1 /* begin_generated_IBM_copyright_prolog                             */
2 /*                                                                  */
3 /* This is an automatically generated copyright prolog.             */
4 /* After initializing,  DO NOT MODIFY OR MOVE                       */
5 /*  --------------------------------------------------------------- */
6 /* Licensed Materials - Property of IBM                             */
7 /* Blue Gene/Q 5765-PER 5765-PRP                                    */
8 /*                                                                  */
9 /* (C) Copyright IBM Corp. 2011, 2012 All Rights Reserved           */
10 /* US Government Users Restricted Rights -                          */
11 /* Use, duplication, or disclosure restricted                       */
12 /* by GSA ADP Schedule Contract with IBM Corp.                      */
13 /*                                                                  */
14 /*  --------------------------------------------------------------- */
15 /*                                                                  */
16 /* end_generated_IBM_copyright_prolog                               */
17 /*  (C)Copyright IBM Corp.  2007, 2011  */
18 /**
19  * \file src/coll/allreduce/mpido_allreduce.c
20  * \brief MPI allreduce over PAMI collectives, with fallback to the MPICH (MPIR_Allreduce) implementation
21  */
22 
23 /*#define TRACE_ON*/
24 
25 #include <mpidimpl.h>
26 
/**
 * \brief Completion callback installed as cb_done by MPIDO_Allreduce.
 *
 * Signals the progress engine and then decrements the completion counter
 * that the cookie points at.
 *
 * \param[in]     ctxt        Not used by this callback
 * \param[in,out] clientdata  Cookie; MPIDO_Allreduce passes the address of its
 *                            `volatile unsigned` completion counter
 * \param[in]     err         Not examined by this callback
 */
static void cb_allreduce(void *ctxt, void *clientdata, pami_result_t err)
{
   /* The counter is defined volatile by the poster, so it must also be
    * accessed through a volatile-qualified lvalue of the same type. */
   volatile unsigned *active = (volatile unsigned *) clientdata;
   TRACE_ERR("callback enter, active: %u\n", (*active));
   MPIDI_Progress_signal();
   (*active)--;
}
34 
MPIDO_Allreduce(const void * sendbuf,void * recvbuf,int count,MPI_Datatype dt,MPI_Op op,MPID_Comm * comm_ptr,int * mpierrno)35 int MPIDO_Allreduce(const void *sendbuf,
36                     void *recvbuf,
37                     int count,
38                     MPI_Datatype dt,
39                     MPI_Op op,
40                     MPID_Comm *comm_ptr,
41                     int *mpierrno)
42 {
43    void *sbuf;
44    TRACE_ERR("Entering mpido_allreduce\n");
45    pami_type_t pdt;
46    pami_data_function pop;
47    int mu;
48    int rc;
49 #ifdef TRACE_ON
50     int len;
51     char op_str[255];
52     char dt_str[255];
53     MPIDI_Op_to_string(op, op_str);
54     PMPI_Type_get_name(dt, dt_str, &len);
55 #endif
56    volatile unsigned active = 1;
57    pami_xfer_t allred;
58    pami_algorithm_t my_allred;
59    pami_metadata_t *my_allred_md = (pami_metadata_t *)NULL;
60    int alg_selected = 0;
61 
62    if(likely(dt == MPI_DOUBLE || dt == MPI_DOUBLE_PRECISION))
63    {
64       rc = MPI_SUCCESS;
65       pdt = PAMI_TYPE_DOUBLE;
66       if(likely(op == MPI_SUM))
67          pop = PAMI_DATA_SUM;
68       else if(likely(op == MPI_MAX))
69          pop = PAMI_DATA_MAX;
70       else if(likely(op == MPI_MIN))
71          pop = PAMI_DATA_MIN;
72       else rc = MPIDI_Datatype_to_pami(dt, &pdt, op, &pop, &mu);
73    }
74    else rc = MPIDI_Datatype_to_pami(dt, &pdt, op, &pop, &mu);
75 
76   if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
77       fprintf(stderr,"allred rc %u, Datatype %p, op %p, mu %u, selectedvar %u != %u\n",
78               rc, pdt, pop, mu,
79               (unsigned)comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE],MPID_COLL_USE_MPICH);
80       /* convert to metadata query */
81   if(unlikely(rc != MPI_SUCCESS ||
82 	      comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE] == MPID_COLL_USE_MPICH))
83    {
84       if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
85          fprintf(stderr,"Using MPICH allreduce type %u.\n",
86                  comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE]);
87       MPIDI_Update_last_algorithm(comm_ptr, "ALLREDUCE_MPICH");
88       return MPIR_Allreduce(sendbuf, recvbuf, count, dt, op, comm_ptr, mpierrno);
89    }
90 
91   if(unlikely(sendbuf == MPI_IN_PLACE))
92    {
93       if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL))
94          fprintf(stderr,"allreduce MPI_IN_PLACE buffering\n");
95       sbuf = recvbuf;
96    }
97    else sbuf = (void *)sendbuf;
98 
99    allred.cb_done = cb_allreduce;
100    allred.cookie = (void *)&active;
101    allred.cmd.xfer_allreduce.sndbuf = sbuf;
102    allred.cmd.xfer_allreduce.stype = pdt;
103    allred.cmd.xfer_allreduce.rcvbuf = recvbuf;
104    allred.cmd.xfer_allreduce.rtype = pdt;
105    allred.cmd.xfer_allreduce.stypecount = count;
106    allred.cmd.xfer_allreduce.rtypecount = count;
107    allred.cmd.xfer_allreduce.op = pop;
108 
109    TRACE_ERR("Allreduce - Basic Collective Selection\n");
110    if(likely(comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE] == MPID_COLL_OPTIMIZED))
111    {
112      if(likely(pop == PAMI_DATA_SUM || pop == PAMI_DATA_MAX || pop == PAMI_DATA_MIN))
113       {
114          /* double protocol works on all message sizes */
115          if(likely(pdt == PAMI_TYPE_DOUBLE && comm_ptr->mpid.query_allred_dsmm == MPID_COLL_QUERY))
116          {
117             my_allred = comm_ptr->mpid.cached_allred_dsmm;
118             my_allred_md = &comm_ptr->mpid.cached_allred_dsmm_md;
119             alg_selected = 1;
120          }
121          else if(pdt == PAMI_TYPE_UNSIGNED_INT && comm_ptr->mpid.query_allred_ismm == MPID_COLL_QUERY)
122          {
123             my_allred = comm_ptr->mpid.cached_allred_ismm;
124             my_allred_md = &comm_ptr->mpid.cached_allred_ismm_md;
125             alg_selected = 1;
126          }
127          /* The integer protocol at >1 ppn requires small messages only */
128          else if(pdt == PAMI_TYPE_UNSIGNED_INT && comm_ptr->mpid.query_allred_ismm == MPID_COLL_CHECK_FN_REQUIRED &&
129                  count <= comm_ptr->mpid.cutoff_size[PAMI_XFER_ALLREDUCE][0])
130          {
131             my_allred = comm_ptr->mpid.cached_allred_ismm;
132             my_allred_md = &comm_ptr->mpid.cached_allred_ismm_md;
133             alg_selected = 1;
134          }
135          else if(comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_NOQUERY &&
136                  count <= comm_ptr->mpid.cutoff_size[PAMI_XFER_ALLREDUCE][0])
137          {
138             my_allred = comm_ptr->mpid.opt_protocol[PAMI_XFER_ALLREDUCE][0];
139             my_allred_md = &comm_ptr->mpid.opt_protocol_md[PAMI_XFER_ALLREDUCE][0];
140             alg_selected = 1;
141          }
142          else if(comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][1] == MPID_COLL_NOQUERY &&
143                  count > comm_ptr->mpid.cutoff_size[PAMI_XFER_ALLREDUCE][0])
144          {
145             my_allred = comm_ptr->mpid.opt_protocol[PAMI_XFER_ALLREDUCE][1];
146             my_allred_md = &comm_ptr->mpid.opt_protocol_md[PAMI_XFER_ALLREDUCE][1];
147             alg_selected = 1;
148          }
149          else if((comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_CHECK_FN_REQUIRED) ||
150 		 (comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_QUERY) ||
151 		 (comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] ==  MPID_COLL_ALWAYS_QUERY))
152          {
153             my_allred = comm_ptr->mpid.opt_protocol[PAMI_XFER_ALLREDUCE][0];
154             my_allred_md = &comm_ptr->mpid.opt_protocol_md[PAMI_XFER_ALLREDUCE][0];
155             alg_selected = 1;
156          }
157       }
158       else
159       {
160          /* so we aren't one of the key ops... */
161          if(comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_NOQUERY &&
162             count <= comm_ptr->mpid.cutoff_size[PAMI_XFER_ALLREDUCE][0])
163          {
164             my_allred = comm_ptr->mpid.opt_protocol[PAMI_XFER_ALLREDUCE][0];
165             my_allred_md = &comm_ptr->mpid.opt_protocol_md[PAMI_XFER_ALLREDUCE][0];
166             alg_selected = 1;
167          }
168          else if(comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][1] == MPID_COLL_NOQUERY &&
169                  count > comm_ptr->mpid.cutoff_size[PAMI_XFER_ALLREDUCE][0])
170          {
171             my_allred = comm_ptr->mpid.opt_protocol[PAMI_XFER_ALLREDUCE][1];
172             my_allred_md = &comm_ptr->mpid.opt_protocol_md[PAMI_XFER_ALLREDUCE][1];
173             alg_selected = 1;
174          }
175          else if((comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_CHECK_FN_REQUIRED) ||
176 		 (comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_QUERY) ||
177 		 (comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_ALWAYS_QUERY))
178          {
179             my_allred = comm_ptr->mpid.opt_protocol[PAMI_XFER_ALLREDUCE][0];
180             my_allred_md = &comm_ptr->mpid.opt_protocol_md[PAMI_XFER_ALLREDUCE][0];
181             alg_selected = 1;
182          }
183       }
184       TRACE_ERR("Alg selected: %d\n", alg_selected);
185       if(likely(alg_selected))
186       {
187 	if(unlikely(comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_CHECK_FN_REQUIRED))
188         {
189            if(my_allred_md->check_fn != NULL)/*This should always be the case in FCA.. Otherwise punt to mpich*/
190            {
191               metadata_result_t result = {0};
192               TRACE_ERR("querying allreduce algorithm %s, type was %d\n",
193                  my_allred_md->name,
194                  comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE]);
195               result = my_allred_md->check_fn(&allred);
196               TRACE_ERR("bitmask: %#X\n", result.bitmask);
197               /* \todo Ignore check_correct.values.nonlocal until we implement the
198                  'pre-allreduce allreduce' or the 'safe' environment flag.
199                  We will basically assume 'safe' -- that all ranks are aligned (or not).
200               */
201               result.check.nonlocal = 0; /* #warning REMOVE THIS WHEN IMPLEMENTED */
202               if(!result.bitmask)
203               {
204                  allred.algorithm = my_allred;
205               }
206               else
207               {
208                  alg_selected = 0;
209                  if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
210                     fprintf(stderr,"check_fn failed for %s.\n", my_allred_md->name);
211               }
212            }
213          else alg_selected = 0;
214 	}
215 	else if(unlikely(((comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_QUERY) ||
216 			  (comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE][0] == MPID_COLL_ALWAYS_QUERY))))
217         {
218            if(my_allred_md->check_fn != NULL)/*This should always be the case in FCA.. Otherwise punt to mpich*/
219            {
220               metadata_result_t result = {0};
221               TRACE_ERR("querying allreduce algorithm %s, type was %d\n",
222                  my_allred_md->name,
223                  comm_ptr->mpid.must_query[PAMI_XFER_ALLREDUCE]);
224               result = my_allred_md->check_fn(&allred);
225               TRACE_ERR("bitmask: %#X\n", result.bitmask);
226               /* \todo Ignore check_correct.values.nonlocal until we implement the
227                  'pre-allreduce allreduce' or the 'safe' environment flag.
228                  We will basically assume 'safe' -- that all ranks are aligned (or not).
229               */
230               result.check.nonlocal = 0; /* #warning REMOVE THIS WHEN IMPLEMENTED */
231               if(!result.bitmask)
232               {
233                  allred.algorithm = my_allred;
234               }
235               else
236               {
237                  alg_selected = 0;
238                  if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
239                     fprintf(stderr,"check_fn failed for %s.\n", my_allred_md->name);
240               }
241            }
242 	   else /* no check_fn, manually look at the metadata fields */
243 	   {
244 	     /* Check if the message range if restricted */
245 	     if(my_allred_md->check_correct.values.rangeminmax)
246 	     {
247                MPI_Aint data_true_lb;
248                MPID_Datatype *data_ptr;
249                int data_size, data_contig;
250                MPIDI_Datatype_get_info(count, dt, data_contig, data_size, data_ptr, data_true_lb);
251                if((my_allred_md->range_lo <= data_size) &&
252                   (my_allred_md->range_hi >= data_size))
253                  allred.algorithm = my_allred; /* query algorithm successfully selected */
254                else
255 		 {
256 		   if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
257                      fprintf(stderr,"message size (%u) outside range (%zu<->%zu) for %s.\n",
258                              data_size,
259                              my_allred_md->range_lo,
260                              my_allred_md->range_hi,
261                              my_allred_md->name);
262 		   alg_selected = 0;
263 		 }
264 	     }
265 	     /* \todo check the rest of the metadata */
266 	   }
267         }
268         else
269         {
270            TRACE_ERR("Using %s for allreduce\n", my_allred_md->name);
271            allred.algorithm = my_allred;
272         }
273       }
274    }
275    else
276    {
277       my_allred = comm_ptr->mpid.user_selected[PAMI_XFER_ALLREDUCE];
278       my_allred_md = &comm_ptr->mpid.user_metadata[PAMI_XFER_ALLREDUCE];
279       allred.algorithm = my_allred;
280       if(comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE] == MPID_COLL_QUERY ||
281          comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE] == MPID_COLL_ALWAYS_QUERY ||
282          comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE] == MPID_COLL_CHECK_FN_REQUIRED)
283       {
284          if(my_allred_md->check_fn != NULL)
285          {
286             /* For now, we don't distinguish between MPID_COLL_ALWAYS_QUERY &
287                MPID_COLL_CHECK_FN_REQUIRED, we just call the fn                */
288             metadata_result_t result = {0};
289             TRACE_ERR("querying allreduce algorithm %s, type was %d\n",
290                my_allred_md->name,
291                comm_ptr->mpid.user_selected_type[PAMI_XFER_ALLREDUCE]);
292             result = comm_ptr->mpid.user_metadata[PAMI_XFER_ALLREDUCE].check_fn(&allred);
293             TRACE_ERR("bitmask: %#X\n", result.bitmask);
294             /* \todo Ignore check_correct.values.nonlocal until we implement the
295                'pre-allreduce allreduce' or the 'safe' environment flag.
296                We will basically assume 'safe' -- that all ranks are aligned (or not).
297             */
298             result.check.nonlocal = 0; /* #warning REMOVE THIS WHEN IMPLEMENTED */
299             if(!result.bitmask)
300                alg_selected = 1; /* query algorithm successfully selected */
301             else
302                if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
303                   fprintf(stderr,"check_fn failed for %s.\n", my_allred_md->name);
304          }
305          else /* no check_fn, manually look at the metadata fields */
306          {
307             /* Check if the message range if restricted */
308             if(my_allred_md->check_correct.values.rangeminmax)
309             {
310                MPI_Aint data_true_lb;
311                MPID_Datatype *data_ptr;
312                int data_size, data_contig;
313                MPIDI_Datatype_get_info(count, dt, data_contig, data_size, data_ptr, data_true_lb);
314                if((my_allred_md->range_lo <= data_size) &&
315                   (my_allred_md->range_hi >= data_size))
316                   alg_selected = 1; /* query algorithm successfully selected */
317                else
318                  if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
319                      fprintf(stderr,"message size (%u) outside range (%zu<->%zu) for %s.\n",
320                              data_size,
321                              my_allred_md->range_lo,
322                              my_allred_md->range_hi,
323                              my_allred_md->name);
324             }
325             /* \todo check the rest of the metadata */
326          }
327       }
328       else alg_selected = 1; /* non-query algorithm selected */
329 
330    }
331 
332    if(unlikely(!alg_selected)) /* must be fallback to MPICH */
333    {
334       if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
335          fprintf(stderr,"Using MPICH allreduce\n");
336       MPIDI_Update_last_algorithm(comm_ptr, "ALLREDUCE_MPICH");
337       return MPIR_Allreduce(sendbuf, recvbuf, count, dt, op, comm_ptr, mpierrno);
338    }
339 
340    if(unlikely(MPIDI_Process.verbose >= MPIDI_VERBOSE_DETAILS_ALL && comm_ptr->rank == 0))
341    {
342       unsigned long long int threadID;
343       MPIU_Thread_id_t tid;
344       MPIU_Thread_self(&tid);
345       threadID = (unsigned long long int)tid;
346       fprintf(stderr,"<%llx> Using protocol %s for allreduce on %u\n",
347               threadID,
348               my_allred_md->name,
349               (unsigned) comm_ptr->context_id);
350    }
351 
352    MPIDI_Post_coll_t allred_post;
353    MPIDI_Context_post(MPIDI_Context[0], &allred_post.state,
354                       MPIDI_Pami_post_wrapper, (void *)&allred);
355 
356    MPID_PROGRESS_WAIT_WHILE(active);
357    TRACE_ERR("allreduce done\n");
358    return MPI_SUCCESS;
359 }
360 
361