1 /*
2  * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED.
3  * Copyright (C) Mellanox Technologies Ltd. 2020.  ALL RIGHTS RESERVED.
4  * See file LICENSE for terms.
5  */
6 
7 #ifdef HAVE_CONFIG_H
8 #  include "config.h"
9 #endif
10 
11 #include "rocm_ipc_iface.h"
12 #include "rocm_ipc_md.h"
13 #include "rocm_ipc_ep.h"
14 
15 #include <uct/rocm/base/rocm_base.h>
16 #include <ucs/arch/cpu.h>
17 #include <ucs/type/class.h>
18 #include <ucs/sys/string.h>
19 
20 
21 static ucs_config_field_t uct_rocm_ipc_iface_config_table[] = {
22 
23     {"", "", NULL,
24      ucs_offsetof(uct_rocm_ipc_iface_config_t, super),
25      UCS_CONFIG_TYPE_TABLE(uct_iface_config_table)},
26 
27     {NULL}
28 };
29 
/*
 * Build the 64-bit identifier used as this transport's device address:
 * the machine GUID multiplied by an id derived from the name of the
 * component that owns the interface's memory domain.
 */
static uint64_t uct_rocm_ipc_iface_node_guid(uct_base_iface_t *iface)
{
    const uint64_t machine_id   = ucs_machine_guid();
    const uint64_t component_id = ucs_string_to_id(iface->md->component->name);

    return machine_id * component_id;
}
35 
uct_rocm_ipc_iface_get_device_address(uct_iface_t * tl_iface,uct_device_addr_t * addr)36 ucs_status_t uct_rocm_ipc_iface_get_device_address(uct_iface_t *tl_iface,
37                                                    uct_device_addr_t *addr)
38 {
39     uct_base_iface_t *iface = ucs_derived_of(tl_iface, uct_base_iface_t);
40 
41     *(uint64_t*)addr = uct_rocm_ipc_iface_node_guid(iface);
42     return UCS_OK;
43 }
44 
uct_rocm_ipc_iface_get_address(uct_iface_h tl_iface,uct_iface_addr_t * iface_addr)45 static ucs_status_t uct_rocm_ipc_iface_get_address(uct_iface_h tl_iface,
46                                                    uct_iface_addr_t *iface_addr)
47 {
48     *(pid_t*)iface_addr = getpid();
49     return UCS_OK;
50 }
51 
uct_rocm_ipc_iface_is_reachable(const uct_iface_h tl_iface,const uct_device_addr_t * dev_addr,const uct_iface_addr_t * iface_addr)52 static int uct_rocm_ipc_iface_is_reachable(const uct_iface_h tl_iface,
53                                            const uct_device_addr_t *dev_addr,
54                                            const uct_iface_addr_t *iface_addr)
55 {
56     uct_rocm_ipc_iface_t  *iface = ucs_derived_of(tl_iface, uct_rocm_ipc_iface_t);
57 
58     return ((uct_rocm_ipc_iface_node_guid(&iface->super) ==
59             *((const uint64_t *)dev_addr)) && ((getpid() != *(pid_t *)iface_addr)));
60 }
61 
uct_rocm_ipc_iface_query(uct_iface_h tl_iface,uct_iface_attr_t * iface_attr)62 static ucs_status_t uct_rocm_ipc_iface_query(uct_iface_h tl_iface,
63                                              uct_iface_attr_t *iface_attr)
64 {
65     uct_rocm_ipc_iface_t *iface = ucs_derived_of(tl_iface, uct_rocm_ipc_iface_t);
66 
67     uct_base_iface_query(&iface->super, iface_attr);
68 
69     iface_attr->cap.put.min_zcopy       = 0;
70     iface_attr->cap.put.max_zcopy       = SIZE_MAX;
71     iface_attr->cap.put.opt_zcopy_align = sizeof(uint32_t);
72     iface_attr->cap.put.align_mtu       = iface_attr->cap.put.opt_zcopy_align;
73     iface_attr->cap.put.max_iov         = 1;
74 
75     iface_attr->cap.get.min_zcopy       = 0;
76     iface_attr->cap.get.max_zcopy       = SIZE_MAX;
77     iface_attr->cap.get.opt_zcopy_align = sizeof(uint32_t);
78     iface_attr->cap.get.align_mtu       = iface_attr->cap.get.opt_zcopy_align;
79     iface_attr->cap.get.max_iov         = 1;
80 
81     iface_attr->iface_addr_len          = sizeof(pid_t);
82     iface_attr->device_addr_len         = sizeof(uint64_t);
83     iface_attr->ep_addr_len             = 0;
84     iface_attr->max_conn_priv           = 0;
85     iface_attr->cap.flags               = UCT_IFACE_FLAG_GET_ZCOPY |
86                                           UCT_IFACE_FLAG_PUT_ZCOPY |
87                                           UCT_IFACE_FLAG_PENDING   |
88                                           UCT_IFACE_FLAG_CONNECT_TO_IFACE;
89 
90     /* TODO: get accurate info */
91     iface_attr->latency                 = ucs_linear_func_make(80e-9, 0);
92     iface_attr->bandwidth.dedicated     = 10.0 * UCS_GBYTE; /* 10 GB */
93     iface_attr->bandwidth.shared        = 0;
94     iface_attr->overhead                = 0.4e-6; /* 0.4 us */
95 
96     return UCS_OK;
97 }
98 
/* Forward declaration: the delete function is referenced by the iface_close
 * entry of uct_rocm_ipc_iface_ops before it is defined at the end of the
 * file. */
static UCS_CLASS_DECLARE_DELETE_FUNC(uct_rocm_ipc_iface_t, uct_iface_t);
100 
101 static ucs_status_t
uct_rocm_ipc_iface_flush(uct_iface_h tl_iface,unsigned flags,uct_completion_t * comp)102 uct_rocm_ipc_iface_flush(uct_iface_h tl_iface, unsigned flags,
103                          uct_completion_t *comp)
104 {
105     uct_rocm_ipc_iface_t *iface = ucs_derived_of(tl_iface, uct_rocm_ipc_iface_t);
106 
107     if (comp != NULL) {
108         return UCS_ERR_UNSUPPORTED;
109     }
110 
111     if (ucs_queue_is_empty(&iface->signal_queue)) {
112         UCT_TL_IFACE_STAT_FLUSH(ucs_derived_of(tl_iface, uct_base_iface_t));
113         return UCS_OK;
114     }
115 
116     UCT_TL_IFACE_STAT_FLUSH_WAIT(ucs_derived_of(tl_iface, uct_base_iface_t));
117     return UCS_INPROGRESS;
118 }
119 
/*
 * Progress the interface by reaping finished signal descriptors.
 *
 * Walks the signal queue; every descriptor whose HSA signal reads 0 is
 * removed from the queue, its completion (if any) is invoked with UCS_OK,
 * and the descriptor is returned to its memory pool. Descriptors whose
 * signal is still non-zero stay queued. At most max_signals descriptors are
 * reaped per call.
 *
 * Returns the number of descriptors reaped.
 */
static unsigned uct_rocm_ipc_iface_progress(uct_iface_h tl_iface)
{
    static const unsigned max_signals = 16;
    uct_rocm_ipc_iface_t *rocm_iface  = ucs_derived_of(tl_iface,
                                                       uct_rocm_ipc_iface_t);
    uct_rocm_ipc_signal_desc_t *desc;
    ucs_queue_iter_t it;
    unsigned completed;

    completed = 0;
    ucs_queue_for_each_safe(desc, it, &rocm_iface->signal_queue, queue) {
        if (hsa_signal_load_scacquire(desc->signal) == 0) {
            /* Unlink before the completion callback runs */
            ucs_queue_del_iter(&rocm_iface->signal_queue, it);
            if (desc->comp != NULL) {
                uct_invoke_completion(desc->comp, UCS_OK);
            }

            ucs_trace_poll("ROCM_IPC Signal Done :%p", desc);
            ucs_mpool_put(desc);

            if (++completed >= max_signals) {
                break;
            }
        }
    }

    return completed;
}
149 
150 static uct_iface_ops_t uct_rocm_ipc_iface_ops = {
151     .ep_put_zcopy             = uct_rocm_ipc_ep_put_zcopy,
152     .ep_get_zcopy             = uct_rocm_ipc_ep_get_zcopy,
153     .ep_pending_add           = ucs_empty_function_return_busy,
154     .ep_pending_purge         = ucs_empty_function,
155     .ep_flush                 = uct_base_ep_flush,
156     .ep_fence                 = uct_base_ep_fence,
157     .ep_create                = UCS_CLASS_NEW_FUNC_NAME(uct_rocm_ipc_ep_t),
158     .ep_destroy               = UCS_CLASS_DELETE_FUNC_NAME(uct_rocm_ipc_ep_t),
159     .iface_flush              = uct_rocm_ipc_iface_flush,
160     .iface_fence              = uct_base_iface_fence,
161     .iface_progress_enable    = uct_base_iface_progress_enable,
162     .iface_progress_disable   = uct_base_iface_progress_disable,
163     .iface_progress           = uct_rocm_ipc_iface_progress,
164     .iface_close              = UCS_CLASS_DELETE_FUNC_NAME(uct_rocm_ipc_iface_t),
165     .iface_query              = uct_rocm_ipc_iface_query,
166     .iface_get_address        = uct_rocm_ipc_iface_get_address,
167     .iface_get_device_address = uct_rocm_ipc_iface_get_device_address,
168     .iface_is_reachable       = uct_rocm_ipc_iface_is_reachable
169 };
170 
/*
 * Memory-pool object initializer for a signal descriptor: zero the
 * descriptor and create its HSA signal with an initial value of 1.
 * Aborts via ucs_fatal() if the signal cannot be created.
 */
static void uct_rocm_ipc_signal_desc_init(ucs_mpool_t *mp, void *obj, void *chunk)
{
    uct_rocm_ipc_signal_desc_t *desc = obj;

    memset(desc, 0, sizeof(*desc));

    if (hsa_signal_create(1, 0, NULL, &desc->signal) != HSA_STATUS_SUCCESS) {
        ucs_fatal("fail to create signal");
    }
}
182 
/*
 * Memory-pool object cleanup for a signal descriptor: destroy the HSA signal
 * held by the descriptor. Aborts via ucs_fatal() if destruction fails.
 */
static void uct_rocm_ipc_signal_desc_cleanup(ucs_mpool_t *mp, void *obj)
{
    uct_rocm_ipc_signal_desc_t *desc = obj;

    if (hsa_signal_destroy(desc->signal) != HSA_STATUS_SUCCESS) {
        ucs_fatal("fail to destroy signal");
    }
}
193 
194 static ucs_mpool_ops_t uct_rocm_ipc_signal_desc_mpool_ops = {
195     .chunk_alloc   = ucs_mpool_chunk_malloc,
196     .chunk_release = ucs_mpool_chunk_free,
197     .obj_init      = uct_rocm_ipc_signal_desc_init,
198     .obj_cleanup   = uct_rocm_ipc_signal_desc_cleanup,
199 };
200 
/*
 * Constructor of uct_rocm_ipc_iface_t.
 *
 * Initializes the base interface with the ROCm IPC operations table, then
 * sets up the queue of outstanding signal descriptors and the memory pool
 * they are allocated from.
 *
 * Returns UCS_OK, or the ucs_mpool_init() error status if the pool cannot
 * be created.
 */
static UCS_CLASS_INIT_FUNC(uct_rocm_ipc_iface_t, uct_md_h md, uct_worker_h worker,
                           const uct_iface_params_t *params,
                           const uct_iface_config_t *tl_config)
{
    ucs_status_t status;

    UCS_CLASS_CALL_SUPER_INIT(uct_base_iface_t, &uct_rocm_ipc_iface_ops, md, worker,
                              params, tl_config UCS_STATS_ARG(params->stats_root)
                              UCS_STATS_ARG(UCT_ROCM_IPC_TL_NAME));

    ucs_queue_head_init(&self->signal_queue);

    /* NOTE(review): the per-argument labels below are my reading of the
     * ucs_mpool_init() prototype - confirm against ucs/datastruct/mpool.h */
    status = ucs_mpool_init(&self->signal_pool,
                            0,                                  /* private data size */
                            sizeof(uct_rocm_ipc_signal_desc_t), /* element size */
                            0,                                  /* alignment offset */
                            UCS_SYS_CACHE_LINE_SIZE,            /* alignment */
                            128,                                /* elements per chunk */
                            1024,                               /* max elements */
                            &uct_rocm_ipc_signal_desc_mpool_ops,
                            "ROCM_IPC signal objects");
    if (status != UCS_OK) {
        ucs_error("rocm/ipc signal mpool creation failed");
        return status;
    }

    return UCS_OK;
}
229 
230 
UCS_CLASS_CLEANUP_FUNC(uct_rocm_ipc_iface_t)231 static UCS_CLASS_CLEANUP_FUNC(uct_rocm_ipc_iface_t)
232 {
233     uct_base_iface_progress_disable(&self->super.super,
234                                     UCT_PROGRESS_SEND | UCT_PROGRESS_RECV);
235     ucs_mpool_cleanup(&self->signal_pool, 1);
236 }
237 
238 UCS_CLASS_DEFINE(uct_rocm_ipc_iface_t, uct_base_iface_t);
239 
240 static UCS_CLASS_DEFINE_NEW_FUNC(uct_rocm_ipc_iface_t, uct_iface_t, uct_md_h,
241                                  uct_worker_h, const uct_iface_params_t*,
242                                  const uct_iface_config_t *);
243 static UCS_CLASS_DEFINE_DELETE_FUNC(uct_rocm_ipc_iface_t, uct_iface_t);
244 
245 UCT_TL_DEFINE(&uct_rocm_ipc_component, rocm_ipc, uct_rocm_base_query_devices,
246               uct_rocm_ipc_iface_t, "ROCM_IPC_",
247               uct_rocm_ipc_iface_config_table, uct_rocm_ipc_iface_config_t);
248