1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 /* Header protection (i.e., IALLTOALLV_TSP_ALGOS_H_INCLUDED) is
7  * intentionally omitted since this header might get included multiple
8  * times within the same .c file. */
9 
10 #include "tsp_namespace_def.h"
11 
12 /* Routine to schedule a scattered based alltoallv */
13 /* Alltoallv doesn't support MPI_IN_PLACE */
/*
 * Append the vertices of a scattered (throttled) alltoallv to a schedule.
 *
 * One receive and one send are scheduled for every rank of comm. The first
 * bblock pairs are issued without dependencies. The remaining pairs are
 * issued in batches of at most batch_size, each batch depending on a sink
 * vertex built over previously issued receive/send vertices.
 *
 * sendbuf, sendcounts, sdispls, sendtype: send-side arguments, indexed by
 *     destination rank; sdispls[] is scaled by the send type extent.
 * recvbuf, recvcounts, rdispls, recvtype: receive-side arguments, indexed by
 *     source rank; rdispls[] is scaled by the receive type extent.
 * comm:       communicator supplying size, rank and the schedule tag.
 * batch_size: number of pairs added per throttled batch (values < 1 are
 *             treated as 1).
 * bblock:     number of pairs issued up front and number of id slots tracked
 *             (capped at the communicator size, values < 1 are treated as 1).
 * sched:      schedule the vertices are appended to.
 *
 * MPI_IN_PLACE is not accepted for sendbuf (asserted).
 *
 * Returns MPI_SUCCESS, or the error code produced by a failing allocation or
 * by MPIR_Sched_next_tag.
 */
int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const int sendcounts[],
                                              const int sdispls[], MPI_Datatype sendtype,
                                              void *recvbuf, const int recvcounts[],
                                              const int rdispls[], MPI_Datatype recvtype,
                                              MPIR_Comm * comm, int batch_size, int bblock,
                                              MPIR_TSP_sched_t * sched)
{
    int mpi_errno = MPI_SUCCESS;
    int src, dst;
    int i, j, ww;
    int invtcs;
    int tag;
    int *vtcs, *recv_id, *send_id;
    MPIR_CHKLMEM_DECL(3);

    MPIR_Assert(!(sendbuf == MPI_IN_PLACE));

    int size = MPIR_Comm_size(comm);
    int rank = MPIR_Comm_rank(comm);

    MPI_Aint recvtype_lb, recvtype_extent;
    MPI_Aint sendtype_lb, sendtype_extent;
    MPI_Aint sendtype_true_extent, recvtype_true_extent;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLTOALLV_SCHED_INTRA_SCATTERED);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLTOALLV_SCHED_INTRA_SCATTERED);

    /* Alltoall: each process needs to send #size msgs and recv #size msgs.
     * This algorithm does #bblock send/recv first,
     * followed by additional batches of #batch_size each.
     */
    if (bblock > size)
        bblock = size;

    /* bblock is used as a modulus and batch_size as the loop stride below;
     * a non-positive value would cause a modulo by zero / out-of-range slot
     * index, or a loop that never advances. Fall back to the smallest
     * usable value instead. */
    if (bblock < 1)
        bblock = 1;
    if (batch_size < 1)
        batch_size = 1;

    /* vtcs is twice the batch size to store both send and recv ids */
    MPIR_CHKLMEM_MALLOC(vtcs, int *, 2 * batch_size * sizeof(int), mpi_errno, "vtcs", MPL_MEM_COLL);
    MPIR_CHKLMEM_MALLOC(recv_id, int *, bblock * sizeof(int), mpi_errno, "recv_id", MPL_MEM_COLL);
    MPIR_CHKLMEM_MALLOC(send_id, int *, bblock * sizeof(int), mpi_errno, "send_id", MPL_MEM_COLL);

    /* Get datatype info of sendtype and recvtype; the larger of extent and
     * true extent is used to scale the displacements */
    MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent);
    MPIR_Type_get_true_extent_impl(recvtype, &recvtype_lb, &recvtype_true_extent);
    recvtype_extent = MPL_MAX(recvtype_extent, recvtype_true_extent);

    MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent);
    MPIR_Type_get_true_extent_impl(sendtype, &sendtype_lb, &sendtype_true_extent);
    sendtype_extent = MPL_MAX(sendtype_extent, sendtype_true_extent);

    mpi_errno = MPIR_Sched_next_tag(comm, &tag);
    MPIR_ERR_CHECK(mpi_errno);

    /* First, post bblock number of sends/recvs with no dependencies.
     * Receives walk upward from our own rank, sends walk downward. */
    for (i = 0; i < bblock; i++) {
        src = (rank + i) % size;
        recv_id[i] =
            MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent,
                                 recvcounts[src], recvtype, src, tag, comm, sched, 0, NULL);
        dst = (rank - i + size) % size;
        send_id[i] =
            MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent,
                                 sendcounts[dst], sendtype, dst, tag, comm, sched, 0, NULL);
    }

    /* Post more send/recv pairs as the previous ones finish */
    for (i = bblock; i < size; i += batch_size) {
        int i_vtcs = 0;
        /* The last batch may be shorter than batch_size */
        ww = MPL_MIN(size - i, batch_size);
        /* Add a dependency so this batch does not run until the operations
         * currently occupying its id slots are finished -- effectively
         * limiting the on-going tasks to bblock.
         */
        for (j = 0; j < ww; j++) {
            vtcs[i_vtcs++] = recv_id[(i + j) % bblock];
            vtcs[i_vtcs++] = send_id[(i + j) % bblock];
        }
        invtcs = MPIR_TSP_sched_selective_sink(sched, 2 * ww, vtcs);
        /* Every operation of the batch depends on that single sink vertex;
         * the new ids overwrite the slots that were just collected. */
        for (j = 0; j < ww; j++) {
            src = (rank + i + j) % size;
            recv_id[(i + j) % bblock] =
                MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent,
                                     recvcounts[src], recvtype, src, tag, comm, sched, 1, &invtcs);
            dst = (rank - i - j + size) % size;
            send_id[(i + j) % bblock] =
                MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent,
                                     sendcounts[dst], sendtype, dst, tag, comm, sched, 1, &invtcs);
        }
    }

  fn_exit:
    MPIR_CHKLMEM_FREEALL();
    /* Balance the ENTER above (the exit path previously logged ENTER again) */
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLTOALLV_SCHED_INTRA_SCATTERED);
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
109 
110 /* Scattered sliding window based Alltoallv */
/*
 * Nonblocking alltoallv using the scattered algorithm.
 *
 * Allocates and creates a schedule, fills it through
 * MPIR_TSP_Ialltoallv_sched_intra_scattered (forwarding every buffer, count,
 * displacement, datatype and tuning argument unchanged), then hands it to
 * MPIR_TSP_sched_start together with comm and req.
 *
 * *req is set to NULL before any other work is done.
 *
 * Returns MPI_SUCCESS, MPI_ERR_OTHER-class "**nomem" when the schedule object
 * cannot be allocated, or the error code reported by schedule generation or
 * by MPIR_TSP_sched_start.
 */
int MPIR_TSP_Ialltoallv_intra_scattered(const void *sendbuf, const int sendcounts[],
                                        const int sdispls[], MPI_Datatype sendtype, void *recvbuf,
                                        const int recvcounts[], const int rdispls[],
                                        MPI_Datatype recvtype, MPIR_Comm * comm, int batch_size,
                                        int bblock, MPIR_Request ** req)
{
    MPIR_TSP_sched_t *sched = NULL;
    int mpi_errno = MPI_SUCCESS;

    MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLTOALLV_INTRA_SCATTERED);
    MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLTOALLV_INTRA_SCATTERED);

    /* Make sure the caller never sees a stale request on failure */
    *req = NULL;

    /* Step 1: obtain an empty schedule object */
    sched = MPL_malloc(sizeof(*sched), MPL_MEM_COLL);
    MPIR_ERR_CHKANDJUMP(sched == NULL, mpi_errno, MPI_ERR_OTHER, "**nomem");
    MPIR_TSP_sched_create(sched);

    /* Step 2: populate it with the alltoallv communication pattern.
     * NOTE(review): sched does not appear to be released if this or the
     * start call below fails -- confirm who owns it on error. */
    mpi_errno = MPIR_TSP_Ialltoallv_sched_intra_scattered(sendbuf, sendcounts, sdispls,
                                                          sendtype, recvbuf, recvcounts,
                                                          rdispls, recvtype, comm,
                                                          batch_size, bblock, sched);
    MPIR_ERR_CHECK(mpi_errno);

    /* Step 3: start and register the schedule */
    mpi_errno = MPIR_TSP_sched_start(sched, comm, req);
    MPIR_ERR_CHECK(mpi_errno);

  fn_exit:
    MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLTOALLV_INTRA_SCATTERED);
    return mpi_errno;

  fn_fail:
    goto fn_exit;
}
145