1 /* ----------------------------------------------------------------------
2    SPARTA - Stochastic PArallel Rarefied-gas Time-accurate Analyzer
3    http://sparta.sandia.gov
4    Steve Plimpton, sjplimp@sandia.gov, Michael Gallis, magalli@sandia.gov
5    Sandia National Laboratories
6 
7    Copyright (2014) Sandia Corporation.  Under the terms of Contract
8    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
9    certain rights in this software.  This software is distributed under
10    the GNU General Public License.
11 
12    See the README file in the top-level SPARTA directory.
13 ------------------------------------------------------------------------- */
14 
15 #include "string.h"
16 #include "stdlib.h"
17 #include "fix_balance.h"
18 #include "balance_grid.h"
19 #include "update.h"
20 #include "grid.h"
21 #include "particle.h"
22 #include "domain.h"
23 #include "comm.h"
24 #include "rcb.h"
25 #include "modify.h"
26 #include "compute.h"
27 #include "output.h"
28 #include "dump.h"
29 #include "random_mars.h"
30 #include "random_knuth.h"
31 #include "memory.h"
32 #include "error.h"
33 #include "timer.h"
34 
35 using namespace SPARTA_NS;
36 
37 enum{RANDOM,PROC,BISECTION};
38 enum{CELL,PARTICLE,TIME};
39 
40 #define ZEROPARTICLE 0.1
41 
42 /* ---------------------------------------------------------------------- */
43 
FixBalance(SPARTA * sparta,int narg,char ** arg)44 FixBalance::FixBalance(SPARTA *sparta, int narg, char **arg) :
45   Fix(sparta, narg, arg)
46 {
47   if (narg < 5) error->all(FLERR,"Illegal fix balance command");
48 
49   scalar_flag = 1;
50   vector_flag = 1;
51   size_vector = 2;
52   global_freq = 1;
53 
54   // parse arguments
55 
56   nevery = atoi(arg[2]);
57   thresh = atof(arg[3]);
58 
59   int iarg;
60   if (strcmp(arg[4],"random") == 0) {
61     bstyle = RANDOM;
62     iarg = 5;
63   } else if (strcmp(arg[4],"proc") == 0) {
64     bstyle = PROC;
65     iarg = 5;
66   } else if (strcmp(arg[4],"rcb") == 0) {
67     if (narg < 6) error->all(FLERR,"Illegal fix balance command");
68     bstyle = BISECTION;
69     if (strcmp(arg[5],"cell") == 0) rcbwt = CELL;
70     else if (strcmp(arg[5],"part") == 0) rcbwt = PARTICLE;
71     else if (strcmp(arg[5],"time") == 0) rcbwt = TIME;
72     else error->all(FLERR,"Illegal fix balance command");
73     iarg = 6;
74   } else error->all(FLERR,"Illegal fix balance command");
75 
76   // optional args
77 
78   strcpy(eligible,"xyz");
79   rcbflip = 0;
80 
81   while (iarg < narg) {
82     if (strcmp(arg[iarg],"axes") == 0) {
83       if (iarg+2 > narg) error->all(FLERR,"Illegal fix balance command");
84       if (strlen(arg[iarg+1]) > 3)
85         error->all(FLERR,"Illegal fix balance command");
86       strcpy(eligible,arg[iarg+1]);
87       int xdim = 0;
88       int ydim = 0;
89       int zdim = 0;
90       if (strchr(eligible,'x')) xdim = 1;
91       if (strchr(eligible,'y')) ydim = 1;
92       if (strchr(eligible,'z')) zdim = 1;
93       if (zdim && domain->dimension == 2)
94         error->all(FLERR,"Illegal balance_grid command");
95       if (xdim+ydim+zdim != strlen(eligible))
96         error->all(FLERR,"Illegal fix balance command");
97       iarg += 2;
98     } else if (strcmp(arg[iarg],"flip") == 0) {
99       if (iarg+2 > narg) error->all(FLERR,"Illegal fix balance command");
100       if (strcmp(arg[iarg+1],"yes") == 0) rcbflip = 1;
101       else if (strcmp(arg[iarg+1],"no") == 0) rcbflip = 0;
102       else error->all(FLERR,"Illegal fix balance command");
103       iarg += 2;
104     } else error->all(FLERR,"Illegal fix balance command");
105   }
106 
107   // error check
108 
109   if (nevery < 0 || thresh < 1.0)
110     error->all(FLERR,"Illegal fix balance command");
111 
112   me = comm->me;
113   nprocs = comm->nprocs;
114 
115   // create instance of RNG or RCB
116 
117   random = NULL;
118   rcb = NULL;
119 
120   if (bstyle == RANDOM || bstyle == PROC)
121     random = new RanKnuth(update->ranmaster->uniform());
122   if (bstyle == BISECTION) rcb = new RCB(sparta);
123 
124   // compute initial outputs
125 
126   last = 0.0;
127   imbfinal = imbprev = imbalance_factor(maxperproc);
128 }
129 
130 /* ---------------------------------------------------------------------- */
131 
~FixBalance()132 FixBalance::~FixBalance()
133 {
134   delete random;
135   delete rcb;
136 }
137 
138 /* ---------------------------------------------------------------------- */
139 
setmask()140 int FixBalance::setmask()
141 {
142   int mask = 0;
143   mask |= END_OF_STEP;
144   return mask;
145 }
146 
147 /* ---------------------------------------------------------------------- */
148 
init()149 void FixBalance::init()
150 {
151   // error b/c acquire_ghosts() is a no-op in this case
152 
153   if (bstyle != BISECTION && grid->cutoff >= 0.0)
154     error->all(FLERR,"Cannot use non-rcb fix balance with a grid cutoff");
155 
156   last = 0.0;
157   timer->init();
158 }
159 
160 /* ----------------------------------------------------------------------
161    perform dynamic load balancing
162 ------------------------------------------------------------------------- */
163 
void FixBalance::end_of_step()
{
  // return if imbalance <= threshold
  // imbalance_factor() also refreshes maxperproc (reported by compute_vector)
  // imbprev = imbalance measured just before this rebalance

  imbnow = imbalance_factor(maxperproc);
  if (imbnow <= thresh) return;
  imbprev = imbnow;

  Grid::ChildCell *cells = grid->cells;
  Grid::ChildInfo *cinfo = grid->cinfo;
  int nglocal = grid->nlocal;

  // re-assign each of my local child cells to a proc
  // only assign unsplit and split cells (nsplit > 0)
  // do not assign sub-cells (skipped via nsplit <= 0)
  //   since they migrate with their split cell
  // set nmigrate = # of cells that will migrate to a new proc
  // reset proc field in cells for migrating cells

  int nmigrate = 0;

  if (bstyle == RANDOM) {
    // each balanced cell goes to a proc drawn from the RNG
    int newproc;
    for (int icell = 0; icell < nglocal; icell++) {
      if (cells[icell].nsplit <= 0) continue;
      newproc = nprocs * random->uniform();
      if (newproc != cells[icell].proc) nmigrate++;
      cells[icell].proc = newproc;
    }

  } else if (bstyle == PROC) {
    // round-robin over all procs, starting from a randomly chosen proc
    int newproc = nprocs * random->uniform();
    for (int icell = 0; icell < nglocal; icell++) {
      if (cells[icell].nsplit <= 0) continue;
      if (newproc != cells[icell].proc) nmigrate++;
      cells[icell].proc = newproc;
      newproc++;
      if (newproc == nprocs) newproc = 0;
    }

  } else if (bstyle == BISECTION) {
    // x = center point of each balanced cell, compacted so that
    //   x[0..nbalance-1] skips sub-cells
    double **x;
    memory->create(x,nglocal,3,"balance:x");

    double *lo,*hi;

    int nbalance = 0;
    for (int icell = 0; icell < nglocal; icell++) {
      if (cells[icell].nsplit <= 0) continue;
      lo = cells[icell].lo;
      hi = cells[icell].hi;
      x[nbalance][0] = 0.5*(lo[0]+hi[0]);
      x[nbalance][1] = 0.5*(lo[1]+hi[1]);
      x[nbalance][2] = 0.5*(lo[2]+hi[2]);
      nbalance++;
    }

    // per-cell RCB weights, in the same compacted order as x:
    //   cell -> wt stays NULL (no per-cell weights passed to RCB)
    //   part -> # of particles in the cell, ZEROPARTICLE for empty cells
    //           (particles sorted first, presumably so cinfo[].count is current)
    //   time -> filled by timer_cell_weights()

    double *wt = NULL;
    if (rcbwt == PARTICLE) {
      if (!particle->sorted) particle->sort();
      memory->create(wt,nglocal,"balance:wt");
      int n;
      nbalance = 0;
      for (int icell = 0; icell < nglocal; icell++) {
        if (cells[icell].nsplit <= 0) continue;
        n = cinfo[icell].count;
        if (n) wt[nbalance++] = n;
        else wt[nbalance++] = ZEROPARTICLE;
      }
    } else if (rcbwt == TIME) {
      memory->create(wt,nglocal,"balance:wt");
      timer_cell_weights(wt);
    }

    rcb->compute(nbalance,x,wt,eligible,rcbflip);
    rcb->invert();

    // sendproc[i] = proc assigned to the i-th balanced cell (same order as x)
    // nkeep is presumably the # of those cells staying on this proc,
    //   since nmigrate is taken as the remainder

    nbalance = 0;
    int *sendproc = rcb->sendproc;
    for (int icell = 0; icell < nglocal; icell++) {
      if (cells[icell].nsplit <= 0) continue;
      cells[icell].proc = sendproc[nbalance++];
    }
    nmigrate = nbalance - rcb->nkeep;

    memory->destroy(x);
    memory->destroy(wt);
  }

  // grid is flagged as clumped only for a single proc or an RCB assignment

  if (nprocs == 1 || bstyle == BISECTION) grid->clumped = 1;
  else grid->clumped = 0;

  // sort particles

  if (!particle->sorted) particle->sort();

  // migrate grid cells and their particles to new owners
  // invoke grid methods to complete grid setup
  // some fixes have post migration operations to perform

  grid->unset_neighbors();
  grid->remove_ghosts();

  comm->migrate_cells(nmigrate);
  grid->hashfilled = 0;

  grid->setup_owned();
  grid->acquire_ghosts();

  grid->reset_neighbors();
  comm->reset_neighbors();

  // notify all classes that store per-grid data that grid may have changed

  grid->notify_changed();

  // final imbalance factor
  // for time weighting, maxperproc keeps the pre-rebalance value from above

  if (bstyle == BISECTION && rcbwt == TIME)
    imbfinal = 0.0; // can't compute imbalance from timers since grid cells moved
  else
    imbfinal = imbalance_factor(maxperproc);
}
286 
/* ----------------------------------------------------------------------
   calculate current load imbalance across all procs
   cost per proc = # of owned particles, or the CPU time since the
     previous call when the rcb style uses time weighting
   return maxcost = max cost per proc
   return imbalance factor = max cost per proc / ave cost per proc
------------------------------------------------------------------------- */
292 
imbalance_factor(double & maxcost)293 double FixBalance::imbalance_factor(double &maxcost)
294 {
295   double mycost,totalcost;
296   double mycost_proc_weighted,maxcost_proc_weighted,nprocs_weighted;
297 
298   if (bstyle == BISECTION && rcbwt == TIME) {
299     timer_cost();
300     mycost = my_timer_cost;
301   } else mycost = particle->nlocal;
302 
303   MPI_Allreduce(&mycost,&totalcost,1,MPI_DOUBLE,MPI_SUM,world);
304   MPI_Allreduce(&mycost,&maxcost,1,MPI_DOUBLE,MPI_MAX,world);
305 
306   double imbalance = 1.0;
307   if (maxcost) imbalance = maxcost / (totalcost / nprocs);
308   return imbalance;
309 }
310 
311 /* ----------------------------------------------------------------------
312    return imbalance factor after last rebalance
313 ------------------------------------------------------------------------- */
314 
compute_scalar()315 double FixBalance::compute_scalar()
316 {
317   return imbfinal;
318 }
319 
320 /* ----------------------------------------------------------------------
321    return stats for last rebalance
322 ------------------------------------------------------------------------- */
323 
compute_vector(int i)324 double FixBalance::compute_vector(int i)
325 {
326   if (i == 0) return maxperproc;
327   return imbprev;
328 }
329 
330 /* -------------------------------------------------------------------- */
331 
timer_cost()332 void FixBalance::timer_cost()
333 {
334   // my_timer_cost = CPU time for relevant timers since last invocation
335 
336   my_timer_cost = -last;
337   my_timer_cost += timer->array[TIME_MOVE];
338   my_timer_cost += timer->array[TIME_SORT];
339   my_timer_cost += timer->array[TIME_COLLIDE];
340   my_timer_cost += timer->array[TIME_MODIFY];
341 
342   // last = time up to this point
343 
344   last += my_timer_cost;
345 }
346 
347 /* -------------------------------------------------------------------- */
348 
timer_cell_weights(double * weight)349 void FixBalance::timer_cell_weights(double *weight)
350 {
351   // localwt = weight assigned to each owned grid cell
352   // just return if no time yet tallied
353 
354   double maxcost;
355   MPI_Allreduce(&my_timer_cost,&maxcost,1,MPI_DOUBLE,MPI_MAX,world);
356   if (maxcost <= 0.0) {
357     memory->destroy(weight);
358     weight = NULL;
359     return;
360   }
361 
362   Grid::ChildCell *cells = grid->cells;
363   Grid::ChildInfo *cinfo = grid->cinfo;
364   int nglocal = grid->nlocal;
365 
366   double localwt_total = 0.0;
367   if (nglocal) localwt_total = my_timer_cost/nglocal;
368   if (nglocal && localwt_total <= 0.0) error->one(FLERR,"Balance weight <= 0.0");
369 
370   if (!particle->sorted) particle->sort();
371   double wttotal = 0;
372   int nbalance = 0;
373   double* localwt;
374   memory->create(localwt,nglocal,"imbalance_time:localwt");
375   for (int icell = 0; icell < nglocal; icell++) {
376     localwt[icell] = 0.0;
377     if (cells[icell].nsplit <= 0) continue;
378     int n = cinfo[icell].count;
379     if (n) localwt[nbalance++] = n;
380     else localwt[nbalance++] = ZEROPARTICLE;
381     wttotal += localwt[nbalance-1];
382   }
383 
384   for (int icell = 0; icell < nglocal; icell++)
385     weight[icell] = my_timer_cost*localwt[icell]/wttotal;
386 
387   memory->destroy(localwt);
388 }
389 
390 /* ----------------------------------------------------------------------
391    return # of bytes of allocated memory
392 ------------------------------------------------------------------------- */
393 
memory_usage()394 double FixBalance::memory_usage()
395 {
396   double bytes = 0.0;
397   // tally wt vector?
398   return bytes;
399 }
400