1 // clang-format off
2 /* ----------------------------------------------------------------------
3    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
4    https://www.lammps.org/, Sandia National Laboratories
5    Steve Plimpton, sjplimp@sandia.gov
6 
7    Copyright (2003) Sandia Corporation.  Under the terms of Contract
8    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
9    certain rights in this software.  This software is distributed under
10    the GNU General Public License.
11 
12    See the README file in the top-level LAMMPS directory.
13 ------------------------------------------------------------------------- */
14 
15 #include "comm.h"
16 
17 #include "accelerator_kokkos.h"
18 #include "atom.h"               // IWYU pragma: keep
19 #include "atom_vec.h"
20 #include "bond.h"
21 #include "compute.h"
22 #include "domain.h"             // IWYU pragma: keep
23 #include "dump.h"
24 #include "error.h"
25 #include "fix.h"
26 #include "force.h"
27 #include "group.h"
28 #include "irregular.h"
29 #include "memory.h"             // IWYU pragma: keep
30 #include "modify.h"
31 #include "neighbor.h"           // IWYU pragma: keep
32 #include "output.h"
33 #include "pair.h"
34 #include "procmap.h"
35 #include "universe.h"
36 #include "update.h"
37 
38 #include <cstring>
39 #ifdef _OPENMP
40 #include <omp.h>
41 #endif
42 
43 using namespace LAMMPS_NS;
44 
// extra exchange-buffer capacity (in datums) added on top of the largest
// exchanged item; used to initialize and update bufextra
#define BUFEXTRA 1024

// processor grid styles, selected by "processors ... grid <style>"
enum{ONELEVEL,     // "onelevel" (constructor default)
     TWOLEVEL,     // "twolevel": also uses ncores and a per-node core grid
     NUMA,         // "numa"
     CUSTOM};      // "custom": grid layout read from customfile

// mapping of procs onto the grid, selected by "processors ... map <style>"
enum{CART,         // "cart" (constructor default)
     CARTREORDER,  // "cart/reorder"
     XYZ};         // one of the explicit orderings xyz,xzy,yxz,yzx,zxy,zyx
49 
50 /* ---------------------------------------------------------------------- */
51 
Comm(LAMMPS * lmp)52 Comm::Comm(LAMMPS *lmp) : Pointers(lmp)
53 {
54   MPI_Comm_rank(world,&me);
55   MPI_Comm_size(world,&nprocs);
56 
57   mode = 0;
58   bordergroup = 0;
59   cutghostuser = 0.0;
60   cutusermulti = nullptr;
61   cutusermultiold = nullptr;
62   ncollections = 0;
63   ncollections_cutoff = 0;
64   ghost_velocity = 0;
65 
66   user_procgrid[0] = user_procgrid[1] = user_procgrid[2] = 0;
67   coregrid[0] = coregrid[1] = coregrid[2] = 1;
68   gridflag = ONELEVEL;
69   mapflag = CART;
70   customfile = nullptr;
71   outfile = nullptr;
72   recv_from_partition = send_to_partition = -1;
73   otherflag = 0;
74 
75   maxexchange = maxexchange_atom = maxexchange_fix = 0;
76   maxexchange_fix_dynamic = 0;
77   bufextra = BUFEXTRA;
78 
79   grid2proc = nullptr;
80   xsplit = ysplit = zsplit = nullptr;
81   rcbnew = 0;
82   multi_reduce = 0;
83 
84   // use of OpenMP threads
85   // query OpenMP for number of threads/process set by user at run-time
86   // if the OMP_NUM_THREADS environment variable is not set, we default
87   // to using 1 thread. This follows the principle of the least surprise,
88   // while practically all OpenMP implementations violate it by using
89   // as many threads as there are (virtual) CPU cores by default.
90 
91   nthreads = 1;
92 #ifdef _OPENMP
93   if (lmp->kokkos) {
94     nthreads = lmp->kokkos->nthreads * lmp->kokkos->numa;
95   } else if (getenv("OMP_NUM_THREADS") == nullptr) {
96     nthreads = 1;
97     if (me == 0)
98       error->message(FLERR,"OMP_NUM_THREADS environment is not set. "
99                            "Defaulting to 1 thread.");
100   } else {
101     nthreads = omp_get_max_threads();
102   }
103 
104   // enforce consistent number of threads across all MPI tasks
105 
106   MPI_Bcast(&nthreads,1,MPI_INT,0,world);
107   if (!lmp->kokkos) omp_set_num_threads(nthreads);
108 
109   if (me == 0)
110     utils::logmesg(lmp,"  using {} OpenMP thread(s) per MPI task\n",nthreads);
111 #endif
112 
113 }
114 
115 /* ---------------------------------------------------------------------- */
116 
~Comm()117 Comm::~Comm()
118 {
119   memory->destroy(grid2proc);
120   memory->destroy(xsplit);
121   memory->destroy(ysplit);
122   memory->destroy(zsplit);
123   memory->destroy(cutusermulti);
124   memory->destroy(cutusermultiold);
125   delete [] customfile;
126   delete [] outfile;
127 }
128 
129 /* ----------------------------------------------------------------------
130    deep copy of arrays from old Comm class to new one
131    all public/protected vectors/arrays in parent Comm class must be copied
132    called from alternate constructor of child classes
133    when new comm style is created from Input
134 ------------------------------------------------------------------------- */
135 
copy_arrays(Comm * oldcomm)136 void Comm::copy_arrays(Comm *oldcomm)
137 {
138   if (oldcomm->grid2proc) {
139     memory->create(grid2proc,procgrid[0],procgrid[1],procgrid[2],
140                    "comm:grid2proc");
141     memcpy(&grid2proc[0][0][0],&oldcomm->grid2proc[0][0][0],
142            (procgrid[0]*procgrid[1]*procgrid[2])*sizeof(int));
143 
144     memory->create(xsplit,procgrid[0]+1,"comm:xsplit");
145     memory->create(ysplit,procgrid[1]+1,"comm:ysplit");
146     memory->create(zsplit,procgrid[2]+1,"comm:zsplit");
147     memcpy(xsplit,oldcomm->xsplit,(procgrid[0]+1)*sizeof(double));
148     memcpy(ysplit,oldcomm->ysplit,(procgrid[1]+1)*sizeof(double));
149     memcpy(zsplit,oldcomm->zsplit,(procgrid[2]+1)*sizeof(double));
150   }
151 
152   ncollections = oldcomm->ncollections;
153   ncollections_cutoff = oldcomm->ncollections_cutoff;
154   if (oldcomm->cutusermulti) {
155     memory->create(cutusermulti,ncollections_cutoff,"comm:cutusermulti");
156     memcpy(cutusermulti,oldcomm->cutusermulti,ncollections_cutoff);
157   }
158 
159   if (oldcomm->cutusermultiold) {
160     memory->create(cutusermultiold,atom->ntypes+1,"comm:cutusermultiold");
161     memcpy(cutusermultiold,oldcomm->cutusermultiold,atom->ntypes+1);
162   }
163 
164   if (customfile)
165     customfile = utils::strdup(oldcomm->customfile);
166 
167   if (outfile)
168     outfile = utils::strdup(oldcomm->outfile);
169 }
170 
171 /* ----------------------------------------------------------------------
172    common to all Comm styles
173 ------------------------------------------------------------------------- */
174 
init()175 void Comm::init()
176 {
177   triclinic = domain->triclinic;
178   map_style = atom->map_style;
179 
180   // check warn if any proc's subbox is smaller than neigh skin
181   //   since may lead to lost atoms in exchange()
182   // really should check every exchange() in case box size is shrinking
183   //   but seems overkill to do that (fix balance does perform this check)
184 
185   domain->subbox_too_small_check(neighbor->skin);
186 
187   // comm_only = 1 if only x,f are exchanged in forward/reverse comm
188   // comm_x_only = 0 if ghost_velocity since velocities are added
189 
190   comm_x_only = atom->avec->comm_x_only;
191   comm_f_only = atom->avec->comm_f_only;
192   if (ghost_velocity) comm_x_only = 0;
193 
194   // set per-atom sizes for forward/reverse/border comm
195   // augment by velocity and fix quantities if needed
196 
197   size_forward = atom->avec->size_forward;
198   size_reverse = atom->avec->size_reverse;
199   size_border = atom->avec->size_border;
200 
201   if (ghost_velocity) size_forward += atom->avec->size_velocity;
202   if (ghost_velocity) size_border += atom->avec->size_velocity;
203 
204   for (int i = 0; i < modify->nfix; i++)
205     size_border += modify->fix[i]->comm_border;
206 
207   // per-atom limits for communication
208   // maxexchange = max # of datums in exchange comm, set in exchange()
209   // maxforward = # of datums in largest forward comm
210   // maxreverse = # of datums in largest reverse comm
211   // query pair,fix,compute,dump for their requirements
212   // pair style can force reverse comm even if newton off
213 
214   maxforward = MAX(size_forward,size_border);
215   maxreverse = size_reverse;
216 
217   if (force->pair) maxforward = MAX(maxforward,force->pair->comm_forward);
218   if (force->pair) maxreverse = MAX(maxreverse,force->pair->comm_reverse);
219 
220   for (int i = 0; i < modify->nfix; i++) {
221     maxforward = MAX(maxforward,modify->fix[i]->comm_forward);
222     maxreverse = MAX(maxreverse,modify->fix[i]->comm_reverse);
223   }
224 
225   for (int i = 0; i < modify->ncompute; i++) {
226     maxforward = MAX(maxforward,modify->compute[i]->comm_forward);
227     maxreverse = MAX(maxreverse,modify->compute[i]->comm_reverse);
228   }
229 
230   for (int i = 0; i < output->ndump; i++) {
231     maxforward = MAX(maxforward,output->dump[i]->comm_forward);
232     maxreverse = MAX(maxreverse,output->dump[i]->comm_reverse);
233   }
234 
235   if (force->newton == 0) maxreverse = 0;
236   if (force->pair) maxreverse = MAX(maxreverse,force->pair->comm_reverse_off);
237 
238   // maxexchange_atom = size of an exchanged atom, set by AtomVec
239   //   only needs to be set if size > BUFEXTRA
240   // maxexchange_fix_dynamic = 1 if any fix sets its maxexchange dynamically
241 
242   maxexchange_atom = atom->avec->maxexchange;
243 
244   int nfix = modify->nfix;
245   Fix **fix = modify->fix;
246 
247   maxexchange_fix_dynamic = 0;
248   for (int i = 0; i < nfix; i++)
249     if (fix[i]->maxexchange_dynamic) maxexchange_fix_dynamic = 1;
250 
251   if ((mode == Comm::MULTI) && (neighbor->style != Neighbor::MULTI))
252     error->all(FLERR,"Cannot use comm mode multi without multi-style neighbor lists");
253 
254   if (multi_reduce) {
255     if (force->newton == 0)
256       error->all(FLERR,"Cannot use multi/reduce communication with Newton off");
257     if (neighbor->any_full())
258       error->all(FLERR,"Cannot use multi/reduce communication with a full neighbor list");
259     if (mode != Comm::MULTI)
260       error->all(FLERR,"Cannot use multi/reduce communication without mode multi");
261   }
262 }
263 
264 /* ----------------------------------------------------------------------
265    set maxexchange based on AtomVec and fixes
266 ------------------------------------------------------------------------- */
267 
init_exchange()268 void Comm::init_exchange()
269 {
270   int nfix = modify->nfix;
271   Fix **fix = modify->fix;
272 
273   maxexchange_fix = 0;
274   for (int i = 0; i < nfix; i++)
275     maxexchange_fix += fix[i]->maxexchange;
276 
277   maxexchange = maxexchange_atom + maxexchange_fix;
278   bufextra = maxexchange + BUFEXTRA;
279 }
280 
281 /* ----------------------------------------------------------------------
282    modify communication params
283    invoked from input script by comm_modify command
284 ------------------------------------------------------------------------- */
285 
modify_params(int narg,char ** arg)286 void Comm::modify_params(int narg, char **arg)
287 {
288   if (narg < 1) error->all(FLERR,"Illegal comm_modify command");
289 
290   int iarg = 0;
291   while (iarg < narg) {
292     if (strcmp(arg[iarg],"mode") == 0) {
293       if (iarg+2 > narg) error->all(FLERR,"Illegal comm_modify command");
294       if (strcmp(arg[iarg+1],"single") == 0) {
295         // need to reset cutghostuser when switching comm mode
296         if (mode == Comm::MULTI) cutghostuser = 0.0;
297         if (mode == Comm::MULTIOLD) cutghostuser = 0.0;
298         memory->destroy(cutusermulti);
299         memory->destroy(cutusermultiold);
300         mode = Comm::SINGLE;
301       } else if (strcmp(arg[iarg+1],"multi") == 0) {
302         if (neighbor->style != Neighbor::MULTI)
303           error->all(FLERR,"Cannot use comm mode 'multi' without 'multi' style neighbor lists");
304         // need to reset cutghostuser when switching comm mode
305         if (mode == Comm::SINGLE) cutghostuser = 0.0;
306         if (mode == Comm::MULTIOLD) cutghostuser = 0.0;
307         memory->destroy(cutusermultiold);
308         mode = Comm::MULTI;
309       } else if (strcmp(arg[iarg+1],"multi/old") == 0) {
310         if (neighbor->style == Neighbor::MULTI)
311           error->all(FLERR,"Cannot use comm mode 'multi/old' with 'multi' style neighbor lists");
312         // need to reset cutghostuser when switching comm mode
313         if (mode == Comm::SINGLE) cutghostuser = 0.0;
314         if (mode == Comm::MULTI) cutghostuser = 0.0;
315         memory->destroy(cutusermulti);
316         mode = Comm::MULTIOLD;
317       } else error->all(FLERR,"Illegal comm_modify command");
318       iarg += 2;
319     } else if (strcmp(arg[iarg],"group") == 0) {
320       if (iarg+2 > narg) error->all(FLERR,"Illegal comm_modify command");
321       bordergroup = group->find(arg[iarg+1]);
322       if (bordergroup < 0)
323         error->all(FLERR,"Invalid group in comm_modify command");
324       if (bordergroup && (atom->firstgroupname == nullptr ||
325                           strcmp(arg[iarg+1],atom->firstgroupname) != 0))
326         error->all(FLERR,"Comm_modify group != atom_modify first group");
327       iarg += 2;
328     } else if (strcmp(arg[iarg],"cutoff") == 0) {
329       if (iarg+2 > narg) error->all(FLERR,"Illegal comm_modify command");
330       if (mode == Comm::MULTI)
331         error->all(FLERR, "Use cutoff/multi keyword to set cutoff in multi mode");
332       if (mode == Comm::MULTIOLD)
333         error->all(FLERR, "Use cutoff/multi/old keyword to set cutoff in multi mode");
334       cutghostuser = utils::numeric(FLERR,arg[iarg+1],false,lmp);
335       if (cutghostuser < 0.0)
336         error->all(FLERR,"Invalid cutoff in comm_modify command");
337       iarg += 2;
338     } else if (strcmp(arg[iarg],"cutoff/multi") == 0) {
339       int i,nlo,nhi;
340       double cut;
341       if (mode == Comm::SINGLE)
342         error->all(FLERR,"Use cutoff keyword to set cutoff in single mode");
343       if (mode == Comm::MULTIOLD)
344         error->all(FLERR,"Use cutoff/multi/old keyword to set cutoff in multi/old mode");
345       if (domain->box_exist == 0)
346         error->all(FLERR, "Cannot set cutoff/multi before simulation box is defined");
347 
348       // Check if # of collections has changed, if so erase any previously defined cutoffs
349       // Neighbor will reset ncollections if collections are redefined
350       if (! cutusermulti || ncollections_cutoff != neighbor->ncollections) {
351         ncollections_cutoff = neighbor->ncollections;
352         memory->destroy(cutusermulti);
353         memory->create(cutusermulti,ncollections_cutoff,"comm:cutusermulti");
354         for (i=0; i < ncollections_cutoff; ++i)
355           cutusermulti[i] = -1.0;
356       }
357       utils::bounds(FLERR,arg[iarg+1],1,ncollections_cutoff,nlo,nhi,error);
358       cut = utils::numeric(FLERR,arg[iarg+2],false,lmp);
359       cutghostuser = MAX(cutghostuser,cut);
360       if (cut < 0.0)
361         error->all(FLERR,"Invalid cutoff in comm_modify command");
362       // collections use 1-based indexing externally and 0-based indexing internally
363       for (i=nlo; i<=nhi; ++i)
364         cutusermulti[i-1] = cut;
365       iarg += 3;
366     }  else if (strcmp(arg[iarg],"cutoff/multi/old") == 0) {
367       int i,nlo,nhi;
368       double cut;
369       if (mode == Comm::SINGLE)
370         error->all(FLERR,"Use cutoff keyword to set cutoff in single mode");
371       if (mode == Comm::MULTI)
372         error->all(FLERR,"Use cutoff/multi keyword to set cutoff in multi mode");
373       if (domain->box_exist == 0)
374         error->all(FLERR, "Cannot set cutoff/multi before simulation box is defined");
375       const int ntypes = atom->ntypes;
376       if (iarg+3 > narg)
377         error->all(FLERR,"Illegal comm_modify command");
378       if (cutusermultiold == nullptr) {
379         memory->create(cutusermultiold,ntypes+1,"comm:cutusermultiold");
380         for (i=0; i < ntypes+1; ++i)
381           cutusermultiold[i] = -1.0;
382       }
383       utils::bounds(FLERR,arg[iarg+1],1,ntypes,nlo,nhi,error);
384       cut = utils::numeric(FLERR,arg[iarg+2],false,lmp);
385       cutghostuser = MAX(cutghostuser,cut);
386       if (cut < 0.0)
387         error->all(FLERR,"Invalid cutoff in comm_modify command");
388       for (i=nlo; i<=nhi; ++i)
389         cutusermultiold[i] = cut;
390       iarg += 3;
391     } else if (strcmp(arg[iarg],"reduce/multi") == 0) {
392       if (mode == Comm::SINGLE)
393         error->all(FLERR,"Use reduce/multi in mode multi only");
394       multi_reduce = 1;
395       iarg += 1;
396     } else if (strcmp(arg[iarg],"vel") == 0) {
397       if (iarg+2 > narg) error->all(FLERR,"Illegal comm_modify command");
398       if (strcmp(arg[iarg+1],"yes") == 0) ghost_velocity = 1;
399       else if (strcmp(arg[iarg+1],"no") == 0) ghost_velocity = 0;
400       else error->all(FLERR,"Illegal comm_modify command");
401       iarg += 2;
402     } else error->all(FLERR,"Illegal comm_modify command");
403   }
404 }
405 
406 /* ----------------------------------------------------------------------
407    set dimensions for 3d grid of processors, and associated flags
408    invoked from input script by processors command
409 ------------------------------------------------------------------------- */
410 
set_processors(int narg,char ** arg)411 void Comm::set_processors(int narg, char **arg)
412 {
413   if (narg < 3) error->all(FLERR,"Illegal processors command");
414 
415   if (strcmp(arg[0],"*") == 0) user_procgrid[0] = 0;
416   else user_procgrid[0] = utils::inumeric(FLERR,arg[0],false,lmp);
417   if (strcmp(arg[1],"*") == 0) user_procgrid[1] = 0;
418   else user_procgrid[1] = utils::inumeric(FLERR,arg[1],false,lmp);
419   if (strcmp(arg[2],"*") == 0) user_procgrid[2] = 0;
420   else user_procgrid[2] = utils::inumeric(FLERR,arg[2],false,lmp);
421 
422   if (user_procgrid[0] < 0 || user_procgrid[1] < 0 || user_procgrid[2] < 0)
423     error->all(FLERR,"Illegal processors command");
424 
425   int p = user_procgrid[0]*user_procgrid[1]*user_procgrid[2];
426   if (p && p != nprocs)
427     error->all(FLERR,"Specified processors != physical processors");
428 
429   int iarg = 3;
430   while (iarg < narg) {
431     if (strcmp(arg[iarg],"grid") == 0) {
432       if (iarg+2 > narg) error->all(FLERR,"Illegal processors command");
433 
434       if (strcmp(arg[iarg+1],"onelevel") == 0) {
435         gridflag = ONELEVEL;
436 
437       } else if (strcmp(arg[iarg+1],"twolevel") == 0) {
438         if (iarg+6 > narg) error->all(FLERR,"Illegal processors command");
439         gridflag = TWOLEVEL;
440 
441         ncores = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
442         if (strcmp(arg[iarg+3],"*") == 0) user_coregrid[0] = 0;
443         else user_coregrid[0] = utils::inumeric(FLERR,arg[iarg+3],false,lmp);
444         if (strcmp(arg[iarg+4],"*") == 0) user_coregrid[1] = 0;
445         else user_coregrid[1] = utils::inumeric(FLERR,arg[iarg+4],false,lmp);
446         if (strcmp(arg[iarg+5],"*") == 0) user_coregrid[2] = 0;
447         else user_coregrid[2] = utils::inumeric(FLERR,arg[iarg+5],false,lmp);
448 
449         if (ncores <= 0 || user_coregrid[0] < 0 ||
450             user_coregrid[1] < 0 || user_coregrid[2] < 0)
451           error->all(FLERR,"Illegal processors command");
452         iarg += 4;
453 
454       } else if (strcmp(arg[iarg+1],"numa") == 0) {
455         gridflag = NUMA;
456 
457       } else if (strcmp(arg[iarg+1],"custom") == 0) {
458         if (iarg+3 > narg) error->all(FLERR,"Illegal processors command");
459         gridflag = CUSTOM;
460         delete [] customfile;
461         customfile = utils::strdup(arg[iarg+2]);
462         iarg += 1;
463 
464       } else error->all(FLERR,"Illegal processors command");
465       iarg += 2;
466 
467     } else if (strcmp(arg[iarg],"map") == 0) {
468       if (iarg+2 > narg) error->all(FLERR,"Illegal processors command");
469       if (strcmp(arg[iarg+1],"cart") == 0) mapflag = CART;
470       else if (strcmp(arg[iarg+1],"cart/reorder") == 0) mapflag = CARTREORDER;
471       else if (strcmp(arg[iarg+1],"xyz") == 0 ||
472                strcmp(arg[iarg+1],"xzy") == 0 ||
473                strcmp(arg[iarg+1],"yxz") == 0 ||
474                strcmp(arg[iarg+1],"yzx") == 0 ||
475                strcmp(arg[iarg+1],"zxy") == 0 ||
476                strcmp(arg[iarg+1],"zyx") == 0) {
477         mapflag = XYZ;
478         strncpy(xyz,arg[iarg+1],3);
479       } else error->all(FLERR,"Illegal processors command");
480       iarg += 2;
481 
482     } else if (strcmp(arg[iarg],"part") == 0) {
483       if (iarg+4 > narg) error->all(FLERR,"Illegal processors command");
484       if (universe->nworlds == 1)
485         error->all(FLERR,
486                    "Cannot use processors part command "
487                    "without using partitions");
488       int isend = utils::inumeric(FLERR,arg[iarg+1],false,lmp);
489       int irecv = utils::inumeric(FLERR,arg[iarg+2],false,lmp);
490       if (isend < 1 || isend > universe->nworlds ||
491           irecv < 1 || irecv > universe->nworlds || isend == irecv)
492         error->all(FLERR,"Invalid partitions in processors part command");
493       if (isend-1 == universe->iworld) {
494         if (send_to_partition >= 0)
495           error->all(FLERR,
496                      "Sending partition in processors part command "
497                      "is already a sender");
498         send_to_partition = irecv-1;
499       }
500       if (irecv-1 == universe->iworld) {
501         if (recv_from_partition >= 0)
502           error->all(FLERR,
503                      "Receiving partition in processors part command "
504                      "is already a receiver");
505         recv_from_partition = isend-1;
506       }
507 
508       // only receiver has otherflag dependency
509 
510       if (strcmp(arg[iarg+3],"multiple") == 0) {
511         if (universe->iworld == irecv-1) {
512           otherflag = 1;
513           other_style = Comm::MULTIPLE;
514         }
515       } else error->all(FLERR,"Illegal processors command");
516       iarg += 4;
517 
518     } else if (strcmp(arg[iarg],"file") == 0) {
519       if (iarg+2 > narg) error->all(FLERR,"Illegal processors command");
520       delete [] outfile;
521       outfile = utils::strdup(arg[iarg+1]);
522       iarg += 2;
523 
524     } else error->all(FLERR,"Illegal processors command");
525   }
526 
527   // error checks
528 
529   if (gridflag == NUMA && mapflag != CART)
530     error->all(FLERR,"Processors grid numa and map style are incompatible");
531   if (otherflag && (gridflag == NUMA || gridflag == CUSTOM))
532     error->all(FLERR,
533                "Processors part option and grid style are incompatible");
534 }
535 
536 /* ----------------------------------------------------------------------
537    create a 3d grid of procs based on Nprocs and box size & shape
538    map processors to grid, setup xyz split for a uniform grid
539 ------------------------------------------------------------------------- */
540 
void Comm::set_proc_grid(int outflag)
{
  // build the 3d processor grid (procgrid, and coregrid where relevant),
  //   map this proc into it (myloc, procneigh, grid2proc), and set uniform
  //   xsplit/ysplit/zsplit fractions
  // outflag != 0: rank 0 prints the resulting grid(s) to screen and logfile
  // the MPI exchange with another partition ("processors part") brackets the
  //   work: receive first, send last; keep that order intact

  // recv 3d proc grid of another partition if my 3d grid depends on it
  // only rank 0 receives from the other partition's root proc over uworld,
  //   then the values are broadcast within this partition's world

  if (recv_from_partition >= 0) {
    if (me == 0) {
      MPI_Recv(other_procgrid,3,MPI_INT,
               universe->root_proc[recv_from_partition],0,
               universe->uworld,MPI_STATUS_IGNORE);
      MPI_Recv(other_coregrid,3,MPI_INT,
               universe->root_proc[recv_from_partition],0,
               universe->uworld,MPI_STATUS_IGNORE);
    }
    MPI_Bcast(other_procgrid,3,MPI_INT,0,world);
    MPI_Bcast(other_coregrid,3,MPI_INT,0,world);
  }

  // create ProcMap class to create 3d grid and map procs to it
  // NOTE(review): pmap is a raw new/delete pair; if an error->all() call
  //   between here and the delete below throws instead of aborting, pmap
  //   would not be freed -- confirm whether that matters for library use

  ProcMap *pmap = new ProcMap(lmp);

  // create 3d grid of processors
  // produces procgrid and coregrid (if relevant)

  if (gridflag == ONELEVEL) {
    pmap->onelevel_grid(nprocs,user_procgrid,procgrid,
                        otherflag,other_style,other_procgrid,other_coregrid);

  } else if (gridflag == TWOLEVEL) {
    pmap->twolevel_grid(nprocs,user_procgrid,procgrid,
                        ncores,user_coregrid,coregrid,
                        otherflag,other_style,other_procgrid,other_coregrid);

  } else if (gridflag == NUMA) {
    pmap->numa_grid(nprocs,user_procgrid,procgrid,coregrid);

  } else if (gridflag == CUSTOM) {
    pmap->custom_grid(customfile,nprocs,user_procgrid,procgrid);
  }

  // error check on procgrid
  // should not be necessary due to ProcMap

  if (procgrid[0]*procgrid[1]*procgrid[2] != nprocs)
    error->all(FLERR,"Bad grid of processors");
  if (domain->dimension == 2 && procgrid[2] != 1)
    error->all(FLERR,"Processor count in z must be 1 for 2d simulation");

  // grid2proc[i][j][k] = proc that owns i,j,k location in 3d grid
  // reallocated here since procgrid may differ from a previous call

  if (grid2proc) memory->destroy(grid2proc);
  memory->create(grid2proc,procgrid[0],procgrid[1],procgrid[2],
                 "comm:grid2proc");

  // map processor IDs to 3d processor grid
  // produces myloc, procneigh, grid2proc

  if (gridflag == ONELEVEL) {
    if (mapflag == CART)
      pmap->cart_map(0,procgrid,myloc,procneigh,grid2proc);
    else if (mapflag == CARTREORDER)
      pmap->cart_map(1,procgrid,myloc,procneigh,grid2proc);
    else if (mapflag == XYZ)
      pmap->xyz_map(xyz,procgrid,myloc,procneigh,grid2proc);

  } else if (gridflag == TWOLEVEL) {
    if (mapflag == CART)
      pmap->cart_map(0,procgrid,ncores,coregrid,myloc,procneigh,grid2proc);
    else if (mapflag == CARTREORDER)
      pmap->cart_map(1,procgrid,ncores,coregrid,myloc,procneigh,grid2proc);
    else if (mapflag == XYZ)
      pmap->xyz_map(xyz,procgrid,ncores,coregrid,myloc,procneigh,grid2proc);

  } else if (gridflag == NUMA) {
    pmap->numa_map(0,coregrid,myloc,procneigh,grid2proc);

  } else if (gridflag == CUSTOM) {
    pmap->custom_map(procgrid,myloc,procneigh,grid2proc);
  }

  // print 3d grid info to screen and logfile
  // the core grid line is only printed for grid styles that use a coregrid

  if (outflag && me == 0) {
    auto mesg = fmt::format("  {} by {} by {} MPI processor grid\n",
                            procgrid[0],procgrid[1],procgrid[2]);
    if (gridflag == NUMA || gridflag == TWOLEVEL)
      mesg += fmt::format("  {} by {} by {} core grid within node\n",
                          coregrid[0],coregrid[1],coregrid[2]);
    utils::logmesg(lmp,mesg);
  }

  // print 3d grid details to outfile (set by "processors ... file")

  if (outfile) pmap->output(outfile,procgrid,grid2proc);

  // free ProcMap class

  delete pmap;

  // set xsplit,ysplit,zsplit for uniform spacings
  // each array has procgrid[d]+1 fractional positions from 0.0 to 1.0

  memory->destroy(xsplit);
  memory->destroy(ysplit);
  memory->destroy(zsplit);

  memory->create(xsplit,procgrid[0]+1,"comm:xsplit");
  memory->create(ysplit,procgrid[1]+1,"comm:ysplit");
  memory->create(zsplit,procgrid[2]+1,"comm:zsplit");

  for (int i = 0; i < procgrid[0]; i++) xsplit[i] = i * 1.0/procgrid[0];
  for (int i = 0; i < procgrid[1]; i++) ysplit[i] = i * 1.0/procgrid[1];
  for (int i = 0; i < procgrid[2]; i++) zsplit[i] = i * 1.0/procgrid[2];

  // last entry is set to exactly 1.0 rather than computed

  xsplit[procgrid[0]] = ysplit[procgrid[1]] = zsplit[procgrid[2]] = 1.0;

  // set lamda box params after procs are assigned
  // only set once unless load-balancing occurs

  if (domain->triclinic) domain->set_lamda_box();

  // send my 3d proc grid to another partition if requested
  // rank 0 sends procgrid then coregrid, matching the two MPI_Recv calls
  //   at the top of this function on the receiving partition

  if (send_to_partition >= 0) {
    if (me == 0) {
      MPI_Send(procgrid,3,MPI_INT,
               universe->root_proc[send_to_partition],0,
               universe->uworld);
      MPI_Send(coregrid,3,MPI_INT,
               universe->root_proc[send_to_partition],0,
               universe->uworld);
    }
  }
}
674 
675 /* ----------------------------------------------------------------------
676    determine suitable communication cutoff.
677    this uses three inputs: 1) maximum neighborlist cutoff, 2) an estimate
678    based on bond lengths and bonded interaction styles present, and 3) a
679    user supplied communication cutoff.
680    the neighbor list cutoff (1) is *always* used, since it is a requirement
681    for neighborlists working correctly. the bond length based cutoff is
682    *only* used, if no pair style is defined and no user cutoff is provided.
   otherwise, a warning is printed if the bond length based estimate is
   larger than the cutoff actually used.
   a warning is also printed if a user specified communication cutoff is overridden.
686 ------------------------------------------------------------------------- */
687 
get_comm_cutoff()688 double Comm::get_comm_cutoff()
689 {
690   double maxcommcutoff, maxbondcutoff = 0.0;
691 
692   if (force->bond) {
693     int n = atom->nbondtypes;
694     for (int i = 1; i <= n; ++i)
695       maxbondcutoff = MAX(maxbondcutoff,force->bond->equilibrium_distance(i));
696 
697     // apply bond length based heuristics.
698 
699     if (force->newton_bond) {
700       if (force->dihedral || force->improper) {
701         maxbondcutoff *= 2.25;
702       } else {
703         maxbondcutoff *=1.5;
704       }
705     } else {
706       if (force->dihedral || force->improper) {
707         maxbondcutoff *= 3.125;
708       } else if (force->angle) {
709         maxbondcutoff *= 2.25;
710       } else {
711         maxbondcutoff *=1.5;
712       }
713     }
714     maxbondcutoff += neighbor->skin;
715   }
716 
717   // always take the larger of max neighbor list and user specified cutoff
718 
719   maxcommcutoff = MAX(cutghostuser,neighbor->cutneighmax);
720 
721   // use cutoff estimate from bond length only if no user specified
722   // cutoff was given and no pair style present. Otherwise print a
723   // warning, if the estimated bond based cutoff is larger than what
724   // is currently used.
725 
726   if (!force->pair && (cutghostuser == 0.0)) {
727     maxcommcutoff = MAX(maxcommcutoff,maxbondcutoff);
728   } else {
729     if ((me == 0) && (maxbondcutoff > maxcommcutoff))
730       error->warning(FLERR,"Communication cutoff {} is shorter than a bond "
731                      "length based estimate of {}. This may lead to errors.",
732                      maxcommcutoff,maxbondcutoff);
733   }
734 
735   // print warning if neighborlist cutoff overrides user cutoff
736 
737   if ((me == 0) && (update->setupflag == 1)) {
738     if ((cutghostuser > 0.0) && (maxcommcutoff > cutghostuser))
739       error->warning(FLERR,"Communication cutoff adjusted to {}",maxcommcutoff);
740   }
741 
742   // Check maximum interval size for neighbor multi
743   if (neighbor->interval_collection_flag) {
744     for (int i = 0; i < neighbor->ncollections; i++){
745       maxcommcutoff = MAX(maxcommcutoff, neighbor->collection2cut[i]);
746     }
747   }
748 
749   return maxcommcutoff;
750 }
751 
752 /* ----------------------------------------------------------------------
753    determine which proc owns atom with coord x[3] based on current decomp
754    x will be in box (orthogonal) or lamda coords (triclinic)
755    if layout = UNIFORM, calculate owning proc directly
756    if layout = NONUNIFORM, iteratively find owning proc via binary search
757    if layout = TILED, CommTiled has its own method
758    return owning proc ID via grid2proc
759    return igx,igy,igz = logical grid loc of owing proc within 3d grid of procs
760 ------------------------------------------------------------------------- */
761 
coord2proc(double * x,int & igx,int & igy,int & igz)762 int Comm::coord2proc(double *x, int &igx, int &igy, int &igz)
763 {
764   double *prd = domain->prd;
765   double *boxlo = domain->boxlo;
766 
767   // initialize triclinic b/c coord2proc can be called before Comm::init()
768   // via Irregular::migrate_atoms()
769 
770   triclinic = domain->triclinic;
771 
772   if (layout == Comm::LAYOUT_UNIFORM) {
773     if (triclinic == 0) {
774       igx = static_cast<int> (procgrid[0] * (x[0]-boxlo[0]) / prd[0]);
775       igy = static_cast<int> (procgrid[1] * (x[1]-boxlo[1]) / prd[1]);
776       igz = static_cast<int> (procgrid[2] * (x[2]-boxlo[2]) / prd[2]);
777     } else {
778       igx = static_cast<int> (procgrid[0] * x[0]);
779       igy = static_cast<int> (procgrid[1] * x[1]);
780       igz = static_cast<int> (procgrid[2] * x[2]);
781     }
782 
783   } else if (layout == Comm::LAYOUT_NONUNIFORM) {
784     if (triclinic == 0) {
785       igx = utils::binary_search((x[0]-boxlo[0])/prd[0],procgrid[0],xsplit);
786       igy = utils::binary_search((x[1]-boxlo[1])/prd[1],procgrid[1],ysplit);
787       igz = utils::binary_search((x[2]-boxlo[2])/prd[2],procgrid[2],zsplit);
788     } else {
789       igx = utils::binary_search(x[0],procgrid[0],xsplit);
790       igy = utils::binary_search(x[1],procgrid[1],ysplit);
791       igz = utils::binary_search(x[2],procgrid[2],zsplit);
792     }
793   }
794 
795   if (igx < 0) igx = 0;
796   if (igx >= procgrid[0]) igx = procgrid[0] - 1;
797   if (igy < 0) igy = 0;
798   if (igy >= procgrid[1]) igy = procgrid[1] - 1;
799   if (igz < 0) igz = 0;
800   if (igz >= procgrid[2]) igz = procgrid[2] - 1;
801 
802   return grid2proc[igx][igy][igz];
803 }
804 
805 /* ----------------------------------------------------------------------
806    partition a global regular grid into one brick-shaped sub-grid per proc
807    if grid point is inside my sub-domain I own it,
808      this includes sub-domain lo boundary but excludes hi boundary
809    nx,ny,nz = extent of global grid
810      indices into the global grid range from 0 to N-1 in each dim
811    zfactor = 0.0 if the grid exactly covers the simulation box
812    zfactor > 1.0 if the grid extends beyond the +z boundary by this factor
813      used by 2d slab-mode PPPM
814      this effectively maps proc sub-grids to a smaller subset of the grid
815    nxyz lo/hi = inclusive lo/hi bounds of global grid sub-brick I own
816    if proc owns no grid cells in a dim, then nlo > nhi
817    special case: 2 procs share boundary which a grid point is exactly on
818      2 equality if tests insure a consistent decision as to which proc owns it
819 ------------------------------------------------------------------------- */
820 
partition_grid(int nx,int ny,int nz,double zfactor,int & nxlo,int & nxhi,int & nylo,int & nyhi,int & nzlo,int & nzhi)821 void Comm::partition_grid(int nx, int ny, int nz, double zfactor,
822                           int &nxlo, int &nxhi, int &nylo, int &nyhi,
823                           int &nzlo, int &nzhi)
824 {
825   double xfraclo,xfrachi,yfraclo,yfrachi,zfraclo,zfrachi;
826 
827   if (layout != LAYOUT_TILED) {
828     xfraclo = xsplit[myloc[0]];
829     xfrachi = xsplit[myloc[0]+1];
830     yfraclo = ysplit[myloc[1]];
831     yfrachi = ysplit[myloc[1]+1];
832     zfraclo = zsplit[myloc[2]];
833     zfrachi = zsplit[myloc[2]+1];
834   } else {
835     xfraclo = mysplit[0][0];
836     xfrachi = mysplit[0][1];
837     yfraclo = mysplit[1][0];
838     yfrachi = mysplit[1][1];
839     zfraclo = mysplit[2][0];
840     zfrachi = mysplit[2][1];
841   }
842 
843   nxlo = static_cast<int> (xfraclo * nx);
844   if (1.0*nxlo != xfraclo*nx) nxlo++;
845   nxhi = static_cast<int> (xfrachi * nx);
846   if (1.0*nxhi == xfrachi*nx) nxhi--;
847 
848   nylo = static_cast<int> (yfraclo * ny);
849   if (1.0*nylo != yfraclo*ny) nylo++;
850   nyhi = static_cast<int> (yfrachi * ny);
851   if (1.0*nyhi == yfrachi*ny) nyhi--;
852 
853   if (zfactor == 0.0) {
854     nzlo = static_cast<int> (zfraclo * nz);
855     if (1.0*nzlo != zfraclo*nz) nzlo++;
856     nzhi = static_cast<int> (zfrachi * nz);
857     if (1.0*nzhi == zfrachi*nz) nzhi--;
858   } else {
859     nzlo = static_cast<int> (zfraclo * nz/zfactor);
860     if (1.0*nzlo != zfraclo*nz) nzlo++;
861     nzhi = static_cast<int> (zfrachi * nz/zfactor);
862     if (1.0*nzhi == zfrachi*nz) nzhi--;
863   }
864 
865   // OLD code
866   // could sometimes map grid points slightly outside a proc to the proc
867 
868   /*
869   if (layout != LAYOUT_TILED) {
870     nxlo = static_cast<int> (xsplit[myloc[0]] * nx);
871     nxhi = static_cast<int> (xsplit[myloc[0]+1] * nx) - 1;
872 
873     nylo = static_cast<int> (ysplit[myloc[1]] * ny);
874     nyhi = static_cast<int> (ysplit[myloc[1]+1] * ny) - 1;
875 
876     if (zfactor == 0.0) {
877       nzlo = static_cast<int> (zsplit[myloc[2]] * nz);
878       nzhi = static_cast<int> (zsplit[myloc[2]+1] * nz) - 1;
879     } else {
880       nzlo = static_cast<int> (zsplit[myloc[2]] * nz/zfactor);
881       nzhi = static_cast<int> (zsplit[myloc[2]+1] * nz/zfactor) - 1;
882     }
883 
884   } else {
885     nxlo = static_cast<int> (mysplit[0][0] * nx);
886     nxhi = static_cast<int> (mysplit[0][1] * nx) - 1;
887 
888     nylo = static_cast<int> (mysplit[1][0] * ny);
889     nyhi = static_cast<int> (mysplit[1][1] * ny) - 1;
890 
891     if (zfactor == 0.0) {
892       nzlo = static_cast<int> (mysplit[2][0] * nz);
893       nzhi = static_cast<int> (mysplit[2][1] * nz) - 1;
894     } else {
895       nzlo = static_cast<int> (mysplit[2][0] * nz/zfactor);
896       nzhi = static_cast<int> (mysplit[2][1] * nz/zfactor) - 1;
897     }
898   }
899   */
900 }
901 
902 /* ----------------------------------------------------------------------
903    communicate inbuf around full ring of processors with messtag
904    nbytes = size of inbuf = n datums * nper bytes
905    callback() is invoked to allow caller to process/update each proc's inbuf
906    if self=1 (default), then callback() is invoked on final iteration
907      using original inbuf, which may have been updated
908    for non-nullptr outbuf, final updated inbuf is copied to it
909      ok to specify outbuf = inbuf
910    the ptr argument is a pointer to the instance of calling class
911 ------------------------------------------------------------------------- */
912 
ring(int n,int nper,void * inbuf,int messtag,void (* callback)(int,char *,void *),void * outbuf,void * ptr,int self)913 void Comm::ring(int n, int nper, void *inbuf, int messtag,
914                 void (*callback)(int, char *, void *),
915                 void *outbuf, void *ptr, int self)
916 {
917   MPI_Request request;
918   MPI_Status status;
919 
920   int nbytes = n*nper;
921   int maxbytes;
922   MPI_Allreduce(&nbytes,&maxbytes,1,MPI_INT,MPI_MAX,world);
923 
924   // no need to communicate without data
925 
926   if (maxbytes == 0) return;
927 
928   // sanity check
929 
930   if ((nbytes > 0) && inbuf == nullptr)
931     error->one(FLERR,"Cannot put data on ring from NULL pointer");
932 
933   char *buf,*bufcopy;
934   memory->create(buf,maxbytes,"comm:buf");
935   memory->create(bufcopy,maxbytes,"comm:bufcopy");
936   if (nbytes && inbuf) memcpy(buf,inbuf,nbytes);
937 
938   int next = me + 1;
939   int prev = me - 1;
940   if (next == nprocs) next = 0;
941   if (prev < 0) prev = nprocs - 1;
942 
943   for (int loop = 0; loop < nprocs; loop++) {
944     if (me != next) {
945       MPI_Irecv(bufcopy,maxbytes,MPI_CHAR,prev,messtag,world,&request);
946       MPI_Send(buf,nbytes,MPI_CHAR,next,messtag,world);
947       MPI_Wait(&request,&status);
948       MPI_Get_count(&status,MPI_CHAR,&nbytes);
949       if (nbytes) memcpy(buf,bufcopy,nbytes);
950     }
951     if (self || loop < nprocs-1) callback(nbytes/nper,buf,ptr);
952   }
953 
954   if (nbytes && outbuf) memcpy(outbuf,buf,nbytes);
955 
956   memory->destroy(buf);
957   memory->destroy(bufcopy);
958 }
959 
960 /* ----------------------------------------------------------------------
961    rendezvous communication operation
962    three stages:
963      first comm sends inbuf from caller decomp to rvous decomp
964      callback operates on data in rendezvous decomp
965      second comm sends outbuf from rvous decomp back to caller decomp
966    inputs:
967      which = perform (0) irregular or (1) MPI_All2allv communication
968      n = # of datums in inbuf
969      inbuf = vector of input datums
970      insize = byte size of each input datum
971      inorder = 0 for inbuf in random proc order, 1 for datums ordered by proc
972      procs: inorder 0 = proc to send each datum to, 1 = # of datums/proc,
973      callback = caller function to invoke in rendezvous decomposition
974                 takes input datums, returns output datums
975      outorder = same as inorder, but for datums returned by callback()
976      ptr = pointer to caller class, passed to callback()
977    outputs:
978      nout = # of output datums (function return)
979      outbuf = vector of output datums
980      outsize = byte size of each output datum
981    callback inputs:
982      nrvous = # of rvous decomp datums in inbuf_rvous
983      inbuf_rvous = vector of rvous decomp input datums
984      ptr = pointer to caller class
985    callback outputs:
986      nrvous_out = # of rvous decomp output datums (function return)
987      flag = 0 for no second comm, 1 for outbuf_rvous = inbuf_rvous,
988             2 for second comm with new outbuf_rvous
989      procs_rvous = outorder 0 = proc to send each datum to, 1 = # of datums/proc
990                    allocated
991      outbuf_rvous = vector of rvous decomp output datums
992    NOTE: could use MPI_INT or MPI_DOUBLE insead of MPI_CHAR
993          to avoid checked-for overflow in MPI_Alltoallv?
994 ------------------------------------------------------------------------- */
995 
996 int Comm::
rendezvous(int which,int n,char * inbuf,int insize,int inorder,int * procs,int (* callback)(int,char *,int &,int * &,char * &,void *),int outorder,char * & outbuf,int outsize,void * ptr,int statflag)997 rendezvous(int which, int n, char *inbuf, int insize,
998            int inorder, int *procs,
999            int (*callback)(int, char *, int &, int *&, char *&, void *),
1000            int outorder, char *&outbuf, int outsize, void *ptr, int statflag)
1001 {
1002   if (which == 0)
1003     return rendezvous_irregular(n,inbuf,insize,inorder,procs,callback,
1004                                 outorder,outbuf,outsize,ptr,statflag);
1005   else
1006     return rendezvous_all2all(n,inbuf,insize,inorder,procs,callback,
1007                               outorder,outbuf,outsize,ptr,statflag);
1008 }
1009 
1010 /* ---------------------------------------------------------------------- */
1011 
1012 int Comm::
rendezvous_irregular(int n,char * inbuf,int insize,int inorder,int * procs,int (* callback)(int,char *,int &,int * &,char * &,void *),int outorder,char * & outbuf,int outsize,void * ptr,int statflag)1013 rendezvous_irregular(int n, char *inbuf, int insize, int inorder, int *procs,
1014                      int (*callback)(int, char *, int &, int *&, char *&, void *),
1015                      int outorder, char *&outbuf,
1016                      int outsize, void *ptr, int statflag)
1017 {
1018   // irregular comm of inbuf from caller decomp to rendezvous decomp
1019 
1020   Irregular *irregular = new Irregular(lmp);
1021 
1022   int nrvous;
1023   if (inorder) nrvous = irregular->create_data_grouped(n,procs);
1024   else nrvous = irregular->create_data(n,procs);
1025 
1026   // add 1 item to the allocated buffer size, so the returned pointer is not a null pointer
1027 
1028   char *inbuf_rvous = (char *) memory->smalloc((bigint) nrvous*insize+1,
1029                                                "rendezvous:inbuf");
1030   irregular->exchange_data(inbuf,insize,inbuf_rvous);
1031 
1032   bigint irregular1_bytes = irregular->memory_usage();
1033   irregular->destroy_data();
1034   delete irregular;
1035 
1036   // peform rendezvous computation via callback()
1037   // callback() allocates/populates proclist_rvous and outbuf_rvous
1038 
1039   int flag;
1040   int *procs_rvous;
1041   char *outbuf_rvous;
1042   int nrvous_out = callback(nrvous,inbuf_rvous,flag,
1043                             procs_rvous,outbuf_rvous,ptr);
1044 
1045   if (flag != 1) memory->sfree(inbuf_rvous);  // outbuf_rvous = inbuf_vous
1046   if (flag == 0) {
1047     if (statflag) rendezvous_stats(n,0,nrvous,nrvous_out,insize,outsize,
1048                                    (bigint) nrvous_out*sizeof(int) +
1049                                    irregular1_bytes);
1050     return 0;    // all nout_rvous are 0, no 2nd comm stage
1051   }
1052 
1053   // irregular comm of outbuf from rendezvous decomp back to caller decomp
1054   // caller will free outbuf
1055 
1056   irregular = new Irregular(lmp);
1057 
1058   int nout;
1059   if (outorder)
1060     nout = irregular->create_data_grouped(nrvous_out,procs_rvous);
1061   else nout = irregular->create_data(nrvous_out,procs_rvous);
1062 
1063   // add 1 item to the allocated buffer size, so the returned pointer is not a null pointer
1064 
1065   outbuf = (char *) memory->smalloc((bigint) nout*outsize+1,
1066                                     "rendezvous:outbuf");
1067   irregular->exchange_data(outbuf_rvous,outsize,outbuf);
1068 
1069   bigint irregular2_bytes = irregular->memory_usage();
1070   irregular->destroy_data();
1071   delete irregular;
1072 
1073   memory->destroy(procs_rvous);
1074   memory->sfree(outbuf_rvous);
1075 
1076   // return number of output datums
1077   // last arg to stats() = memory for procs_rvous + irregular comm
1078 
1079   if (statflag) rendezvous_stats(n,nout,nrvous,nrvous_out,insize,outsize,
1080                                  (bigint) nrvous_out*sizeof(int) +
1081                                  MAX(irregular1_bytes,irregular2_bytes));
1082   return nout;
1083 }
1084 
1085 /* ---------------------------------------------------------------------- */
1086 
1087 int Comm::
rendezvous_all2all(int n,char * inbuf,int insize,int inorder,int * procs,int (* callback)(int,char *,int &,int * &,char * &,void *),int outorder,char * & outbuf,int outsize,void * ptr,int statflag)1088 rendezvous_all2all(int n, char *inbuf, int insize, int inorder, int *procs,
1089                    int (*callback)(int, char *, int &, int *&, char *&, void *),
1090                    int outorder, char *&outbuf, int outsize, void *ptr,
1091                    int statflag)
1092 {
1093   int iproc;
1094   bigint all2all1_bytes,all2all2_bytes;
1095   int *sendcount,*sdispls,*recvcount,*rdispls;
1096   int *procs_a2a;
1097   bigint *offsets;
1098   char *inbuf_a2a,*outbuf_a2a;
1099 
1100   // create procs and inbuf for All2all if necessary
1101 
1102   if (!inorder) {
1103     memory->create(procs_a2a,nprocs,"rendezvous:procs");
1104 
1105     // add 1 item to the allocated buffer size, so the returned pointer is not a null pointer
1106 
1107     inbuf_a2a = (char *) memory->smalloc((bigint) n*insize+1,
1108                                          "rendezvous:inbuf");
1109     memset(inbuf_a2a,0,(bigint)n*insize*sizeof(char));
1110     memory->create(offsets,nprocs,"rendezvous:offsets");
1111 
1112     for (int i = 0; i < nprocs; i++) procs_a2a[i] = 0;
1113     for (int i = 0; i < n; i++) procs_a2a[procs[i]]++;
1114 
1115     offsets[0] = 0;
1116     for (int i = 1; i < nprocs; i++)
1117       offsets[i] = offsets[i-1] + (bigint)insize*procs_a2a[i-1];
1118 
1119     bigint offset = 0;
1120     for (int i = 0; i < n; i++) {
1121       iproc = procs[i];
1122       memcpy(&inbuf_a2a[offsets[iproc]],&inbuf[offset],insize);
1123       offsets[iproc] += insize;
1124       offset += insize;
1125     }
1126 
1127     all2all1_bytes = nprocs*sizeof(int) + nprocs*sizeof(bigint)
1128                      + (bigint)n*insize;
1129 
1130   } else {
1131     procs_a2a = procs;
1132     inbuf_a2a = inbuf;
1133     all2all1_bytes = 0;
1134   }
1135 
1136   // create args for MPI_Alltoallv() on input data
1137 
1138   memory->create(sendcount,nprocs,"rendezvous:sendcount");
1139   memcpy(sendcount,procs_a2a,nprocs*sizeof(int));
1140 
1141   memory->create(recvcount,nprocs,"rendezvous:recvcount");
1142   MPI_Alltoall(sendcount,1,MPI_INT,recvcount,1,MPI_INT,world);
1143 
1144   memory->create(sdispls,nprocs,"rendezvous:sdispls");
1145   memory->create(rdispls,nprocs,"rendezvous:rdispls");
1146   sdispls[0] = rdispls[0] = 0;
1147   for (int i = 1; i < nprocs; i++) {
1148     sdispls[i] = sdispls[i-1] + sendcount[i-1];
1149     rdispls[i] = rdispls[i-1] + recvcount[i-1];
1150   }
1151   int nrvous = rdispls[nprocs-1] + recvcount[nprocs-1];
1152 
1153   // test for overflow of input data due to imbalance or insize
1154   // means that individual sdispls or rdispls values overflow
1155 
1156   int overflow = 0;
1157   if ((bigint) n*insize > MAXSMALLINT) overflow = 1;
1158   if ((bigint) nrvous*insize > MAXSMALLINT) overflow = 1;
1159   int overflowall;
1160   MPI_Allreduce(&overflow,&overflowall,1,MPI_INT,MPI_MAX,world);
1161   if (overflowall) error->all(FLERR,"Overflow input size in rendezvous_a2a");
1162 
1163   for (int i = 0; i < nprocs; i++) {
1164     sendcount[i] *= insize;
1165     sdispls[i] *= insize;
1166     recvcount[i] *= insize;
1167     rdispls[i] *= insize;
1168   }
1169 
1170   // all2all comm of inbuf from caller decomp to rendezvous decomp
1171   // add 1 item to the allocated buffer size, so the returned pointer is not a null pointer
1172 
1173   char *inbuf_rvous = (char *) memory->smalloc((bigint) nrvous*insize+1,
1174                                                "rendezvous:inbuf");
1175   memset(inbuf_rvous,0,(bigint) nrvous*insize*sizeof(char));
1176 
1177   MPI_Alltoallv(inbuf_a2a,sendcount,sdispls,MPI_CHAR,
1178                 inbuf_rvous,recvcount,rdispls,MPI_CHAR,world);
1179 
1180   if (!inorder) {
1181     memory->destroy(procs_a2a);
1182     memory->sfree(inbuf_a2a);
1183     memory->destroy(offsets);
1184   }
1185 
1186   // peform rendezvous computation via callback()
1187   // callback() allocates/populates proclist_rvous and outbuf_rvous
1188 
1189   int flag;
1190   int *procs_rvous;
1191   char *outbuf_rvous;
1192 
1193   int nrvous_out = callback(nrvous,inbuf_rvous,flag,
1194                             procs_rvous,outbuf_rvous,ptr);
1195 
1196   if (flag != 1) memory->sfree(inbuf_rvous);  // outbuf_rvous = inbuf_vous
1197   if (flag == 0) {
1198     memory->destroy(sendcount);
1199     memory->destroy(recvcount);
1200     memory->destroy(sdispls);
1201     memory->destroy(rdispls);
1202     if (statflag) rendezvous_stats(n,0,nrvous,nrvous_out,insize,outsize,
1203                                    (bigint) nrvous_out*sizeof(int) +
1204                                    4*nprocs*sizeof(int) + all2all1_bytes);
1205     return 0;    // all nout_rvous are 0, no 2nd irregular
1206   }
1207 
1208   // create procs and outbuf for All2all if necessary
1209 
1210   if (!outorder) {
1211     memory->create(procs_a2a,nprocs,"rendezvous_a2a:procs");
1212 
1213     // add 1 item to the allocated buffer size, so the returned pointer is not a null pointer
1214 
1215     outbuf_a2a = (char *) memory->smalloc((bigint) nrvous_out*outsize+1,
1216                                           "rendezvous:outbuf");
1217     memory->create(offsets,nprocs,"rendezvous:offsets");
1218 
1219     for (int i = 0; i < nprocs; i++) procs_a2a[i] = 0;
1220     for (int i = 0; i < nrvous_out; i++) procs_a2a[procs_rvous[i]]++;
1221 
1222     offsets[0] = 0;
1223     for (int i = 1; i < nprocs; i++)
1224       offsets[i] = offsets[i-1] + (bigint)outsize*procs_a2a[i-1];
1225 
1226     bigint offset = 0;
1227     for (int i = 0; i < nrvous_out; i++) {
1228       iproc = procs_rvous[i];
1229       memcpy(&outbuf_a2a[offsets[iproc]],&outbuf_rvous[offset],outsize);
1230       offsets[iproc] += outsize;
1231       offset += outsize;
1232     }
1233 
1234     all2all2_bytes = nprocs*sizeof(int) + nprocs*sizeof(bigint) +
1235       (bigint)nrvous_out*outsize;
1236 
1237   } else {
1238     procs_a2a = procs_rvous;
1239     outbuf_a2a = outbuf_rvous;
1240     all2all2_bytes = 0;
1241   }
1242 
1243   // comm outbuf from rendezvous decomposition back to caller
1244 
1245   memcpy(sendcount,procs_a2a,nprocs*sizeof(int));
1246 
1247   MPI_Alltoall(sendcount,1,MPI_INT,recvcount,1,MPI_INT,world);
1248 
1249   sdispls[0] = rdispls[0] = 0;
1250   for (int i = 1; i < nprocs; i++) {
1251     sdispls[i] = sdispls[i-1] + sendcount[i-1];
1252     rdispls[i] = rdispls[i-1] + recvcount[i-1];
1253   }
1254   int nout = rdispls[nprocs-1] + recvcount[nprocs-1];
1255 
1256   // test for overflow of outbuf due to imbalance or outsize
1257   // means that individual sdispls or rdispls values overflow
1258 
1259   overflow = 0;
1260   if ((bigint) nrvous*outsize > MAXSMALLINT) overflow = 1;
1261   if ((bigint) nout*outsize > MAXSMALLINT) overflow = 1;
1262   MPI_Allreduce(&overflow,&overflowall,1,MPI_INT,MPI_MAX,world);
1263   if (overflowall) error->all(FLERR,"Overflow output in rendezvous_a2a");
1264 
1265   for (int i = 0; i < nprocs; i++) {
1266     sendcount[i] *= outsize;
1267     sdispls[i] *= outsize;
1268     recvcount[i] *= outsize;
1269     rdispls[i] *= outsize;
1270   }
1271 
1272   // all2all comm of outbuf from rendezvous decomp back to caller decomp
1273   // caller will free outbuf
1274   // add 1 item to the allocated buffer size, so the returned pointer is not a null pointer
1275 
1276   outbuf = (char *) memory->smalloc((bigint) nout*outsize+1,"rendezvous:outbuf");
1277 
1278   MPI_Alltoallv(outbuf_a2a,sendcount,sdispls,MPI_CHAR,
1279                 outbuf,recvcount,rdispls,MPI_CHAR,world);
1280 
1281   memory->destroy(procs_rvous);
1282   memory->sfree(outbuf_rvous);
1283 
1284   if (!outorder) {
1285     memory->destroy(procs_a2a);
1286     memory->sfree(outbuf_a2a);
1287     memory->destroy(offsets);
1288   }
1289 
1290   // clean up
1291 
1292   memory->destroy(sendcount);
1293   memory->destroy(recvcount);
1294   memory->destroy(sdispls);
1295   memory->destroy(rdispls);
1296 
1297   // return number of output datums
1298   // last arg to stats() = mem for procs_rvous + per-proc vecs + reordering ops
1299 
1300   if (statflag) rendezvous_stats(n,nout,nrvous,nrvous_out,insize,outsize,
1301                                  (bigint) nrvous_out*sizeof(int) +
1302                                  4*nprocs*sizeof(int) +
1303                                  MAX(all2all1_bytes,all2all2_bytes));
1304   return nout;
1305 }
1306 
1307 /* ----------------------------------------------------------------------
1308    print balance and memory info for rendezvous operation
1309    useful for debugging
1310 ------------------------------------------------------------------------- */
1311 
rendezvous_stats(int n,int nout,int nrvous,int nrvous_out,int insize,int outsize,bigint commsize)1312 void Comm::rendezvous_stats(int n, int nout, int nrvous, int nrvous_out,
1313                             int insize, int outsize, bigint commsize)
1314 {
1315   bigint size_in_all,size_in_max,size_in_min;
1316   bigint size_out_all,size_out_max,size_out_min;
1317   bigint size_inrvous_all,size_inrvous_max,size_inrvous_min;
1318   bigint size_outrvous_all,size_outrvous_max,size_outrvous_min;
1319   bigint size_comm_all,size_comm_max,size_comm_min;
1320 
1321   bigint size = (bigint) n*insize;
1322   MPI_Allreduce(&size,&size_in_all,1,MPI_LMP_BIGINT,MPI_SUM,world);
1323   MPI_Allreduce(&size,&size_in_max,1,MPI_LMP_BIGINT,MPI_MAX,world);
1324   MPI_Allreduce(&size,&size_in_min,1,MPI_LMP_BIGINT,MPI_MIN,world);
1325 
1326   size = (bigint) nout*outsize;
1327   MPI_Allreduce(&size,&size_out_all,1,MPI_LMP_BIGINT,MPI_SUM,world);
1328   MPI_Allreduce(&size,&size_out_max,1,MPI_LMP_BIGINT,MPI_MAX,world);
1329   MPI_Allreduce(&size,&size_out_min,1,MPI_LMP_BIGINT,MPI_MIN,world);
1330 
1331   size = (bigint) nrvous*insize;
1332   MPI_Allreduce(&size,&size_inrvous_all,1,MPI_LMP_BIGINT,MPI_SUM,world);
1333   MPI_Allreduce(&size,&size_inrvous_max,1,MPI_LMP_BIGINT,MPI_MAX,world);
1334   MPI_Allreduce(&size,&size_inrvous_min,1,MPI_LMP_BIGINT,MPI_MIN,world);
1335 
1336   size = (bigint) nrvous_out*insize;
1337   MPI_Allreduce(&size,&size_outrvous_all,1,MPI_LMP_BIGINT,MPI_SUM,world);
1338   MPI_Allreduce(&size,&size_outrvous_max,1,MPI_LMP_BIGINT,MPI_MAX,world);
1339   MPI_Allreduce(&size,&size_outrvous_min,1,MPI_LMP_BIGINT,MPI_MIN,world);
1340 
1341   size = commsize;
1342   MPI_Allreduce(&size,&size_comm_all,1,MPI_LMP_BIGINT,MPI_SUM,world);
1343   MPI_Allreduce(&size,&size_comm_max,1,MPI_LMP_BIGINT,MPI_MAX,world);
1344   MPI_Allreduce(&size,&size_comm_min,1,MPI_LMP_BIGINT,MPI_MIN,world);
1345 
1346   int mbytes = 1024*1024;
1347 
1348   if (me == 0) {
1349     std::string mesg = "Rendezvous balance and memory info: (tot,ave,max,min) \n";
1350     mesg += fmt::format("  input datum count: {} {} {} {}\n",
1351                         size_in_all/insize,1.0*size_in_all/nprocs/insize,
1352                         size_in_max/insize,size_in_min/insize);
1353     mesg += fmt::format("  input data (MB): {:.6} {:.6} {:.6} {:.6}\n",
1354                         1.0*size_in_all/mbytes,1.0*size_in_all/nprocs/mbytes,
1355                         1.0*size_in_max/mbytes,1.0*size_in_min/mbytes);
1356     if (outsize)
1357       mesg += fmt::format("  output datum count: {} {} {} {}\n",
1358                           size_out_all/outsize,1.0*size_out_all/nprocs/outsize,
1359                           size_out_max/outsize,size_out_min/outsize);
1360     else
1361       mesg += fmt::format("  output datum count: {} {:.6} {} {}\n",0,0.0,0,0);
1362 
1363     mesg += fmt::format("  output data (MB): {:.6} {:.6} {:.6} {:.6}\n",
1364                         1.0*size_out_all/mbytes,1.0*size_out_all/nprocs/mbytes,
1365                         1.0*size_out_max/mbytes,1.0*size_out_min/mbytes);
1366     mesg += fmt::format("  input rvous datum count: {} {} {} {}\n",
1367                         size_inrvous_all/insize,1.0*size_inrvous_all/nprocs/insize,
1368                         size_inrvous_max/insize,size_inrvous_min/insize);
1369     mesg += fmt::format("  input rvous data (MB): {:.6} {:.6} {:.6} {:.6}\n",
1370                         1.0*size_inrvous_all/mbytes,1.0*size_inrvous_all/nprocs/mbytes,
1371                         1.0*size_inrvous_max/mbytes,1.0*size_inrvous_min/mbytes);
1372     if (outsize)
1373       mesg += fmt::format("  output rvous datum count: {} {} {} {}\n",
1374                           size_outrvous_all/outsize,1.0*size_outrvous_all/nprocs/outsize,
1375                           size_outrvous_max/outsize,size_outrvous_min/outsize);
1376     else
1377       mesg += fmt::format("  output rvous datum count: {} {:.6} {} {}\n",0,0.0,0,0);
1378     mesg += fmt::format("  output rvous data (MB): {:.6} {:.6} {:.6} {:.6}\n",
1379                         1.0*size_outrvous_all/mbytes,1.0*size_outrvous_all/nprocs/mbytes,
1380                         1.0*size_outrvous_max/mbytes,1.0*size_outrvous_min/mbytes);
1381     mesg += fmt::format("  rvous comm (MB): {:.6} {:.6} {:.6} {:.6}\n",
1382                         1.0*size_comm_all/mbytes,1.0*size_comm_all/nprocs/mbytes,
1383                         1.0*size_comm_max/mbytes,1.0*size_comm_min/mbytes);
1384     utils::logmesg(lmp,mesg);
1385   }
1386 }
1387