1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 /* Header protection (i.e., IALLREDUCE_TSP_RING_ALGOS_H_INCLUDED) is
7 * intentionally omitted since this header might get included multiple
8 * times within the same .c file. */
9
10 #include "algo_common.h"
11 #include "tsp_namespace_def.h"
12 #include "../iallgatherv/iallgatherv_tsp_ring_algos_prototypes.h"
13
14 /* Routine to schedule a ring exchange based allreduce.
15 * The implementation is based on Baidu's ring algorithm
16 * for Machine Learning/Deep Learning. The algorithm is
17 * explained here: http://andrew.gibiansky.com/ */
/* Generate the schedule for a ring-exchange allreduce:
 * Phase 1 copies sendbuf into recvbuf (unless MPI_IN_PLACE),
 * Phase 2 performs a ring reduce-scatter so each rank owns the
 * fully reduced chunk at displs[rank], and Phase 3 runs a ring
 * allgatherv so every rank ends up with the complete result.
 *
 * sendbuf  - input buffer (may be MPI_IN_PLACE)
 * recvbuf  - output buffer; also working buffer during the ring
 * count    - number of elements of 'datatype' to reduce
 * datatype - element datatype
 * op       - reduction operation
 * comm     - communicator over which to reduce
 * sched    - transport schedule being built
 *
 * Returns MPI_SUCCESS or an MPI error code. */
int MPIR_TSP_Iallreduce_sched_intra_ring(const void *sendbuf, void *recvbuf, int count,
                                         MPI_Datatype datatype, MPI_Op op,
                                         MPIR_Comm * comm, MPIR_TSP_sched_t * sched)
{
    int mpi_errno = MPI_SUCCESS;
    int i, src, dst;
    int nranks, is_inplace, rank;
    size_t extent;
    MPI_Aint lb, true_extent;
    int *cnts, *displs, recv_id, *reduce_id, nvtcs, vtcs;
    int send_rank, recv_rank, total_count;
    void *tmpbuf;
    int tag;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLREDUCE_SCHED_INTRA_RING);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLREDUCE_SCHED_INTRA_RING);
    /* three CHKLMEM allocations below: cnts, displs, reduce_id */
    MPIR_CHKLMEM_DECL(3);

    is_inplace = (sendbuf == MPI_IN_PLACE);
    nranks = MPIR_Comm_size(comm);
    rank = MPIR_Comm_rank(comm);

    /* use the larger of extent and true_extent for buffer sizing */
    MPIR_Datatype_get_extent_macro(datatype, extent);
    MPIR_Type_get_true_extent_impl(datatype, &lb, &true_extent);
    extent = MPL_MAX(extent, true_extent);

    MPIR_CHKLMEM_MALLOC(cnts, int *, nranks * sizeof(int), mpi_errno, "cnts", MPL_MEM_COLL);
    MPIR_CHKLMEM_MALLOC(displs, int *, nranks * sizeof(int), mpi_errno, "displs", MPL_MEM_COLL);

    for (i = 0; i < nranks; i++)
        cnts[i] = 0;

    /* split 'count' into nranks chunks of ceil(count/nranks) elements;
     * the last nonzero chunk absorbs the remainder, trailing chunks stay 0 */
    total_count = 0;
    for (i = 0; i < nranks; i++) {
        cnts[i] = (count + nranks - 1) / nranks;
        if (total_count + cnts[i] > count) {
            cnts[i] = count - total_count;
            break;
        } else
            total_count += cnts[i];
    }

    displs[0] = 0;
    for (i = 1; i < nranks; i++)
        displs[i] = displs[i - 1] + cnts[i - 1];

    /* Phase 1: copy to tmp buf */
    if (!is_inplace) {
        MPIR_TSP_sched_localcopy(sendbuf, count, datatype, recvbuf, count, datatype, sched, 0,
                                 NULL);
        MPIR_TSP_sched_fence(sched);
    }

    /* Phase 2: Ring based send recv reduce scatter */
    /* Need only 2 spaces for current and previous reduce_id(s) */
    MPIR_CHKLMEM_MALLOC(reduce_id, int *, 2 * sizeof(int), mpi_errno, "reduce_id", MPL_MEM_COLL);
    tmpbuf = MPIR_TSP_sched_malloc(count * extent, sched);

    src = (nranks + rank - 1) % nranks;
    dst = (rank + 1) % nranks;

    for (i = 0; i < nranks - 1; i++) {
        /* chunk received from 'src' this round, and chunk forwarded to 'dst' */
        recv_rank = (nranks + rank - 2 - i) % nranks;
        send_rank = (nranks + rank - 1 - i) % nranks;

        /* get a new tag to prevent out of order messages */
        mpi_errno = MPIR_Sched_next_tag(comm, &tag);
        MPIR_ERR_CHECK(mpi_errno);

        /* after round 0, depend on the previous round's local reduction so
         * the chunk being received into tmpbuf is no longer in use */
        nvtcs = (i == 0) ? 0 : 1;
        vtcs = (i == 0) ? 0 : reduce_id[(i - 1) % 2];
        recv_id =
            MPIR_TSP_sched_irecv(tmpbuf, cnts[recv_rank], datatype, src, tag, comm, sched, nvtcs,
                                 &vtcs);

        /* reduce the received chunk into recvbuf once the irecv completes */
        reduce_id[i % 2] =
            MPIR_TSP_sched_reduce_local(tmpbuf, (char *) recvbuf + displs[recv_rank] * extent,
                                        cnts[recv_rank], datatype, op, sched, 1, &recv_id);

        MPIR_TSP_sched_isend((char *) recvbuf + displs[send_rank] * extent, cnts[send_rank],
                             datatype, dst, tag, comm, sched, nvtcs, &vtcs);

        MPL_DBG_MSG_FMT(MPIR_DBG_COLL, VERBOSE,
                        (MPL_DBG_FDEST,
                         "displs[recv_rank:%d]:%d, cnts[recv_rank:%d, displs[send_rank:%d]:%d, cnts[send_rank:%d]:%d]:%d ",
                         recv_rank, displs[recv_rank], recv_rank, cnts[recv_rank], send_rank,
                         displs[send_rank], send_rank, cnts[send_rank]));
    }

    MPIR_TSP_sched_fence(sched);

    /* Phase 3: Allgatherv ring, so everyone has the reduced data */
    MPIR_TSP_Iallgatherv_sched_intra_ring(MPI_IN_PLACE, -1, MPI_DATATYPE_NULL, recvbuf, cnts,
                                          displs, datatype, comm, sched);

    MPIR_CHKLMEM_FREEALL();

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLREDUCE_SCHED_INTRA_RING);
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
123
124 /* Non-blocking ring based Allreduce */
/* Non-blocking ring-based allreduce: allocate a transport schedule,
 * generate the ring-allreduce operations into it, and start it.
 * On success, *req is the request callers wait on; on failure the
 * unstarted schedule is released and *req remains NULL. */
int MPIR_TSP_Iallreduce_intra_ring(const void *sendbuf, void *recvbuf, int count,
                                   MPI_Datatype datatype, MPI_Op op, MPIR_Comm * comm,
                                   MPIR_Request ** req)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_TSP_sched_t *sched = NULL;
    *req = NULL;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLREDUCE_INTRA_RING);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLREDUCE_INTRA_RING);

    /* generate the schedule */
    sched = MPL_malloc(sizeof(MPIR_TSP_sched_t), MPL_MEM_COLL);
    MPIR_ERR_CHKANDJUMP(!sched, mpi_errno, MPI_ERR_OTHER, "**nomem");
    MPIR_TSP_sched_create(sched);

    mpi_errno =
        MPIR_TSP_Iallreduce_sched_intra_ring(sendbuf, recvbuf, count, datatype, op, comm, sched);
    MPIR_ERR_CHECK(mpi_errno);

    /* start and register the schedule; ownership of 'sched' passes to the
     * started schedule's completion path */
    mpi_errno = MPIR_TSP_sched_start(sched, comm, req);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLREDUCE_INTRA_RING);
    return mpi_errno;
  fn_fail:
    /* the schedule was never started, so nothing else owns it; free the
     * container to avoid leaking it on the error path.
     * NOTE(review): if MPIR_TSP_sched_create acquires internal resources, a
     * matching schedule-free routine should be called here instead — confirm
     * against the transport's sched API. */
    MPL_free(sched);
    sched = NULL;
    goto fn_exit;
}
155