// Divisor applied to the per-PE element count (NX*NY*NZ/noPE) in
// compute_layout to obtain the target number of sub-parts per PE; in effect
// the approximate number of grid elements aimed at per sub-part.
#define SLICE_ELEMS (32*1024)

// Fills in the decomposition entries of process `pe` on an np1 x np2 process
// grid: one coarse box per layout in env.dom0/dom1/dom2 and a list of smaller
// sub-boxes ("mparts") per layout in env.mdom0/mdom1/mdom2.
//
// Judging by the variable names, set() takes (z-range, y-range, x-range):
//   layout 0: z split over np2, y split over np1, full x
//   layout 1: z split over np2, full y,           x split over np1
//   layout 2: full z,           y split over np2, x split over np1
//
// pe   index of the process whose entries are written
// np1  extent of the process grid in the first direction
// np2  extent of the process grid in the second direction
// env  environment holding the decomposition tables and the FFT block size
void compute_layout (int pe, int np1, int np2, Env &env)
{
  Alias (int, block, env.sp.fftblock);

  int zs0, ze0, ys0, ye0, xs1, xe1, ys2, ye2, zzs0, zze0, yys2, yye2;

  // Position of pe inside the np1 x np2 process grid.
  int me1 = pe / np2, me2 = pe % np2;

  env.dom0.setNo(pe, 1); env.dom1.setNo(pe, 1); env.dom2.setNo(pe, 1);

  // Layout 0. splitInt presumably returns the me-th of n pieces of [lo,hi]
  // in its last two arguments -- TODO confirm against its definition.
  splitInt (0, NY-1, me1, np1, ys0, ye0);
  splitInt (0, NZ-1, me2, np2, zs0, ze0);
  env.dom0[pe].set(zs0, ze0, ys0, ye0, 0, NX-1);
  env.mdom0[pe].set(zs0, ze0, ys0, ye0, 0, NX-1);

  // Layout 1: keeps the z range of layout 0.
  splitInt (0, NX-1, me1, np1, xs1, xe1);
  env.dom1[pe].set(zs0, ze0, 0, NY-1, xs1, xe1);
  env.mdom1[pe].set(zs0, ze0, 0, NY-1, xs1, xe1);

  // Layout 2: keeps the x range of layout 1.
  splitInt (0, NY-1, me2, np2, ys2, ye2);
  env.dom2[pe].set(0, NZ-1, ys2, ye2, xs1, xe1);
  env.mdom2[pe].set(0, NZ-1, ys2, ye2, xs1, xe1);

  // Target number of sub-parts per PE. The product is formed in 64 bits:
  // NX*NY*NZ evaluated in int overflows once the grid exceeds ~2^31 points
  // (e.g. 2048^3), which is undefined behaviour for signed integers.
  long long elemsPerPE = (long long)NX * NY * NZ / noPE;
  int aimRaw = (int)(elemsPerPE / SLICE_ELEMS);
  int no2aim = MAX(1,aimRaw);

  // Per layout k:
  //   nsk  number of slabs along the outer split direction (capped at no2aim)
  //   nbk  number of blocks along the inner direction
  //   sbk  block width, first counted in units of `block`, later in elements
  //   ctk  running count of sub-parts actually written
  int ns0, ns1, ns2, nb0, nb1, nb2, sb0, sb1, sb2, ct0=0, ct1=0, ct2=0;

  // Double the block width until slabs*blocks no longer exceeds the target.
  sb0 = 1; ns0 = MIN(NZ/np2,no2aim); nb0 = NY/np1/block;
  while (ns0*nb0 > no2aim) { sb0 = sb0<<1; nb0 = int_ceil(NY/np1/block,sb0); }
  sb1 = 1; ns1 = MIN(NZ/np2,no2aim); nb1 = NX/np1/block;
  while (ns1*nb1 > no2aim) { sb1 = sb1<<1; nb1 = int_ceil(NX/np1/block,sb1); }
  sb2 = 1; ns2 = MIN(NY/np2,no2aim); nb2 = NX/np1/block;
  while (ns2*nb2 > no2aim) { sb2 = sb2<<1; nb2 = int_ceil(NX/np1/block,sb2); }

  // Convert block widths from multiples of `block` to grid elements.
  sb0 = sb0*block; sb1 = sb1 * block; sb2 = sb2 * block;

  if (pe==0)
    PRINTP0("aim %d\n0: ns %d nb %d sb %d\n1: ns %d nb %d sb %d\n2: ns %d nb %d sb %d\n",
            no2aim,ns0,nb0,sb0,ns1,nb1,sb1,ns2,nb2,sb2);

  // Layout 0 sub-parts: z slabs, each cut into y blocks of width sb0.
  env.mdom0.setNo(pe, nb0*ns0);
  for (int i=0; i<ns0; i++) {
    splitInt (zs0, ze0, i, ns0, zzs0, zze0);
    for (int j=ys0; j<=ye0; j+=sb0)
      env.mdom0.idx(pe,ct0++).set(zzs0, zze0, j, MIN(j+sb0-1,ye0), 0, NX-1);
  }
  // The number written must match the number announced via setNo.
  ASSERT(ct0==nb0*ns0,"not enough mparts");

  // Layout 1 sub-parts: z slabs, each cut into x blocks of width sb1.
  env.mdom1.setNo(pe, nb1*ns1);
  for (int i=0; i<ns1; i++) {
    splitInt (zs0, ze0, i, ns1, zzs0, zze0);
    for (int j=xs1; j<=xe1; j+=sb1)
      env.mdom1.idx(pe,ct1++).set(zzs0, zze0, 0, NY-1, j, MIN(j+sb1-1,xe1));
  }
  ASSERT(ct1==nb1*ns1,"not enough mparts");

  // Layout 2 sub-parts: y slabs, each cut into x blocks of width sb2.
  env.mdom2.setNo(pe, nb2*ns2);
  for (int i=0; i<ns2; i++) {
    splitInt (ys2, ye2, i, ns2, yys2, yye2);
    for (int j=xs1; j<=xe1; j+=sb2)
      env.mdom2.idx(pe,ct2++).set(0, NZ-1, yys2, yye2, j, MIN(j+sb2-1,xe1));
  }
  ASSERT(ct2==nb2*ns2,"not enough mparts");
}

// Chooses the process grid (np1 x np2), clamps the FFT block size, allocates
// the decomposition and swap tables in env, fills the decomposition of every
// process via compute_layout and finally builds the six swap tables between
// the layouts.
//
// env  environment that is read (env.pre, env.sp.fftblock) and filled in
//      (lay1d, dom*, mdom*, sw*)
void compute_decompositions (Env &env)
{
  Alias (int, pe, env.pre.pe); Alias (int, block, env.sp.fftblock);

  int np1 = 1, np2 = 1;
  if (noPE <= NZ && noPE <= 16) {
    // Few processes: 1 x noPE grid.
    env.lay1d = 1;
    np2 = noPE;
  } else {
    env.lay1d = 0;
    // np2 becomes the largest power of two with np2*np2 <= noPE: the loop
    // overshoots by one doubling, which is undone right after it.
    while (np2*np2 <= noPE)
      np2 = np2<<1;
    np2 = np2>>1;
    np1 = noPE / np2;
  }
  ASSERT (np1*np2==noPE,"no proc %d not power of 2",noPE);
  ASSERT (NX%np1==0 && NY%np1==0,"NX or NY not power of 2");

  // The block may not be wider than the per-process extent in x or y.
  // NOTE(review): whether this also updates env.sp.fftblock depends on how
  // the Alias macro binds `block` -- confirm.
  block = MIN(block,NX/np1); block = MIN(block,NY/np1);
  // Reject a non-positive block before it is used as a divisor below; a
  // negative value would also make the j+=sb loops in compute_layout never
  // terminate. Assumes ASSERT stops execution on failure.
  ASSERT (block>0,"block %d must be positive",block);
  ASSERT (NX%block==0 && NY%block==0,"block %d doesn't divide NX or NY",block);
  PRINTP0("Layout %s np1: %d np2: %d; Blocking %d\n",env.lay1d?"1D":"2D",np1,np2,block);

  // partitions
  Alias (RangeNP, ppen, env.pre.PPEn[pe]); Alias (RangeNP, alln, env.pre.ALLn[pe]);
  Alloc (env.dom0, alln); Alloc (env.dom1, alln); Alloc (env.dom2, alln);
  Alloc (env.mdom0, alln); Alloc (env.mdom1, alln); Alloc (env.mdom2, alln);
  Alloc (env.sw01, ppen); Alloc (env.sw12, ppen);
  Alloc (env.sw21, ppen); Alloc (env.sw10, ppen);
  Alloc (env.sw02, ppen); Alloc (env.sw20, ppen);

  // Every process computes the layout of all processes.
  for (int i=0; i<noPE; i++) compute_layout (i, np1, np2, env);
  // NOTE(review): in the 1D case mdom0 and mdom1 are made identical; which
  // argument of copy() is the destination is not visible here -- confirm.
  if (env.lay1d) copy (env.mdom1, env.mdom0);

  // swaps: swXY is built from coarse layout X and sub-part layout Y
  computeSwap (env.dom0, env.mdom1, env.sw01);
  computeSwap (env.dom1, env.mdom2, env.sw12);
  computeSwap (env.dom2, env.mdom1, env.sw21);
  computeSwap (env.dom1, env.mdom0, env.sw10);
  computeSwap (env.dom0, env.mdom2, env.sw02);
  computeSwap (env.dom2, env.mdom0, env.sw20);
}

