1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 #include "hcoll.h"
8 
9 /*
10 === BEGIN_MPI_T_CVAR_INFO_BLOCK ===
11 
12 cvars:
13     - name        : MPIR_CVAR_ENABLE_HCOLL
14       category    : COLLECTIVE
15       type        : boolean
16       default     : false
17       class       : none
18       verbosity   : MPI_T_VERBOSITY_USER_BASIC
19       scope       : MPI_T_SCOPE_LOCAL
20       description : >-
21         Enable hcoll collective support.
22 
23 === END_MPI_T_CVAR_INFO_BLOCK ===
24 */
25 
/* ---- File-private state ------------------------------------------------ */

/* Handle filled in by MPIR_Progress_hook_register() in hcoll_initialize(). */
static int hcoll_progress_hook_id = 0;
/* Becomes 1 once hcoll_comm_create() has been called for comm_world;
 * communicators created before that point do not get an hcoll context. */
static int hcoll_comm_world_initialized = 0;

/* ---- Library-wide state ------------------------------------------------ */

/* -1 until hcoll_initialize() evaluates the CVARs and thread level;
 * afterwards 0 (disabled) or 1 (enabled). */
int hcoll_enable = -1;
/* 1 between a successful hcoll_initialize() and hcoll_destroy(). */
int hcoll_initialized = 0;
/* Set to 1 by hcoll_comm_destroy() when MPI_COMM_WORLD is being destroyed. */
int world_comm_destroying = 0;

/* ---- Per-collective switches -------------------------------------------
 * All default to 1; hcoll_initialize() overrides each one from the matching
 * HCOLL_ENABLE_<NAME> environment variable when that variable is set. */

/* Blocking collectives. */
int hcoll_enable_barrier = 1;
int hcoll_enable_bcast = 1;
int hcoll_enable_reduce = 1;
int hcoll_enable_allgather = 1;
int hcoll_enable_allreduce = 1;
int hcoll_enable_alltoall = 1;
int hcoll_enable_alltoallv = 1;

/* Nonblocking collectives. */
int hcoll_enable_ibarrier = 1;
int hcoll_enable_ibcast = 1;
int hcoll_enable_iallgather = 1;
int hcoll_enable_iallreduce = 1;

#if defined(MPL_USE_DBG_LOGGING)
/* Debug logging class, allocated in hcoll_initialize(). */
MPL_dbg_class MPIR_DBG_HCOLL;
#endif /* MPL_USE_DBG_LOGGING */

/* Defined outside this part of the file; called from hcoll_initialize(). */
void hcoll_rte_fns_setup(void);
49 
50 
/*
 * Tear down hcoll if it is currently initialized.
 *
 * param - unused.
 *
 * Finalizes the hcoll library, then deactivates and deregisters the progress
 * hook recorded in hcoll_progress_hook_id.  hcoll_initialized is cleared in
 * every case.  Always returns 0.
 */
int hcoll_destroy(void *param ATTRIBUTE((unused)))
{
    if (hcoll_initialized != 1) {
        /* Nothing to tear down; just normalize the flag. */
        hcoll_initialized = 0;
        return 0;
    }

    hcoll_finalize();
    MPIR_Progress_hook_deactivate(hcoll_progress_hook_id);
    MPIR_Progress_hook_deregister(hcoll_progress_hook_id);

    hcoll_initialized = 0;
    return 0;
}
61 
/*
 * If the environment variable HCOLL_ENABLE_<nameEnv> is set, parse it as a
 * base-10 integer and store the result in hcoll_enable_<name>; otherwise
 * leave hcoll_enable_<name> untouched.
 *
 * The expansion assigns to a `char *envar` that must be in scope at the
 * point of use.
 *
 * strtol() is used instead of atoi() because atoi() has undefined behavior
 * when the value cannot be represented; strtol() is defined for every input
 * and yields the same result as atoi() for in-range values (0 for text that
 * does not start with a number).
 */
#define CHECK_ENABLE_ENV_VARS(nameEnv, name) \
    do { \
        envar = getenv("HCOLL_ENABLE_" #nameEnv); \
        if (NULL != envar) { \
            hcoll_enable_##name = (int) strtol(envar, NULL, 10); \
            MPL_DBG_MSG_D(MPIR_DBG_HCOLL, VERBOSE, "HCOLL_ENABLE_" #nameEnv " = %d\n", hcoll_enable_##name); \
        } \
    } while (0)
70 
hcoll_initialize(void)71 int hcoll_initialize(void)
72 {
73     int mpi_errno;
74     char *envar;
75     hcoll_init_opts_t *init_opts;
76     mpi_errno = MPI_SUCCESS;
77 
78     hcoll_enable = (MPIR_CVAR_ENABLE_HCOLL | MPIR_CVAR_CH3_ENABLE_HCOLL) &&
79         MPIR_ThreadInfo.thread_provided != MPI_THREAD_MULTIPLE;
80     if (0 >= hcoll_enable) {
81         goto fn_exit;
82     }
83 #if defined(MPL_USE_DBG_LOGGING)
84     MPIR_DBG_HCOLL = MPL_dbg_class_alloc("HCOLL", "hcoll");
85 #endif /* MPL_USE_DBG_LOGGING */
86 
87     hcoll_rte_fns_setup();
88 
89     hcoll_read_init_opts(&init_opts);
90     init_opts->base_tag = MPIR_FIRST_HCOLL_TAG;
91     init_opts->max_tag = MPIR_LAST_HCOLL_TAG;
92 
93     init_opts->enable_thread_support = MPIR_IS_THREADED;
94 
95     mpi_errno = hcoll_init_with_opts(&init_opts);
96     MPIR_ERR_CHECK(mpi_errno);
97 
98     if (!hcoll_initialized) {
99         hcoll_initialized = 1;
100         mpi_errno = MPIR_Progress_hook_register(hcoll_do_progress, &hcoll_progress_hook_id);
101         MPIR_ERR_CHECK(mpi_errno);
102 
103         MPIR_Progress_hook_activate(hcoll_progress_hook_id);
104     }
105     MPIR_Add_finalize(hcoll_destroy, 0, 0);
106 
107     CHECK_ENABLE_ENV_VARS(BARRIER, barrier);
108     CHECK_ENABLE_ENV_VARS(BCAST, bcast);
109     CHECK_ENABLE_ENV_VARS(REDUCE, reduce);
110     CHECK_ENABLE_ENV_VARS(ALLGATHER, allgather);
111     CHECK_ENABLE_ENV_VARS(ALLREDUCE, allreduce);
112     CHECK_ENABLE_ENV_VARS(ALLTOALL, alltoall);
113     CHECK_ENABLE_ENV_VARS(ALLTOALLV, alltoallv);
114     CHECK_ENABLE_ENV_VARS(IBARRIER, ibarrier);
115     CHECK_ENABLE_ENV_VARS(IBCAST, ibcast);
116     CHECK_ENABLE_ENV_VARS(IALLGATHER, iallgather);
117     CHECK_ENABLE_ENV_VARS(IALLREDUCE, iallreduce);
118   fn_exit:
119     return mpi_errno;
120   fn_fail:
121     goto fn_exit;
122 }
123 
124 
/*
 * Communicator-creation hook: attach an hcoll context to comm_ptr when the
 * communicator is eligible.
 *
 * comm_ptr - communicator being created; its hcoll_priv fields are updated.
 * param    - not used by this routine.
 *
 * hcoll is lazily initialized on the first call.  The communicator is left
 * without an hcoll context (is_hcoll_init == 0) when:
 *   - hcoll is disabled,
 *   - comm_world has not yet passed through this hook,
 *   - the communicator is not an intracommunicator, has fewer than 2 ranks,
 *     or is a NODE / NODE_ROOTS hierarchy communicator,
 *   - initialization or context creation fails.
 *
 * Returns the error from hcoll_initialize() if that fails.  A failure of
 * hcoll_create_context() is not reported as an error: mpi_errno stays
 * MPI_SUCCESS and the communicator simply runs without hcoll.
 */
int hcoll_comm_create(MPIR_Comm * comm_ptr, void *param)
{
    int mpi_errno = MPI_SUCCESS;
    int num_ranks;

    if (0 == hcoll_initialized) {
        mpi_errno = hcoll_initialize();
        MPIR_ERR_CHECK(mpi_errno);
    }

    if (0 == hcoll_enable) {
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
        goto fn_exit;
    }

    /* No communicator gets an hcoll context before comm_world has been seen. */
    if (MPIR_Process.comm_world == comm_ptr) {
        hcoll_comm_world_initialized = 1;
    }
    if (!hcoll_comm_world_initialized) {
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
        goto fn_exit;
    }

    num_ranks = comm_ptr->local_size;
    if ((MPIR_COMM_KIND__INTRACOMM != comm_ptr->comm_kind) || (2 > num_ranks)
        || comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE_ROOTS
        || comm_ptr->hierarchy_kind == MPIR_COMM_HIERARCHY_KIND__NODE) {
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
        goto fn_exit;
    }

    comm_ptr->hcoll_priv.hcoll_context = hcoll_create_context((rte_grp_handle_t) comm_ptr);
    if (NULL == comm_ptr->hcoll_priv.hcoll_context) {
        MPL_DBG_MSG(MPIR_DBG_HCOLL, VERBOSE, "Couldn't create hcoll context.");
        goto fn_fail;
    }

    comm_ptr->hcoll_priv.is_hcoll_init = 1;
  fn_exit:
    return mpi_errno;
  fn_fail:
    /* Never leave the flag unset on failure: hcoll_comm_destroy() consults it
     * to decide whether a context must be destroyed. */
    comm_ptr->hcoll_priv.is_hcoll_init = 0;
    goto fn_exit;
}
169 
/*
 * Communicator-destruction hook: release the hcoll context attached to
 * comm_ptr, if any.
 *
 * comm_ptr - communicator being destroyed; a NULL pointer is tolerated.
 * param    - not used by this routine.
 *
 * Does nothing when hcoll is disabled or was never initialized
 * (hcoll_enable <= 0).  Sets world_comm_destroying when the communicator's
 * handle is MPI_COMM_WORLD.  Always returns MPI_SUCCESS.
 */
int hcoll_comm_destroy(MPIR_Comm * comm_ptr, void *param)
{
    /* Initialized up front so the early-exit paths return a defined value. */
    int mpi_errno = MPI_SUCCESS;
    int context_destroyed = 0;

    /* Check comm_ptr before any dereference below. */
    if (0 >= hcoll_enable || NULL == comm_ptr) {
        goto fn_exit;
    }

    if (comm_ptr->handle == MPI_COMM_WORLD)
        world_comm_destroying = 1;

    if (0 != comm_ptr->hcoll_priv.is_hcoll_init) {
        hcoll_destroy_context(comm_ptr->hcoll_priv.hcoll_context,
                              (rte_grp_handle_t) comm_ptr, &context_destroyed);
        comm_ptr->hcoll_priv.is_hcoll_init = 0;
    }
  fn_exit:
    return mpi_errno;
}
193 
hcoll_do_progress(int * made_progress)194 int hcoll_do_progress(int *made_progress)
195 {
196     *made_progress = 1;
197 
198     /* hcoll_progress_fn() has been deprecated since v4.0. */
199 #if HCOLL_API < HCOLL_VERSION(4,0)
200     MPID_THREAD_CS_ENTER(VCI, MPIDIU_THREAD_HCOLL_MUTEX);
201     hcoll_progress_fn();
202     MPID_THREAD_CS_EXIT(VCI, MPIDIU_THREAD_HCOLL_MUTEX);
203 #endif
204     return MPI_SUCCESS;
205 }
206