1 // Copyright (c) Lawrence Livermore National Security, LLC and other Conduit
2 // Project developers. See top-level LICENSE AND COPYRIGHT files for dates and
3 // other details. No copyright assignment is required to contribute to Conduit.
4
5 //-----------------------------------------------------------------------------
6 ///
7 /// file: conduit_relay_mpi.cpp
8 ///
9 //-----------------------------------------------------------------------------
10
11 #include "conduit_relay_mpi.hpp"
12 #include <iostream>
13 #include <limits>
14
15 //-----------------------------------------------------------------------------
16 /// The CONDUIT_CHECK_MPI_ERROR macro is used to check the return values
17 /// of MPI calls.
18 //-----------------------------------------------------------------------------
19 #define CONDUIT_CHECK_MPI_ERROR( check_mpi_err_code ) \
20 { \
21 if( static_cast<int>(check_mpi_err_code) != MPI_SUCCESS) \
22 { \
23 char check_mpi_err_str_buff[MPI_MAX_ERROR_STRING]; \
24 int check_mpi_err_str_len=0; \
25 MPI_Error_string( check_mpi_err_code , \
26 check_mpi_err_str_buff, \
27 &check_mpi_err_str_len); \
28 \
29 CONDUIT_ERROR("MPI call failed: \n" \
30 << " error code = " \
31 << check_mpi_err_code << "\n" \
32 << " error message = " \
33 << check_mpi_err_str_buff << "\n"); \
34 return check_mpi_err_code; \
35 } \
36 }
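
//-----------------------------------------------------------------------------
/// Usage sketch (illustrative only): the macro wraps raw MPI return codes
/// inside the wrappers below, e.g.
///
///     int mpi_error = MPI_Barrier(comm);
///     CONDUIT_CHECK_MPI_ERROR(mpi_error);
///
/// Note that on failure the macro returns the error code from the enclosing
/// function, so it is only usable inside functions that return int.
//-----------------------------------------------------------------------------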
37
38
39 //-----------------------------------------------------------------------------
40 // -- begin conduit:: --
41 //-----------------------------------------------------------------------------
42 namespace conduit
43 {
44
45 //-----------------------------------------------------------------------------
46 // -- begin conduit::relay --
47 //-----------------------------------------------------------------------------
48 namespace relay
49 {
50
51
52 //-----------------------------------------------------------------------------
53 // -- begin conduit::relay::mpi --
54 //-----------------------------------------------------------------------------
55 namespace mpi
56 {
57
58
59 //-----------------------------------------------------------------------------
60 int
61 size(MPI_Comm mpi_comm)
62 {
63 int res;
64 MPI_Comm_size(mpi_comm,&res);
65 return res;
66 }
67
68 //-----------------------------------------------------------------------------
69 int
70 rank(MPI_Comm mpi_comm)
71 {
72 int res;
73 MPI_Comm_rank(mpi_comm,&res);
74 return res;
75 }
76
77 //-----------------------------------------------------------------------------
78 MPI_Datatype
79 conduit_dtype_to_mpi_dtype(const DataType &dt)
80 {
81 MPI_Datatype res = MPI_DATATYPE_NULL;
82
83 // can't use switch w/ case statements here b/c NATIVE_IDS may actually
84 // be overloaded on some platforms (this happens on windows)
85
86 index_t dt_id = dt.id();
87
88 // signed integer types
89 if(dt_id == CONDUIT_INT8_ID)
90 {
91 res = MPI_INT8_T;
92 }
93 else if( dt_id == CONDUIT_INT16_ID)
94 {
95 res = MPI_INT16_T;
96 }
97 else if( dt_id == CONDUIT_INT32_ID)
98 {
99 res = MPI_INT32_T;
100 }
101 else if( dt_id == CONDUIT_INT64_ID)
102 {
103 res = MPI_INT64_T;
104 }
105 // unsigned integer types
106 else if( dt_id == CONDUIT_UINT8_ID)
107 {
108 res = MPI_UINT8_T;
109 }
110 else if( dt_id == CONDUIT_UINT16_ID)
111 {
112 res = MPI_UINT16_T;
113 }
114 else if( dt_id == CONDUIT_UINT32_ID)
115 {
116 res = MPI_UINT32_T;
117 }
118 else if( dt_id == CONDUIT_UINT64_ID)
119 {
120 res = MPI_UINT64_T;
121 }
122 // floating point types
123 else if( dt_id == CONDUIT_NATIVE_FLOAT_ID)
124 {
125 res = MPI_FLOAT;
126 }
127 else if( dt_id == CONDUIT_NATIVE_DOUBLE_ID)
128 {
129 res = MPI_DOUBLE;
130 }
131 #if defined(CONDUIT_USE_LONG_DOUBLE)
132 else if( dt_id == CONDUIT_NATIVE_LONG_DOUBLE_ID)
133 {
134 res = MPI_LONG_DOUBLE;
135 }
136 #endif
137 // string type
138 else if( dt_id == DataType::CHAR8_STR_ID)
139 {
140 res = MPI_CHAR;
141 }
142
143 return res;
144 }
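
//-----------------------------------------------------------------------------
// Usage sketch (illustrative only, variable names are assumptions): mapping a
// leaf Node's dtype to the matching MPI datatype before a raw MPI call:
//
//     Node n;
//     n.set(DataType::float64(16));
//     MPI_Datatype mpi_dt = conduit_dtype_to_mpi_dtype(n.dtype());
//     // typically MPI_DOUBLE; MPI_DATATYPE_NULL signals an unsupported dtype
//-----------------------------------------------------------------------------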
145
146
147 //-----------------------------------------------------------------------------
148 index_t
149 mpi_dtype_to_conduit_dtype_id(MPI_Datatype dt)
150 {
151 index_t res = DataType::EMPTY_ID;
152
153 // can't use switch w/ case statements here b/c in some
154 // MPI implementations MPI_Datatype is a struct (or something more complex)
155 // that won't compile when used in a switch statement.
156
157 // string type
158 if(dt == MPI_CHAR)
159 {
160 res = DataType::CHAR8_STR_ID;
161 }
162 // mpi c bw-style signed integer types
163     else if( dt == MPI_INT8_T)
164 {
165 res = CONDUIT_INT8_ID;
166 }
167 else if( dt == MPI_INT16_T)
168 {
169 res = CONDUIT_INT16_ID;
170 }
171 else if( dt == MPI_INT32_T)
172 {
173 res = CONDUIT_INT32_ID;
174 }
175 else if( dt == MPI_INT64_T)
176 {
177 res = CONDUIT_INT64_ID;
178 }
179 // mpi c bw-style unsigned integer types
180 else if( dt == MPI_UINT8_T)
181 {
182 res = CONDUIT_UINT8_ID;
183 }
184 else if( dt == MPI_UINT16_T)
185 {
186 res = CONDUIT_UINT16_ID;
187 }
188 else if( dt == MPI_UINT32_T)
189 {
190 res = CONDUIT_UINT32_ID;
191 }
192 else if( dt == MPI_UINT64_T)
193 {
194 res = CONDUIT_UINT64_ID;
195 }
196 // native c signed integer types
197 else if(dt == MPI_SHORT)
198 {
199 res = CONDUIT_NATIVE_SHORT_ID;
200 }
201 else if(dt == MPI_INT)
202 {
203 res = CONDUIT_NATIVE_INT_ID;
204 }
205 else if(dt == MPI_LONG)
206 {
207 res = CONDUIT_NATIVE_LONG_ID;
208 }
209 #if defined(CONDUIT_HAS_LONG_LONG)
210 else if(dt == MPI_LONG_LONG)
211 {
212 res = CONDUIT_NATIVE_LONG_LONG_ID;
213 }
214 #endif
215 // native c unsigned integer types
216 else if(dt == MPI_BYTE)
217 {
218 res = CONDUIT_NATIVE_UNSIGNED_CHAR_ID;
219 }
220 else if(dt == MPI_UNSIGNED_CHAR)
221 {
222 res = CONDUIT_NATIVE_UNSIGNED_CHAR_ID;
223 }
224 else if(dt == MPI_UNSIGNED_SHORT)
225 {
226 res = CONDUIT_NATIVE_UNSIGNED_SHORT_ID;
227 }
228 else if(dt == MPI_UNSIGNED)
229 {
230 res = CONDUIT_NATIVE_UNSIGNED_INT_ID;
231 }
232 else if(dt == MPI_UNSIGNED_LONG)
233 {
234 res = CONDUIT_NATIVE_UNSIGNED_LONG_ID;
235 }
236 #if defined(CONDUIT_HAS_LONG_LONG)
237 else if(dt == MPI_UNSIGNED_LONG_LONG)
238 {
239 res = CONDUIT_NATIVE_UNSIGNED_LONG_LONG_ID;
240 }
241 #endif
242 // floating point types
243 else if(dt == MPI_FLOAT)
244 {
245 res = CONDUIT_NATIVE_FLOAT_ID;
246 }
247 else if(dt == MPI_DOUBLE)
248 {
249 res = CONDUIT_NATIVE_DOUBLE_ID;
250 }
251 #if defined(CONDUIT_USE_LONG_DOUBLE)
252 else if(dt == MPI_LONG_DOUBLE)
253 {
254 res = CONDUIT_NATIVE_LONG_DOUBLE_ID;
255 }
256 #endif
257 return res;
258 }
259
260 //---------------------------------------------------------------------------//
261 int
262 send_using_schema(const Node &node, int dest, int tag, MPI_Comm comm)
263 {
264 Schema s_data_compact;
265
266 // schema will only be valid if compact and contig
267 if( node.is_compact() && node.is_contiguous())
268 {
269 s_data_compact = node.schema();
270 }
271 else
272 {
273 node.schema().compact_to(s_data_compact);
274 }
275
276 std::string snd_schema_json = s_data_compact.to_json();
277
278 Schema s_msg;
279 s_msg["schema_len"].set(DataType::int64());
280 s_msg["schema"].set(DataType::char8_str(snd_schema_json.size()+1));
281 s_msg["data"].set(s_data_compact);
282
283 // create a compact schema to use
284 Schema s_msg_compact;
285 s_msg.compact_to(s_msg_compact);
286
287 Node n_msg(s_msg_compact);
288 // these sets won't realloc since schemas are compatible
289 n_msg["schema_len"].set((int64)snd_schema_json.length());
290 n_msg["schema"].set(snd_schema_json);
291 n_msg["data"].update(node);
292
293
294 index_t msg_data_size = n_msg.total_bytes_compact();
295
296 if(!conduit::utils::value_fits<index_t,int>(msg_data_size))
297 {
298 CONDUIT_INFO("Warning size value (" << msg_data_size << ")"
299 " exceeds the size of MPI_Send max value "
300 "(" << std::numeric_limits<int>::max() << ")")
301 }
302
303 int mpi_error = MPI_Send(const_cast<void*>(n_msg.data_ptr()),
304 static_cast<int>(msg_data_size),
305 MPI_BYTE,
306 dest,
307 tag,
308 comm);
309
310 CONDUIT_CHECK_MPI_ERROR(mpi_error);
311
312 return mpi_error;
313 }
314
315
316 //---------------------------------------------------------------------------//
317 int
318 recv_using_schema(Node &node, int src, int tag, MPI_Comm comm)
319 {
320 MPI_Status status;
321
322 int mpi_error = MPI_Probe(src, tag, comm, &status);
323
324 CONDUIT_CHECK_MPI_ERROR(mpi_error);
325
326 int buffer_size = 0;
327 MPI_Get_count(&status, MPI_BYTE, &buffer_size);
328
329 Node n_buffer(DataType::uint8(buffer_size));
330
331 mpi_error = MPI_Recv(n_buffer.data_ptr(),
332 buffer_size,
333 MPI_BYTE,
334 src,
335 tag,
336 comm,
337 &status);
338
339 uint8 *n_buff_ptr = (uint8*)n_buffer.data_ptr();
340
341 Node n_msg;
342 // length of the schema is sent as a 64-bit signed int
343 // NOTE: we aren't using this value ...
344 n_msg["schema_len"].set_external((int64*)n_buff_ptr);
345 n_buff_ptr +=8;
346 // wrap the schema string
347 n_msg["schema"].set_external_char8_str((char*)(n_buff_ptr));
348 // create the schema
349 Schema rcv_schema;
350 Generator gen(n_msg["schema"].as_char8_str());
351 gen.walk(rcv_schema);
352
353 // advance by the schema length
354 n_buff_ptr += n_msg["schema"].total_bytes_compact();
355
356 // apply the schema to the data
357 n_msg["data"].set_external(rcv_schema,n_buff_ptr);
358
359 // copy out to our result node
360 node.update(n_msg["data"]);
361
362 return mpi_error;
363 }
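
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; tag and field names are assumptions):
// exchanging a Node whose layout is unknown on the receiving side. The schema
// travels with the message, so the receiver does not pre-describe the data.
//
//     Node n;
//     if(mpi::rank(comm) == 0)
//     {
//         n["values"].set(DataType::float64(8));
//         mpi::send_using_schema(n, /*dest*/ 1, /*tag*/ 2001, comm);
//     }
//     else if(mpi::rank(comm) == 1)
//     {
//         mpi::recv_using_schema(n, /*src*/ 0, /*tag*/ 2001, comm);
//     }
//---------------------------------------------------------------------------//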
364
365 //---------------------------------------------------------------------------//
366 int
367 send(const Node &node, int dest, int tag, MPI_Comm comm)
368 {
369 // assumes size and type are known on the other end
370
371 Node snd_compact;
372
373     const void *snd_ptr = node.contiguous_data_ptr();
374     index_t snd_size = node.total_bytes_compact();
375
376 if( snd_ptr == NULL ||
377 ! node.is_compact())
378 {
379 node.compact_to(snd_compact);
380 snd_ptr = snd_compact.data_ptr();
381 }
382
383 if(!conduit::utils::value_fits<index_t,int>(snd_size))
384 {
385 CONDUIT_INFO("Warning size value (" << snd_size << ")"
386 " exceeds the size of MPI_Send max value "
387 "(" << std::numeric_limits<int>::max() << ")")
388 }
389
390 int mpi_error = MPI_Send(const_cast<void*>(snd_ptr),
391 static_cast<int>(snd_size),
392 MPI_BYTE,
393 dest,
394 tag,
395 comm);
396
397 CONDUIT_CHECK_MPI_ERROR(mpi_error);
398
399 return mpi_error;
400 }
401
402 //---------------------------------------------------------------------------//
403 int
404 recv(Node &node, int src, int tag, MPI_Comm comm)
405 {
406
407 MPI_Status status;
408 Node rcv_compact;
409
410 bool cpy_out = false;
411
412 const void *rcv_ptr = node.contiguous_data_ptr();
413 index_t rcv_size = node.total_bytes_compact();
414
415 if( rcv_ptr == NULL ||
416 ! node.is_compact() )
417 {
418 // we will need to update into rcv node
419 cpy_out = true;
420 Schema s_rcv_compact;
421 node.schema().compact_to(s_rcv_compact);
422 rcv_compact.set_schema(s_rcv_compact);
423 rcv_ptr = rcv_compact.data_ptr();
424 }
425
426 if(!conduit::utils::value_fits<index_t,int>(rcv_size))
427 {
428 CONDUIT_INFO("Warning size value (" << rcv_size << ")"
429 " exceeds the size of MPI_Recv max value "
430 "(" << std::numeric_limits<int>::max() << ")")
431 }
432
433
434 int mpi_error = MPI_Recv(const_cast<void*>(rcv_ptr),
435 static_cast<int>(rcv_size),
436 MPI_BYTE,
437 src,
438 tag,
439 comm,
440 &status);
441
442 CONDUIT_CHECK_MPI_ERROR(mpi_error);
443
444 if(cpy_out)
445 {
446 node.update(rcv_compact);
447 }
448
449 return mpi_error;
450 }
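
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; tag and field names are assumptions): the
// plain send/recv pair moves raw bytes only, so both ranks must describe a
// compatible Node up front.
//
//     Node n;
//     n["vals"].set(DataType::int32(4));   // same layout on both ranks
//     if(mpi::rank(comm) == 0)
//     {
//         int32 *vals = n["vals"].value();
//         for(int i = 0; i < 4; i++) vals[i] = i;
//         mpi::send(n, /*dest*/ 1, /*tag*/ 1001, comm);
//     }
//     else if(mpi::rank(comm) == 1)
//     {
//         mpi::recv(n, /*src*/ 0, /*tag*/ 1001, comm);
//     }
//---------------------------------------------------------------------------//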
451
452
453 //---------------------------------------------------------------------------//
454 int
455 reduce(const Node &snd_node,
456 Node &rcv_node,
457 MPI_Op mpi_op,
458 int root,
459 MPI_Comm mpi_comm)
460 {
461 MPI_Datatype mpi_dtype = conduit_dtype_to_mpi_dtype(snd_node.dtype());
462
463 if(mpi_dtype == MPI_DATATYPE_NULL)
464 {
465         CONDUIT_ERROR("Unsupported send DataType for mpi::reduce: "
466 << snd_node.dtype().name());
467 }
468
469 void *snd_ptr = NULL;
470 void *rcv_ptr = NULL;
471
472 Node snd_compact;
473 Node rcv_compact;
474
475 //note: we don't have to ask for contig in this case, since
476 // we can only reduce leaf types
477 if(snd_node.is_compact())
478 {
479 snd_ptr = const_cast<void*>(snd_node.data_ptr());
480 }
481 else
482 {
483 snd_node.compact_to(snd_compact);
484 snd_ptr = snd_compact.data_ptr();
485 }
486
487 bool cpy_out = false;
488
489 int rank = mpi::rank(mpi_comm);
490
491 if( rank == root )
492 {
493
494 rcv_ptr = rcv_node.contiguous_data_ptr();
495
496
497 if( !snd_node.compatible(rcv_node) ||
498 rcv_ptr == NULL ||
499 !rcv_node.is_compact() )
500 {
501 // we will need to update into rcv node
502 cpy_out = true;
503
504 Schema s_snd_compact;
505 snd_node.schema().compact_to(s_snd_compact);
506
507 rcv_compact.set_schema(s_snd_compact);
508 rcv_ptr = rcv_compact.data_ptr();
509 }
510 }
511
512 int num_eles = (int) snd_node.dtype().number_of_elements();
513
514 int mpi_error = MPI_Reduce(snd_ptr,
515 rcv_ptr,
516 num_eles,
517 mpi_dtype,
518 mpi_op,
519 root,
520 mpi_comm);
521
522 CONDUIT_CHECK_MPI_ERROR(mpi_error);
523
524 if( rank == root && cpy_out )
525 {
526 rcv_node.update(rcv_compact);
527 }
528
529 return mpi_error;
530 }
531
532 //--------------------------------------------------------------------------//
533 int
534 all_reduce(const Node &snd_node,
535 Node &rcv_node,
536 MPI_Op mpi_op,
537 MPI_Comm mpi_comm)
538 {
539 MPI_Datatype mpi_dtype = conduit_dtype_to_mpi_dtype(snd_node.dtype());
540
541 if(mpi_dtype == MPI_DATATYPE_NULL)
542 {
543         CONDUIT_ERROR("Unsupported send DataType for mpi::all_reduce: "
544 << snd_node.dtype().name());
545 }
546
547
548 void *snd_ptr = NULL;
549 void *rcv_ptr = NULL;
550
551 Node snd_compact;
552 Node rcv_compact;
553
554 //note: we don't have to ask for contig in this case, since
555 // we can only reduce leaf types
556 if(snd_node.is_compact())
557 {
558 snd_ptr = const_cast<void*>(snd_node.data_ptr());
559 }
560 else
561 {
562 snd_node.compact_to(snd_compact);
563 snd_ptr = snd_compact.data_ptr();
564 }
565
566 bool cpy_out = false;
567
568
569 rcv_ptr = rcv_node.contiguous_data_ptr();
570
571
572 if( !snd_node.compatible(rcv_node) ||
573 rcv_ptr == NULL ||
574 !rcv_node.is_compact() )
575 {
576 // we will need to update into rcv node
577 cpy_out = true;
578
579 Schema s_snd_compact;
580 snd_node.schema().compact_to(s_snd_compact);
581
582 rcv_compact.set_schema(s_snd_compact);
583 rcv_ptr = rcv_compact.data_ptr();
584 }
585
586 int num_eles = (int) snd_node.dtype().number_of_elements();
587
588 int mpi_error = MPI_Allreduce(snd_ptr,
589 rcv_ptr,
590 num_eles,
591 mpi_dtype,
592 mpi_op,
593 mpi_comm);
594
595 CONDUIT_CHECK_MPI_ERROR(mpi_error);
596
597 if(cpy_out)
598 {
599 rcv_node.update(rcv_compact);
600 }
601
602
603 return mpi_error;
604 }
605
606 //-- reduce helpers -- //
607
608 //---------------------------------------------------------------------------//
609 int
610 sum_reduce(const Node &snd_node,
611 Node &rcv_node,
612 int root,
613 MPI_Comm mpi_comm)
614 {
615 return reduce(snd_node,
616 rcv_node,
617 MPI_SUM,
618 root,
619 mpi_comm);
620 }
621
622
623 //---------------------------------------------------------------------------//
624 int
625 min_reduce(const Node &snd_node,
626 Node &rcv_node,
627 int root,
628 MPI_Comm mpi_comm)
629 {
630 return reduce(snd_node,
631 rcv_node,
632 MPI_MIN,
633 root,
634 mpi_comm);
635 }
636
637
638
639 //---------------------------------------------------------------------------//
640 int
641 max_reduce(const Node &snd_node,
642 Node &rcv_node,
643 int root,
644 MPI_Comm mpi_comm)
645 {
646 return reduce(snd_node,
647 rcv_node,
648 MPI_MAX,
649 root,
650 mpi_comm);
651 }
652
653
654
655 //---------------------------------------------------------------------------//
656 int
657 prod_reduce(const Node &snd_node,
658 Node &rcv_node,
659 int root,
660 MPI_Comm mpi_comm)
661 {
662 return reduce(snd_node,
663 rcv_node,
664 MPI_PROD,
665 root,
666 mpi_comm);
667 }
668
669
670
671 //-- all reduce helpers -- //
672 //---------------------------------------------------------------------------//
673 int
674 sum_all_reduce(const Node &snd_node,
675 Node &rcv_node,
676 MPI_Comm mpi_comm)
677 {
678 return all_reduce(snd_node,
679 rcv_node,
680 MPI_SUM,
681 mpi_comm);
682 }
683
684
685 //---------------------------------------------------------------------------//
686 int
687 min_all_reduce(const Node &snd_node,
688 Node &rcv_node,
689 MPI_Comm mpi_comm)
690 {
691 return all_reduce(snd_node,
692 rcv_node,
693 MPI_MIN,
694 mpi_comm);
695
696 }
697
698
699
700 //---------------------------------------------------------------------------//
701 int
702 max_all_reduce(const Node &snd_node,
703 Node &rcv_node,
704 MPI_Comm mpi_comm)
705 {
706 return all_reduce(snd_node,
707 rcv_node,
708 MPI_MAX,
709 mpi_comm);
710
711 }
712
713
714 //---------------------------------------------------------------------------//
715 int
716 prod_all_reduce(const Node &snd_node,
717 Node &rcv_node,
718 MPI_Comm mpi_comm)
719 {
720 return all_reduce(snd_node,
721 rcv_node,
722 MPI_PROD,
723 mpi_comm);
724
725 }
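
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; names are assumptions): the helpers above
// simply forward to reduce() / all_reduce() with a fixed MPI_Op. A global sum
// of one float64 value per rank looks like:
//
//     Node n_local, n_global;
//     n_local.set((float64)local_value);
//     mpi::sum_all_reduce(n_local, n_global, comm);
//     float64 total = n_global.value();
//---------------------------------------------------------------------------//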
726
727
728
729 //---------------------------------------------------------------------------//
730 int
731 isend(const Node &node,
732 int dest,
733 int tag,
734 MPI_Comm mpi_comm,
735 Request *request)
736 {
737
738 const void *data_ptr = node.contiguous_data_ptr();
739 index_t data_size = node.total_bytes_compact();
740
741 // note: this checks for both compact and contig
742 if( data_ptr == NULL ||
743 !node.is_compact() )
744 {
745 node.compact_to(request->m_buffer);
746 data_ptr = request->m_buffer.data_ptr();
747 }
748
749     // m_rcv_ptr is only consulted by wait/wait_all for the irecv
750     // copy-out case; for an isend it must always be NULL so no
751     // copy-out is attempted when the request completes
752 request->m_rcv_ptr = NULL;
753
754
755 if(!conduit::utils::value_fits<index_t,int>(data_size))
756 {
757 CONDUIT_INFO("Warning size value (" << data_size << ")"
758 " exceeds the size of MPI_Isend max value "
759 "(" << std::numeric_limits<int>::max() << ")")
760 }
761
762 int mpi_error = MPI_Isend(const_cast<void*>(data_ptr),
763 static_cast<int>(data_size),
764 MPI_BYTE,
765 dest,
766 tag,
767 mpi_comm,
768 &(request->m_request));
769
770 CONDUIT_CHECK_MPI_ERROR(mpi_error);
771 return mpi_error;
772 }
773
774 //---------------------------------------------------------------------------//
775 int
776 irecv(Node &node,
777 int src,
778 int tag,
779 MPI_Comm mpi_comm,
780 Request *request)
781 {
782 // if rcv is compact, we can write directly into recv
783 // if it's not compact, we need a recv_buffer
784
785 void *data_ptr = node.contiguous_data_ptr();
786 index_t data_size = node.total_bytes_compact();
787
788 // note: this checks for both compact and contig
789 if(data_ptr != NULL &&
790 node.is_compact() )
791 {
792 // for wait_all, this must always be NULL except for
793 // the irecv cases where copy out is necessary
794 request->m_rcv_ptr = NULL;
795 }
796 else
797 {
798 node.compact_to(request->m_buffer);
799 data_ptr = request->m_buffer.data_ptr();
800 request->m_rcv_ptr = &node;
801 }
802
803 if(!conduit::utils::value_fits<index_t,int>(data_size))
804 {
805 CONDUIT_INFO("Warning size value (" << data_size << ")"
806 " exceeds the size of MPI_Irecv max value "
807 "(" << std::numeric_limits<int>::max() << ")")
808 }
809
810 int mpi_error = MPI_Irecv(data_ptr,
811 static_cast<int>(data_size),
812 MPI_BYTE,
813 src,
814 tag,
815 mpi_comm,
816 &(request->m_request));
817
818 CONDUIT_CHECK_MPI_ERROR(mpi_error);
819 return mpi_error;
820 }
821
822
823 //---------------------------------------------------------------------------//
824 // wait handles both send and recv requests
825 int
826 wait(Request *request,
827 MPI_Status *status)
828 {
829 int mpi_error = MPI_Wait(&(request->m_request), status);
830 CONDUIT_CHECK_MPI_ERROR(mpi_error);
831
832 // we need to update if m_rcv_ptr was used
833 // this will only be non NULL in the recv copy out case,
834 // sends will always be NULL
835 if(request->m_rcv_ptr)
836 {
837 request->m_rcv_ptr->update(request->m_buffer);
838 }
839
840 request->m_buffer.reset();
841 request->m_rcv_ptr = NULL;
842
843 return mpi_error;
844 }
845
846 //---------------------------------------------------------------------------//
847 int
848 wait_send(Request *request,
849 MPI_Status *status)
850 {
851 return wait(request,status);
852 }
853
854 //---------------------------------------------------------------------------//
855 int
856 wait_recv(Request *request,
857 MPI_Status *status)
858 {
859 return wait(request,status);
860 }
861
862
863 //---------------------------------------------------------------------------//
864 // wait all handles mixed batches of sends and receives
865 int
866 wait_all(int count,
867 Request requests[],
868 MPI_Status statuses[])
869 {
870 MPI_Request *justrequests = new MPI_Request[count];
871
872 for (int i = 0; i < count; ++i)
873 {
874 // mpi requests can be simply copied
875 justrequests[i] = requests[i].m_request;
876 }
877
878 int mpi_error = MPI_Waitall(count, justrequests, statuses);
879 CONDUIT_CHECK_MPI_ERROR(mpi_error);
880
881 for (int i = 0; i < count; ++i)
882 {
883 // if this request is a recv, we need to check for copy out
884 // m_rcv_ptr will always be NULL, unless we have done a
885 // irecv where we need to use the pointer.
886 if(requests[i].m_rcv_ptr != NULL)
887 {
888 requests[i].m_rcv_ptr->update(requests[i].m_buffer);
889 requests[i].m_rcv_ptr = NULL;
890 }
891
892 requests[i].m_request = justrequests[i];
893 requests[i].m_buffer.reset();
894 }
895
896 delete [] justrequests;
897
898 return mpi_error;
899 }
900
901 //---------------------------------------------------------------------------//
902 int
903 wait_all_send(int count,
904 Request requests[],
905 MPI_Status statuses[])
906 {
907 return wait_all(count,requests,statuses);
908 }
909
910 //---------------------------------------------------------------------------//
911 int
912 wait_all_recv(int count,
913 Request requests[],
914 MPI_Status statuses[])
915 {
916 return wait_all(count,requests,statuses);
917 }
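
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; ranks, tags, and nodes are assumptions):
// pairing isend/irecv with wait_all. The Request objects own any compaction
// buffers, so they must stay alive until the wait completes.
//
//     Request reqs[2];
//     MPI_Status stats[2];
//     mpi::isend(n_out, /*dest*/ 1, /*tag*/ 77, comm, &reqs[0]);
//     mpi::irecv(n_in,  /*src*/  1, /*tag*/ 78, comm, &reqs[1]);
//     mpi::wait_all(2, reqs, stats);   // completes both and finalizes n_in
//---------------------------------------------------------------------------//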
918
919
920 //---------------------------------------------------------------------------//
921 int
922 gather(Node &send_node,
923 Node &recv_node,
924 int root,
925 MPI_Comm mpi_comm)
926 {
927 Node n_snd_compact;
928 Schema s_snd_compact;
929
930 send_node.schema().compact_to(s_snd_compact);
931
932 const void *snd_ptr = send_node.contiguous_data_ptr();
933 index_t snd_size = 0;
934
935
936 if(snd_ptr != NULL &&
937 send_node.is_compact() )
938 {
939 snd_ptr = send_node.data_ptr();
940 snd_size = send_node.total_bytes_compact();
941 }
942 else
943 {
944 send_node.compact_to(n_snd_compact);
945 snd_ptr = n_snd_compact.data_ptr();
946 snd_size = n_snd_compact.total_bytes_compact();
947 }
948
949 int mpi_rank = mpi::rank(mpi_comm);
950 int mpi_size = mpi::size(mpi_comm);
951
952 if(mpi_rank == root)
953 {
954 // TODO: copy out support w/o always reallocing?
955 recv_node.list_of(s_snd_compact,
956 mpi_size);
957 }
958
959 if(!conduit::utils::value_fits<index_t,int>(snd_size))
960 {
961 CONDUIT_INFO("Warning size value (" << snd_size << ")"
962 " exceeds the size of MPI_Gather max value "
963 "(" << std::numeric_limits<int>::max() << ")")
964 }
965
966 int mpi_error = MPI_Gather( const_cast<void*>(snd_ptr), // local data
967 static_cast<int>(snd_size), // local data len
968 MPI_BYTE, // send chars
969 recv_node.data_ptr(), // rcv buffer
970 static_cast<int>(snd_size), // data len
971 MPI_BYTE, // rcv chars
972 root,
973 mpi_comm); // mpi com
974
975 CONDUIT_CHECK_MPI_ERROR(mpi_error);
976
977 return mpi_error;
978 }
979
980 //---------------------------------------------------------------------------//
981 int
982 all_gather(Node &send_node,
983 Node &recv_node,
984 MPI_Comm mpi_comm)
985 {
986 Node n_snd_compact;
987 Schema s_snd_compact;
988
989 send_node.schema().compact_to(s_snd_compact);
990
991 const void *snd_ptr = send_node.contiguous_data_ptr();
992 index_t snd_size = send_node.total_bytes_compact();
993
994 if( snd_ptr == NULL ||
995 !send_node.is_compact() )
996 {
997 send_node.compact_to(n_snd_compact);
998 snd_ptr = n_snd_compact.data_ptr();
999 }
1000
1001
1002 int mpi_size = mpi::size(mpi_comm);
1003
1004
1005 // TODO: copy out support w/o always reallocing?
1006 // TODO: what about common case of scatter w/ leaf types?
1007 // instead of list_of, we would have a leaf of
1008 // of a given type w/ # of elements == # of ranks.
1009     recv_node.list_of(s_snd_compact,
1010                       mpi_size);
1011
1012 if(!conduit::utils::value_fits<index_t,int>(snd_size))
1013 {
1014 CONDUIT_INFO("Warning size value (" << snd_size << ")"
1015                  " exceeds the size of MPI_Allgather max value "
1016 "(" << std::numeric_limits<int>::max() << ")")
1017 }
1018
1019 int mpi_error = MPI_Allgather( const_cast<void*>(snd_ptr), // local data
1020 static_cast<int>(snd_size), // local data len
1021 MPI_BYTE, // send chars
1022 recv_node.data_ptr(), // rcv buffer
1023 static_cast<int>(snd_size), // data len
1024 MPI_BYTE, // rcv chars
1025 mpi_comm); // mpi com
1026
1027 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1028
1029 return mpi_error;
1030 }
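
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; field names are assumptions): gather and
// all_gather assume every rank contributes the same compact layout; the
// result is a list with one child per rank.
//
//     Node n_local, n_gathered;
//     n_local["id"].set((int64)mpi::rank(comm));
//     mpi::all_gather(n_local, n_gathered, comm);
//     // n_gathered has mpi::size(comm) children, each with an "id" leaf
//---------------------------------------------------------------------------//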
1031
1032
1033
1034 //---------------------------------------------------------------------------//
1035 int
1036 gather_using_schema(Node &send_node,
1037 Node &recv_node,
1038 int root,
1039 MPI_Comm mpi_comm)
1040 {
1041 Node n_snd_compact;
1042 send_node.compact_to(n_snd_compact);
1043
1044 int m_size = mpi::size(mpi_comm);
1045 int m_rank = mpi::rank(mpi_comm);
1046
1047 std::string schema_str = n_snd_compact.schema().to_json();
1048
1049 int schema_len = static_cast<int>(schema_str.length() + 1);
1050 int data_len = static_cast<int>(n_snd_compact.total_bytes_compact());
1051
1052 // to do the conduit gatherv, first need a gather to get the
1053 // schema and data buffer sizes
1054
1055 int snd_sizes[] = {schema_len, data_len};
1056
1057 Node n_rcv_sizes;
1058
1059 if( m_rank == root )
1060 {
1061 Schema s;
1062 s["schema_len"].set(DataType::c_int());
1063 s["data_len"].set(DataType::c_int());
1064 n_rcv_sizes.list_of(s,m_size);
1065 }
1066
1067 int mpi_error = MPI_Gather( snd_sizes, // local data
1068 2, // two ints per rank
1069 MPI_INT, // send ints
1070 n_rcv_sizes.data_ptr(), // rcv buffer
1071 2, // two ints per rank
1072 MPI_INT, // rcv ints
1073 root, // id of root for gather op
1074 mpi_comm); // mpi com
1075
1076 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1077
1078 Node n_rcv_tmp;
1079
1080 int *schema_rcv_counts = NULL;
1081 int *schema_rcv_displs = NULL;
1082 char *schema_rcv_buff = NULL;
1083
1084 int *data_rcv_counts = NULL;
1085 int *data_rcv_displs = NULL;
1086 char *data_rcv_buff = NULL;
1087
1088 // we only need rcv params on the gather root
1089 if( m_rank == root )
1090 {
1091 // alloc data for the mpi gather counts and displ arrays
1092 n_rcv_tmp["schemas/counts"].set(DataType::c_int(m_size));
1093 n_rcv_tmp["schemas/displs"].set(DataType::c_int(m_size));
1094
1095 n_rcv_tmp["data/counts"].set(DataType::c_int(m_size));
1096 n_rcv_tmp["data/displs"].set(DataType::c_int(m_size));
1097
1098 // get pointers to counts and displs
1099 schema_rcv_counts = n_rcv_tmp["schemas/counts"].value();
1100 schema_rcv_displs = n_rcv_tmp["schemas/displs"].value();
1101
1102 data_rcv_counts = n_rcv_tmp["data/counts"].value();
1103 data_rcv_displs = n_rcv_tmp["data/displs"].value();
1104
1105 int schema_curr_displ = 0;
1106 int data_curr_displ = 0;
1107 int i=0;
1108
1109 NodeIterator itr = n_rcv_sizes.children();
1110 while(itr.has_next())
1111 {
1112 Node &curr = itr.next();
1113
1114 int schema_curr_count = curr["schema_len"].value();
1115 int data_curr_count = curr["data_len"].value();
1116
1117 schema_rcv_counts[i] = schema_curr_count;
1118 schema_rcv_displs[i] = schema_curr_displ;
1119 schema_curr_displ += schema_curr_count;
1120
1121 data_rcv_counts[i] = data_curr_count;
1122 data_rcv_displs[i] = data_curr_displ;
1123 data_curr_displ += data_curr_count;
1124
1125 i++;
1126 }
1127
1128 n_rcv_tmp["schemas/data"].set(DataType::c_char(schema_curr_displ));
1129 schema_rcv_buff = n_rcv_tmp["schemas/data"].value();
1130 }
1131
1132 mpi_error = MPI_Gatherv( const_cast <char*>(schema_str.c_str()),
1133 schema_len,
1134 MPI_BYTE,
1135 schema_rcv_buff,
1136 schema_rcv_counts,
1137 schema_rcv_displs,
1138 MPI_BYTE,
1139 root,
1140 mpi_comm);
1141
1142 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1143
1144 // build all schemas from JSON, compact them.
1145 Schema rcv_schema;
1146 if( m_rank == root )
1147 {
1148         //TODO: should we make it easier to create a compact schema?
1149 // TODO: Revisit, I think we can do this better
1150
1151 Schema s_tmp;
1152 for(int i=0;i < m_size; i++)
1153 {
1154 Schema &s = s_tmp.append();
1155 s.set(&schema_rcv_buff[schema_rcv_displs[i]]);
1156 }
1157
1158 s_tmp.compact_to(rcv_schema);
1159 }
1160
1161
1162 if( m_rank == root )
1163 {
1164 // allocate data to hold the gather result
1165 // TODO can we support copy out w/out realloc
1166 recv_node.set(rcv_schema);
1167 data_rcv_buff = (char*)recv_node.data_ptr();
1168 }
1169
1170 mpi_error = MPI_Gatherv( n_snd_compact.data_ptr(),
1171 data_len,
1172 MPI_BYTE,
1173 data_rcv_buff,
1174 data_rcv_counts,
1175 data_rcv_displs,
1176 MPI_BYTE,
1177 root,
1178 mpi_comm);
1179
1180 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1181
1182 return mpi_error;
1183 }
1184
1185 //---------------------------------------------------------------------------//
1186 int
1187 all_gather_using_schema(Node &send_node,
1188 Node &recv_node,
1189 MPI_Comm mpi_comm)
1190 {
1191 Node n_snd_compact;
1192 send_node.compact_to(n_snd_compact);
1193
1194 int m_size = mpi::size(mpi_comm);
1195
1196 std::string schema_str = n_snd_compact.schema().to_json();
1197
1198 int schema_len = static_cast<int>(schema_str.length() + 1);
1199 int data_len = static_cast<int>(n_snd_compact.total_bytes_compact());
1200
1201 // to do the conduit gatherv, first need a gather to get the
1202 // schema and data buffer sizes
1203
1204 int snd_sizes[] = {schema_len, data_len};
1205
1206 Node n_rcv_sizes;
1207
1208 Schema s;
1209 s["schema_len"].set(DataType::c_int());
1210 s["data_len"].set(DataType::c_int());
1211 n_rcv_sizes.list_of(s,m_size);
1212
1213 int mpi_error = MPI_Allgather( snd_sizes, // local data
1214 2, // two ints per rank
1215 MPI_INT, // send ints
1216 n_rcv_sizes.data_ptr(), // rcv buffer
1217 2, // two ints per rank
1218 MPI_INT, // rcv ints
1219 mpi_comm); // mpi com
1220
1221 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1222
1223 Node n_rcv_tmp;
1224
1225 int *schema_rcv_counts = NULL;
1226 int *schema_rcv_displs = NULL;
1227 char *schema_rcv_buff = NULL;
1228
1229 int *data_rcv_counts = NULL;
1230 int *data_rcv_displs = NULL;
1231 char *data_rcv_buff = NULL;
1232
1233
1234 // alloc data for the mpi gather counts and displ arrays
1235 n_rcv_tmp["schemas/counts"].set(DataType::c_int(m_size));
1236 n_rcv_tmp["schemas/displs"].set(DataType::c_int(m_size));
1237
1238 n_rcv_tmp["data/counts"].set(DataType::c_int(m_size));
1239 n_rcv_tmp["data/displs"].set(DataType::c_int(m_size));
1240
1241 // get pointers to counts and displs
1242 schema_rcv_counts = n_rcv_tmp["schemas/counts"].value();
1243 schema_rcv_displs = n_rcv_tmp["schemas/displs"].value();
1244
1245 data_rcv_counts = n_rcv_tmp["data/counts"].value();
1246 data_rcv_displs = n_rcv_tmp["data/displs"].value();
1247
1248 int schema_curr_displ = 0;
1249 int data_curr_displ = 0;
1250
1251 NodeIterator itr = n_rcv_sizes.children();
1252
1253 index_t child_idx = 0;
1254
1255 while(itr.has_next())
1256 {
1257 Node &curr = itr.next();
1258
1259 int schema_curr_count = curr["schema_len"].value();
1260 int data_curr_count = curr["data_len"].value();
1261
1262 schema_rcv_counts[child_idx] = schema_curr_count;
1263 schema_rcv_displs[child_idx] = schema_curr_displ;
1264 schema_curr_displ += schema_curr_count;
1265
1266 data_rcv_counts[child_idx] = data_curr_count;
1267 data_rcv_displs[child_idx] = data_curr_displ;
1268 data_curr_displ += data_curr_count;
1269
1270 child_idx+=1;
1271 }
1272
1273 n_rcv_tmp["schemas/data"].set(DataType::c_char(schema_curr_displ));
1274 schema_rcv_buff = n_rcv_tmp["schemas/data"].value();
1275
1276 mpi_error = MPI_Allgatherv( const_cast <char*>(schema_str.c_str()),
1277 schema_len,
1278 MPI_BYTE,
1279 schema_rcv_buff,
1280 schema_rcv_counts,
1281 schema_rcv_displs,
1282 MPI_BYTE,
1283 mpi_comm);
1284
1285 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1286
1287 // build all schemas from JSON, compact them.
1288 Schema rcv_schema;
1289     //TODO: should we make it easier to create a compact schema?
1290 // TODO: Revisit, I think we can do this better
1291 Schema s_tmp;
1292 for(int s_idx=0; s_idx < m_size; s_idx++)
1293 {
1294 Schema &s_new = s_tmp.append();
1295 s_new.set(&schema_rcv_buff[schema_rcv_displs[s_idx]]);
1296 }
1297
1298 // TODO can we support copy out w/out realloc
1299 s_tmp.compact_to(rcv_schema);
1300
1301 // allocate data to hold the gather result
1302 recv_node.set(rcv_schema);
1303 data_rcv_buff = (char*)recv_node.data_ptr();
1304
1305 mpi_error = MPI_Allgatherv( n_snd_compact.data_ptr(),
1306 data_len,
1307 MPI_BYTE,
1308 data_rcv_buff,
1309 data_rcv_counts,
1310 data_rcv_displs,
1311 MPI_BYTE,
1312 mpi_comm);
1313
1314 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1315
1316 return mpi_error;
1317 }
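
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; field names are assumptions): the
// *_using_schema gather variants let each rank contribute a different layout,
// since the schemas are exchanged before the data.
//
//     Node n_local, n_all;
//     // per-rank sized data is fine: rank r contributes r+1 float64 values
//     n_local["vals"].set(DataType::float64(mpi::rank(comm) + 1));
//     mpi::all_gather_using_schema(n_local, n_all, comm);
//     // n_all has mpi::size(comm) children whose "vals" leaves differ in size
//---------------------------------------------------------------------------//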
1318
1319
1320 //---------------------------------------------------------------------------//
1321 int
1322 broadcast(Node &node,
1323 int root,
1324 MPI_Comm comm)
1325 {
1326 int rank = mpi::rank(comm);
1327
1328 Node bcast_buffer;
1329
1330 bool cpy_out = false;
1331
1332 void *bcast_data_ptr = node.contiguous_data_ptr();
1333 index_t bcast_data_size = node.total_bytes_compact();
1334
1335 // setup buffers on root for send
1336 if(rank == root)
1337 {
1338 if( bcast_data_ptr == NULL ||
1339 ! node.is_compact() )
1340 {
1341 node.compact_to(bcast_buffer);
1342 bcast_data_ptr = bcast_buffer.data_ptr();
1343 }
1344
1345 }
1346 else // rank != root, setup buffers on non root for rcv
1347 {
1348 if( bcast_data_ptr == NULL ||
1349 ! node.is_compact() )
1350 {
1351 Schema s_compact;
1352 node.schema().compact_to(s_compact);
1353 bcast_buffer.set_schema(s_compact);
1354
1355 bcast_data_ptr = bcast_buffer.data_ptr();
1356 cpy_out = true;
1357 }
1358 }
1359
1360
1361 if(!conduit::utils::value_fits<index_t,int>(bcast_data_size))
1362 {
1363 CONDUIT_INFO("Warning size value (" << bcast_data_size << ")"
1364 " exceeds the size of MPI_Bcast max value "
1365 "(" << std::numeric_limits<int>::max() << ")")
1366 }
1367
1368
1369 int mpi_error = MPI_Bcast(bcast_data_ptr,
1370 static_cast<int>(bcast_data_size),
1371 MPI_BYTE,
1372 root,
1373 comm);
1374
1375 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1376
1377 // note: cpy_out will always be false when rank == root
1378 if( cpy_out )
1379 {
1380 node.update(bcast_buffer);
1381 }
1382
1383 return mpi_error;
1384 }
1385
1386 //---------------------------------------------------------------------------//
1387 int
1388 broadcast_using_schema(Node &node,
1389 int root,
1390 MPI_Comm comm)
1391 {
1392 int rank = mpi::rank(comm);
1393
1394 Node bcast_buffers;
1395
1396 void *bcast_data_ptr = NULL;
1397 int bcast_data_size = 0;
1398
1399 int bcast_schema_size = 0;
1400 int rcv_bcast_schema_size = 0;
1401
1402 // setup buffers for send
1403 if(rank == root)
1404 {
1405
1406 bcast_data_ptr = node.contiguous_data_ptr();
1407 bcast_data_size = static_cast<int>(node.total_bytes_compact());
1408
1409 if(bcast_data_ptr != NULL &&
1410 node.is_compact() &&
1411 node.is_contiguous())
1412 {
1413 bcast_buffers["schema"] = node.schema().to_json();
1414 }
1415 else
1416 {
1417 Node &bcast_data_compact = bcast_buffers["data"];
1418 node.compact_to(bcast_data_compact);
1419
1420 bcast_data_ptr = bcast_data_compact.data_ptr();
1421 bcast_buffers["schema"] = bcast_data_compact.schema().to_json();
1422 }
1423
1424
1425
1426 bcast_schema_size = static_cast<int>(bcast_buffers["schema"].dtype().number_of_elements());
1427 }
1428
1429 int mpi_error = MPI_Allreduce(&bcast_schema_size,
1430 &rcv_bcast_schema_size,
1431 1,
1432 MPI_INT,
1433 MPI_MAX,
1434 comm);
1435
1436 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1437
1438 bcast_schema_size = rcv_bcast_schema_size;
1439
1440 // alloc for rcv for schema
1441 if(rank != root)
1442 {
1443 bcast_buffers["schema"].set(DataType::char8_str(bcast_schema_size));
1444 }
1445
1446 // broadcast the schema
1447 mpi_error = MPI_Bcast(bcast_buffers["schema"].data_ptr(),
1448 bcast_schema_size,
1449 MPI_CHAR,
1450 root,
1451 comm);
1452
1453 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1454
1455 bool cpy_out = false;
1456
1457 // setup buffers for receive
1458 if(rank != root)
1459 {
1460 Schema bcast_schema;
1461 Generator gen(bcast_buffers["schema"].as_char8_str());
1462 gen.walk(bcast_schema);
1463
1464 // only check compat for leaves
1465 // there are more zero copy cases possible here, but
1466 // we need a better way to identify them
1467 // compatible won't work for object cases that
1468 // have different named leaves
1469 if( !(node.dtype().is_empty() ||
1470 node.dtype().is_object() ||
1471 node.dtype().is_list() ) &&
1472 !(bcast_schema.dtype().is_empty() ||
1473 bcast_schema.dtype().is_object() ||
1474 bcast_schema.dtype().is_list() )
1475 && bcast_schema.compatible(node.schema()))
1476 {
1477
1478 bcast_data_ptr = node.contiguous_data_ptr();
1479 bcast_data_size = static_cast<int>(node.total_bytes_compact());
1480
1481 if( bcast_data_ptr == NULL ||
1482 ! node.is_compact() )
1483 {
1484 Node &bcast_data_buffer = bcast_buffers["data"];
1485 bcast_data_buffer.set_schema(bcast_schema);
1486
1487 bcast_data_ptr = bcast_data_buffer.data_ptr();
1488 cpy_out = true;
1489 }
1490 }
1491 else
1492 {
1493 node.set_schema(bcast_schema);
1494
1495 bcast_data_ptr = node.data_ptr();
1496 bcast_data_size = static_cast<int>(node.total_bytes_compact());
1497 }
1498 }
1499
1500 mpi_error = MPI_Bcast(bcast_data_ptr,
1501 bcast_data_size,
1502 MPI_BYTE,
1503 root,
1504 comm);
1505
1506 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1507
1508 // note: cpy_out will always be false when rank == root
1509 if( cpy_out )
1510 {
1511 node.update(bcast_buffers["data"]);
1512 }
1513
1514 return mpi_error;
1515 }
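
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; field name is an assumption): broadcast()
// requires non-root ranks to already hold a compatible Node, while
// broadcast_using_schema() also ships the schema so receivers may start from
// an empty Node.
//
//     Node n;
//     if(mpi::rank(comm) == root)
//     {
//         n["config/steps"].set((int64)100);
//     }
//     mpi::broadcast_using_schema(n, root, comm);
//     // all ranks now hold the same "config/steps" leaf
//---------------------------------------------------------------------------//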
1516
1517 //-----------------------------------------------------------------------------
1518 //-----------------------------------------------------------------------------
1519 const int communicate_using_schema::OP_SEND = 1;
1520 const int communicate_using_schema::OP_RECV = 2;
1521
1522 //-----------------------------------------------------------------------------
1523 communicate_using_schema::communicate_using_schema(MPI_Comm c) :
1524 comm(c), operations(), logging(false)
1525 {
1526 }
1527
1528 //-----------------------------------------------------------------------------
1529 communicate_using_schema::~communicate_using_schema()
1530 {
1531 clear();
1532 }
1533
1534 //-----------------------------------------------------------------------------
1535 void
1536 communicate_using_schema::clear()
1537 {
1538 for(size_t i = 0; i < operations.size(); i++)
1539 {
1540 if(operations[i].free[0])
1541 delete operations[i].node[0];
1542 if(operations[i].free[1])
1543 delete operations[i].node[1];
1544 }
1545 operations.clear();
1546 }
1547
1548 //-----------------------------------------------------------------------------
1549 void
1550 communicate_using_schema::set_logging(bool val)
1551 {
1552 logging = val;
1553 }
1554
1555 //-----------------------------------------------------------------------------
1556 void
1557 communicate_using_schema::add_isend(const Node &node, int dest, int tag)
1558 {
1559 // Append the work to the operations.
1560 operation work;
1561 work.op = OP_SEND;
1562 work.rank = dest;
1563 work.tag = tag;
1564 work.node[0] = const_cast<Node *>(&node); // The node we're sending.
1565 work.free[0] = false;
1566 work.node[1] = nullptr;
1567 work.free[1] = false;
1568 operations.push_back(work);
1569 }
1570
1571 //-----------------------------------------------------------------------------
1572 void
1573 communicate_using_schema::add_irecv(Node &node, int src, int tag)
1574 {
1575 // Append the work to the operations.
1576 operation work;
1577 work.op = OP_RECV;
1578 work.rank = src;
1579 work.tag = tag;
1580 work.node[0] = &node; // Node that will contain final data.
1581 work.free[0] = false; // Don't need to free it.
1582 work.node[1] = nullptr;
1583 work.free[1] = false;
1584 operations.push_back(work);
1585 }
1586
1587 //-----------------------------------------------------------------------------
1588 int
1589 communicate_using_schema::execute()
1590 {
1591 int mpi_error = 0;
1592 std::vector<MPI_Request> requests(operations.size());
1593 std::vector<MPI_Status> statuses(operations.size());
1594
1595 int rank, size;
1596 MPI_Comm_rank(comm, &rank);
1597 MPI_Comm_size(comm, &size);
1598 std::ofstream log;
1599 double t0 = MPI_Wtime();
1600 if(logging)
1601 {
1602 char fn[128];
1603 sprintf(fn, "communicate_using_schema.%04d.log", rank);
1604 log.open(fn, std::ofstream::out);
1605 log << "* Log started on rank " << rank << " at " << t0 << std::endl;
1606 }
1607
1608 // Issue all the sends (so they are in flight by the time we probe them)
1609 for(size_t i = 0; i < operations.size(); i++)
1610 {
1611 if(operations[i].op == OP_SEND)
1612 {
1613 Schema s_data_compact;
1614 const Node &node = *operations[i].node[0];
1615 // schema will only be valid if compact and contig
1616 if( node.is_compact() && node.is_contiguous())
1617 {
1618 s_data_compact = node.schema();
1619 }
1620 else
1621 {
1622 node.schema().compact_to(s_data_compact);
1623 }
1624
1625 std::string snd_schema_json = s_data_compact.to_json();
1626
1627 Schema s_msg;
1628 s_msg["schema_len"].set(DataType::int64());
1629 s_msg["schema"].set(DataType::char8_str(snd_schema_json.size()+1));
1630 s_msg["data"].set(s_data_compact);
1631
1632 // create a compact schema to use
1633 Schema s_msg_compact;
1634 s_msg.compact_to(s_msg_compact);
1635
1636 operations[i].node[1] = new Node(s_msg_compact);
1637 operations[i].free[1] = true;
1638 Node &n_msg = *operations[i].node[1];
1639 // these sets won't realloc since schemas are compatible
1640 n_msg["schema_len"].set((int64)snd_schema_json.length());
1641 n_msg["schema"].set(snd_schema_json);
1642 n_msg["data"].update(node);
1643
1644 // Send the serialized node data.
1645 index_t msg_data_size = operations[i].node[1]->total_bytes_compact();
1646 if(logging)
1647 {
1648 log << " MPI_Isend("
1649 << const_cast<void*>(operations[i].node[1]->data_ptr()) << ", "
1650 << msg_data_size << ", "
1651 << "MPI_BYTE, "
1652 << operations[i].rank << ", "
1653 << operations[i].tag << ", "
1654 << "comm, &requests[" << i << "]);" << std::endl;
1655 }
1656
1657 if(!conduit::utils::value_fits<index_t,int>(msg_data_size))
1658 {
1659 CONDUIT_INFO("Warning size value (" << msg_data_size << ")"
1660 " exceeds the size of MPI_Isend max value "
1661 "(" << std::numeric_limits<int>::max() << ")")
1662 }
1663
1664 mpi_error = MPI_Isend(const_cast<void*>(operations[i].node[1]->data_ptr()),
1665 static_cast<int>(msg_data_size),
1666 MPI_BYTE,
1667 operations[i].rank,
1668 operations[i].tag,
1669 comm,
1670 &requests[i]);
1671 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1672 }
1673 }
1674 double t1 = MPI_Wtime();
1675 if(logging)
1676 {
1677 log << "* Time issuing MPI_Isend calls: " << (t1-t0) << std::endl;
1678 }
1679
1680 // Issue all the recvs.
1681 for(size_t i = 0; i < operations.size(); i++)
1682 {
1683 if(operations[i].op == OP_RECV)
1684 {
1685 // Probe the message for its buffer size.
1686 if(logging)
1687 {
1688 log << " MPI_Probe("
1689 << operations[i].rank << ", "
1690 << operations[i].tag << ", "
1691 << "comm, &statuses[" << i << "]);" << std::endl;
1692 }
1693 mpi_error = MPI_Probe(operations[i].rank, operations[i].tag, comm, &statuses[i]);
1694 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1695
1696 int buffer_size = 0;
1697 MPI_Get_count(&statuses[i], MPI_BYTE, &buffer_size);
1698 if(logging)
1699 {
1700 log << " MPI_Get_count(&statuses[" << i << "], MPI_BYTE, &buffer_size); -> "
1701 << buffer_size << std::endl;
1702 }
1703
1704 // Allocate a node into which we'll receive the raw data.
1705 operations[i].node[1] = new Node(DataType::uint8(buffer_size));
1706 operations[i].free[1] = true;
1707
1708 if(logging)
1709 {
1710 log << " MPI_Irecv("
1711 << operations[i].node[1]->data_ptr() << ", "
1712 << buffer_size << ", "
1713 << "MPI_BYTE, "
1714 << operations[i].rank << ", "
1715 << operations[i].tag << ", "
1716 << "comm, &requests[" << i << "]);" << std::endl;
1717 }
1718
1719 // Post the actual receive.
1720 mpi_error = MPI_Irecv(operations[i].node[1]->data_ptr(),
1721 buffer_size,
1722 MPI_BYTE,
1723 operations[i].rank,
1724 operations[i].tag,
1725 comm,
1726 &requests[i]);
1727 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1728 }
1729 }
1730 double t2 = MPI_Wtime();
1731 if(logging)
1732 {
1733 log << "* Time issuing MPI_Irecv calls: " << (t2-t1) << std::endl;
1734 }
1735
1736 // Wait for the requests to complete.
1737 int n = static_cast<int>(operations.size());
1738 if(logging)
1739 {
1740 log << " MPI_Waitall(" << n << ", &requests[0], &statuses[0]);" << std::endl;
1741 }
1742 mpi_error = MPI_Waitall(n, &requests[0], &statuses[0]);
1743 CONDUIT_CHECK_MPI_ERROR(mpi_error);
1744 double t3 = MPI_Wtime();
1745 if(logging)
1746 {
1747 log << "* Time in MPI_Waitall: " << (t3-t2) << std::endl;
1748 }
1749
1750 // Finish building the nodes for which we received data.
1751 for(size_t i = 0; i < operations.size(); i++)
1752 {
1753 if(operations[i].op == OP_RECV)
1754 {
1755 // Get the buffer of the data we received.
1756 uint8 *n_buff_ptr = (uint8*)operations[i].node[1]->data_ptr();
1757
1758 Node n_msg;
1759 // length of the schema is sent as a 64-bit signed int
1760 // NOTE: we aren't using this value ...
1761 n_msg["schema_len"].set_external((int64*)n_buff_ptr);
1762 n_buff_ptr +=8;
1763 // wrap the schema string
1764 n_msg["schema"].set_external_char8_str((char*)(n_buff_ptr));
1765 // create the schema
1766 Schema rcv_schema;
1767 Generator gen(n_msg["schema"].as_char8_str());
1768 gen.walk(rcv_schema);
1769
1770 // advance by the schema length
1771 n_buff_ptr += n_msg["schema"].total_bytes_compact();
1772
1773 // apply the schema to the data
1774 n_msg["data"].set_external(rcv_schema,n_buff_ptr);
1775
1776 // copy out to our result node
1777 operations[i].node[0]->update(n_msg["data"]);
1778
1779 if(logging)
1780 {
1781 log << "* Built output node " << i << std::endl;
1782 }
1783 }
1784 }
1785 double t4 = MPI_Wtime();
1786 if(logging)
1787 {
1788 log << "* Time building output nodes " << (t4-t3) << std::endl;
1789 log.close();
1790 }
1791
1792 // Cleanup
1793 clear();
1794
1795 return 0;
1796 }
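
//---------------------------------------------------------------------------//
// Usage sketch (illustrative only; ranks, tags, and nodes are assumptions):
// communicate_using_schema batches schema-carrying sends and receives and
// completes them in a single execute() call.
//
//     communicate_using_schema C(comm);
//     C.add_isend(n_out, /*dest*/ 1, /*tag*/ 5);
//     C.add_irecv(n_in,  /*src*/  1, /*tag*/ 6);
//     C.execute();   // posts the isends/irecvs, waits, and fills n_in
//---------------------------------------------------------------------------//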
1797
1798 //---------------------------------------------------------------------------//
1799 std::string
1800 about()
1801 {
1802 Node n;
1803 mpi::about(n);
1804 return n.to_yaml();
1805 }
1806
1807 //---------------------------------------------------------------------------//
1808 void
1809 about(Node &n)
1810 {
1811 n.reset();
1812 n["mpi"] = "enabled";
1813 }
1814
1815 }
1816 //-----------------------------------------------------------------------------
1817 // -- end conduit::relay::mpi --
1818 //-----------------------------------------------------------------------------
1819
1820 }
1821 //-----------------------------------------------------------------------------
1822 // -- end conduit::relay --
1823 //-----------------------------------------------------------------------------
1824
1825
1826 }
1827 //-----------------------------------------------------------------------------
1828 // -- end conduit:: --
1829 //-----------------------------------------------------------------------------
1830
1831
1832