/*
 * Copyright (C) by Argonne National Laboratory
 *     See COPYRIGHT in top-level directory
 */

/* Header protection (i.e., IALLGATHERV_TSP_RING_ALGOS_H_INCLUDED) is
 * intentionally omitted since this header might get included multiple
 * times within the same .c file. */

#include "algo_common.h"
#include "tsp_namespace_def.h"

/* Routine to schedule a ring exchange based allgatherv */
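/*
 * Each rank first copies its own contribution into a temporary buffer
 * (buf1) and then, for nranks - 1 steps, forwards the most recently
 * received block to the next rank in the ring while receiving a new block
 * from the previous rank.  Every received block is also copied into its
 * final location in recvbuf.  Two temporary buffers (buf1/buf2) are used
 * in a double-buffered fashion so that the send and the receive of a step
 * operate on different buffers.
 */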
int MPIR_TSP_Iallgatherv_sched_intra_ring(const void *sendbuf, int sendcount,
                                          MPI_Datatype sendtype, void *recvbuf,
                                          const int *recvcounts, const int *displs,
                                          MPI_Datatype recvtype, MPIR_Comm * comm,
                                          MPIR_TSP_sched_t * sched)
{
    size_t extent;
    MPI_Aint lb, true_extent;
    int mpi_errno = MPI_SUCCESS;
    int i, src, dst;
    int nranks, is_inplace, rank;
    int send_rank, recv_rank;
    void *data_buf, *buf1, *buf2, *sbuf, *rbuf;
    int max_count;
    int tag;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLGATHERV_SCHED_INTRA_RING);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLGATHERV_SCHED_INTRA_RING);

    is_inplace = (sendbuf == MPI_IN_PLACE);
    nranks = MPIR_Comm_size(comm);
    rank = MPIR_Comm_rank(comm);

    /* find out the buffer which has the send data and point data_buf to it */
    if (is_inplace) {
        sendcount = recvcounts[rank];
        sendtype = recvtype;
        data_buf = (char *) recvbuf;
    } else
        data_buf = (char *) sendbuf;

    MPIR_Datatype_get_extent_macro(recvtype, extent);
    MPIR_Type_get_true_extent_impl(recvtype, &lb, &true_extent);
    extent = MPL_MAX(extent, true_extent);

    max_count = recvcounts[0];
    for (i = 1; i < nranks; i++) {
        if (recvcounts[i] > max_count)
            max_count = recvcounts[i];
    }

    /* allocate space for temporary buffers to accommodate the largest recvcount */
    buf1 = MPIR_TSP_sched_malloc(max_count * extent, sched);
    buf2 = MPIR_TSP_sched_malloc(max_count * extent, sched);

    /* Phase 1: copy data to buf1 from sendbuf or recvbuf (in case of in-place) */
    int dtcopy_id[3];
    if (is_inplace) {
        dtcopy_id[0] =
            MPIR_TSP_sched_localcopy((char *) data_buf + displs[rank] * extent, sendcount, sendtype,
                                     buf1, recvcounts[rank], recvtype, sched, 0, NULL);
    } else {
        /* copy your data into your recvbuf from your sendbuf */
        MPIR_TSP_sched_localcopy(sendbuf, sendcount, sendtype,
                                 (char *) recvbuf + displs[rank] * extent, recvcounts[rank],
                                 recvtype, sched, 0, NULL);
        /* copy data from sendbuf into buf1, from where it is sent in the first step */
        dtcopy_id[0] =
            MPIR_TSP_sched_localcopy(sendbuf, sendcount, sendtype, buf1, recvcounts[rank], recvtype,
                                     sched, 0, NULL);
    }

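    /* In the ring, data is always received from the left neighbor (src) and
     * forwarded to the right neighbor (dst). */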
    src = (nranks + rank - 1) % nranks;
    dst = (rank + 1) % nranks;

    sbuf = buf1;
    rbuf = buf2;
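    /* sbuf always holds the block being sent in the current step and rbuf
     * receives the incoming block; the two buffers are swapped at the end of
     * every iteration. */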

    int send_id[3];
    int recv_id[3] = { 0 };     /* warning fix: icc: maybe used before set */
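
    /* Step i of the ring: send the block that belongs to send_rank (the
     * block received in step i - 1, or this rank's own block when i == 0) to
     * dst, and receive the block that belongs to recv_rank from src.  The
     * vertex ids of recent sends, receives, and copies are kept modulo 3
     * because dependencies reach back at most two steps. */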
    for (i = 0; i < nranks - 1; i++) {
        recv_rank = (rank - i - 1 + nranks) % nranks;   /* Rank whose data you're receiving */
        send_rank = (rank - i + nranks) % nranks;       /* Rank whose data you're sending */

        /* New tag for each send-recv pair. */
        mpi_errno = MPIR_Sched_next_tag(comm, &tag);
        MPIR_ERR_CHECK(mpi_errno);

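        /* Set up the dependencies (vtcs) for this step.  The send must wait
         * until sbuf holds the block to be forwarded, and the receive must
         * wait until rbuf is no longer in use by earlier operations. */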
        int nvtcs, vtcs[3];
        if (i == 0) {
            /* the first send depends only on the local copy of this rank's
             * data into buf1 */
            nvtcs = 1;
            vtcs[0] = dtcopy_id[0];

            send_id[i % 3] =
                MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, sched,
                                     nvtcs, vtcs);

            /* the first receive has no dependencies */
            nvtcs = 0;
        } else {
            /* this step's send waits for the previous step's receive (which
             * filled the current sbuf) and the previous step's send */
            nvtcs = 2;
            vtcs[0] = recv_id[(i - 1) % 3];
            vtcs[1] = send_id[(i - 1) % 3];

            send_id[i % 3] =
                MPIR_TSP_sched_isend(sbuf, recvcounts[send_rank], recvtype, dst, tag, comm, sched,
                                     nvtcs, vtcs);

            if (i == 1) {
                /* the receive waits for the step-0 send and receive */
                nvtcs = 2;
                vtcs[0] = send_id[0];
                vtcs[1] = recv_id[0];
            } else {
                /* the receive waits for the previous send and receive and for
                 * the local copy (issued two steps ago) that drains the buffer
                 * now being reused as rbuf */
                nvtcs = 3;
                vtcs[0] = send_id[(i - 1) % 3];
                vtcs[1] = dtcopy_id[(i - 2) % 3];
                vtcs[2] = recv_id[(i - 1) % 3];
            }
        }

        recv_id[i % 3] =
            MPIR_TSP_sched_irecv(rbuf, recvcounts[recv_rank], recvtype, src, tag, comm, sched,
                                 nvtcs, vtcs);

        /* Copy the received block to its correct position in recvbuf */
        dtcopy_id[i % 3] =
            MPIR_TSP_sched_localcopy(rbuf, recvcounts[recv_rank], recvtype,
                                     (char *) recvbuf + displs[recv_rank] * extent,
                                     recvcounts[recv_rank], recvtype, sched, 1, &recv_id[i % 3]);

        /* swap sbuf and rbuf for the next step */
        data_buf = sbuf;
        sbuf = rbuf;
        rbuf = data_buf;
    }

    MPIR_TSP_sched_fence(sched);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLGATHERV_SCHED_INTRA_RING);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}


/* Non-blocking ring based Allgatherv */
int MPIR_TSP_Iallgatherv_intra_ring(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                    void *recvbuf, const int *recvcounts, const int *displs,
                                    MPI_Datatype recvtype, MPIR_Comm * comm, MPIR_Request ** req)
{
    int mpi_errno = MPI_SUCCESS;
    MPIR_TSP_sched_t *sched;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLGATHERV_INTRA_RING);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLGATHERV_INTRA_RING);

    *req = NULL;

    /* generate the schedule */
    sched = MPL_malloc(sizeof(MPIR_TSP_sched_t), MPL_MEM_COLL);
    MPIR_ERR_CHKANDJUMP(!sched, mpi_errno, MPI_ERR_OTHER, "**nomem");
    MPIR_TSP_sched_create(sched);

    mpi_errno =
        MPIR_TSP_Iallgatherv_sched_intra_ring(sendbuf, sendcount, sendtype, recvbuf, recvcounts,
                                              displs, recvtype, comm, sched);
    MPIR_ERR_CHECK(mpi_errno);

    /* start and register the schedule */
    mpi_errno = MPIR_TSP_sched_start(sched, comm, req);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLGATHERV_INTRA_RING);
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}