1 // Copyright (c) Lawrence Livermore National Security, LLC and other Conduit
2 // Project developers. See top-level LICENSE AND COPYRIGHT files for dates and
3 // other details. No copyright assignment is required to contribute to Conduit.
4 
5 //-----------------------------------------------------------------------------
6 ///
7 /// file: conduit_relay_mpi.cpp
8 ///
9 //-----------------------------------------------------------------------------
10 
11 #include "conduit_relay_mpi.hpp"
12 #include <iostream>
#include <fstream>   // std::ofstream is used by communicate_using_schema::execute()
13 #include <limits>
14 
15 //-----------------------------------------------------------------------------
16 /// The CONDUIT_CHECK_MPI_ERROR macro is used to check return values for
17 /// mpi calls.
18 //-----------------------------------------------------------------------------
19 #define CONDUIT_CHECK_MPI_ERROR( check_mpi_err_code )               \
20 {                                                                   \
21     if( static_cast<int>(check_mpi_err_code) != MPI_SUCCESS)        \
22     {                                                               \
23         char check_mpi_err_str_buff[MPI_MAX_ERROR_STRING];          \
24         int  check_mpi_err_str_len=0;                               \
25         MPI_Error_string( check_mpi_err_code ,                      \
26                          check_mpi_err_str_buff,                    \
27                          &check_mpi_err_str_len);                   \
28                                                                     \
29         CONDUIT_ERROR("MPI call failed: \n"                         \
30                       << " error code = "                           \
31                       <<  check_mpi_err_code  << "\n"               \
32                       << " error message = "                        \
33                       <<  check_mpi_err_str_buff << "\n");          \
34         return  check_mpi_err_code;                                 \
35     }                                                               \
36 }
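
//-----------------------------------------------------------------------------
// Usage sketch (illustrative only): the macro issues a `return` on failure, so
// it is meant for functions that return the MPI error code. A wrapper written
// against this pattern looks like the hypothetical helper below:
//
//   int my_barrier(MPI_Comm comm)
//   {
//       int mpi_error = MPI_Barrier(comm);
//       CONDUIT_CHECK_MPI_ERROR(mpi_error); // reports failure, returns the code
//       return mpi_error;
//   }
//-----------------------------------------------------------------------------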
37 
38 
39 //-----------------------------------------------------------------------------
40 // -- begin conduit:: --
41 //-----------------------------------------------------------------------------
42 namespace conduit
43 {
44 
45 //-----------------------------------------------------------------------------
46 // -- begin conduit::relay --
47 //-----------------------------------------------------------------------------
48 namespace relay
49 {
50 
51 
52 //-----------------------------------------------------------------------------
53 // -- begin conduit::relay::mpi --
54 //-----------------------------------------------------------------------------
55 namespace mpi
56 {
57 
58 
59 //-----------------------------------------------------------------------------
60 int
61 size(MPI_Comm mpi_comm)
62 {
63     int res;
64     MPI_Comm_size(mpi_comm,&res);
65     return res;
66 }
67 
68 //-----------------------------------------------------------------------------
69 int
70 rank(MPI_Comm mpi_comm)
71 {
72     int res;
73     MPI_Comm_rank(mpi_comm,&res);
74     return res;
75 }
76 
77 //-----------------------------------------------------------------------------
78 MPI_Datatype
79 conduit_dtype_to_mpi_dtype(const DataType &dt)
80 {
81     MPI_Datatype res = MPI_DATATYPE_NULL;
82 
83     // can't use switch w/ case statements here b/c the NATIVE_IDS may
84     // alias the same value on some platforms (this happens on windows)
85 
86     index_t dt_id = dt.id();
87 
88     // signed integer types
89     if(dt_id == CONDUIT_INT8_ID)
90     {
91         res = MPI_INT8_T;
92     }
93     else if( dt_id == CONDUIT_INT16_ID)
94     {
95         res = MPI_INT16_T;
96     }
97     else if( dt_id == CONDUIT_INT32_ID)
98     {
99         res = MPI_INT32_T;
100     }
101     else if( dt_id == CONDUIT_INT64_ID)
102     {
103         res = MPI_INT64_T;
104     }
105     // unsigned integer types
106     else if( dt_id == CONDUIT_UINT8_ID)
107     {
108         res = MPI_UINT8_T;
109     }
110     else if( dt_id == CONDUIT_UINT16_ID)
111     {
112         res = MPI_UINT16_T;
113     }
114     else if( dt_id == CONDUIT_UINT32_ID)
115     {
116         res = MPI_UINT32_T;
117     }
118     else if( dt_id == CONDUIT_UINT64_ID)
119     {
120         res = MPI_UINT64_T;
121     }
122     // floating point types
123     else if( dt_id == CONDUIT_NATIVE_FLOAT_ID)
124     {
125         res = MPI_FLOAT;
126     }
127     else if( dt_id == CONDUIT_NATIVE_DOUBLE_ID)
128     {
129         res = MPI_DOUBLE;
130     }
131     #if defined(CONDUIT_USE_LONG_DOUBLE)
132     else if( dt_id == CONDUIT_NATIVE_LONG_DOUBLE_ID)
133     {
134         res = MPI_LONG_DOUBLE;
135     }
136     #endif
137     // string type
138     else if( dt_id == DataType::CHAR8_STR_ID)
139     {
140         res = MPI_CHAR;
141     }
142 
143     return res;
144 }
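
//-----------------------------------------------------------------------------
// Usage sketch (illustrative only): mapping a numeric leaf Node's dtype to the
// matching MPI type before a raw MPI call. Assumes MPI is initialized and
// `using namespace conduit;` in the caller; `n` is a hypothetical node.
//
//   Node n;
//   n.set(DataType::float64(10));
//   MPI_Datatype mpi_dt = relay::mpi::conduit_dtype_to_mpi_dtype(n.dtype());
//   if(mpi_dt == MPI_DATATYPE_NULL)
//   {
//       // no MPI equivalent (e.g. an object or list node)
//   }
//
// The reverse mapping (MPI type -> conduit dtype id) is provided by
// mpi_dtype_to_conduit_dtype_id() below.
//-----------------------------------------------------------------------------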
145 
146 
147 //-----------------------------------------------------------------------------
148 index_t
149 mpi_dtype_to_conduit_dtype_id(MPI_Datatype dt)
150 {
151     index_t res = DataType::EMPTY_ID;
152 
153     // can't use switch w/ case statements here b/c in some
154     // MPI implementations MPI_Datatype is a struct (or something more complex)
155     // that won't compile when used in a switch statement.
156 
157     // string type
158     if(dt == MPI_CHAR)
159     {
160         res = DataType::CHAR8_STR_ID;
161     }
162     // mpi c bitwidth-style signed integer types
163     if(dt == MPI_INT8_T)
164     {
165         res = CONDUIT_INT8_ID;
166     }
167     else if( dt == MPI_INT16_T)
168     {
169         res = CONDUIT_INT16_ID;
170     }
171     else if( dt == MPI_INT32_T)
172     {
173         res = CONDUIT_INT32_ID;
174     }
175     else if( dt == MPI_INT64_T)
176     {
177         res = CONDUIT_INT64_ID;
178     }
179     // mpi c bitwidth-style unsigned integer types
180     else if( dt == MPI_UINT8_T)
181     {
182         res = CONDUIT_UINT8_ID;
183     }
184     else if( dt == MPI_UINT16_T)
185     {
186         res = CONDUIT_UINT16_ID;
187     }
188     else if( dt == MPI_UINT32_T)
189     {
190         res = CONDUIT_UINT32_ID;
191     }
192     else if( dt == MPI_UINT64_T)
193     {
194         res = CONDUIT_UINT64_ID;
195     }
196     // native c signed integer types
197     else if(dt == MPI_SHORT)
198     {
199         res = CONDUIT_NATIVE_SHORT_ID;
200     }
201     else if(dt == MPI_INT)
202     {
203         res = CONDUIT_NATIVE_INT_ID;
204     }
205     else if(dt == MPI_LONG)
206     {
207         res = CONDUIT_NATIVE_LONG_ID;
208     }
209     #if defined(CONDUIT_HAS_LONG_LONG)
210     else if(dt == MPI_LONG_LONG)
211     {
212         res = CONDUIT_NATIVE_LONG_LONG_ID;
213     }
214     #endif
215     // native c unsigned integer types
216     else if(dt == MPI_BYTE)
217     {
218         res = CONDUIT_NATIVE_UNSIGNED_CHAR_ID;
219     }
220     else if(dt == MPI_UNSIGNED_CHAR)
221     {
222         res = CONDUIT_NATIVE_UNSIGNED_CHAR_ID;
223     }
224     else if(dt == MPI_UNSIGNED_SHORT)
225     {
226         res = CONDUIT_NATIVE_UNSIGNED_SHORT_ID;
227     }
228     else if(dt == MPI_UNSIGNED)
229     {
230         res = CONDUIT_NATIVE_UNSIGNED_INT_ID;
231     }
232     else if(dt == MPI_UNSIGNED_LONG)
233     {
234         res = CONDUIT_NATIVE_UNSIGNED_LONG_ID;
235     }
236     #if defined(CONDUIT_HAS_LONG_LONG)
237     else if(dt == MPI_UNSIGNED_LONG_LONG)
238     {
239         res = CONDUIT_NATIVE_UNSIGNED_LONG_LONG_ID;
240     }
241     #endif
242     // floating point types
243     else if(dt == MPI_FLOAT)
244     {
245         res = CONDUIT_NATIVE_FLOAT_ID;
246     }
247     else if(dt == MPI_DOUBLE)
248     {
249         res = CONDUIT_NATIVE_DOUBLE_ID;
250     }
251     #if defined(CONDUIT_USE_LONG_DOUBLE)
252     else if(dt == MPI_LONG_DOUBLE)
253     {
254         res = CONDUIT_NATIVE_LONG_DOUBLE_ID;
255     }
256     #endif
257     return res;
258 }
259 
260 //---------------------------------------------------------------------------//
261 int
262 send_using_schema(const Node &node, int dest, int tag, MPI_Comm comm)
263 {
264     Schema s_data_compact;
265 
266     // schema will only be valid if compact and contig
267     if( node.is_compact() && node.is_contiguous())
268     {
269         s_data_compact = node.schema();
270     }
271     else
272     {
273         node.schema().compact_to(s_data_compact);
274     }
275 
276     std::string snd_schema_json = s_data_compact.to_json();
277 
278     Schema s_msg;
279     s_msg["schema_len"].set(DataType::int64());
280     s_msg["schema"].set(DataType::char8_str(snd_schema_json.size()+1));
281     s_msg["data"].set(s_data_compact);
282 
283     // create a compact schema to use
284     Schema s_msg_compact;
285     s_msg.compact_to(s_msg_compact);
286 
287     Node n_msg(s_msg_compact);
288     // these sets won't realloc since schemas are compatible
289     n_msg["schema_len"].set((int64)snd_schema_json.length());
290     n_msg["schema"].set(snd_schema_json);
291     n_msg["data"].update(node);
292 
293 
294     index_t msg_data_size = n_msg.total_bytes_compact();
295 
296     if(!conduit::utils::value_fits<index_t,int>(msg_data_size))
297     {
298         CONDUIT_INFO("Warning size value (" << msg_data_size << ")"
299                      " exceeds the size of MPI_Send max value "
300                      "(" << std::numeric_limits<int>::max() << ")")
301     }
302 
303     int mpi_error = MPI_Send(const_cast<void*>(n_msg.data_ptr()),
304                              static_cast<int>(msg_data_size),
305                              MPI_BYTE,
306                              dest,
307                              tag,
308                              comm);
309 
310     CONDUIT_CHECK_MPI_ERROR(mpi_error);
311 
312     return mpi_error;
313 }
314 
315 
316 //---------------------------------------------------------------------------//
317 int
318 recv_using_schema(Node &node, int src, int tag, MPI_Comm comm)
319 {
320     MPI_Status status;
321 
322     int mpi_error = MPI_Probe(src, tag, comm, &status);
323 
324     CONDUIT_CHECK_MPI_ERROR(mpi_error);
325 
326     int buffer_size = 0;
327     MPI_Get_count(&status, MPI_BYTE, &buffer_size);
328 
329     Node n_buffer(DataType::uint8(buffer_size));
330 
331     mpi_error = MPI_Recv(n_buffer.data_ptr(),
332                          buffer_size,
333                          MPI_BYTE,
334                          src,
335                          tag,
336                          comm,
337                          &status);
338 
339     uint8 *n_buff_ptr = (uint8*)n_buffer.data_ptr();
340 
341     Node n_msg;
342     // length of the schema is sent as a 64-bit signed int
343     // NOTE: this value isn't used here; the schema string that follows is null-terminated
344     n_msg["schema_len"].set_external((int64*)n_buff_ptr);
345     n_buff_ptr +=8;
346     // wrap the schema string
347     n_msg["schema"].set_external_char8_str((char*)(n_buff_ptr));
348     // create the schema
349     Schema rcv_schema;
350     Generator gen(n_msg["schema"].as_char8_str());
351     gen.walk(rcv_schema);
352 
353     // advance by the schema length
354     n_buff_ptr += n_msg["schema"].total_bytes_compact();
355 
356     // apply the schema to the data
357     n_msg["data"].set_external(rcv_schema,n_buff_ptr);
358 
359     // copy out to our result node
360     node.update(n_msg["data"]);
361 
362     return mpi_error;
363 }
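
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): schema-aware point-to-point exchange where
// the receiver does not know the layout ahead of time. Assumes MPI is
// initialized, at least two ranks, `using namespace conduit;`, and
// `namespace mpi = conduit::relay::mpi;`; ranks and tags are arbitrary
// example values.
//
//   Node n;
//   if(mpi::rank(MPI_COMM_WORLD) == 0)
//   {
//       n["values"].set(DataType::float64(4));
//       n["name"] = "example";
//       mpi::send_using_schema(n, /*dest*/ 1, /*tag*/ 0, MPI_COMM_WORLD);
//   }
//   else if(mpi::rank(MPI_COMM_WORLD) == 1)
//   {
//       // n can start empty; the schema travels with the message
//       mpi::recv_using_schema(n, /*src*/ 0, /*tag*/ 0, MPI_COMM_WORLD);
//   }
//---------------------------------------------------------------------------//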
364 
365 //---------------------------------------------------------------------------//
366 int
367 send(const Node &node, int dest, int tag, MPI_Comm comm)
368 {
369     // assumes size and type are known on the other end
370 
371     Node snd_compact;
372 
373     const void *snd_ptr = node.contiguous_data_ptr();
374     index_t    snd_size = node.total_bytes_compact();
375 
376     if( snd_ptr == NULL ||
377         ! node.is_compact())
378     {
379          node.compact_to(snd_compact);
380          snd_ptr = snd_compact.data_ptr();
381     }
382 
383     if(!conduit::utils::value_fits<index_t,int>(snd_size))
384     {
385         CONDUIT_INFO("Warning size value (" << snd_size << ")"
386                      " exceeds the size of MPI_Send max value "
387                      "(" << std::numeric_limits<int>::max() << ")")
388     }
389 
390     int mpi_error = MPI_Send(const_cast<void*>(snd_ptr),
391                              static_cast<int>(snd_size),
392                              MPI_BYTE,
393                              dest,
394                              tag,
395                              comm);
396 
397     CONDUIT_CHECK_MPI_ERROR(mpi_error);
398 
399     return mpi_error;
400 }
401 
402 //---------------------------------------------------------------------------//
403 int
404 recv(Node &node, int src, int tag, MPI_Comm comm)
405 {
406 
407     MPI_Status status;
408     Node rcv_compact;
409 
410     bool cpy_out = false;
411 
412     const void *rcv_ptr  = node.contiguous_data_ptr();
413     index_t     rcv_size = node.total_bytes_compact();
414 
415     if( rcv_ptr == NULL  ||
416         ! node.is_compact() )
417     {
418         // we will need to update into rcv node
419         cpy_out = true;
420         Schema s_rcv_compact;
421         node.schema().compact_to(s_rcv_compact);
422         rcv_compact.set_schema(s_rcv_compact);
423         rcv_ptr  = rcv_compact.data_ptr();
424     }
425 
426     if(!conduit::utils::value_fits<index_t,int>(rcv_size))
427     {
428         CONDUIT_INFO("Warning size value (" << rcv_size << ")"
429                      " exceeds the size of MPI_Recv max value "
430                      "(" << std::numeric_limits<int>::max() << ")")
431     }
432 
433 
434     int mpi_error = MPI_Recv(const_cast<void*>(rcv_ptr),
435                              static_cast<int>(rcv_size),
436                              MPI_BYTE,
437                              src,
438                              tag,
439                              comm,
440                              &status);
441 
442     CONDUIT_CHECK_MPI_ERROR(mpi_error);
443 
444     if(cpy_out)
445     {
446         node.update(rcv_compact);
447     }
448 
449     return mpi_error;
450 }
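
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): send()/recv() move raw bytes only, so the
// receiver must already hold a node with a compatible (same-sized) layout.
// Same namespace and two-rank assumptions as the sketch above; ranks and tags
// are arbitrary.
//
//   Node n(DataType::int64(8));   // identical layout pre-allocated on both ranks
//   if(mpi::rank(MPI_COMM_WORLD) == 0)
//   {
//       int64_array vals = n.value();
//       for(index_t i = 0; i < 8; i++)
//       {
//           vals[i] = i;
//       }
//       mpi::send(n, /*dest*/ 1, /*tag*/ 0, MPI_COMM_WORLD);
//   }
//   else if(mpi::rank(MPI_COMM_WORLD) == 1)
//   {
//       mpi::recv(n, /*src*/ 0, /*tag*/ 0, MPI_COMM_WORLD);
//   }
//---------------------------------------------------------------------------//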
451 
452 
453 //---------------------------------------------------------------------------//
454 int
455 reduce(const Node &snd_node,
456        Node &rcv_node,
457        MPI_Op mpi_op,
458        int root,
459        MPI_Comm mpi_comm)
460 {
461     MPI_Datatype mpi_dtype = conduit_dtype_to_mpi_dtype(snd_node.dtype());
462 
463     if(mpi_dtype == MPI_DATATYPE_NULL)
464     {
465         CONDUIT_ERROR("Unsupported send DataType for mpi::reduce: "
466                       << snd_node.dtype().name());
467     }
468 
469     void *snd_ptr = NULL;
470     void *rcv_ptr = NULL;
471 
472     Node snd_compact;
473     Node rcv_compact;
474 
475     //note: we don't have to ask for contig in this case, since
476     // we can only reduce leaf types
477     if(snd_node.is_compact())
478     {
479         snd_ptr = const_cast<void*>(snd_node.data_ptr());
480     }
481     else
482     {
483         snd_node.compact_to(snd_compact);
484         snd_ptr = snd_compact.data_ptr();
485     }
486 
487     bool cpy_out = false;
488 
489     int rank = mpi::rank(mpi_comm);
490 
491     if( rank == root )
492     {
493 
494         rcv_ptr = rcv_node.contiguous_data_ptr();
495 
496 
497         if( !snd_node.compatible(rcv_node) ||
498             rcv_ptr == NULL ||
499             !rcv_node.is_compact() )
500         {
501             // we will need to update into rcv node
502             cpy_out = true;
503 
504             Schema s_snd_compact;
505             snd_node.schema().compact_to(s_snd_compact);
506 
507             rcv_compact.set_schema(s_snd_compact);
508             rcv_ptr = rcv_compact.data_ptr();
509         }
510     }
511 
512     int num_eles = (int) snd_node.dtype().number_of_elements();
513 
514     int mpi_error = MPI_Reduce(snd_ptr,
515                                rcv_ptr,
516                                num_eles,
517                                mpi_dtype,
518                                mpi_op,
519                                root,
520                                mpi_comm);
521 
522     CONDUIT_CHECK_MPI_ERROR(mpi_error);
523 
524     if( rank == root  && cpy_out )
525     {
526         rcv_node.update(rcv_compact);
527     }
528 
529     return mpi_error;
530 }
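
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): reducing a numeric leaf onto a root rank.
// Only leaf (numeric) nodes can be reduced, and rcv_node is only meaningful on
// the root. Same namespace assumptions as the earlier sketches.
//
//   Node local(DataType::float64(3));
//   float64_array v = local.value();
//   v[0] = 1.0; v[1] = 2.0; v[2] = 3.0;
//
//   Node global;
//   mpi::reduce(local, global, MPI_SUM, /*root*/ 0, MPI_COMM_WORLD);
//   // equivalently: mpi::sum_reduce(local, global, 0, MPI_COMM_WORLD);
//---------------------------------------------------------------------------//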
531 
532 //--------------------------------------------------------------------------//
533 int
534 all_reduce(const Node &snd_node,
535            Node &rcv_node,
536            MPI_Op mpi_op,
537            MPI_Comm mpi_comm)
538 {
539     MPI_Datatype mpi_dtype = conduit_dtype_to_mpi_dtype(snd_node.dtype());
540 
541     if(mpi_dtype == MPI_DATATYPE_NULL)
542     {
543         CONDUIT_ERROR("Unsupported send DataType for mpi::all_reduce: "
544                       << snd_node.dtype().name());
545     }
546 
547 
548     void *snd_ptr = NULL;
549     void *rcv_ptr = NULL;
550 
551     Node snd_compact;
552     Node rcv_compact;
553 
554     //note: we don't have to ask for contig in this case, since
555     // we can only reduce leaf types
556     if(snd_node.is_compact())
557     {
558         snd_ptr = const_cast<void*>(snd_node.data_ptr());
559     }
560     else
561     {
562         snd_node.compact_to(snd_compact);
563         snd_ptr = snd_compact.data_ptr();
564     }
565 
566     bool cpy_out = false;
567 
568 
569     rcv_ptr = rcv_node.contiguous_data_ptr();
570 
571 
572     if( !snd_node.compatible(rcv_node) ||
573         rcv_ptr == NULL ||
574         !rcv_node.is_compact() )
575     {
576         // we will need to update into rcv node
577         cpy_out = true;
578 
579         Schema s_snd_compact;
580         snd_node.schema().compact_to(s_snd_compact);
581 
582         rcv_compact.set_schema(s_snd_compact);
583         rcv_ptr = rcv_compact.data_ptr();
584     }
585 
586     int num_eles = (int) snd_node.dtype().number_of_elements();
587 
588     int mpi_error = MPI_Allreduce(snd_ptr,
589                                   rcv_ptr,
590                                   num_eles,
591                                   mpi_dtype,
592                                   mpi_op,
593                                   mpi_comm);
594 
595     CONDUIT_CHECK_MPI_ERROR(mpi_error);
596 
597     if(cpy_out)
598     {
599         rcv_node.update(rcv_compact);
600     }
601 
602 
603     return mpi_error;
604 }
605 
606 //-- reduce helpers -- //
607 
608 //---------------------------------------------------------------------------//
609 int
610 sum_reduce(const Node &snd_node,
611            Node &rcv_node,
612            int root,
613            MPI_Comm mpi_comm)
614 {
615     return reduce(snd_node,
616                   rcv_node,
617                   MPI_SUM,
618                   root,
619                   mpi_comm);
620 }
621 
622 
623 //---------------------------------------------------------------------------//
624 int
625 min_reduce(const Node &snd_node,
626            Node &rcv_node,
627            int root,
628            MPI_Comm mpi_comm)
629 {
630     return reduce(snd_node,
631                   rcv_node,
632                   MPI_MIN,
633                   root,
634                   mpi_comm);
635 }
636 
637 
638 
639 //---------------------------------------------------------------------------//
640 int
641 max_reduce(const Node &snd_node,
642            Node &rcv_node,
643            int root,
644            MPI_Comm mpi_comm)
645 {
646     return reduce(snd_node,
647                   rcv_node,
648                   MPI_MAX,
649                   root,
650                   mpi_comm);
651 }
652 
653 
654 
655 //---------------------------------------------------------------------------//
656 int
657 prod_reduce(const Node &snd_node,
658             Node &rcv_node,
659             int root,
660             MPI_Comm mpi_comm)
661 {
662     return reduce(snd_node,
663                   rcv_node,
664                   MPI_PROD,
665                   root,
666                   mpi_comm);
667 }
668 
669 
670 
671 //-- all reduce helpers -- //
672 //---------------------------------------------------------------------------//
673 int
674 sum_all_reduce(const Node &snd_node,
675                Node &rcv_node,
676                MPI_Comm mpi_comm)
677 {
678     return all_reduce(snd_node,
679                       rcv_node,
680                       MPI_SUM,
681                       mpi_comm);
682 }
683 
684 
685 //---------------------------------------------------------------------------//
686 int
687 min_all_reduce(const Node &snd_node,
688                Node &rcv_node,
689                MPI_Comm mpi_comm)
690 {
691     return all_reduce(snd_node,
692                       rcv_node,
693                       MPI_MIN,
694                       mpi_comm);
695 
696 }
697 
698 
699 
700 //---------------------------------------------------------------------------//
701 int
702 max_all_reduce(const Node &snd_node,
703                Node &rcv_node,
704                MPI_Comm mpi_comm)
705 {
706     return all_reduce(snd_node,
707                       rcv_node,
708                       MPI_MAX,
709                       mpi_comm);
710 
711 }
712 
713 
714 //---------------------------------------------------------------------------//
715 int
716 prod_all_reduce(const Node &snd_node,
717                 Node &rcv_node,
718                 MPI_Comm mpi_comm)
719 {
720     return all_reduce(snd_node,
721                       rcv_node,
722                       MPI_PROD,
723                       mpi_comm);
724 
725 }
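
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): the all_reduce helpers wrap the common
// MPI_Op choices, and every rank receives the result. Same namespace
// assumptions as the earlier sketches.
//
//   Node local, global;
//   local.set((int64)mpi::rank(MPI_COMM_WORLD));
//
//   mpi::sum_all_reduce(local, global, MPI_COMM_WORLD);
//   // on every rank, global now holds 0 + 1 + ... + (size-1)
//   int64 total = global.to_int64();
//---------------------------------------------------------------------------//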
726 
727 
728 
729 //---------------------------------------------------------------------------//
730 int
731 isend(const Node &node,
732       int dest,
733       int tag,
734       MPI_Comm mpi_comm,
735       Request *request)
736 {
737 
738     const void *data_ptr  = node.contiguous_data_ptr();
739     index_t     data_size = node.total_bytes_compact();
740 
741     // note: this checks for both compact and contig
742     if( data_ptr == NULL ||
743        !node.is_compact() )
744     {
745         node.compact_to(request->m_buffer);
746         data_ptr  = request->m_buffer.data_ptr();
747     }
748 
749     // m_rcv_ptr is only used by wait/wait_all for the irecv cases
750     // that require a copy out;
751     // for isend it must always be NULL
752     request->m_rcv_ptr = NULL;
753 
754 
755     if(!conduit::utils::value_fits<index_t,int>(data_size))
756     {
757         CONDUIT_INFO("Warning size value (" << data_size << ")"
758                      " exceeds the size of MPI_Isend max value "
759                      "(" << std::numeric_limits<int>::max() << ")")
760     }
761 
762     int mpi_error =  MPI_Isend(const_cast<void*>(data_ptr),
763                                static_cast<int>(data_size),
764                                MPI_BYTE,
765                                dest,
766                                tag,
767                                mpi_comm,
768                                &(request->m_request));
769 
770     CONDUIT_CHECK_MPI_ERROR(mpi_error);
771     return mpi_error;
772 }
773 
774 //---------------------------------------------------------------------------//
775 int
776 irecv(Node &node,
777       int src,
778       int tag,
779       MPI_Comm mpi_comm,
780       Request *request)
781 {
782     // if rcv is compact, we can write directly into recv
783     // if it's not compact, we need a recv_buffer
784 
785     void    *data_ptr  = node.contiguous_data_ptr();
786     index_t  data_size = node.total_bytes_compact();
787 
788     // note: this checks for both compact and contig
789     if(data_ptr != NULL &&
790        node.is_compact() )
791     {
792         // the node is already compact and contiguous, so we can receive
793         // directly into it; no copy out is needed and m_rcv_ptr stays NULL
794         request->m_rcv_ptr = NULL;
795     }
796     else
797     {
798         node.compact_to(request->m_buffer);
799         data_ptr  = request->m_buffer.data_ptr();
800         request->m_rcv_ptr = &node;
801     }
802 
803     if(!conduit::utils::value_fits<index_t,int>(data_size))
804     {
805         CONDUIT_INFO("Warning size value (" << data_size << ")"
806                      " exceeds the size of MPI_Irecv max value "
807                      "(" << std::numeric_limits<int>::max() << ")")
808     }
809 
810     int mpi_error =  MPI_Irecv(data_ptr,
811                                static_cast<int>(data_size),
812                                MPI_BYTE,
813                                src,
814                                tag,
815                                mpi_comm,
816                                &(request->m_request));
817 
818     CONDUIT_CHECK_MPI_ERROR(mpi_error);
819     return mpi_error;
820 }
821 
822 
823 //---------------------------------------------------------------------------//
824 // wait handles both send and recv requests
825 int
826 wait(Request *request,
827      MPI_Status *status)
828 {
829     int mpi_error = MPI_Wait(&(request->m_request), status);
830     CONDUIT_CHECK_MPI_ERROR(mpi_error);
831 
832     // we need to update if m_rcv_ptr was used
833     // this will only be non NULL in the recv copy out case,
834     // sends will always be NULL
835     if(request->m_rcv_ptr)
836     {
837         request->m_rcv_ptr->update(request->m_buffer);
838     }
839 
840     request->m_buffer.reset();
841     request->m_rcv_ptr = NULL;
842 
843     return mpi_error;
844 }
845 
846 //---------------------------------------------------------------------------//
847 int
848 wait_send(Request *request,
849           MPI_Status *status)
850 {
851     return wait(request,status);
852 }
853 
854 //---------------------------------------------------------------------------//
855 int
856 wait_recv(Request *request,
857           MPI_Status *status)
858 {
859     return wait(request,status);
860 }
861 
862 
863 //---------------------------------------------------------------------------//
864 // wait all handles mixed batches of sends and receives
865 int
866 wait_all(int count,
867          Request requests[],
868          MPI_Status statuses[])
869 {
870      MPI_Request *justrequests = new MPI_Request[count];
871 
872      for (int i = 0; i < count; ++i)
873      {
874          // mpi requests can be simply copied
875          justrequests[i] = requests[i].m_request;
876      }
877 
878      int mpi_error = MPI_Waitall(count, justrequests, statuses);
879      CONDUIT_CHECK_MPI_ERROR(mpi_error);
880 
881      for (int i = 0; i < count; ++i)
882      {
883          // if this request is a recv, we need to check for copy out
884          // m_rcv_ptr will always be NULL, unless we have done a
885          // irecv where we need to use the pointer.
886          if(requests[i].m_rcv_ptr != NULL)
887          {
888              requests[i].m_rcv_ptr->update(requests[i].m_buffer);
889              requests[i].m_rcv_ptr = NULL;
890          }
891 
892          requests[i].m_request = justrequests[i];
893          requests[i].m_buffer.reset();
894      }
895 
896      delete [] justrequests;
897 
898      return mpi_error;
899 }
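
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): a non-blocking exchange with isend/irecv
// completed by wait_all. As with send/recv, the fixed-size layout must match
// on both sides. Assumes exactly two ranks and the same namespace shorthands
// as the earlier sketches; the tag is an arbitrary example value.
//
//   Node n_out(DataType::float64(16));
//   Node n_in(DataType::float64(16));
//   mpi::Request reqs[2];
//
//   int peer = (mpi::rank(MPI_COMM_WORLD) == 0) ? 1 : 0;
//   mpi::isend(n_out, peer, /*tag*/ 7, MPI_COMM_WORLD, &reqs[0]);
//   mpi::irecv(n_in,  peer, /*tag*/ 7, MPI_COMM_WORLD, &reqs[1]);
//
//   MPI_Status statuses[2];
//   mpi::wait_all(2, reqs, statuses);   // handles the mixed send/recv batch
//---------------------------------------------------------------------------//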
900 
901 //---------------------------------------------------------------------------//
902 int
903 wait_all_send(int count,
904               Request requests[],
905               MPI_Status statuses[])
906 {
907     return wait_all(count,requests,statuses);
908 }
909 
910 //---------------------------------------------------------------------------//
911 int
912 wait_all_recv(int count,
913               Request requests[],
914               MPI_Status statuses[])
915 {
916    return  wait_all(count,requests,statuses);
917 }
918 
919 
920 //---------------------------------------------------------------------------//
921 int
922 gather(Node &send_node,
923        Node &recv_node,
924        int root,
925        MPI_Comm mpi_comm)
926 {
927     Node   n_snd_compact;
928     Schema s_snd_compact;
929 
930     send_node.schema().compact_to(s_snd_compact);
931 
932     const void *snd_ptr = send_node.contiguous_data_ptr();
933     index_t    snd_size = 0;
934 
935 
936     if(snd_ptr != NULL &&
937        send_node.is_compact() )
938     {
939         snd_ptr  = send_node.data_ptr();
940         snd_size = send_node.total_bytes_compact();
941     }
942     else
943     {
944         send_node.compact_to(n_snd_compact);
945         snd_ptr  = n_snd_compact.data_ptr();
946         snd_size = n_snd_compact.total_bytes_compact();
947     }
948 
949     int mpi_rank = mpi::rank(mpi_comm);
950     int mpi_size = mpi::size(mpi_comm);
951 
952     if(mpi_rank == root)
953     {
954         // TODO: copy out support w/o always reallocing?
955         recv_node.list_of(s_snd_compact,
956                           mpi_size);
957     }
958 
959     if(!conduit::utils::value_fits<index_t,int>(snd_size))
960     {
961         CONDUIT_INFO("Warning size value (" << snd_size << ")"
962                      " exceeds the size of MPI_Gather max value "
963                      "(" << std::numeric_limits<int>::max() << ")")
964     }
965 
966     int mpi_error = MPI_Gather( const_cast<void*>(snd_ptr), // local data
967                                 static_cast<int>(snd_size), // local data len
968                                 MPI_BYTE, // send chars
969                                 recv_node.data_ptr(),  // rcv buffer
970                                 static_cast<int>(snd_size), // data len
971                                 MPI_BYTE,  // rcv chars
972                                 root,
973                                 mpi_comm); // mpi com
974 
975     CONDUIT_CHECK_MPI_ERROR(mpi_error);
976 
977     return mpi_error;
978 }
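
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): gather expects every rank to contribute
// the same fixed-size layout; on the root, recv_node becomes a list with one
// child per rank. Same namespace assumptions as the earlier sketches.
//
//   Node local(DataType::int32(2));
//   int32_array v = local.value();
//   v[0] = mpi::rank(MPI_COMM_WORLD);
//   v[1] = 2 * mpi::rank(MPI_COMM_WORLD);
//
//   Node gathered;
//   mpi::gather(local, gathered, /*root*/ 0, MPI_COMM_WORLD);
//   // on rank 0: gathered.number_of_children() == mpi::size(MPI_COMM_WORLD)
//---------------------------------------------------------------------------//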
979 
980 //---------------------------------------------------------------------------//
981 int
982 all_gather(Node &send_node,
983            Node &recv_node,
984            MPI_Comm mpi_comm)
985 {
986     Node   n_snd_compact;
987     Schema s_snd_compact;
988 
989     send_node.schema().compact_to(s_snd_compact);
990 
991     const void *snd_ptr  = send_node.contiguous_data_ptr();
992     index_t     snd_size = send_node.total_bytes_compact();
993 
994     if( snd_ptr == NULL ||
995        !send_node.is_compact() )
996     {
997         send_node.compact_to(n_snd_compact);
998         snd_ptr  = n_snd_compact.data_ptr();
999     }
1000 
1001 
1002     int mpi_size = mpi::size(mpi_comm);
1003 
1004 
1005     // TODO: copy out support w/o always reallocing?
1006     // TODO: what about common case of scatter w/ leaf types?
1007     //       instead of list_of, we would have a leaf of
1008     //       of a given type w/ # of elements == # of ranks.
1009     recv_node.list_of(n_snd_compact.schema(),
1010                       mpi_size);
1011 
1012     if(!conduit::utils::value_fits<index_t,int>(snd_size))
1013     {
1014         CONDUIT_INFO("Warning size value (" << snd_size << ")"
1015                      " exceeds the size of MPI_Allgather max value "
1016                      "(" << std::numeric_limits<int>::max() << ")")
1017     }
1018 
1019     int mpi_error = MPI_Allgather( const_cast<void*>(snd_ptr), // local data
1020                                    static_cast<int>(snd_size), // local data len
1021                                    MPI_BYTE, // send chars
1022                                    recv_node.data_ptr(),  // rcv buffer
1023                                    static_cast<int>(snd_size), // data len
1024                                    MPI_BYTE,  // rcv chars
1025                                    mpi_comm); // mpi com
1026 
1027     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1028 
1029     return mpi_error;
1030 }
1031 
1032 
1033 
1034 //---------------------------------------------------------------------------//
1035 int
1036 gather_using_schema(Node &send_node,
1037                     Node &recv_node,
1038                     int root,
1039                     MPI_Comm mpi_comm)
1040 {
1041     Node n_snd_compact;
1042     send_node.compact_to(n_snd_compact);
1043 
1044     int m_size = mpi::size(mpi_comm);
1045     int m_rank = mpi::rank(mpi_comm);
1046 
1047     std::string schema_str = n_snd_compact.schema().to_json();
1048 
1049     int schema_len = static_cast<int>(schema_str.length() + 1);
1050     int data_len   = static_cast<int>(n_snd_compact.total_bytes_compact());
1051 
1052     // to do the conduit gatherv, first need a gather to get the
1053     // schema and data buffer sizes
1054 
1055     int snd_sizes[] = {schema_len, data_len};
1056 
1057     Node n_rcv_sizes;
1058 
1059     if( m_rank == root )
1060     {
1061         Schema s;
1062         s["schema_len"].set(DataType::c_int());
1063         s["data_len"].set(DataType::c_int());
1064         n_rcv_sizes.list_of(s,m_size);
1065     }
1066 
1067     int mpi_error = MPI_Gather( snd_sizes, // local data
1068                                 2, // two ints per rank
1069                                 MPI_INT, // send ints
1070                                 n_rcv_sizes.data_ptr(),  // rcv buffer
1071                                 2,  // two ints per rank
1072                                 MPI_INT,  // rcv ints
1073                                 root,  // id of root for gather op
1074                                 mpi_comm); // mpi com
1075 
1076     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1077 
1078     Node n_rcv_tmp;
1079 
1080     int  *schema_rcv_counts = NULL;
1081     int  *schema_rcv_displs = NULL;
1082     char *schema_rcv_buff   = NULL;
1083 
1084     int  *data_rcv_counts = NULL;
1085     int  *data_rcv_displs = NULL;
1086     char *data_rcv_buff   = NULL;
1087 
1088     // we only need rcv params on the gather root
1089     if( m_rank == root )
1090     {
1091         // alloc data for the mpi gather counts and displ arrays
1092         n_rcv_tmp["schemas/counts"].set(DataType::c_int(m_size));
1093         n_rcv_tmp["schemas/displs"].set(DataType::c_int(m_size));
1094 
1095         n_rcv_tmp["data/counts"].set(DataType::c_int(m_size));
1096         n_rcv_tmp["data/displs"].set(DataType::c_int(m_size));
1097 
1098         // get pointers to counts and displs
1099         schema_rcv_counts = n_rcv_tmp["schemas/counts"].value();
1100         schema_rcv_displs = n_rcv_tmp["schemas/displs"].value();
1101 
1102         data_rcv_counts = n_rcv_tmp["data/counts"].value();
1103         data_rcv_displs = n_rcv_tmp["data/displs"].value();
1104 
1105         int schema_curr_displ = 0;
1106         int data_curr_displ   = 0;
1107         int i=0;
1108 
1109         NodeIterator itr = n_rcv_sizes.children();
1110         while(itr.has_next())
1111         {
1112             Node &curr = itr.next();
1113 
1114             int schema_curr_count = curr["schema_len"].value();
1115             int data_curr_count   = curr["data_len"].value();
1116 
1117             schema_rcv_counts[i] = schema_curr_count;
1118             schema_rcv_displs[i] = schema_curr_displ;
1119             schema_curr_displ   += schema_curr_count;
1120 
1121             data_rcv_counts[i] = data_curr_count;
1122             data_rcv_displs[i] = data_curr_displ;
1123             data_curr_displ   += data_curr_count;
1124 
1125             i++;
1126         }
1127 
1128         n_rcv_tmp["schemas/data"].set(DataType::c_char(schema_curr_displ));
1129         schema_rcv_buff = n_rcv_tmp["schemas/data"].value();
1130     }
1131 
1132     mpi_error = MPI_Gatherv( const_cast <char*>(schema_str.c_str()),
1133                              schema_len,
1134                              MPI_BYTE,
1135                              schema_rcv_buff,
1136                              schema_rcv_counts,
1137                              schema_rcv_displs,
1138                              MPI_BYTE,
1139                              root,
1140                              mpi_comm);
1141 
1142     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1143 
1144     // build all schemas from JSON, compact them.
1145     Schema rcv_schema;
1146     if( m_rank == root )
1147     {
1148         //TODO: should we make it easier to create a compact schema?
1149         // TODO: Revisit, I think we can do this better
1150 
1151         Schema s_tmp;
1152         for(int i=0;i < m_size; i++)
1153         {
1154             Schema &s = s_tmp.append();
1155             s.set(&schema_rcv_buff[schema_rcv_displs[i]]);
1156         }
1157 
1158         s_tmp.compact_to(rcv_schema);
1159     }
1160 
1161 
1162     if( m_rank == root )
1163     {
1164         // allocate data to hold the gather result
1165         // TODO can we support copy out w/out realloc
1166         recv_node.set(rcv_schema);
1167         data_rcv_buff = (char*)recv_node.data_ptr();
1168     }
1169 
1170     mpi_error = MPI_Gatherv( n_snd_compact.data_ptr(),
1171                              data_len,
1172                              MPI_BYTE,
1173                              data_rcv_buff,
1174                              data_rcv_counts,
1175                              data_rcv_displs,
1176                              MPI_BYTE,
1177                              root,
1178                              mpi_comm);
1179 
1180     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1181 
1182     return mpi_error;
1183 }
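
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): gather_using_schema lets each rank
// contribute a differently shaped node; the root rebuilds a list with one
// child per rank from the gathered schemas. Same namespace assumptions as the
// earlier sketches.
//
//   Node local;
//   // each rank contributes a different number of values
//   local["vals"].set(DataType::float64(mpi::rank(MPI_COMM_WORLD) + 1));
//
//   Node gathered;
//   mpi::gather_using_schema(local, gathered, /*root*/ 0, MPI_COMM_WORLD);
//---------------------------------------------------------------------------//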
1184 
1185 //---------------------------------------------------------------------------//
1186 int
1187 all_gather_using_schema(Node &send_node,
1188                         Node &recv_node,
1189                         MPI_Comm mpi_comm)
1190 {
1191     Node n_snd_compact;
1192     send_node.compact_to(n_snd_compact);
1193 
1194     int m_size = mpi::size(mpi_comm);
1195 
1196     std::string schema_str = n_snd_compact.schema().to_json();
1197 
1198     int schema_len = static_cast<int>(schema_str.length() + 1);
1199     int data_len   = static_cast<int>(n_snd_compact.total_bytes_compact());
1200 
1201     // to do the conduit gatherv, first need a gather to get the
1202     // schema and data buffer sizes
1203 
1204     int snd_sizes[] = {schema_len, data_len};
1205 
1206     Node n_rcv_sizes;
1207 
1208     Schema s;
1209     s["schema_len"].set(DataType::c_int());
1210     s["data_len"].set(DataType::c_int());
1211     n_rcv_sizes.list_of(s,m_size);
1212 
1213     int mpi_error = MPI_Allgather( snd_sizes, // local data
1214                                    2, // two ints per rank
1215                                    MPI_INT, // send ints
1216                                    n_rcv_sizes.data_ptr(),  // rcv buffer
1217                                    2,  // two ints per rank
1218                                    MPI_INT,  // rcv ints
1219                                    mpi_comm); // mpi com
1220 
1221     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1222 
1223     Node n_rcv_tmp;
1224 
1225     int  *schema_rcv_counts = NULL;
1226     int  *schema_rcv_displs = NULL;
1227     char *schema_rcv_buff   = NULL;
1228 
1229     int  *data_rcv_counts = NULL;
1230     int  *data_rcv_displs = NULL;
1231     char *data_rcv_buff   = NULL;
1232 
1233 
1234     // alloc data for the mpi gather counts and displ arrays
1235     n_rcv_tmp["schemas/counts"].set(DataType::c_int(m_size));
1236     n_rcv_tmp["schemas/displs"].set(DataType::c_int(m_size));
1237 
1238     n_rcv_tmp["data/counts"].set(DataType::c_int(m_size));
1239     n_rcv_tmp["data/displs"].set(DataType::c_int(m_size));
1240 
1241     // get pointers to counts and displs
1242     schema_rcv_counts = n_rcv_tmp["schemas/counts"].value();
1243     schema_rcv_displs = n_rcv_tmp["schemas/displs"].value();
1244 
1245     data_rcv_counts = n_rcv_tmp["data/counts"].value();
1246     data_rcv_displs = n_rcv_tmp["data/displs"].value();
1247 
1248     int schema_curr_displ = 0;
1249     int data_curr_displ   = 0;
1250 
1251     NodeIterator itr = n_rcv_sizes.children();
1252 
1253     index_t child_idx = 0;
1254 
1255     while(itr.has_next())
1256     {
1257         Node &curr = itr.next();
1258 
1259         int schema_curr_count = curr["schema_len"].value();
1260         int data_curr_count   = curr["data_len"].value();
1261 
1262         schema_rcv_counts[child_idx] = schema_curr_count;
1263         schema_rcv_displs[child_idx] = schema_curr_displ;
1264         schema_curr_displ   += schema_curr_count;
1265 
1266         data_rcv_counts[child_idx] = data_curr_count;
1267         data_rcv_displs[child_idx] = data_curr_displ;
1268         data_curr_displ   += data_curr_count;
1269 
1270         child_idx+=1;
1271     }
1272 
1273     n_rcv_tmp["schemas/data"].set(DataType::c_char(schema_curr_displ));
1274     schema_rcv_buff = n_rcv_tmp["schemas/data"].value();
1275 
1276     mpi_error = MPI_Allgatherv( const_cast <char*>(schema_str.c_str()),
1277                                 schema_len,
1278                                 MPI_BYTE,
1279                                 schema_rcv_buff,
1280                                 schema_rcv_counts,
1281                                 schema_rcv_displs,
1282                                 MPI_BYTE,
1283                                 mpi_comm);
1284 
1285     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1286 
1287     // build all schemas from JSON, compact them.
1288     Schema rcv_schema;
1289     //TODO: should we make it easier to create a compact schema?
1290     // TODO: Revisit, I think we can do this better
1291     Schema s_tmp;
1292     for(int s_idx=0; s_idx < m_size; s_idx++)
1293     {
1294         Schema &s_new = s_tmp.append();
1295         s_new.set(&schema_rcv_buff[schema_rcv_displs[s_idx]]);
1296     }
1297 
1298     // TODO can we support copy out w/out realloc
1299     s_tmp.compact_to(rcv_schema);
1300 
1301     // allocate data to hold the gather result
1302     recv_node.set(rcv_schema);
1303     data_rcv_buff = (char*)recv_node.data_ptr();
1304 
1305     mpi_error = MPI_Allgatherv( n_snd_compact.data_ptr(),
1306                                 data_len,
1307                                 MPI_BYTE,
1308                                 data_rcv_buff,
1309                                 data_rcv_counts,
1310                                 data_rcv_displs,
1311                                 MPI_BYTE,
1312                                 mpi_comm);
1313 
1314     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1315 
1316     return mpi_error;
1317 }
1318 
1319 
1320 //---------------------------------------------------------------------------//
1321 int
1322 broadcast(Node &node,
1323           int root,
1324           MPI_Comm comm)
1325 {
1326     int rank = mpi::rank(comm);
1327 
1328     Node bcast_buffer;
1329 
1330     bool cpy_out = false;
1331 
1332     void    *bcast_data_ptr  = node.contiguous_data_ptr();
1333     index_t  bcast_data_size = node.total_bytes_compact();
1334 
1335     // setup buffers on root for send
1336     if(rank == root)
1337     {
1338         if( bcast_data_ptr == NULL ||
1339             ! node.is_compact() )
1340         {
1341             node.compact_to(bcast_buffer);
1342             bcast_data_ptr  = bcast_buffer.data_ptr();
1343         }
1344 
1345     }
1346     else // rank != root,  setup buffers on non root for rcv
1347     {
1348         if( bcast_data_ptr == NULL ||
1349             ! node.is_compact() )
1350         {
1351             Schema s_compact;
1352             node.schema().compact_to(s_compact);
1353             bcast_buffer.set_schema(s_compact);
1354 
1355             bcast_data_ptr  = bcast_buffer.data_ptr();
1356             cpy_out = true;
1357         }
1358     }
1359 
1360 
1361     if(!conduit::utils::value_fits<index_t,int>(bcast_data_size))
1362     {
1363         CONDUIT_INFO("Warning size value (" << bcast_data_size << ")"
1364                      " exceeds the size of MPI_Bcast max value "
1365                      "(" << std::numeric_limits<int>::max() << ")")
1366     }
1367 
1368 
1369     int mpi_error = MPI_Bcast(bcast_data_ptr,
1370                               static_cast<int>(bcast_data_size),
1371                               MPI_BYTE,
1372                               root,
1373                               comm);
1374 
1375     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1376 
1377     // note: cpy_out will always be false when rank == root
1378     if( cpy_out )
1379     {
1380         node.update(bcast_buffer);
1381     }
1382 
1383     return mpi_error;
1384 }
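
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): broadcast moves raw bytes, so every rank
// must enter with a node of compatible layout. Same namespace assumptions as
// the earlier sketches.
//
//   Node n(DataType::float64(4));   // identical layout on every rank
//   if(mpi::rank(MPI_COMM_WORLD) == 0)
//   {
//       float64_array v = n.value();
//       v[0] = 3.14;
//   }
//   mpi::broadcast(n, /*root*/ 0, MPI_COMM_WORLD);
//
// When only the root knows the layout, use broadcast_using_schema() below.
//---------------------------------------------------------------------------//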
1385 
1386 //---------------------------------------------------------------------------//
1387 int
1388 broadcast_using_schema(Node &node,
1389                        int root,
1390                        MPI_Comm comm)
1391 {
1392     int rank = mpi::rank(comm);
1393 
1394     Node bcast_buffers;
1395 
1396     void *bcast_data_ptr = NULL;
1397     int   bcast_data_size = 0;
1398 
1399     int bcast_schema_size = 0;
1400     int rcv_bcast_schema_size = 0;
1401 
1402     // setup buffers for send
1403     if(rank == root)
1404     {
1405 
1406         bcast_data_ptr  = node.contiguous_data_ptr();
1407         bcast_data_size = static_cast<int>(node.total_bytes_compact());
1408 
1409         if(bcast_data_ptr != NULL &&
1410            node.is_compact() &&
1411            node.is_contiguous())
1412         {
1413             bcast_buffers["schema"] = node.schema().to_json();
1414         }
1415         else
1416         {
1417             Node &bcast_data_compact = bcast_buffers["data"];
1418             node.compact_to(bcast_data_compact);
1419 
1420             bcast_data_ptr  = bcast_data_compact.data_ptr();
1421             bcast_buffers["schema"] =  bcast_data_compact.schema().to_json();
1422         }
1423 
1424 
1425 
1426         bcast_schema_size = static_cast<int>(bcast_buffers["schema"].dtype().number_of_elements());
1427     }
1428 
1429     int mpi_error = MPI_Allreduce(&bcast_schema_size,
1430                                   &rcv_bcast_schema_size,
1431                                   1,
1432                                   MPI_INT,
1433                                   MPI_MAX,
1434                                   comm);
1435 
1436     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1437 
1438     bcast_schema_size = rcv_bcast_schema_size;
1439 
1440     // alloc for rcv for schema
1441     if(rank != root)
1442     {
1443         bcast_buffers["schema"].set(DataType::char8_str(bcast_schema_size));
1444     }
1445 
1446     // broadcast the schema
1447     mpi_error = MPI_Bcast(bcast_buffers["schema"].data_ptr(),
1448                           bcast_schema_size,
1449                           MPI_CHAR,
1450                           root,
1451                           comm);
1452 
1453     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1454 
1455     bool cpy_out = false;
1456 
1457     // setup buffers for receive
1458     if(rank != root)
1459     {
1460         Schema bcast_schema;
1461         Generator gen(bcast_buffers["schema"].as_char8_str());
1462         gen.walk(bcast_schema);
1463 
1464         // only check compat for leaves
1465         // there are more zero copy cases possible here, but
1466         // we need a better way to identify them
1467         // compatible won't work for object cases that
1468         // have different named leaves
1469         if( !(node.dtype().is_empty() ||
1470               node.dtype().is_object() ||
1471               node.dtype().is_list() ) &&
1472             !(bcast_schema.dtype().is_empty() ||
1473               bcast_schema.dtype().is_object() ||
1474               bcast_schema.dtype().is_list() )
1475             && bcast_schema.compatible(node.schema()))
1476         {
1477 
1478             bcast_data_ptr  = node.contiguous_data_ptr();
1479             bcast_data_size = static_cast<int>(node.total_bytes_compact());
1480 
1481             if( bcast_data_ptr == NULL ||
1482                 ! node.is_compact() )
1483             {
1484                 Node &bcast_data_buffer = bcast_buffers["data"];
1485                 bcast_data_buffer.set_schema(bcast_schema);
1486 
1487                 bcast_data_ptr  = bcast_data_buffer.data_ptr();
1488                 cpy_out = true;
1489             }
1490         }
1491         else
1492         {
1493             node.set_schema(bcast_schema);
1494 
1495             bcast_data_ptr  = node.data_ptr();
1496             bcast_data_size = static_cast<int>(node.total_bytes_compact());
1497         }
1498     }
1499 
1500     mpi_error = MPI_Bcast(bcast_data_ptr,
1501                           bcast_data_size,
1502                           MPI_BYTE,
1503                           root,
1504                           comm);
1505 
1506     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1507 
1508     // note: cpy_out will always be false when rank == root
1509     if( cpy_out )
1510     {
1511         node.update(bcast_buffers["data"]);
1512     }
1513 
1514     return mpi_error;
1515 }
1516 
1517 //-----------------------------------------------------------------------------
1518 //-----------------------------------------------------------------------------
1519 const int communicate_using_schema::OP_SEND = 1;
1520 const int communicate_using_schema::OP_RECV = 2;
1521 
1522 //-----------------------------------------------------------------------------
1523 communicate_using_schema::communicate_using_schema(MPI_Comm c) :
1524     comm(c), operations(), logging(false)
1525 {
1526 }
1527 
1528 //-----------------------------------------------------------------------------
1529 communicate_using_schema::~communicate_using_schema()
1530 {
1531     clear();
1532 }
1533 
1534 //-----------------------------------------------------------------------------
1535 void
1536 communicate_using_schema::clear()
1537 {
1538     for(size_t i = 0; i < operations.size(); i++)
1539     {
1540         if(operations[i].free[0])
1541             delete operations[i].node[0];
1542         if(operations[i].free[1])
1543             delete operations[i].node[1];
1544     }
1545     operations.clear();
1546 }
1547 
1548 //-----------------------------------------------------------------------------
1549 void
1550 communicate_using_schema::set_logging(bool val)
1551 {
1552     logging = val;
1553 }
1554 
1555 //-----------------------------------------------------------------------------
1556 void
1557 communicate_using_schema::add_isend(const Node &node, int dest, int tag)
1558 {
1559     // Append the work to the operations.
1560     operation work;
1561     work.op = OP_SEND;
1562     work.rank = dest;
1563     work.tag = tag;
1564     work.node[0] = const_cast<Node *>(&node); // The node we're sending.
1565     work.free[0] = false;
1566     work.node[1] = nullptr;
1567     work.free[1] = false;
1568     operations.push_back(work);
1569 }
1570 
1571 //-----------------------------------------------------------------------------
1572 void
1573 communicate_using_schema::add_irecv(Node &node, int src, int tag)
1574 {
1575     // Append the work to the operations.
1576     operation work;
1577     work.op = OP_RECV;
1578     work.rank = src;
1579     work.tag = tag;
1580     work.node[0] = &node; // Node that will contain final data.
1581     work.free[0] = false; // Don't need to free it.
1582     work.node[1] = nullptr;
1583     work.free[1] = false;
1584     operations.push_back(work);
1585 }
1586 
1587 //-----------------------------------------------------------------------------
1588 int
1589 communicate_using_schema::execute()
1590 {
1591     int mpi_error = 0;
1592     std::vector<MPI_Request> requests(operations.size());
1593     std::vector<MPI_Status>  statuses(operations.size());
1594 
1595     int rank, size;
1596     MPI_Comm_rank(comm, &rank);
1597     MPI_Comm_size(comm, &size);
1598     std::ofstream log;
1599     double t0 = MPI_Wtime();
1600     if(logging)
1601     {
1602         char fn[128];
1603         sprintf(fn, "communicate_using_schema.%04d.log", rank);
1604         log.open(fn, std::ofstream::out);
1605         log << "* Log started on rank " << rank << " at " << t0 << std::endl;
1606     }
1607 
1608     // Issue all the sends (so they are in flight by the time we probe them)
1609     for(size_t i = 0; i < operations.size(); i++)
1610     {
1611         if(operations[i].op == OP_SEND)
1612         {
1613             Schema s_data_compact;
1614             const Node &node = *operations[i].node[0];
1615             // schema will only be valid if compact and contig
1616             if( node.is_compact() && node.is_contiguous())
1617             {
1618                 s_data_compact = node.schema();
1619             }
1620             else
1621             {
1622                 node.schema().compact_to(s_data_compact);
1623             }
1624 
1625             std::string snd_schema_json = s_data_compact.to_json();
1626 
1627             Schema s_msg;
1628             s_msg["schema_len"].set(DataType::int64());
1629             s_msg["schema"].set(DataType::char8_str(snd_schema_json.size()+1));
1630             s_msg["data"].set(s_data_compact);
1631 
1632             // create a compact schema to use
1633             Schema s_msg_compact;
1634             s_msg.compact_to(s_msg_compact);
1635 
1636             operations[i].node[1] = new Node(s_msg_compact);
1637             operations[i].free[1] = true;
1638             Node &n_msg = *operations[i].node[1];
1639             // these sets won't realloc since schemas are compatible
1640             n_msg["schema_len"].set((int64)snd_schema_json.length());
1641             n_msg["schema"].set(snd_schema_json);
1642             n_msg["data"].update(node);
1643 
1644             // Send the serialized node data.
1645             index_t msg_data_size = operations[i].node[1]->total_bytes_compact();
1646             if(logging)
1647             {
1648                 log << "    MPI_Isend("
1649                     << const_cast<void*>(operations[i].node[1]->data_ptr()) << ", "
1650                     << msg_data_size << ", "
1651                     << "MPI_BYTE, "
1652                     << operations[i].rank << ", "
1653                     << operations[i].tag << ", "
1654                     << "comm, &requests[" << i << "]);" << std::endl;
1655             }
1656 
1657             if(!conduit::utils::value_fits<index_t,int>(msg_data_size))
1658             {
1659                 CONDUIT_INFO("Warning size value (" << msg_data_size << ")"
1660                              " exceeds the size of MPI_Isend max value "
1661                              "(" << std::numeric_limits<int>::max() << ")")
1662             }
1663 
1664             mpi_error = MPI_Isend(const_cast<void*>(operations[i].node[1]->data_ptr()),
1665                                   static_cast<int>(msg_data_size),
1666                                   MPI_BYTE,
1667                                   operations[i].rank,
1668                                   operations[i].tag,
1669                                   comm,
1670                                   &requests[i]);
1671             CONDUIT_CHECK_MPI_ERROR(mpi_error);
1672         }
1673     }
1674     double t1 = MPI_Wtime();
1675     if(logging)
1676     {
1677         log << "* Time issuing MPI_Isend calls: " << (t1-t0) << std::endl;
1678     }
1679 
1680     // Issue all the recvs.
1681     for(size_t i = 0; i < operations.size(); i++)
1682     {
1683         if(operations[i].op == OP_RECV)
1684         {
1685             // Probe the message for its buffer size.
1686             if(logging)
1687             {
1688                 log << "    MPI_Probe("
1689                     << operations[i].rank << ", "
1690                     << operations[i].tag << ", "
1691                     << "comm, &statuses[" << i << "]);" << std::endl;
1692             }
1693             mpi_error = MPI_Probe(operations[i].rank, operations[i].tag, comm, &statuses[i]);
1694             CONDUIT_CHECK_MPI_ERROR(mpi_error);
1695 
1696             int buffer_size = 0;
1697             MPI_Get_count(&statuses[i], MPI_BYTE, &buffer_size);
1698             if(logging)
1699             {
1700                 log << "    MPI_Get_count(&statuses[" << i << "], MPI_BYTE, &buffer_size); -> "
1701                     << buffer_size << std::endl;
1702             }
1703 
1704             // Allocate a node into which we'll receive the raw data.
1705             operations[i].node[1] = new Node(DataType::uint8(buffer_size));
1706             operations[i].free[1] = true;
1707 
1708             if(logging)
1709             {
1710                 log << "    MPI_Irecv("
1711                     << operations[i].node[1]->data_ptr() << ", "
1712                     << buffer_size << ", "
1713                     << "MPI_BYTE, "
1714                     << operations[i].rank << ", "
1715                     << operations[i].tag << ", "
1716                     << "comm, &requests[" << i << "]);" << std::endl;
1717             }
1718 
1719             // Post the actual receive.
1720             mpi_error = MPI_Irecv(operations[i].node[1]->data_ptr(),
1721                                   buffer_size,
1722                                   MPI_BYTE,
1723                                   operations[i].rank,
1724                                   operations[i].tag,
1725                                   comm,
1726                                   &requests[i]);
1727             CONDUIT_CHECK_MPI_ERROR(mpi_error);
1728         }
1729     }
1730     double t2 = MPI_Wtime();
1731     if(logging)
1732     {
1733         log << "* Time issuing MPI_Irecv calls: " << (t2-t1) << std::endl;
1734     }
1735 
1736     // Wait for the requests to complete.
1737     int n = static_cast<int>(operations.size());
1738     if(logging)
1739     {
1740         log << "    MPI_Waitall(" << n << ", &requests[0], &statuses[0]);" << std::endl;
1741     }
1742     mpi_error = MPI_Waitall(n, &requests[0], &statuses[0]);
1743     CONDUIT_CHECK_MPI_ERROR(mpi_error);
1744     double t3 = MPI_Wtime();
1745     if(logging)
1746     {
1747         log << "* Time in MPI_Waitall: " << (t3-t2) << std::endl;
1748     }
1749 
1750     // Finish building the nodes for which we received data.
1751     for(size_t i = 0; i < operations.size(); i++)
1752     {
1753         if(operations[i].op == OP_RECV)
1754         {
1755             // Get the buffer of the data we received.
1756             uint8 *n_buff_ptr = (uint8*)operations[i].node[1]->data_ptr();
1757 
1758             Node n_msg;
1759             // length of the schema is sent as a 64-bit signed int
1760             // NOTE: this value isn't used here; the schema string that follows is null-terminated
1761             n_msg["schema_len"].set_external((int64*)n_buff_ptr);
1762             n_buff_ptr +=8;
1763             // wrap the schema string
1764             n_msg["schema"].set_external_char8_str((char*)(n_buff_ptr));
1765             // create the schema
1766             Schema rcv_schema;
1767             Generator gen(n_msg["schema"].as_char8_str());
1768             gen.walk(rcv_schema);
1769 
1770             // advance by the schema length
1771             n_buff_ptr += n_msg["schema"].total_bytes_compact();
1772 
1773             // apply the schema to the data
1774             n_msg["data"].set_external(rcv_schema,n_buff_ptr);
1775 
1776             // copy out to our result node
1777             operations[i].node[0]->update(n_msg["data"]);
1778 
1779             if(logging)
1780             {
1781                 log << "* Built output node " << i << std::endl;
1782             }
1783         }
1784     }
1785     double t4 = MPI_Wtime();
1786     if(logging)
1787     {
1788         log << "* Time building output nodes " << (t4-t3) << std::endl;
1789         log.close();
1790     }
1791 
1792     // Cleanup
1793     clear();
1794 
1795     return 0;
1796 }
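
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only): batching several schema-aware transfers so
// the sends, probes, and receives overlap, instead of issuing blocking
// send_using_schema / recv_using_schema pairs. Assumes exactly two ranks and
// the same namespace shorthands as the earlier sketches; tags are arbitrary.
//
//   mpi::communicate_using_schema xfer(MPI_COMM_WORLD);
//   // xfer.set_logging(true);   // optional: write a per-rank log file
//
//   Node n_out, n_in;
//   n_out["values"].set(DataType::float64(8));
//
//   int peer = (mpi::rank(MPI_COMM_WORLD) == 0) ? 1 : 0;
//   xfer.add_isend(n_out, peer, /*tag*/ 0);
//   xfer.add_irecv(n_in,  peer, /*tag*/ 0);
//   xfer.execute();   // posts isends, probes + posts irecvs, then waits
//---------------------------------------------------------------------------//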
1797 
1798 //---------------------------------------------------------------------------//
1799 std::string
1800 about()
1801 {
1802     Node n;
1803     mpi::about(n);
1804     return n.to_yaml();
1805 }
1806 
1807 //---------------------------------------------------------------------------//
1808 void
1809 about(Node &n)
1810 {
1811     n.reset();
1812     n["mpi"] = "enabled";
1813 }
1814 
1815 }
1816 //-----------------------------------------------------------------------------
1817 // -- end conduit::relay::mpi --
1818 //-----------------------------------------------------------------------------
1819 
1820 }
1821 //-----------------------------------------------------------------------------
1822 // -- end conduit::relay --
1823 //-----------------------------------------------------------------------------
1824 
1825 
1826 }
1827 //-----------------------------------------------------------------------------
1828 // -- end conduit:: --
1829 //-----------------------------------------------------------------------------
1830 
1831 
1832