1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 
8 /* Intercommunicator Allreduce
9  *
10  * We first do intracommunicator reduces to rank 0 on left and right
11  * groups, then an exchange between left and right rank 0, and finally
12  * intracommunicator broadcasts from rank 0 on left and right
13  * group.
14  */
15 
/* Parameters:
 *   sendbuf  - this process's contribution (count elements of datatype)
 *   recvbuf  - on return, holds the reduction of the REMOTE group's
 *              contributions (received by local rank 0, then broadcast)
 *   count, datatype, op - describe the reduction
 *   comm_ptr - intercommunicator; its local intracommunicator is created
 *              on first use if it does not exist yet
 *   errflag  - in/out; set to MPIR_ERR_PROC_FAILED or MPIR_ERR_OTHER when a
 *              communication step fails (the algorithm records the failure
 *              and continues with the remaining steps)
 * Returns MPI_SUCCESS, or an MPI error code combining recorded failures. */
int MPIR_Allreduce_inter_reduce_exchange_bcast(const void *sendbuf, void *recvbuf, int
                                               count, MPI_Datatype datatype, MPI_Op op,
                                               MPIR_Comm * comm_ptr, MPIR_Errflag_t * errflag)
{
    int mpi_errno = MPI_SUCCESS;
    int mpi_errno_ret = MPI_SUCCESS;
    MPI_Aint true_extent, true_lb, extent;
    void *tmp_buf = NULL;
    MPIR_Comm *newcomm_ptr = NULL;
    MPIR_CHKLMEM_DECL(1);

    /* Only local rank 0 is the root of the local reduce, so only it needs a
     * buffer for the locally reduced result; other ranks pass NULL. */
    if (comm_ptr->rank == 0) {
        MPIR_Type_get_true_extent_impl(datatype, &true_lb, &true_extent);
        MPIR_Datatype_get_extent_macro(datatype, extent);
        MPIR_CHKLMEM_MALLOC(tmp_buf, void *, count * (MPL_MAX(extent, true_extent)), mpi_errno,
                            "temporary buffer", MPL_MEM_BUFFER);
        /* adjust for potential negative lower bound in datatype */
        tmp_buf = (void *) ((char *) tmp_buf - true_lb);
    }

    /* Get the local intracommunicator, creating it on first use. Failure here
     * must not be ignored: a NULL local_comm would be handed to the reduce
     * and broadcast below. */
    if (!comm_ptr->local_comm) {
        mpi_errno = MPII_Setup_intercomm_localcomm(comm_ptr);
        if (mpi_errno) {
            MPIR_ERR_POP(mpi_errno);
        }
    }

    newcomm_ptr = comm_ptr->local_comm;

    /* Do a local reduce on this intracommunicator */
    mpi_errno = MPIR_Reduce(sendbuf, tmp_buf, count, datatype, op, 0, newcomm_ptr, errflag);
    if (mpi_errno) {
        /* for communication errors, just record the error but continue */
        *errflag =
            MPIX_ERR_PROC_FAILED ==
            MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
        MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
        MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
    }

    /* Do an exchange between local and remote rank 0 on this intercommunicator:
     * send our group's reduced result, receive the remote group's into recvbuf */
    if (comm_ptr->rank == 0) {
        mpi_errno = MPIC_Sendrecv(tmp_buf, count, datatype, 0, MPIR_REDUCE_TAG,
                                  recvbuf, count, datatype, 0, MPIR_REDUCE_TAG,
                                  comm_ptr, MPI_STATUS_IGNORE, errflag);
        if (mpi_errno) {
            /* for communication errors, just record the error but continue */
            *errflag =
                MPIX_ERR_PROC_FAILED ==
                MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
            MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
            MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
        }
    }

    /* Do a local broadcast of the remote group's result on this intracommunicator */
    mpi_errno = MPIR_Bcast(recvbuf, count, datatype, 0, newcomm_ptr, errflag);
    if (mpi_errno) {
        /* for communication errors, just record the error but continue */
        *errflag =
            MPIX_ERR_PROC_FAILED ==
            MPIR_ERR_GET_CLASS(mpi_errno) ? MPIR_ERR_PROC_FAILED : MPIR_ERR_OTHER;
        MPIR_ERR_SET(mpi_errno, *errflag, "**fail");
        MPIR_ERR_ADD(mpi_errno_ret, mpi_errno);
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* Prefer the accumulated error; otherwise report a collective failure if
     * any step (here or in a callee) flagged one. */
    if (mpi_errno_ret)
        mpi_errno = mpi_errno_ret;
    else if (*errflag != MPIR_ERR_NONE)
        MPIR_ERR_SET(mpi_errno, *errflag, "**coll_fail");

    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
91