/*
 * Copyright (C) by Argonne National Laboratory
 *     See COPYRIGHT in top-level directory
 */

/* Header protection (i.e., IALLREDUCE_TSP_RING_ALGOS_H_INCLUDED) is
 * intentionally omitted since this header might get included multiple
 * times within the same .c file. */

#include "algo_common.h"
#include "tsp_namespace_def.h"
#include "../iallgatherv/iallgatherv_tsp_ring_algos_prototypes.h"

/* Routine to schedule a ring exchange based allreduce.
 * The implementation is based on Baidu's ring algorithm
 * for Machine Learning/Deep Learning. The algorithm is
 * explained here: http://andrew.gibiansky.com/ */
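/* Rough sketch of the data movement (illustrative, not part of the
 * original source): the buffer is split into nranks chunks.  In the
 * reduce-scatter phase each rank sends one chunk to its right neighbor
 * and reduces the chunk arriving from its left neighbor, so after
 * nranks - 1 steps rank r holds the fully reduced chunk r.  The
 * allgatherv phase then circulates the reduced chunks for another
 * nranks - 1 steps until every rank has the complete result. */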
int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, int count,
                                         MPI_Datatype datatype, MPI_Op op,
                                         MPIR_Comm * comm, MPIR_TSP_sched_t * sched)
{
    int mpi_errno = MPI_SUCCESS;
    int i, src, dst;
    int nranks, is_inplace, rank;
    size_t extent;
    MPI_Aint lb, true_extent;
    int *cnts, *displs, recv_id, *reduce_id, nvtcs, vtcs;
    int send_rank, recv_rank, total_count;
    void *tmpbuf;
    int tag;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLREDUCE_SCHED_INTRA_RING);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLREDUCE_SCHED_INTRA_RING);
    MPIR_CHKLMEM_DECL(3);

    is_inplace = (sendbuf == MPI_IN_PLACE);
    nranks = MPIR_Comm_size(comm);
    rank = MPIR_Comm_rank(comm);

    MPIR_Datatype_get_extent_macro(datatype, extent);
    MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent);
    extent = MPL_MAX(extent, true_extent);

    MPIR_CHKLMEM_MALLOC(cnts, int *, nranks * sizeof(int), mpi_errno, "cnts", MPL_MEM_COLL);
    MPIR_CHKLMEM_MALLOC(displs, int *, nranks * sizeof(int), mpi_errno, "displs", MPL_MEM_COLL);

    for (i = 0; i < nranks; i++)
        cnts[i] = 0;

    total_count = 0;
    for (i = 0; i < nranks; i++) {
        cnts[i] = (count + nranks - 1) / nranks;
        if (total_count + cnts[i] > count) {
            cnts[i] = count - total_count;
            break;
        } else
            total_count += cnts[i];
    }

    displs[0] = 0;
    for (i = 1; i < nranks; i++)
        displs[i] = displs[i - 1] + cnts[i - 1];
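
    /* Example (assumed values, for illustration): count = 10 and
     * nranks = 4 give a ceiling chunk size of 3, so cnts = {3, 3, 3, 1}
     * and displs = {0, 3, 6, 9}.  Ranks past the early break keep the
     * zero set by the initialization loop above. */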

    /* Phase 1: copy sendbuf into recvbuf (the reduction accumulates in
     * recvbuf) unless this is an in-place operation */
    if (!is_inplace) {
        MPIR_TSP_sched_localcopy(sendbuf, count, datatype, recvbuf, count, datatype, sched, 0,
                                 NULL);
        MPIR_TSP_sched_fence(sched);
    }

    /* Phase 2: Ring based send recv reduce scatter */
    /* Need only 2 spaces for current and previous reduce_id(s) */
    MPIR_CHKLMEM_MALLOC(reduce_id, int *, 2 * sizeof(int), mpi_errno, "reduce_id", MPL_MEM_COLL);
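    /* Scratch buffer that receives one chunk per iteration before it is
     * reduced into recvbuf; only one chunk is in flight at a time, but
     * the buffer is sized for the full count. */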
    tmpbuf = MPIR_TSP_sched_malloc(count * extent, sched);

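    /* Ring neighbors: each rank always receives from its left neighbor
     * and sends to its right neighbor. */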
    src = (nranks + rank - 1) % nranks;
    dst = (rank + 1) % nranks;

    for (i = 0; i < nranks - 1; i++) {
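        /* Chunk indices for this step: receive the next chunk traveling
         * around the ring and forward the chunk that was reduced (or,
         * for i == 0, merely copied) in the previous step. */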
        recv_rank = (nranks + rank - 2 - i) % nranks;
        send_rank = (nranks + rank - 1 - i) % nranks;

        /* get a new tag to prevent out of order messages */
        mpi_errno = MPIR_Sched_next_tag(comm, &tag);
        MPIR_ERR_CHECK(mpi_errno);

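        /* Iterations after the first must wait for the previous
         * iteration's reduction: the receive reuses tmpbuf, and the send
         * forwards the chunk that reduction produced. */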
        nvtcs = (i == 0) ? 0 : 1;
        vtcs = (i == 0) ? 0 : reduce_id[(i - 1) % 2];
        recv_id =
            MPIR_TSP_sched_irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, sched, nvtcs,
                                 &vtcs);

        reduce_id[i % 2] =
            MPIR_TSP_sched_reduce_local(tmpbuf, (char *) recvbuf + displs[recv_rank] * extent,
                                        cnts[recv_rank], datatype, op, sched, 1, &recv_id);

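        /* The send shares the dependency computed above; it does not
         * wait on this iteration's reduction. */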
        MPIR_TSP_sched_isend((char *) recvbuf + displs[send_rank] * extent, cnts[send_rank],
                             datatype, dst, tag, comm, sched, nvtcs, &vtcs);

        MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE,
                        (MPL_DBG_FDEST,
                         "displs[recv_rank:%d]:%d, cnts[recv_rank:%d]:%d, displs[send_rank:%d]:%d, cnts[send_rank:%d]:%d",
                         recv_rank, displs[recv_rank], recv_rank, cnts[recv_rank], send_rank,
                         displs[send_rank], send_rank, cnts[send_rank]));
    }

    MPIR_TSP_sched_fence(sched);

    /* Phase 3: Allgatherv ring, so everyone has the reduced data */
    MPIR_TSP_Iallgatherv_sched_intra_ring(MPI_IN_PLACE, -1, MPI_DATATYPE_NULL, recvbuf, cnts,
                                          displs, datatype, comm, sched);
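    /* With MPI_IN_PLACE the send count and datatype arguments are
     * ignored, which is why -1 and MPI_DATATYPE_NULL are passed above. */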

    MPIR_CHKLMEM_FREEALL();

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLREDUCE_SCHED_INTRA_RING);
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}

/* Non-blocking ring based Allreduce */
int MPIR_TSP_Iallreduce_intra_ring(const void *sendbuf, void *recvbuf, int count,
                                   MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm,
                                   MPIR_Request ** req)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_TSP_sched_t *sched;
    *req = NULL;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLREDUCE_INTRA_RING);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLREDUCE_INTRA_RING);

    /* generate the schedule */
    sched = MPL_malloc(sizeof(MPIR_TSP_sched_t), MPL_MEM_COLL);
    MPIR_ERR_CHKANDJUMP(!sched, mpi_errno, MPI_ERR_OTHER, "**nomem");
    MPIR_TSP_sched_create(sched);

    mpi_errno =
        MPIR_TSP_Iallreduce_sched_intra_ring(sendbuf, recvbuf, count, datatype, op, comm, sched);
    MPIR_ERR_CHECK(mpi_errno);

    /* start and register the schedule */
    mpi_errno = MPIR_TSP_sched_start(sched, comm, req);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLREDUCE_INTRA_RING);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}