1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 /* Header protection (i.e., IALLTOALLV_TSP_ALGOS_H_INCLUDED) is
7 * intentionally omitted since this header might get included multiple
8 * times within the same .c file. */
9
10 #include "tsp_namespace_def.h"
11
12 /* Routine to schedule a scattered based alltoallv */
13 /* Alltoallv doesn't support MPI_IN_PLACE */
MPIR_TSP_Ialltoallv_sched_intra_scattered(const void * sendbuf,const int sendcounts[],const int sdispls[],MPI_Datatype sendtype,void * recvbuf,const int recvcounts[],const int rdispls[],MPI_Datatype recvtype,MPIR_Comm * comm,int batch_size,int bblock,MPIR_TSP_sched_t * sched)14 int MPIR_TSP_Ialltoallv_sched_intra_scattered(const void *sendbuf, const int sendcounts[],
15 const int sdispls[], MPI_Datatype sendtype,
16 void *recvbuf, const int recvcounts[],
17 const int rdispls[], MPI_Datatype recvtype,
18 MPIR_Comm * comm, int batch_size, int bblock,
19 MPIR_TSP_sched_t * sched)
20 {
21 int mpi_errno = MPI_SUCCESS;
22 int src, dst;
23 int i, j, ww;
24 int invtcs;
25 int tag;
26 int *vtcs, *recv_id, *send_id;
27 MPIR_CHKLMEM_DECL(3);
28
29 MPIR_Assert(!(sendbuf == MPI_IN_PLACE));
30
31 int size = MPIR_Comm_size(comm);
32 int rank = MPIR_Comm_rank(comm);
33
34 MPI_Aint recvtype_lb, recvtype_extent;
35 MPI_Aint sendtype_lb, sendtype_extent;
36 MPI_Aint sendtype_true_extent, recvtype_true_extent;
37
38 MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLTOALLV_SCHED_INTRA_SCATTERED);
39 MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLTOALLV_SCHED_INTRA_SCATTERED);
40
41 /* Alltoall: each process need send #size msgs and recv #size msgs
42 * This algorithm does #bblock send/recv first
43 * followed by addition batchs of #batch_size each.
44 */
45 if (bblock > size)
46 bblock = size;
47
48 /* vtcs is twice the batch size to store both send and recv ids */
49 MPIR_CHKLMEM_MALLOC(vtcs, int *, 2 * batch_size * sizeof(int), mpi_errno, "vtcs", MPL_MEM_COLL);
50 MPIR_CHKLMEM_MALLOC(recv_id, int *, bblock * sizeof(int), mpi_errno, "recv_id", MPL_MEM_COLL);
51 MPIR_CHKLMEM_MALLOC(send_id, int *, bblock * sizeof(int), mpi_errno, "send_id", MPL_MEM_COLL);
52
53 /* Get datatype info of sendtype and recvtype */
54 MPIR_Datatype_get_extent_macro(recvtype, recvtype_extent);
55 MPIR_Type_get_true_extent_impl(recvtype, &recvtype_lb, &recvtype_true_extent);
56 recvtype_extent = MPL_MAX(recvtype_extent, recvtype_true_extent);
57
58 MPIR_Datatype_get_extent_macro(sendtype, sendtype_extent);
59 MPIR_Type_get_true_extent_impl(sendtype, &sendtype_lb, &sendtype_true_extent);
60 sendtype_extent = MPL_MAX(sendtype_extent, sendtype_true_extent);
61
62 mpi_errno = MPIR_Sched_next_tag(comm, &tag);
63 MPIR_ERR_CHECK(mpi_errno);
64
65 /* First, post bblock number of sends/recvs */
66 for (i = 0; i < bblock; i++) {
67 src = (rank + i) % size;
68 recv_id[i] =
69 MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent,
70 recvcounts[src], recvtype, src, tag, comm, sched, 0, NULL);
71 dst = (rank - i + size) % size;
72 send_id[i] =
73 MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent,
74 sendcounts[dst], sendtype, dst, tag, comm, sched, 0, NULL);
75 }
76
77 /* Post more send/recv pairs as the previous ones finish */
78 for (i = bblock; i < size; i += batch_size) {
79 int i_vtcs = 0;
80 ww = MPL_MIN(size - i, batch_size);
81 /* Add dependency to ensure not to run until previous block the batch portion
82 * is finished -- effectively limiting the on-going tasks to bblock
83 */
84 for (j = 0; j < ww; j++) {
85 vtcs[i_vtcs++] = recv_id[(i + j) % bblock];
86 vtcs[i_vtcs++] = send_id[(i + j) % bblock];
87 }
88 invtcs = MPIR_TSP_sched_selective_sink(sched, 2 * ww, vtcs);
89 for (j = 0; j < ww; j++) {
90 src = (rank + i + j) % size;
91 recv_id[(i + j) % bblock] =
92 MPIR_TSP_sched_irecv((char *) recvbuf + rdispls[src] * recvtype_extent,
93 recvcounts[src], recvtype, src, tag, comm, sched, 1, &invtcs);
94 dst = (rank - i - j + size) % size;
95 send_id[(i + j) % bblock] =
96 MPIR_TSP_sched_isend((char *) sendbuf + sdispls[dst] * sendtype_extent,
97 sendcounts[dst], sendtype, dst, tag, comm, sched, 1, &invtcs);
98 }
99 }
100
101 fn_exit:
102 MPIR_CHKLMEM_FREEALL();
103 MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLTOALLV_SCHED_INTRA_SCATTERED);
104 return mpi_errno;
105
106 fn_fail:
107 goto fn_exit;
108 }
109
110 /* Scattered sliding window based Alltoallv */
MPIR_TSP_Ialltoallv_intra_scattered(const void * sendbuf,const int sendcounts[],const int sdispls[],MPI_Datatype sendtype,void * recvbuf,const int recvcounts[],const int rdispls[],MPI_Datatype recvtype,MPIR_Comm * comm,int batch_size,int bblock,MPIR_Request ** req)111 int MPIR_TSP_Ialltoallv_intra_scattered(const void *sendbuf, const int sendcounts[],
112 const int sdispls[], MPI_Datatype sendtype, void *recvbuf,
113 const int recvcounts[], const int rdispls[],
114 MPI_Datatype recvtype, MPIR_Comm * comm, int batch_size,
115 int bblock, MPIR_Request ** req)
116 {
117 int mpi_errno = MPI_SUCCESS;
118 MPIR_TSP_sched_t *sched;
119 *req = NULL;
120
121 MPIR_FUNC_VERBOSE_STATE_DECL(MPID_STATE_MPIR_TSP_IALLTOALLV_INTRA_SCATTERED);
122 MPIR_FUNC_VERBOSE_ENTER(MPID_STATE_MPIR_TSP_IALLTOALLV_INTRA_SCATTERED);
123
124 /* Generate the schedule */
125 sched = MPL_malloc(sizeof(MPIR_TSP_sched_t), MPL_MEM_COLL);
126 MPIR_ERR_CHKANDJUMP(!sched, mpi_errno, MPI_ERR_OTHER, "**nomem");
127 MPIR_TSP_sched_create(sched);
128
129 mpi_errno =
130 MPIR_TSP_Ialltoallv_sched_intra_scattered(sendbuf, sendcounts, sdispls, sendtype,
131 recvbuf, recvcounts, rdispls, recvtype, comm,
132 batch_size, bblock, sched);
133 MPIR_ERR_CHECK(mpi_errno);
134
135 /* Start and register the schedule */
136 mpi_errno = MPIR_TSP_sched_start(sched, comm, req);
137 MPIR_ERR_CHECK(mpi_errno);
138
139 fn_exit:
140 MPIR_FUNC_VERBOSE_EXIT(MPID_STATE_MPIR_TSP_IALLTOALLV_INTRA_SCATTERED);
141 return mpi_errno;
142 fn_fail:
143 goto fn_exit;
144 }
145