1 /*
2  * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED.
3  * See file LICENSE for terms.
4  */
5 
6 #ifdef HAVE_CONFIG_H
7 #  include "config.h"
8 #endif
9 
10 #include "rocm_copy_ep.h"
11 #include "rocm_copy_iface.h"
12 
13 #include <uct/base/uct_log.h>
14 #include <uct/base/uct_iov.inl>
15 #include <ucs/debug/memtrack.h>
16 #include <ucs/type/class.h>
17 #include <ucs/arch/cpu.h>
18 
/*
 * Copy primitives used by the endpoint operations in this file.
 * The "h2d" flavor is a plain memcpy(); the "d2h" flavor is routed through
 * ucs_memcpy_nontemporal().
 * NOTE(review): h2d/d2h presumably mean host-to-device / device-to-host -
 * confirm against the transport design.
 */
#define uct_rocm_memcpy_h2d(_dst, _src, _len) \
    memcpy((_dst), (_src), (_len))
#define uct_rocm_memcpy_d2h(_dst, _src, _len) \
    ucs_memcpy_nontemporal((_dst), (_src), (_len))
21 
/*
 * Constructor for uct_rocm_copy_ep_t.
 *
 * Resolves the rocm_copy interface from params->iface and runs the
 * uct_base_ep_t superclass initializer on the interface's base part
 * (iface->super). Returns UCS_OK when execution reaches the end of the body.
 * NOTE(review): the super-init macro presumably returns early with an error
 * status on failure - confirm in ucs/type/class.h.
 */
static UCS_CLASS_INIT_FUNC(uct_rocm_copy_ep_t, const uct_ep_params_t *params)
{
    uct_rocm_copy_iface_t *iface;

    iface = ucs_derived_of(params->iface, uct_rocm_copy_iface_t);
    UCS_CLASS_CALL_SUPER_INIT(uct_base_ep_t, &iface->super);

    return UCS_OK;
}
30 
/*
 * Destructor for uct_rocm_copy_ep_t. The body is empty: nothing is released
 * at this level of the class hierarchy.
 */
static UCS_CLASS_CLEANUP_FUNC(uct_rocm_copy_ep_t)
{
}
34 
/*
 * Class boilerplate for uct_rocm_copy_ep_t with uct_base_ep_t as its
 * superclass, followed by the new/delete helpers that expose the object as a
 * uct_ep_t. The new helper takes a const uct_ep_params_t * argument.
 * NOTE(review): the exact names of the generated functions come from the
 * UCS class macros - see ucs/type/class.h.
 */
UCS_CLASS_DEFINE(uct_rocm_copy_ep_t, uct_base_ep_t)
UCS_CLASS_DEFINE_NEW_FUNC(uct_rocm_copy_ep_t, uct_ep_t, const uct_ep_params_t *);
UCS_CLASS_DEFINE_DELETE_FUNC(uct_rocm_copy_ep_t, uct_ep_t);
38 
/*
 * Data-path trace helper: forwards the caller's format string and arguments
 * to ucs_trace_data() and appends the remote address (printed with PRIx64)
 * and the rkey (printed with %+ld).
 */
#define uct_rocm_copy_trace_data(_raddr, _key, _format, ...) \
    ucs_trace_data(_format " to %"PRIx64"(%+ld)", ## __VA_ARGS__, \
                   (_raddr), (_key))
42 
43 static UCS_F_ALWAYS_INLINE ucs_status_t
uct_rocm_copy_ep_zcopy(uct_ep_h tl_ep,uint64_t remote_addr,const uct_iov_t * iov,int is_put)44 uct_rocm_copy_ep_zcopy(uct_ep_h tl_ep,
45                                    uint64_t remote_addr,
46                                    const uct_iov_t *iov,
47                                    int is_put)
48 {
49     size_t size = uct_iov_get_length(iov);
50 
51     if (!size) {
52         return UCS_OK;
53     }
54 
55     if (is_put)
56         uct_rocm_memcpy_h2d((void *)remote_addr, iov->buffer, size);
57     else
58         uct_rocm_memcpy_d2h(iov->buffer, (void *)remote_addr, size);
59 
60     return UCS_OK;
61 }
62 
uct_rocm_copy_ep_get_zcopy(uct_ep_h tl_ep,const uct_iov_t * iov,size_t iovcnt,uint64_t remote_addr,uct_rkey_t rkey,uct_completion_t * comp)63 ucs_status_t uct_rocm_copy_ep_get_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt,
64                                         uint64_t remote_addr, uct_rkey_t rkey,
65                                         uct_completion_t *comp)
66 {
67     ucs_status_t status;
68 
69     status = uct_rocm_copy_ep_zcopy(tl_ep, remote_addr, iov, 0);
70 
71     UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), GET, ZCOPY,
72                       uct_iov_total_length(iov, iovcnt));
73     uct_rocm_copy_trace_data(remote_addr, rkey, "GET_ZCOPY [length %zu]",
74                              uct_iov_total_length(iov, iovcnt));
75     return status;
76 }
77 
uct_rocm_copy_ep_put_zcopy(uct_ep_h tl_ep,const uct_iov_t * iov,size_t iovcnt,uint64_t remote_addr,uct_rkey_t rkey,uct_completion_t * comp)78 ucs_status_t uct_rocm_copy_ep_put_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt,
79                                         uint64_t remote_addr, uct_rkey_t rkey,
80                                         uct_completion_t *comp)
81 {
82     ucs_status_t status;
83 
84     status = uct_rocm_copy_ep_zcopy(tl_ep, remote_addr, iov, 1);
85 
86     UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), PUT, ZCOPY,
87                       uct_iov_total_length(iov, iovcnt));
88     uct_rocm_copy_trace_data(remote_addr, rkey, "GET_ZCOPY [length %zu]",
89                              uct_iov_total_length(iov, iovcnt));
90     return status;
91 
92 }
93 
94 
uct_rocm_copy_ep_put_short(uct_ep_h tl_ep,const void * buffer,unsigned length,uint64_t remote_addr,uct_rkey_t rkey)95 ucs_status_t uct_rocm_copy_ep_put_short(uct_ep_h tl_ep, const void *buffer,
96                                         unsigned length, uint64_t remote_addr,
97                                         uct_rkey_t rkey)
98 {
99     uct_rocm_memcpy_h2d((void *)remote_addr, buffer, length);
100 
101     UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), PUT, SHORT, length);
102     ucs_trace_data("PUT_SHORT size %d from %p to %p",
103                    length, buffer, (void *)remote_addr);
104     return UCS_OK;
105 }
106 
uct_rocm_copy_ep_get_short(uct_ep_h tl_ep,void * buffer,unsigned length,uint64_t remote_addr,uct_rkey_t rkey)107 ucs_status_t uct_rocm_copy_ep_get_short(uct_ep_h tl_ep, void *buffer,
108                                         unsigned length, uint64_t remote_addr,
109                                         uct_rkey_t rkey)
110 {
111     uct_rocm_memcpy_d2h(buffer, (void *)remote_addr, length);
112 
113     UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), GET, SHORT, length);
114     ucs_trace_data("GET_SHORT size %d from %p to %p",
115                    length, (void *)remote_addr, buffer);
116     return UCS_OK;
117 }
118