1 /*
2 * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED.
3 * See file LICENSE for terms.
4 */
5
6 #ifdef HAVE_CONFIG_H
7 # include "config.h"
8 #endif
9
10 #include "rocm_copy_ep.h"
11 #include "rocm_copy_iface.h"
12
13 #include <uct/base/uct_log.h>
14 #include <uct/base/uct_iov.inl>
15 #include <ucs/debug/memtrack.h>
16 #include <ucs/type/class.h>
17 #include <ucs/arch/cpu.h>
18
/*
 * Copy primitives used by the endpoint operations in this file.
 * The "h2d" variant is a plain memcpy(); the "d2h" variant goes through
 * ucs_memcpy_nontemporal().
 * NOTE(review): the names presumably stand for host-to-device and
 * device-to-host - confirm against the transport documentation.
 */
#define uct_rocm_memcpy_h2d(_d, _s, _l) memcpy((_d), (_s), (_l))
#define uct_rocm_memcpy_d2h(_d, _s, _l) ucs_memcpy_nontemporal((_d), (_s), (_l))
21
/**
 * Constructor of the ROCm copy endpoint class.
 *
 * Initializes the uct_base_ep_t super class with the base interface embedded
 * in the uct_rocm_copy_iface_t that params->iface refers to. The endpoint
 * sets up no state of its own here.
 *
 * @param [in] params  Endpoint creation parameters; only params->iface is read.
 *
 * @return UCS_OK once the super-class initialization has been invoked.
 */
static UCS_CLASS_INIT_FUNC(uct_rocm_copy_ep_t, const uct_ep_params_t *params)
{
    uct_rocm_copy_iface_t *rocm_iface;

    rocm_iface = ucs_derived_of(params->iface, uct_rocm_copy_iface_t);
    UCS_CLASS_CALL_SUPER_INIT(uct_base_ep_t, &rocm_iface->super);

    return UCS_OK;
}
30
/**
 * Destructor of the ROCm copy endpoint class.
 *
 * Intentionally empty: the constructor in this file acquires nothing that
 * would have to be released here.
 */
static UCS_CLASS_CLEANUP_FUNC(uct_rocm_copy_ep_t)
{
    /* nothing to release */
}
34
35 UCS_CLASS_DEFINE(uct_rocm_copy_ep_t, uct_base_ep_t)
36 UCS_CLASS_DEFINE_NEW_FUNC(uct_rocm_copy_ep_t, uct_ep_t, const uct_ep_params_t *);
37 UCS_CLASS_DEFINE_DELETE_FUNC(uct_rocm_copy_ep_t, uct_ep_t);
38
/*
 * Data-path trace helper. Forwards to ucs_trace_data() with the caller's
 * format and arguments, appending " to <remote_addr>(<rkey>)": the remote
 * address is printed as 64-bit hex and the rkey with "%+ld".
 */
#define uct_rocm_copy_trace_data(_remote_addr, _rkey, _fmt, ...) \
    ucs_trace_data(_fmt " to %" PRIx64 "(%+ld)", ## __VA_ARGS__, \
                   (_remote_addr), (_rkey))
42
43 static UCS_F_ALWAYS_INLINE ucs_status_t
uct_rocm_copy_ep_zcopy(uct_ep_h tl_ep,uint64_t remote_addr,const uct_iov_t * iov,int is_put)44 uct_rocm_copy_ep_zcopy(uct_ep_h tl_ep,
45 uint64_t remote_addr,
46 const uct_iov_t *iov,
47 int is_put)
48 {
49 size_t size = uct_iov_get_length(iov);
50
51 if (!size) {
52 return UCS_OK;
53 }
54
55 if (is_put)
56 uct_rocm_memcpy_h2d((void *)remote_addr, iov->buffer, size);
57 else
58 uct_rocm_memcpy_d2h(iov->buffer, (void *)remote_addr, size);
59
60 return UCS_OK;
61 }
62
uct_rocm_copy_ep_get_zcopy(uct_ep_h tl_ep,const uct_iov_t * iov,size_t iovcnt,uint64_t remote_addr,uct_rkey_t rkey,uct_completion_t * comp)63 ucs_status_t uct_rocm_copy_ep_get_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt,
64 uint64_t remote_addr, uct_rkey_t rkey,
65 uct_completion_t *comp)
66 {
67 ucs_status_t status;
68
69 status = uct_rocm_copy_ep_zcopy(tl_ep, remote_addr, iov, 0);
70
71 UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), GET, ZCOPY,
72 uct_iov_total_length(iov, iovcnt));
73 uct_rocm_copy_trace_data(remote_addr, rkey, "GET_ZCOPY [length %zu]",
74 uct_iov_total_length(iov, iovcnt));
75 return status;
76 }
77
uct_rocm_copy_ep_put_zcopy(uct_ep_h tl_ep,const uct_iov_t * iov,size_t iovcnt,uint64_t remote_addr,uct_rkey_t rkey,uct_completion_t * comp)78 ucs_status_t uct_rocm_copy_ep_put_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt,
79 uint64_t remote_addr, uct_rkey_t rkey,
80 uct_completion_t *comp)
81 {
82 ucs_status_t status;
83
84 status = uct_rocm_copy_ep_zcopy(tl_ep, remote_addr, iov, 1);
85
86 UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), PUT, ZCOPY,
87 uct_iov_total_length(iov, iovcnt));
88 uct_rocm_copy_trace_data(remote_addr, rkey, "GET_ZCOPY [length %zu]",
89 uct_iov_total_length(iov, iovcnt));
90 return status;
91
92 }
93
94
uct_rocm_copy_ep_put_short(uct_ep_h tl_ep,const void * buffer,unsigned length,uint64_t remote_addr,uct_rkey_t rkey)95 ucs_status_t uct_rocm_copy_ep_put_short(uct_ep_h tl_ep, const void *buffer,
96 unsigned length, uint64_t remote_addr,
97 uct_rkey_t rkey)
98 {
99 uct_rocm_memcpy_h2d((void *)remote_addr, buffer, length);
100
101 UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), PUT, SHORT, length);
102 ucs_trace_data("PUT_SHORT size %d from %p to %p",
103 length, buffer, (void *)remote_addr);
104 return UCS_OK;
105 }
106
uct_rocm_copy_ep_get_short(uct_ep_h tl_ep,void * buffer,unsigned length,uint64_t remote_addr,uct_rkey_t rkey)107 ucs_status_t uct_rocm_copy_ep_get_short(uct_ep_h tl_ep, void *buffer,
108 unsigned length, uint64_t remote_addr,
109 uct_rkey_t rkey)
110 {
111 uct_rocm_memcpy_d2h(buffer, (void *)remote_addr, length);
112
113 UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), GET, SHORT, length);
114 ucs_trace_data("GET_SHORT size %d from %p to %p",
115 length, (void *)remote_addr, buffer);
116 return UCS_OK;
117 }
118