1 /*
2  * Copyright (C) Mellanox Technologies Ltd. 2001-2011.  ALL RIGHTS RESERVED.
3  * $COPYRIGHT$
4  *
5  * Additional copyrights may follow
6  *
7  * $HEADER$
8  */
9 
10 #include "pml_ucx_datatype.h"
11 #include "pml_ucx_request.h"
12 
13 #include "ompi/runtime/mpiruntime.h"
14 #include "ompi/attribute/attribute.h"
15 
16 #include <inttypes.h>
17 #include <math.h>
18 
#ifdef HAVE_UCP_REQUEST_PARAM_T
/*
 * Apply the member expression _val to the send, bsend and recv parameter sets
 * stored in _datatype->op_param, e.g.
 *
 *     PML_UCX_DATATYPE_SET_VALUE(dt, op_attr_mask |= SOME_FLAG);
 *
 * The three statements are wrapped in do { } while (0) so that the macro
 * expands to exactly one statement and stays correct under an unbraced
 * if/else. _datatype is evaluated three times and therefore must not have
 * side effects. The invocation must be terminated with a semicolon.
 */
#define PML_UCX_DATATYPE_SET_VALUE(_datatype, _val) \
    do { \
        (_datatype)->op_param.send._val; \
        (_datatype)->op_param.bsend._val; \
        (_datatype)->op_param.recv._val; \
    } while (0)
#endif
25 
pml_ucx_generic_datatype_start_pack(void * context,const void * buffer,size_t count)26 static void* pml_ucx_generic_datatype_start_pack(void *context, const void *buffer,
27                                                  size_t count)
28 {
29     ompi_datatype_t *datatype = context;
30     mca_pml_ucx_convertor_t *convertor;
31 
32     convertor = (mca_pml_ucx_convertor_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.convs);
33 
34     OMPI_DATATYPE_RETAIN(datatype);
35     convertor->datatype = datatype;
36     opal_convertor_copy_and_prepare_for_send(ompi_proc_local_proc->super.proc_convertor,
37                                              &datatype->super, count, buffer, 0,
38                                              &convertor->opal_conv);
39     return convertor;
40 }
41 
pml_ucx_generic_datatype_start_unpack(void * context,void * buffer,size_t count)42 static void* pml_ucx_generic_datatype_start_unpack(void *context, void *buffer,
43                                                    size_t count)
44 {
45     ompi_datatype_t *datatype = context;
46     mca_pml_ucx_convertor_t *convertor;
47 
48     convertor = (mca_pml_ucx_convertor_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.convs);
49 
50     OMPI_DATATYPE_RETAIN(datatype);
51     convertor->datatype = datatype;
52     convertor->offset = 0;
53     opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
54                                              &datatype->super, count, buffer, 0,
55                                              &convertor->opal_conv);
56     return convertor;
57 }
58 
pml_ucx_generic_datatype_packed_size(void * state)59 static size_t pml_ucx_generic_datatype_packed_size(void *state)
60 {
61     mca_pml_ucx_convertor_t *convertor = state;
62     size_t size;
63 
64     opal_convertor_get_packed_size(&convertor->opal_conv, &size);
65     return size;
66 }
67 
pml_ucx_generic_datatype_pack(void * state,size_t offset,void * dest,size_t max_length)68 static size_t pml_ucx_generic_datatype_pack(void *state, size_t offset,
69                                             void *dest, size_t max_length)
70 {
71     mca_pml_ucx_convertor_t *convertor = state;
72     uint32_t iov_count;
73     struct iovec iov;
74     size_t length;
75 
76     iov_count    = 1;
77     iov.iov_base = dest;
78     iov.iov_len  = max_length;
79 
80     opal_convertor_set_position(&convertor->opal_conv, &offset);
81     length = max_length;
82     opal_convertor_pack(&convertor->opal_conv, &iov, &iov_count, &length);
83     return length;
84 }
85 
pml_ucx_generic_datatype_unpack(void * state,size_t offset,const void * src,size_t length)86 static ucs_status_t pml_ucx_generic_datatype_unpack(void *state, size_t offset,
87                                                     const void *src, size_t length)
88 {
89     mca_pml_ucx_convertor_t *convertor = state;
90 
91     uint32_t iov_count;
92     struct iovec iov;
93     opal_convertor_t conv;
94 
95     iov_count    = 1;
96     iov.iov_base = (void*)src;
97     iov.iov_len  = length;
98 
99     /* in case if unordered message arrived - create separate convertor to
100      * unpack data. */
101     if (offset != convertor->offset) {
102         OBJ_CONSTRUCT(&conv, opal_convertor_t);
103         opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
104                                                  &convertor->datatype->super,
105                                                  convertor->opal_conv.count,
106                                                  convertor->opal_conv.pBaseBuf, 0,
107                                                  &conv);
108         opal_convertor_set_position(&conv, &offset);
109         opal_convertor_unpack(&conv, &iov, &iov_count, &length);
110         opal_convertor_cleanup(&conv);
111         OBJ_DESTRUCT(&conv);
112         /* permanently switch to un-ordered mode */
113         convertor->offset = 0;
114     } else {
115         opal_convertor_unpack(&convertor->opal_conv, &iov, &iov_count, &length);
116         convertor->offset += length;
117     }
118     return UCS_OK;
119 }
120 
pml_ucx_generic_datatype_finish(void * state)121 static void pml_ucx_generic_datatype_finish(void *state)
122 {
123     mca_pml_ucx_convertor_t *convertor = state;
124 
125     opal_convertor_cleanup(&convertor->opal_conv);
126     OMPI_DATATYPE_RELEASE(convertor->datatype);
127     PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.convs, &convertor->super);
128 }
129 
130 static ucp_generic_dt_ops_t pml_ucx_generic_datatype_ops = {
131     .start_pack   = pml_ucx_generic_datatype_start_pack,
132     .start_unpack = pml_ucx_generic_datatype_start_unpack,
133     .packed_size  = pml_ucx_generic_datatype_packed_size,
134     .pack         = pml_ucx_generic_datatype_pack,
135     .unpack       = pml_ucx_generic_datatype_unpack,
136     .finish       = pml_ucx_generic_datatype_finish
137 };
138 
/*
 * Attribute delete callback for the UCX datatype attribute.
 *
 * datatype - OMPI datatype the attribute is attached to
 * keyval   - attribute key (unused)
 * attr_val - attribute value; carries the ucp_datatype_t handle
 * extra    - unused
 *
 * Destroys the UCX datatype and marks datatype->pml_data as invalid. When
 * HAVE_UCP_REQUEST_PARAM_T is defined, pml_data is treated as a heap pointer
 * and freed first; otherwise it is expected to equal the UCX handle itself.
 *
 * Returns OMPI_SUCCESS.
 */
int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
                                     void *attr_val, void *extra)
{
    ucp_datatype_t ucp_dt = (ucp_datatype_t)attr_val;

#ifdef HAVE_UCP_REQUEST_PARAM_T
    /* pml_data holds the malloc'ed descriptor, not the UCX handle */
    free((void*)datatype->pml_data);
#else
    PML_UCX_ASSERT((uint64_t)ucp_dt == datatype->pml_data);
#endif

    ucp_dt_destroy(ucp_dt);
    datatype->pml_data = PML_UCX_DATATYPE_INVALID;

    return OMPI_SUCCESS;
}
153 
154 __opal_attribute_always_inline__
mca_pml_ucx_datatype_is_contig(ompi_datatype_t * datatype)155 static inline int mca_pml_ucx_datatype_is_contig(ompi_datatype_t *datatype)
156 {
157     ptrdiff_t lb;
158 
159     ompi_datatype_type_lb(datatype, &lb);
160 
161     return (datatype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) &&
162            (datatype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) &&
163            (lb == 0);
164 }
165 
/*
 * Exact integer base-2 logarithm: returns floor(log2(size)) for size > 0 and
 * 0 for size == 0. Used instead of floating-point log(), whose quotient
 * log(size) / log(2.0) may land just below the exact integer and then be
 * truncated to the wrong shift.
 */
static inline int mca_pml_ucx_size_log2(size_t size)
{
    int shift = 0;

    while (size > 1) {
        size >>= 1;
        ++shift;
    }

    return shift;
}

#ifdef HAVE_UCP_REQUEST_PARAM_T
/*
 * Allocate and fill the per-datatype descriptor used when
 * HAVE_UCP_REQUEST_PARAM_T is defined.
 *
 * datatype     - OMPI datatype being initialized
 * ucp_datatype - UCX datatype handle already created for it
 * size         - size of the datatype as obtained by the caller (0 if the
 *                caller did not query it)
 *
 * The descriptor stores the UCX handle and three pre-filled parameter sets
 * (send, bsend, recv), each with its completion callback. For a contiguous
 * datatype whose size is a power of two, size_shift is set to log2(size) and
 * the parameter sets carry no datatype field. In every other case size_shift
 * is 0 and all three parameter sets get UCP_OP_ATTR_FIELD_DATATYPE together
 * with the UCX datatype handle.
 *
 * Returns the malloc'ed descriptor; the process is aborted through
 * ompi_mpi_abort() if the allocation fails.
 */
__opal_attribute_always_inline__ static inline
pml_ucx_datatype_t *mca_pml_ucx_init_nbx_datatype(ompi_datatype_t *datatype,
                                                  ucp_datatype_t ucp_datatype,
                                                  size_t size)
{
    pml_ucx_datatype_t *pml_datatype;
    int is_contig_pow2;

    pml_datatype = malloc(sizeof(*pml_datatype));
    if (pml_datatype == NULL) {
        PML_UCX_ERROR("Failed to allocate datatype structure");
        ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
    }

    pml_datatype->datatype                    = ucp_datatype;
    pml_datatype->op_param.send.op_attr_mask  = UCP_OP_ATTR_FIELD_CALLBACK;
    pml_datatype->op_param.send.cb.send       = mca_pml_ucx_send_nbx_completion;
    pml_datatype->op_param.bsend.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK;
    pml_datatype->op_param.bsend.cb.send      = mca_pml_ucx_bsend_nbx_completion;
    pml_datatype->op_param.recv.op_attr_mask  = UCP_OP_ATTR_FIELD_CALLBACK |
                                                UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
    pml_datatype->op_param.recv.cb.recv       = mca_pml_ucx_recv_nbx_completion;

    is_contig_pow2 = mca_pml_ucx_datatype_is_contig(datatype) &&
                     (size && !(size & (size - 1))); /* is_pow2(size) */
    if (is_contig_pow2) {
        /* size is a power of two here, so the integer log2 is exact */
        pml_datatype->size_shift = mca_pml_ucx_size_log2(size);
    } else {
        pml_datatype->size_shift = 0;
        PML_UCX_DATATYPE_SET_VALUE(pml_datatype, op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE);
        PML_UCX_DATATYPE_SET_VALUE(pml_datatype, datatype = ucp_datatype);
    }

    return pml_datatype;
}
#endif
203 
mca_pml_ucx_init_datatype(ompi_datatype_t * datatype)204 ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
205 {
206     size_t size = 0; /* init to suppress compiler warning */
207     ucp_datatype_t ucp_datatype;
208     ucs_status_t status;
209     int ret;
210 
211     if (mca_pml_ucx_datatype_is_contig(datatype)) {
212         ompi_datatype_type_size(datatype, &size);
213         PML_UCX_ASSERT(size > 0);
214         ucp_datatype = ucp_dt_make_contig(size);
215         goto out;
216     }
217 
218     status = ucp_dt_create_generic(&pml_ucx_generic_datatype_ops,
219                                    datatype, &ucp_datatype);
220     if (status != UCS_OK) {
221         PML_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name);
222         ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
223     }
224 
225     /* Add custom attribute, to clean up UCX resources when OMPI datatype is
226      * released.
227      */
228     if (ompi_datatype_is_predefined(datatype)) {
229         PML_UCX_ASSERT(datatype->id < OMPI_DATATYPE_MAX_PREDEFINED);
230         ompi_pml_ucx.predefined_types[datatype->id] = ucp_datatype;
231     } else {
232         ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash,
233                               ompi_pml_ucx.datatype_attr_keyval,
234                               (void*)ucp_datatype, false);
235         if (ret != OMPI_SUCCESS) {
236             PML_UCX_ERROR("Failed to add UCX datatype attribute for %s: %d",
237                           datatype->name, ret);
238             ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
239         }
240     }
241 out:
242     PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype)
243 
244 #ifdef HAVE_UCP_REQUEST_PARAM_T
245     UCS_STATIC_ASSERT(sizeof(datatype->pml_data) >= sizeof(pml_ucx_datatype_t*));
246     datatype->pml_data = (uint64_t)mca_pml_ucx_init_nbx_datatype(datatype,
247                                                                  ucp_datatype,
248                                                                  size);
249 #else
250     datatype->pml_data = ucp_datatype;
251 #endif
252 
253     return ucp_datatype;
254 }
255 
/*
 * Class constructor for mca_pml_ucx_convertor_t: constructs the embedded
 * OPAL convertor object.
 */
static void mca_pml_ucx_convertor_construct(mca_pml_ucx_convertor_t *convertor)
{
    opal_convertor_t *inner = &convertor->opal_conv;

    OBJ_CONSTRUCT(inner, opal_convertor_t);
}
260 
/*
 * Class destructor for mca_pml_ucx_convertor_t: destructs the embedded
 * OPAL convertor object.
 */
static void mca_pml_ucx_convertor_destruct(mca_pml_ucx_convertor_t *convertor)
{
    opal_convertor_t *inner = &convertor->opal_conv;

    OBJ_DESTRUCT(inner);
}
265 
/*
 * Class definition for mca_pml_ucx_convertor_t: a free-list item with the
 * constructor/destructor pair that manages the embedded OPAL convertor.
 */
OBJ_CLASS_INSTANCE(mca_pml_ucx_convertor_t, opal_free_list_item_t,
                   mca_pml_ucx_convertor_construct,
                   mca_pml_ucx_convertor_destruct);
270