1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include "mpiimpl.h"
7 #include "hcoll.h"
8
9 /*
10 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
11
12 cvars:
13 - name : MPIR_CVAR_ENABLE_HCOLL
14 category : COLLECTIVE
15 type : boolean
16 default : false
17 class : none
18 verbosity : MPI_T_VERBOSITY_USER_BASIC
19 scope : MPI_T_SCOPE_LOCAL
20 description : >-
21 Enable hcoll collective support.
22
23 === END_MPI_T_CVAR_INFO_BLOCK ===
24 */
25
/* Becomes 1 once hcoll_comm_create() has been called for the world communicator. */
static int hcoll_comm_world_initialized = 0;
/* Progress hook identifier filled in by MPIR_Progress_hook_register(). */
static int hcoll_progress_hook_id = 0;

/* Set to 1 by hcoll_initialize() and cleared again by hcoll_destroy(). */
int hcoll_initialized = 0;
/* -1 until hcoll_initialize() evaluates the enabling condition; 0 or 1 afterwards. */
int hcoll_enable = -1;

/*
 * Per-collective switches. Each defaults to 1 and may be overridden from the
 * matching HCOLL_ENABLE_<NAME> environment variable in hcoll_initialize().
 */
int hcoll_enable_barrier = 1;
int hcoll_enable_bcast = 1;
int hcoll_enable_reduce = 1;
int hcoll_enable_allgather = 1;
int hcoll_enable_allreduce = 1;
int hcoll_enable_alltoall = 1;
int hcoll_enable_alltoallv = 1;
int hcoll_enable_ibarrier = 1;
int hcoll_enable_ibcast = 1;
int hcoll_enable_iallgather = 1;
int hcoll_enable_iallreduce = 1;

/* Set to 1 by hcoll_comm_destroy() when the communicator handle is MPI_COMM_WORLD. */
int world_comm_destroying = 0;

#ifdef MPL_USE_DBG_LOGGING
/* Debug logging class, allocated in hcoll_initialize(). */
MPL_dbg_class MPIR_DBG_HCOLL;
#endif /* MPL_USE_DBG_LOGGING */

/* Not defined in this section of the file; invoked from hcoll_initialize(). */
void hcoll_rte_fns_setup(void);
49
50
/*
 * hcoll_destroy - tear down hcoll state.
 *
 * Registered through MPIR_Add_finalize() by hcoll_initialize(). When hcoll is
 * marked as initialized, the library is finalized and the progress hook is
 * deactivated and deregistered. The initialized flag is cleared in every case.
 *
 * param: unused.
 *
 * Always returns 0.
 */
int hcoll_destroy(void *param ATTRIBUTE((unused)))
{
    if (hcoll_initialized != 1) {
        /* Nothing was set up (or it was already torn down); just normalise the flag. */
        hcoll_initialized = 0;
        return 0;
    }

    hcoll_finalize();
    MPIR_Progress_hook_deactivate(hcoll_progress_hook_id);
    MPIR_Progress_hook_deregister(hcoll_progress_hook_id);

    hcoll_initialized = 0;
    return 0;
}
61
/*
 * CHECK_ENABLE_ENV_VARS(env_suffix, var_suffix)
 *
 * Reads the environment variable HCOLL_ENABLE_<env_suffix>. When it is set,
 * its value is converted with atoi() and stored in hcoll_enable_<var_suffix>,
 * and the new value is logged through MPL_DBG_MSG_D. When it is not set the
 * flag is left untouched.
 *
 * The expansion assigns to a variable named `envar` (char *) that must be
 * declared in the calling scope.
 */
#define CHECK_ENABLE_ENV_VARS(env_suffix, var_suffix)                          \
    do {                                                                       \
        envar = getenv("HCOLL_ENABLE_" #env_suffix);                           \
        if (envar != NULL) {                                                   \
            hcoll_enable_##var_suffix = atoi(envar);                           \
            MPL_DBG_MSG_D(MPIR_DBG_HCOLL, VERBOSE,                             \
                          "HCOLL_ENABLE_" #env_suffix " = %d\n",               \
                          hcoll_enable_##var_suffix);                          \
        }                                                                      \
    } while (0)
70
hcoll_initialize(void)71 int hcoll_initialize(void)
72 {
73 int mpi_errno;
74 char *envar;
75 hcoll_init_opts_t *init_opts;
76 mpi_errno = MPI_SUCCESS;
77
78 hcoll_enable = (MPIR_CVAR_ENABLE_HCOLL | MPIR_CVAR_CH3_ENABLE_HCOLL) &&
79 MPIR_ThreadInfo.thread_provided != MPI_THREAD_MULTIPLE;
80 if (0 >= hcoll_enable) {
81 goto fn_exit;
82 }
83 #if defined(MPL_USE_DBG_LOGGING)
84 MPIR_DBG_HCOLL = MPL_dbg_class_alloc("HCOLL", "hcoll");
85 #endif /* MPL_USE_DBG_LOGGING */
86
87 hcoll_rte_fns_setup();
88
89 hcoll_read_init_opts(&init_opts);
90 init_opts->base_tag = MPIR_FIRST_HCOLL_TAG;
91 init_opts->max_tag = MPIR_LAST_HCOLL_TAG;
92
93 init_opts->enable_thread_support = MPIR_IS_THREADED;
94
95 mpi_errno = hcoll_init_with_opts(&init_opts);
96 MPIR_ERR_CHECK(mpi_errno);
97
98 if (!hcoll_initialized) {
99 hcoll_initialized = 1;
100 mpi_errno = MPIR_Progress_hook_register(hcoll_do_progress, &hcoll_progress_hook_id);
101 MPIR_ERR_CHECK(mpi_errno);
102
103 MPIR_Progress_hook_activate(hcoll_progress_hook_id);
104 }
105 MPIR_Add_finalize(hcoll_destroy, 0, 0);
106
107 CHECK_ENABLE_ENV_VARS(BARRIER, barrier);
108 CHECK_ENABLE_ENV_VARS(BCAST, bcast);
109 CHECK_ENABLE_ENV_VARS(REDUCE, reduce);
110 CHECK_ENABLE_ENV_VARS(ALLGATHER, allgather);
111 CHECK_ENABLE_ENV_VARS(ALLREDUCE, allreduce);
112 CHECK_ENABLE_ENV_VARS(ALLTOALL, alltoall);
113 CHECK_ENABLE_ENV_VARS(ALLTOALLV, alltoallv);
114 CHECK_ENABLE_ENV_VARS(IBARRIER, ibarrier);
115 CHECK_ENABLE_ENV_VARS(IBCAST, ibcast);
116 CHECK_ENABLE_ENV_VARS(IALLGATHER, iallgather);
117 CHECK_ENABLE_ENV_VARS(IALLREDUCE, iallreduce);
118 fn_exit:
119 return mpi_errno;
120 fn_fail:
121 goto fn_exit;
122 }
123
124
/*
 * hcoll_comm_create - communicator-creation hook that attaches an hcoll
 * context to comm_ptr when possible.
 *
 * comm_ptr: communicator being created; its hcoll_priv fields are updated.
 * param:    not used by this function.
 *
 * hcoll is initialized lazily on the first call. The communicator is left
 * without an hcoll context (is_hcoll_init == 0) when:
 *   - hcoll is disabled,
 *   - the world communicator has not yet passed through this hook,
 *   - the communicator is not an intracommunicator, has fewer than two
 *     ranks, or is a NODE / NODE_ROOTS hierarchy communicator,
 *   - hcoll initialization or context creation fails.
 *
 * Returns MPI_SUCCESS, or the error code from hcoll_initialize(). A failure
 * of hcoll_create_context() is only logged; mpi_errno is not modified on
 * that path.
 */
int hcoll_comm_create(MPIR_Comm * comm_ptr, void *param)
{
    int mpi_errno = MPI_SUCCESS;
    int num_ranks;

    if (0 == hcoll_initialized) {
        mpi_errno = hcoll_initialize();
        MPIR_ERR_CHECK(mpi_errno);
    }

    if (0 == hcoll_enable) {
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
        goto fn_exit;
    }

    if (MPIR_Process.comm_world == comm_ptr) {
        hcoll_comm_world_initialized = 1;
    }
    if (!hcoll_comm_world_initialized) {
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
        goto fn_exit;
    }

    num_ranks = comm_ptr->local_size;
    if ((MPIR_COMM_KIND__INTRACOMM != comm_ptr->comm_kind) || (2 > num_ranks)
        || comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE_ROOTS
        || comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE) {
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
        goto fn_exit;
    }

    comm_ptr->hcoll_priv.hcoll_context = hcoll_create_context((rte_grp_handle_t) comm_ptr);
    if (NULL == comm_ptr->hcoll_priv.hcoll_context) {
        MPL_DBG_MSG(MPIR_DBG_HCOLL, VERBOSE, "Couldn't create hcoll context.");
        goto fn_fail;
    }

    comm_ptr->hcoll_priv.is_hcoll_init = 1;

  fn_exit:
    return mpi_errno;
  fn_fail:
    /* Never leave the flag unset on a failure path: hcoll_comm_destroy()
     * consults it before handing the context to hcoll_destroy_context(). */
    comm_ptr->hcoll_priv.is_hcoll_init = 0;
    goto fn_exit;
}
169
/*
 * hcoll_comm_destroy - communicator-destruction hook that releases the hcoll
 * context attached to comm_ptr, if any.
 *
 * comm_ptr: communicator being destroyed; a NULL pointer is tolerated.
 * param:    not used by this function.
 *
 * Does nothing when hcoll is not enabled. When the communicator handle is
 * MPI_COMM_WORLD, world_comm_destroying is set to 1. If the communicator has
 * an hcoll context (is_hcoll_init != 0) the context is destroyed and the
 * flag is cleared.
 *
 * Returns MPI_SUCCESS.
 */
int hcoll_comm_destroy(MPIR_Comm * comm_ptr, void *param)
{
    /* Initialised at declaration so the early exit below never returns an
     * indeterminate value. */
    int mpi_errno = MPI_SUCCESS;
    int context_destroyed = 0;

    /* Check the pointer before the first dereference, not after it. */
    if (0 >= hcoll_enable || NULL == comm_ptr) {
        goto fn_exit;
    }

    if (comm_ptr->handle == MPI_COMM_WORLD) {
        world_comm_destroying = 1;
    }

    if (0 != comm_ptr->hcoll_priv.is_hcoll_init) {
        hcoll_destroy_context(comm_ptr->hcoll_priv.hcoll_context,
                              (rte_grp_handle_t) comm_ptr, &context_destroyed);
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
193
hcoll_do_progress(int * made_progress)194 int hcoll_do_progress(int *made_progress)
195 {
196 *made_progress = 1;
197
198 /* hcoll_progress_fn() has been deprecated since v4.0. */
199 #if HCOLL_API < HCOLL_VERSION(4,0)
200 MPID_THREAD_CS_ENTER(VCI, MPIDIU_THREAD_HCOLL_MUTEX);
201 hcoll_progress_fn();
202 MPID_THREAD_CS_EXIT(VCI, MPIDIU_THREAD_HCOLL_MUTEX);
203 #endif
204 return MPI_SUCCESS;
205 }
206