1 /*
2  * Copyright (c) 2019 Amazon.com, Inc. or its affiliates.
3  * All rights reserved.
4  *
5  * This software is available to you under a choice of one of two
6  * licenses.  You may choose to be licensed under the terms of the GNU
7  * General Public License (GPL) Version 2, available from the file
8  * COPYING in the main directory of this source tree, or the
9  * BSD license below:
10  *
11  *     Redistribution and use in source and binary forms, with or
12  *     without modification, are permitted provided that the following
13  *     conditions are met:
14  *
15  *      - Redistributions of source code must retain the above
16  *        copyright notice, this list of conditions and the following
17  *        disclaimer.
18  *
19  *      - Redistributions in binary form must reproduce the above
20  *        copyright notice, this list of conditions and the following
21  *        disclaimer in the documentation and/or other materials
22  *        provided with the distribution.
23  *
24  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31  * SOFTWARE.
32  */
33 
34 #include <ofi_atomic.h>
35 #include "efa.h"
36 #include "rxr.h"
37 #include "rxr_rma.h"
38 #include "rxr_msg.h"
39 #include "rxr_pkt_cmd.h"
40 #include "rxr_read.h"
41 #include "efa_cuda.h"
42 
/*
 * Utility constants and functions shared by all REQ packet
 * types.
 */
47 struct rxr_req_inf {
48 	uint64_t protover;
49 	uint64_t base_hdr_size;
50 	uint64_t ex_feature_flag;
51 };
52 
/*
 * Starting from protocol version 4, each REQ packet type is assigned a
 * version number, and once assigned, the version number will not change.
 *
 * Baseline features are always version 4 features and do not have an
 * ex_feature_flag.
 *
 * Each extra feature is assigned a version number and an ex_feature_flag.
 * Each extra feature corresponds to 1 or more REQ packet types.
 */
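/*
 * REQ_INF_LIST is indexed by REQ packet type. base_hdr_size is the size of
 * the fixed header struct only; variable parts (the rma_iov array of
 * RTW/RTR/RTA packets and the optional raw-address/CQ-data headers) are
 * accounted for separately in rxr_pkt_req_base_hdr_size() and
 * rxr_pkt_init_req_hdr().
 */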
63 static const
64 struct rxr_req_inf REQ_INF_LIST[] = {
65 	/* rtm header */
66 	[RXR_EAGER_MSGRTM_PKT] = {4, sizeof(struct rxr_eager_msgrtm_hdr), 0},
67 	[RXR_EAGER_TAGRTM_PKT] = {4, sizeof(struct rxr_eager_tagrtm_hdr), 0},
68 	[RXR_MEDIUM_MSGRTM_PKT] = {4, sizeof(struct rxr_medium_msgrtm_hdr), 0},
69 	[RXR_MEDIUM_TAGRTM_PKT] = {4, sizeof(struct rxr_medium_tagrtm_hdr), 0},
70 	[RXR_LONG_MSGRTM_PKT] = {4, sizeof(struct rxr_long_msgrtm_hdr), 0},
71 	[RXR_LONG_TAGRTM_PKT] = {4, sizeof(struct rxr_long_tagrtm_hdr), 0},
72 	[RXR_READ_MSGRTM_PKT] = {4, sizeof(struct rxr_read_msgrtm_hdr), RXR_REQ_FEATURE_RDMA_READ},
73 	[RXR_READ_TAGRTM_PKT] = {4, sizeof(struct rxr_read_tagrtm_hdr), RXR_REQ_FEATURE_RDMA_READ},
74 	/* rtw header */
75 	[RXR_EAGER_RTW_PKT] = {4, sizeof(struct rxr_eager_rtw_hdr), 0},
76 	[RXR_LONG_RTW_PKT] = {4, sizeof(struct rxr_long_rtw_hdr), 0},
77 	[RXR_READ_RTW_PKT] = {4, sizeof(struct rxr_read_rtw_hdr), RXR_REQ_FEATURE_RDMA_READ},
78 	/* rtr header */
79 	[RXR_SHORT_RTR_PKT] = {4, sizeof(struct rxr_rtr_hdr), 0},
80 	[RXR_LONG_RTR_PKT] = {4, sizeof(struct rxr_rtr_hdr), 0},
81 	[RXR_READ_RTR_PKT] = {4, sizeof(struct rxr_base_hdr), RXR_REQ_FEATURE_RDMA_READ},
82 	/* rta header */
83 	[RXR_WRITE_RTA_PKT] = {4, sizeof(struct rxr_rta_hdr), 0},
84 	[RXR_FETCH_RTA_PKT] = {4, sizeof(struct rxr_rta_hdr), 0},
85 	[RXR_COMPARE_RTA_PKT] = {4, sizeof(struct rxr_rta_hdr), 0},
86 };
87 
size_t rxr_pkt_req_data_size(struct rxr_pkt_entry *pkt_entry)
89 {
90 	assert(pkt_entry->hdr_size > 0);
91 	return pkt_entry->pkt_size - pkt_entry->hdr_size;
92 }
93 
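/*
 * rxr_pkt_init_req_hdr() fills the base header, then appends the optional
 * headers in order: a raw-address header when no handshake has been received
 * from the peer yet (so the peer can insert our address into its AV), and a
 * CQ-data header when the send carries FI_REMOTE_CQ_DATA. pkt_entry->hdr_size
 * is set to the total header length.
 */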
void rxr_pkt_init_req_hdr(struct rxr_ep *ep,
95 			  struct rxr_tx_entry *tx_entry,
96 			  int pkt_type,
97 			  struct rxr_pkt_entry *pkt_entry)
98 {
99 	char *opt_hdr;
100 	struct rxr_peer *peer;
101 	struct rxr_base_hdr *base_hdr;
102 
103 	/* init the base header */
104 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
105 	base_hdr->type = pkt_type;
106 	base_hdr->version = REQ_INF_LIST[pkt_type].protover;
107 	base_hdr->flags = 0;
108 
109 	peer = rxr_ep_get_peer(ep, tx_entry->addr);
110 	assert(peer);
111 	if (OFI_UNLIKELY(!(peer->flags & RXR_PEER_HANDSHAKE_RECEIVED))) {
112 		/*
113 		 * This is the first communication with this peer on this
114 		 * endpoint, so send the core's address for this EP in the REQ
115 		 * so the remote side can insert it into its address vector.
116 		 */
117 		base_hdr->flags |= RXR_REQ_OPT_RAW_ADDR_HDR;
118 	}
119 
120 	if (tx_entry->fi_flags & FI_REMOTE_CQ_DATA) {
121 		base_hdr->flags |= RXR_REQ_OPT_CQ_DATA_HDR;
122 	}
123 
124 	/* init the opt header */
125 	opt_hdr = (char *)base_hdr + rxr_pkt_req_base_hdr_size(pkt_entry);
126 	if (base_hdr->flags & RXR_REQ_OPT_RAW_ADDR_HDR) {
127 		struct rxr_req_opt_raw_addr_hdr *raw_addr_hdr;
128 
129 		raw_addr_hdr = (struct rxr_req_opt_raw_addr_hdr *)opt_hdr;
130 		raw_addr_hdr->addr_len = ep->core_addrlen;
131 		memcpy(raw_addr_hdr->raw_addr, ep->core_addr, raw_addr_hdr->addr_len);
132 		opt_hdr += sizeof(*raw_addr_hdr) + raw_addr_hdr->addr_len;
133 	}
134 
135 	if (base_hdr->flags & RXR_REQ_OPT_CQ_DATA_HDR) {
136 		struct rxr_req_opt_cq_data_hdr *cq_data_hdr;
137 
138 		cq_data_hdr = (struct rxr_req_opt_cq_data_hdr *)opt_hdr;
139 		cq_data_hdr->cq_data = tx_entry->cq_entry.data;
140 		opt_hdr += sizeof(*cq_data_hdr);
141 	}
142 
143 	pkt_entry->addr = tx_entry->addr;
144 	pkt_entry->hdr_size = opt_hdr - (char *)pkt_entry->pkt;
145 }
146 
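/*
 * rxr_pkt_req_base_hdr_size() returns the mandatory header size of a REQ
 * packet: the fixed header struct plus, for RTW/RTR/RTA packets, the rma_iov
 * array carried in the header. Optional headers are not included.
 */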
size_t rxr_pkt_req_base_hdr_size(struct rxr_pkt_entry *pkt_entry)
148 {
149 	struct rxr_base_hdr *base_hdr;
150 	size_t hdr_size;
151 
152 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
153 	assert(base_hdr->type >= RXR_REQ_PKT_BEGIN);
154 
155 	hdr_size = REQ_INF_LIST[base_hdr->type].base_hdr_size;
156 	if (base_hdr->type == RXR_EAGER_RTW_PKT ||
157 	    base_hdr->type == RXR_LONG_RTW_PKT ||
158 	    base_hdr->type == RXR_READ_RTW_PKT)
159 		hdr_size += rxr_get_rtw_base_hdr(pkt_entry->pkt)->rma_iov_count * sizeof(struct fi_rma_iov);
160 	else if (base_hdr->type == RXR_SHORT_RTR_PKT ||
161 		 base_hdr->type == RXR_LONG_RTR_PKT)
162 		hdr_size += rxr_get_rtr_hdr(pkt_entry->pkt)->rma_iov_count * sizeof(struct fi_rma_iov);
163 	else if (base_hdr->type == RXR_WRITE_RTA_PKT ||
164 		 base_hdr->type == RXR_FETCH_RTA_PKT ||
165 		 base_hdr->type == RXR_COMPARE_RTA_PKT)
166 		hdr_size += rxr_get_rta_hdr(pkt_entry->pkt)->rma_iov_count * sizeof(struct fi_rma_iov);
167 
168 	return hdr_size;
169 }
170 
void rxr_pkt_proc_req_common_hdr(struct rxr_pkt_entry *pkt_entry)
172 {
173 	char *opt_hdr;
174 	struct rxr_base_hdr *base_hdr;
175 
176 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
177 
178 	opt_hdr = (char *)pkt_entry->pkt + rxr_pkt_req_base_hdr_size(pkt_entry);
179 	if (base_hdr->flags & RXR_REQ_OPT_RAW_ADDR_HDR) {
180 		struct rxr_req_opt_raw_addr_hdr *raw_addr_hdr;
181 
182 		raw_addr_hdr = (struct rxr_req_opt_raw_addr_hdr *)opt_hdr;
183 		pkt_entry->raw_addr = raw_addr_hdr->raw_addr;
184 		opt_hdr += sizeof(*raw_addr_hdr) + raw_addr_hdr->addr_len;
185 	} else {
186 		pkt_entry->raw_addr = NULL;
187 	}
188 
189 	if (base_hdr->flags & RXR_REQ_OPT_CQ_DATA_HDR) {
190 		struct rxr_req_opt_cq_data_hdr *cq_data_hdr;
191 
192 		cq_data_hdr = (struct rxr_req_opt_cq_data_hdr *)opt_hdr;
193 		pkt_entry->cq_data = cq_data_hdr->cq_data;
194 		opt_hdr += sizeof(struct rxr_req_opt_cq_data_hdr);
195 	}
196 
197 	pkt_entry->hdr_size = opt_hdr - (char *)pkt_entry->pkt;
198 }
199 
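/*
 * rxr_pkt_req_max_data_size() computes the maximum payload a REQ packet can
 * carry. It is conservative: it subtracts from the MTU the largest header
 * that could be needed (base header plus both optional headers, plus the full
 * RXR_IOV_LIMIT rma_iov array for RTW packets). Local peers go through shm,
 * so the shm medium-message limit applies instead.
 */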
size_t rxr_pkt_req_max_data_size(struct rxr_ep *ep, fi_addr_t addr, int pkt_type)
201 {
202 	struct rxr_peer *peer;
203 
204 	peer = rxr_ep_get_peer(ep, addr);
205 	assert(peer);
206 
207 	if (peer->is_local) {
208 		assert(ep->use_shm);
209 		return rxr_env.shm_max_medium_size;
210 	}
211 
212 	int max_hdr_size = REQ_INF_LIST[pkt_type].base_hdr_size
213 		+ sizeof(struct rxr_req_opt_raw_addr_hdr)
214 		+ sizeof(struct rxr_req_opt_cq_data_hdr);
215 
216 	if (pkt_type == RXR_EAGER_RTW_PKT || pkt_type == RXR_LONG_RTW_PKT)
217 		max_hdr_size += RXR_IOV_LIMIT * sizeof(struct fi_rma_iov);
218 
219 	return ep->mtu_size - max_hdr_size;
220 }
221 
222 static
size_t rxr_pkt_req_copy_data(struct rxr_rx_entry *rx_entry,
224 			     struct rxr_pkt_entry *pkt_entry,
225 			     char *data, size_t data_size)
226 {
227 	size_t bytes_copied;
228 	int bytes_left;
229 
230 	bytes_copied = rxr_copy_to_rx(data, data_size, rx_entry, 0);
231 
232 	if (OFI_UNLIKELY(bytes_copied < data_size)) {
233 		/* recv buffer is not big enough to hold req, this must be a truncated message */
234 		assert(bytes_copied == rx_entry->cq_entry.len &&
235 		       rx_entry->cq_entry.len < rx_entry->total_len);
236 		rx_entry->bytes_done = bytes_copied;
237 		bytes_left = 0;
238 	} else {
239 		assert(bytes_copied == data_size);
240 		rx_entry->bytes_done = data_size;
241 		bytes_left = rx_entry->total_len - rx_entry->bytes_done;
242 	}
243 
244 	assert(bytes_left >= 0);
245 	return bytes_left;
246 }
247 
248 /*
249  * REQ packet type functions
250  *
251  *     init() functions
252  */
253 
/*
 * This function is called after the header in pkt_entry->pkt and
 * pkt_entry->hdr_size have been set.
 */
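/*
 * Two paths are possible: when no memory descriptor is available for the
 * source iov, the payload is copied into the packet buffer right after the
 * header; otherwise the packet is described as a 2-element iov (header +
 * user buffer) so the payload copy is avoided.
 */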
void rxr_pkt_data_from_tx(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry,
259 			  struct rxr_tx_entry *tx_entry, size_t data_offset,
260 			  size_t data_size)
261 {
262 	int tx_iov_index;
263 	size_t tx_iov_offset;
264 	char *data;
265 
266 	if (data_size == 0) {
267 		pkt_entry->iov_count = 0;
268 		pkt_entry->pkt_size = pkt_entry->hdr_size;
269 		return;
270 	}
271 
272 	rxr_locate_iov_pos(tx_entry->iov, tx_entry->iov_count, data_offset,
273 			   &tx_iov_index, &tx_iov_offset);
274 	assert(tx_iov_index < tx_entry->iov_count);
275 	assert(tx_iov_offset < tx_entry->iov[tx_iov_index].iov_len);
276 	assert(pkt_entry->hdr_size > 0);
277 	if (!tx_entry->desc[tx_iov_index]) {
278 		data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
279 		data_size = rxr_copy_from_tx(data, data_size, tx_entry, data_offset);
280 		pkt_entry->iov_count = 0;
281 		pkt_entry->pkt_size = pkt_entry->hdr_size + data_size;
282 		return;
283 	}
284 
285 	/* when desc is available, we use it instead of copying */
286 	assert(ep->core_iov_limit >= 2);
287 	pkt_entry->iov[0].iov_base = pkt_entry->pkt;
288 	pkt_entry->iov[0].iov_len = pkt_entry->hdr_size;
289 	pkt_entry->desc[0] = fi_mr_desc(pkt_entry->mr);
290 
291 	pkt_entry->iov[1].iov_base = (char *)tx_entry->iov[tx_iov_index].iov_base + tx_iov_offset;
292 	pkt_entry->iov[1].iov_len = MIN(data_size,
293 					tx_entry->iov[tx_iov_index].iov_len - tx_iov_offset);
294 	pkt_entry->desc[1] = tx_entry->desc[tx_iov_index];
295 	pkt_entry->iov_count = 2;
296 	pkt_entry->pkt_size = pkt_entry->hdr_size + pkt_entry->iov[1].iov_len;
297 }
298 
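/*
 * rxr_pkt_init_rtm() builds the common part of an RTM packet: the REQ header
 * (which sets pkt_entry->hdr_size), the message id, and up to one MTU worth
 * of payload starting at data_offset.
 */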
void rxr_pkt_init_rtm(struct rxr_ep *ep,
300 		      struct rxr_tx_entry *tx_entry,
301 		      int pkt_type, uint64_t data_offset,
302 		      struct rxr_pkt_entry *pkt_entry)
303 {
304 	size_t data_size;
305 	struct rxr_rtm_base_hdr *rtm_hdr;
	/* this function sets pkt_entry->hdr_size */
307 	rxr_pkt_init_req_hdr(ep, tx_entry, pkt_type, pkt_entry);
308 
309 	rtm_hdr = (struct rxr_rtm_base_hdr *)pkt_entry->pkt;
310 	rtm_hdr->flags |= RXR_REQ_MSG;
311 	rtm_hdr->msg_id = tx_entry->msg_id;
312 
313 	data_size = MIN(tx_entry->total_len - data_offset, ep->mtu_size - pkt_entry->hdr_size);
314 	rxr_pkt_data_from_tx(ep, pkt_entry, tx_entry, data_offset, data_size);
315 	pkt_entry->x_entry = tx_entry;
316 }
317 
ssize_t rxr_pkt_init_eager_msgrtm(struct rxr_ep *ep,
319 				  struct rxr_tx_entry *tx_entry,
320 				  struct rxr_pkt_entry *pkt_entry)
321 {
322 	rxr_pkt_init_rtm(ep, tx_entry, RXR_EAGER_MSGRTM_PKT, 0, pkt_entry);
323 	return 0;
324 }
325 
ssize_t rxr_pkt_init_eager_tagrtm(struct rxr_ep *ep,
327 				  struct rxr_tx_entry *tx_entry,
328 				  struct rxr_pkt_entry *pkt_entry)
329 {
330 	struct rxr_base_hdr *base_hdr;
331 
332 	rxr_pkt_init_rtm(ep, tx_entry, RXR_EAGER_TAGRTM_PKT, 0, pkt_entry);
333 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
334 	base_hdr->flags |= RXR_REQ_TAGGED;
335 	rxr_pkt_rtm_settag(pkt_entry, tx_entry->tag);
336 	return 0;
337 }
338 
ssize_t rxr_pkt_init_medium_msgrtm(struct rxr_ep *ep,
340 				   struct rxr_tx_entry *tx_entry,
341 				   struct rxr_pkt_entry *pkt_entry)
342 {
343 	struct rxr_medium_rtm_base_hdr *rtm_hdr;
344 
345 	rxr_pkt_init_rtm(ep, tx_entry, RXR_MEDIUM_MSGRTM_PKT,
346 			 tx_entry->bytes_sent, pkt_entry);
347 	rtm_hdr = rxr_get_medium_rtm_base_hdr(pkt_entry->pkt);
348 	rtm_hdr->data_len = tx_entry->total_len;
349 	rtm_hdr->offset = tx_entry->bytes_sent;
350 	return 0;
351 }
352 
ssize_t rxr_pkt_init_medium_tagrtm(struct rxr_ep *ep,
354 				   struct rxr_tx_entry *tx_entry,
355 				   struct rxr_pkt_entry *pkt_entry)
356 {
357 	struct rxr_medium_rtm_base_hdr *rtm_hdr;
358 
359 	rxr_pkt_init_rtm(ep, tx_entry, RXR_MEDIUM_TAGRTM_PKT,
360 			 tx_entry->bytes_sent, pkt_entry);
361 	rtm_hdr = rxr_get_medium_rtm_base_hdr(pkt_entry->pkt);
362 	rtm_hdr->data_len = tx_entry->total_len;
363 	rtm_hdr->offset = tx_entry->bytes_sent;
364 	rtm_hdr->hdr.flags |= RXR_REQ_TAGGED;
365 	rxr_pkt_rtm_settag(pkt_entry, tx_entry->tag);
366 	return 0;
367 }
368 
void rxr_pkt_init_long_rtm(struct rxr_ep *ep,
370 			   struct rxr_tx_entry *tx_entry,
371 			   int pkt_type,
372 			   struct rxr_pkt_entry *pkt_entry)
373 {
374 	struct rxr_long_rtm_base_hdr *rtm_hdr;
375 
376 	rxr_pkt_init_rtm(ep, tx_entry, pkt_type, 0, pkt_entry);
377 	rtm_hdr = rxr_get_long_rtm_base_hdr(pkt_entry->pkt);
378 	rtm_hdr->data_len = tx_entry->total_len;
379 	rtm_hdr->tx_id = tx_entry->tx_id;
380 	rtm_hdr->credit_request = tx_entry->credit_request;
381 }
382 
ssize_t rxr_pkt_init_long_msgrtm(struct rxr_ep *ep,
384 				 struct rxr_tx_entry *tx_entry,
385 				 struct rxr_pkt_entry *pkt_entry)
386 {
387 	rxr_pkt_init_long_rtm(ep, tx_entry, RXR_LONG_MSGRTM_PKT, pkt_entry);
388 	return 0;
389 }
390 
ssize_t rxr_pkt_init_long_tagrtm(struct rxr_ep *ep,
392 				 struct rxr_tx_entry *tx_entry,
393 				 struct rxr_pkt_entry *pkt_entry)
394 {
395 	struct rxr_base_hdr *base_hdr;
396 
397 	rxr_pkt_init_long_rtm(ep, tx_entry, RXR_LONG_TAGRTM_PKT, pkt_entry);
398 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
399 	base_hdr->flags |= RXR_REQ_TAGGED;
400 	rxr_pkt_rtm_settag(pkt_entry, tx_entry->tag);
401 	return 0;
402 }
403 
ssize_t rxr_pkt_init_read_rtm(struct rxr_ep *ep,
405 			      struct rxr_tx_entry *tx_entry,
406 			      int pkt_type,
407 			      struct rxr_pkt_entry *pkt_entry)
408 {
409 	struct rxr_read_rtm_base_hdr *rtm_hdr;
410 	struct fi_rma_iov *read_iov;
411 	int err;
412 
413 	rxr_pkt_init_req_hdr(ep, tx_entry, pkt_type, pkt_entry);
414 
415 	rtm_hdr = rxr_get_read_rtm_base_hdr(pkt_entry->pkt);
416 	rtm_hdr->hdr.flags |= RXR_REQ_MSG;
417 	rtm_hdr->hdr.msg_id = tx_entry->msg_id;
418 	rtm_hdr->data_len = tx_entry->total_len;
419 	rtm_hdr->tx_id = tx_entry->tx_id;
420 	rtm_hdr->read_iov_count = tx_entry->iov_count;
421 
422 	read_iov = (struct fi_rma_iov *)((char *)pkt_entry->pkt + pkt_entry->hdr_size);
423 	err = rxr_read_init_iov(ep, tx_entry, read_iov);
424 	if (OFI_UNLIKELY(err))
425 		return err;
426 
427 	pkt_entry->pkt_size = pkt_entry->hdr_size + tx_entry->iov_count * sizeof(struct fi_rma_iov);
428 	return 0;
429 }
430 
ssize_t rxr_pkt_init_read_msgrtm(struct rxr_ep *ep,
432 				 struct rxr_tx_entry *tx_entry,
433 				 struct rxr_pkt_entry *pkt_entry)
434 {
435 	return rxr_pkt_init_read_rtm(ep, tx_entry, RXR_READ_MSGRTM_PKT, pkt_entry);
436 }
437 
ssize_t rxr_pkt_init_read_tagrtm(struct rxr_ep *ep,
439 				 struct rxr_tx_entry *tx_entry,
440 				 struct rxr_pkt_entry *pkt_entry)
441 {
442 	ssize_t err;
443 	struct rxr_base_hdr *base_hdr;
444 
445 	err = rxr_pkt_init_read_rtm(ep, tx_entry, RXR_READ_TAGRTM_PKT, pkt_entry);
446 	if (err)
447 		return err;
448 
449 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
450 	base_hdr->flags |= RXR_REQ_TAGGED;
451 	rxr_pkt_rtm_settag(pkt_entry, tx_entry->tag);
452 	return 0;
453 }
454 
455 /*
456  *     handle_sent() functions
457  */
458 
459 /*
460  *         rxr_pkt_handle_eager_rtm_sent() is empty and is defined in rxr_pkt_type_req.h
461  */
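/*
 * Medium and long RTM sent handlers advance tx_entry->bytes_sent by the
 * payload carried in the packet just posted; a medium RTM message may be
 * split across several such packets, each carrying its own offset.
 */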
void rxr_pkt_handle_medium_rtm_sent(struct rxr_ep *ep,
463 				    struct rxr_pkt_entry *pkt_entry)
464 {
465 	struct rxr_tx_entry *tx_entry;
466 
467 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
468 	tx_entry->bytes_sent += rxr_pkt_req_data_size(pkt_entry);
469 }
470 
void rxr_pkt_handle_long_rtm_sent(struct rxr_ep *ep,
472 				  struct rxr_pkt_entry *pkt_entry)
473 {
474 	struct rxr_tx_entry *tx_entry;
475 
476 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
477 	tx_entry->bytes_sent += rxr_pkt_req_data_size(pkt_entry);
478 	assert(tx_entry->bytes_sent < tx_entry->total_len);
479 
480 	if (efa_mr_cache_enable || rxr_ep_is_cuda_mr(tx_entry->desc[0]))
481 		rxr_prepare_desc_send(rxr_ep_domain(ep), tx_entry);
482 }
483 
484 /*
485  *     handle_send_completion() functions
486  */
void rxr_pkt_handle_eager_rtm_send_completion(struct rxr_ep *ep,
488 					      struct rxr_pkt_entry *pkt_entry)
489 {
490 	struct rxr_tx_entry *tx_entry;
491 
492 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
493 	assert(tx_entry->total_len == rxr_pkt_req_data_size(pkt_entry));
494 	rxr_cq_handle_tx_completion(ep, tx_entry);
495 }
496 
void rxr_pkt_handle_medium_rtm_send_completion(struct rxr_ep *ep,
498 					       struct rxr_pkt_entry *pkt_entry)
499 {
500 	struct rxr_tx_entry *tx_entry;
501 
502 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
503 	tx_entry->bytes_acked += rxr_pkt_req_data_size(pkt_entry);
504 	if (tx_entry->total_len == tx_entry->bytes_acked)
505 		rxr_cq_handle_tx_completion(ep, tx_entry);
506 }
507 
void rxr_pkt_handle_long_rtm_send_completion(struct rxr_ep *ep,
509 					     struct rxr_pkt_entry *pkt_entry)
510 {
511 	struct rxr_tx_entry *tx_entry;
512 
513 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
514 	tx_entry->bytes_acked += rxr_pkt_req_data_size(pkt_entry);
515 	if (tx_entry->total_len == tx_entry->bytes_acked)
516 		rxr_cq_handle_tx_completion(ep, tx_entry);
517 }
518 
519 /*
520  *     proc() functions
521  */
size_t rxr_pkt_rtm_total_len(struct rxr_pkt_entry *pkt_entry)
523 {
524 	struct rxr_base_hdr *base_hdr;
525 
526 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
527 	switch (base_hdr->type) {
528 	case RXR_EAGER_MSGRTM_PKT:
529 	case RXR_EAGER_TAGRTM_PKT:
530 		return rxr_pkt_req_data_size(pkt_entry);
531 	case RXR_MEDIUM_MSGRTM_PKT:
532 	case RXR_MEDIUM_TAGRTM_PKT:
533 		return rxr_get_medium_rtm_base_hdr(pkt_entry->pkt)->data_len;
534 	case RXR_LONG_MSGRTM_PKT:
535 	case RXR_LONG_TAGRTM_PKT:
536 		return rxr_get_long_rtm_base_hdr(pkt_entry->pkt)->data_len;
537 	case RXR_READ_MSGRTM_PKT:
538 	case RXR_READ_TAGRTM_PKT:
539 		return rxr_get_read_rtm_base_hdr(pkt_entry->pkt)->data_len;
540 	default:
541 		assert(0 && "Unknown REQ packet type\n");
542 	}
543 
544 	return 0;
545 }
546 
void rxr_pkt_rtm_init_rx_entry(struct rxr_pkt_entry *pkt_entry,
548 			       struct rxr_rx_entry *rx_entry)
549 {
550 	struct rxr_base_hdr *base_hdr;
551 
552 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
553 	if (base_hdr->flags & RXR_REQ_OPT_CQ_DATA_HDR) {
554 		rx_entry->rxr_flags |= RXR_REMOTE_CQ_DATA;
555 		rx_entry->cq_entry.flags |= FI_REMOTE_CQ_DATA;
556 		rx_entry->cq_entry.data = pkt_entry->cq_data;
557 	}
558 
559 	rx_entry->addr = pkt_entry->addr;
560 	rx_entry->msg_id = rxr_pkt_msg_id(pkt_entry);
561 	rx_entry->total_len = rxr_pkt_rtm_total_len(pkt_entry);
562 	rx_entry->tag = rxr_pkt_rtm_tag(pkt_entry);
563 	rx_entry->cq_entry.tag = rx_entry->tag;
564 }
565 
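/*
 * When the matched receive was posted with FI_MULTI_RECV, split a child
 * rx_entry off the posted buffer for this message and keep the parent posted
 * while it still has room; otherwise initialize the matched rx_entry from the
 * packet and remove it from the posted list.
 */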
struct rxr_rx_entry *rxr_pkt_get_rtm_matched_rx_entry(struct rxr_ep *ep,
567 						      struct dlist_entry *match,
568 						      struct rxr_pkt_entry *pkt_entry)
569 {
570 	struct rxr_rx_entry *rx_entry;
571 
572 	assert(match);
573 	rx_entry = container_of(match, struct rxr_rx_entry, entry);
574 	if (rx_entry->rxr_flags & RXR_MULTI_RECV_POSTED) {
575 		rx_entry = rxr_ep_split_rx_entry(ep, rx_entry,
576 						 NULL, pkt_entry);
577 		if (OFI_UNLIKELY(!rx_entry)) {
578 			FI_WARN(&rxr_prov, FI_LOG_CQ,
579 				"RX entries exhausted.\n");
580 			efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
581 			return NULL;
582 		}
583 	} else {
584 		rxr_pkt_rtm_init_rx_entry(pkt_entry, rx_entry);
585 	}
586 
587 	rx_entry->state = RXR_RX_MATCHED;
588 
589 	if (!(rx_entry->fi_flags & FI_MULTI_RECV) ||
590 	    !rxr_msg_multi_recv_buffer_available(ep, rx_entry->master_entry))
591 		dlist_remove(match);
592 
593 	return rx_entry;
594 }
595 
596 static
int rxr_pkt_rtm_match_recv_anyaddr(struct dlist_entry *item, const void *arg)
598 {
599 	return 1;
600 }
601 
602 static
int rxr_pkt_rtm_match_recv(struct dlist_entry *item, const void *arg)
604 {
605 	const struct rxr_pkt_entry *pkt_entry = arg;
606 	struct rxr_rx_entry *rx_entry;
607 
608 	rx_entry = container_of(item, struct rxr_rx_entry, entry);
609 	return rxr_match_addr(rx_entry->addr, pkt_entry->addr);
610 }
611 
612 static
int rxr_pkt_rtm_match_trecv_anyaddr(struct dlist_entry *item, const void *arg)
614 {
615 	struct rxr_pkt_entry *pkt_entry = (struct rxr_pkt_entry *)arg;
616 	struct rxr_rx_entry *rx_entry;
617 	uint64_t match_tag;
618 
619 	rx_entry = container_of(item, struct rxr_rx_entry, entry);
620 	match_tag = rxr_pkt_rtm_tag(pkt_entry);
621 
622 	return rxr_match_tag(rx_entry->cq_entry.tag, rx_entry->ignore,
623 			     match_tag);
624 }
625 
626 static
int rxr_pkt_rtm_match_trecv(struct dlist_entry *item, const void *arg)
628 {
629 	struct rxr_pkt_entry *pkt_entry = (struct rxr_pkt_entry *)arg;
630 	struct rxr_rx_entry *rx_entry;
631 	uint64_t match_tag;
632 
633 	rx_entry = container_of(item, struct rxr_rx_entry, entry);
634 	match_tag = rxr_pkt_rtm_tag(pkt_entry);
635 
636 	return rxr_match_addr(rx_entry->addr, pkt_entry->addr) &&
637 	       rxr_match_tag(rx_entry->cq_entry.tag, rx_entry->ignore,
638 			     match_tag);
639 }
640 
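/*
 * Find the rx_entry for an incoming message RTM: match it against the posted
 * receives (honoring FI_DIRECTED_RECV), or queue it as unexpected when no
 * receive matches. Medium RTM packets are also recorded in the msg_id map so
 * that later segments of the same message find this rx_entry.
 */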
641 static
struct rxr_rx_entry *rxr_pkt_get_msgrtm_rx_entry(struct rxr_ep *ep,
643 						 struct rxr_pkt_entry **pkt_entry_ptr)
644 {
645 	struct rxr_rx_entry *rx_entry;
646 	struct dlist_entry *match;
647 	dlist_func_t *match_func;
648 	int pkt_type;
649 
650 	if (ep->util_ep.caps & FI_DIRECTED_RECV)
651 		match_func = &rxr_pkt_rtm_match_recv;
652 	else
653 		match_func = &rxr_pkt_rtm_match_recv_anyaddr;
654 
655 	match = dlist_find_first_match(&ep->rx_list, match_func,
656 	                               *pkt_entry_ptr);
657 	if (OFI_UNLIKELY(!match)) {
658 		/*
659 		 * rxr_ep_alloc_unexp_rx_entry_for_msgrtm() might release pkt_entry,
660 		 * thus we have to use pkt_entry_ptr here
661 		 */
662 		rx_entry = rxr_ep_alloc_unexp_rx_entry_for_msgrtm(ep, pkt_entry_ptr);
663 		if (OFI_UNLIKELY(!rx_entry)) {
664 			FI_WARN(&rxr_prov, FI_LOG_CQ,
665 				"RX entries exhausted.\n");
666 			efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
667 			return NULL;
668 		}
669 
670 	} else {
671 		rx_entry = rxr_pkt_get_rtm_matched_rx_entry(ep, match, *pkt_entry_ptr);
672 	}
673 
674 	pkt_type = rxr_get_base_hdr((*pkt_entry_ptr)->pkt)->type;
675 	if (pkt_type == RXR_MEDIUM_MSGRTM_PKT)
676 		rxr_pkt_rx_map_insert(ep, *pkt_entry_ptr, rx_entry);
677 
678 	return rx_entry;
679 }
680 
681 static
struct rxr_rx_entry *rxr_pkt_get_tagrtm_rx_entry(struct rxr_ep *ep,
683 						 struct rxr_pkt_entry **pkt_entry_ptr)
684 {
685 	struct rxr_rx_entry *rx_entry;
686 	struct dlist_entry *match;
687 	dlist_func_t *match_func;
688 	int pkt_type;
689 
690 	if (ep->util_ep.caps & FI_DIRECTED_RECV)
691 		match_func = &rxr_pkt_rtm_match_trecv;
692 	else
693 		match_func = &rxr_pkt_rtm_match_trecv_anyaddr;
694 
695 	match = dlist_find_first_match(&ep->rx_tagged_list, match_func,
696 	                               *pkt_entry_ptr);
697 	if (OFI_UNLIKELY(!match)) {
698 		/*
699 		 * rxr_ep_alloc_unexp_rx_entry_for_tagrtm() might release pkt_entry,
700 		 * thus we have to use pkt_entry_ptr here
701 		 */
702 		rx_entry = rxr_ep_alloc_unexp_rx_entry_for_tagrtm(ep, pkt_entry_ptr);
703 		if (OFI_UNLIKELY(!rx_entry)) {
704 			efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
705 			return NULL;
706 		}
707 	} else {
708 		rx_entry = rxr_pkt_get_rtm_matched_rx_entry(ep, match, *pkt_entry_ptr);
709 	}
710 
711 	pkt_type = rxr_get_base_hdr((*pkt_entry_ptr)->pkt)->type;
712 	if (pkt_type == RXR_MEDIUM_TAGRTM_PKT)
713 		rxr_pkt_rx_map_insert(ep, *pkt_entry_ptr, rx_entry);
714 
715 	return rx_entry;
716 }
717 
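/*
 * A matched read RTM carries no payload; instead its data section is a list
 * of fi_rma_iov describing the sender's buffer. Copy that list into the
 * rx_entry and post (or queue) RDMA read operations to pull the data.
 */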
ssize_t rxr_pkt_proc_matched_read_rtm(struct rxr_ep *ep,
719 				      struct rxr_rx_entry *rx_entry,
720 				      struct rxr_pkt_entry *pkt_entry)
721 {
722 	struct rxr_read_rtm_base_hdr *rtm_hdr;
723 	struct fi_rma_iov *read_iov;
724 
725 	rtm_hdr = rxr_get_read_rtm_base_hdr(pkt_entry->pkt);
726 	read_iov = (struct fi_rma_iov *)((char *)pkt_entry->pkt + pkt_entry->hdr_size);
727 
728 	rx_entry->tx_id = rtm_hdr->tx_id;
729 	rx_entry->rma_iov_count = rtm_hdr->read_iov_count;
730 	memcpy(rx_entry->rma_iov, read_iov,
731 	       rx_entry->rma_iov_count * sizeof(struct fi_rma_iov));
732 
733 	rxr_pkt_entry_release_rx(ep, pkt_entry);
734 
735 	/* truncate rx_entry->iov to save memory registration pages because we
736 	 * need to do memory registration for the receiving buffer.
737 	 */
738 	ofi_truncate_iov(rx_entry->iov, &rx_entry->iov_count, rx_entry->total_len);
739 	return rxr_read_post_or_queue(ep, RXR_RX_ENTRY, rx_entry);
740 }
741 
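/*
 * Medium RTM data can arrive in several packets chained through
 * pkt_entry->next; each carries the offset of its payload. Copy every segment
 * into the receive buffer and write the completion once bytes_done reaches
 * total_len.
 */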
ssize_t rxr_pkt_proc_matched_medium_rtm(struct rxr_ep *ep,
743 					struct rxr_rx_entry *rx_entry,
744 					struct rxr_pkt_entry *pkt_entry)
745 {
746 	struct rxr_pkt_entry *cur;
747 	char *data;
748 	size_t offset, data_size;
749 
750 	cur = pkt_entry;
751 	while (cur) {
752 		data = (char *)cur->pkt + cur->hdr_size;
753 		offset = rxr_get_medium_rtm_base_hdr(cur->pkt)->offset;
754 		data_size = cur->pkt_size - cur->hdr_size;
755 		rxr_copy_to_rx(data, data_size, rx_entry, offset);
756 		rx_entry->bytes_done += data_size;
757 		cur = cur->next;
758 	}
759 
760 	if (rx_entry->total_len == rx_entry->bytes_done) {
761 		rxr_pkt_rx_map_remove(ep, pkt_entry, rx_entry);
762 		/*
763 		 * rxr_cq_handle_rx_completion() releases pkt_entry, thus
764 		 * we do not release it here.
765 		 */
766 		rxr_cq_handle_rx_completion(ep, pkt_entry, rx_entry);
767 		rxr_msg_multi_recv_free_posted_entry(ep, rx_entry);
768 		rxr_release_rx_entry(ep, rx_entry);
769 		return 0;
770 	}
771 
772 	rxr_pkt_entry_release_rx(ep, pkt_entry);
773 	return 0;
774 }
775 
ssize_t rxr_pkt_proc_matched_rtm(struct rxr_ep *ep,
777 				 struct rxr_rx_entry *rx_entry,
778 				 struct rxr_pkt_entry *pkt_entry)
779 {
780 	int pkt_type;
781 	char *data;
782 	size_t data_size, bytes_left;
783 	ssize_t ret;
784 
785 	assert(rx_entry->state == RXR_RX_MATCHED);
786 
	/* Adjust rx_entry->cq_entry.len as needed.
	 * Initially, rx_entry->cq_entry.len is the total recv buffer size.
	 * rx_entry->total_len is from the REQ packet and is the total send buffer size.
	 * If send buffer size < recv buffer size, we adjust rx_entry->cq_entry.len.
	 * If send buffer size > recv buffer size, the message is truncated and an
	 * error CQ entry will be written.
	 */
794 	if (rx_entry->cq_entry.len > rx_entry->total_len)
795 		rx_entry->cq_entry.len = rx_entry->total_len;
796 
797 	pkt_type = rxr_get_base_hdr(pkt_entry->pkt)->type;
798 	if (pkt_type == RXR_READ_MSGRTM_PKT || pkt_type == RXR_READ_TAGRTM_PKT)
799 		return rxr_pkt_proc_matched_read_rtm(ep, rx_entry, pkt_entry);
800 
801 	if (pkt_type == RXR_MEDIUM_MSGRTM_PKT || pkt_type == RXR_MEDIUM_TAGRTM_PKT)
802 		return rxr_pkt_proc_matched_medium_rtm(ep, rx_entry, pkt_entry);
803 
804 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
805 	data_size = pkt_entry->pkt_size - pkt_entry->hdr_size;
806 	bytes_left = rxr_pkt_req_copy_data(rx_entry, pkt_entry,
807 					   data, data_size);
808 	if (!bytes_left) {
809 		/*
810 		 * rxr_cq_handle_rx_completion() releases pkt_entry, thus
811 		 * we do not release it here.
812 		 */
813 		rxr_cq_handle_rx_completion(ep, pkt_entry, rx_entry);
814 		rxr_msg_multi_recv_free_posted_entry(ep, rx_entry);
815 		rxr_release_rx_entry(ep, rx_entry);
816 		ret = 0;
817 	} else {
818 		/*
819 		 * long message protocol
820 		 */
821 #if ENABLE_DEBUG
822 		dlist_insert_tail(&rx_entry->rx_pending_entry, &ep->rx_pending_list);
823 		ep->rx_pending++;
824 #endif
825 		rx_entry->state = RXR_RX_RECV;
826 		rx_entry->tx_id = rxr_get_long_rtm_base_hdr(pkt_entry->pkt)->tx_id;
		/* we have noticed that using the default value achieves better bandwidth */
828 		rx_entry->credit_request = rxr_env.tx_min_credits;
829 		ret = rxr_pkt_post_ctrl_or_queue(ep, RXR_RX_ENTRY, rx_entry, RXR_CTS_PKT, 0);
830 		rxr_pkt_entry_release_rx(ep, pkt_entry);
831 	}
832 
833 	return ret;
834 }
835 
ssize_t rxr_pkt_proc_msgrtm(struct rxr_ep *ep,
837 			    struct rxr_pkt_entry *pkt_entry)
838 {
839 	ssize_t err;
840 	struct rxr_rx_entry *rx_entry;
841 
842 	rx_entry = rxr_pkt_get_msgrtm_rx_entry(ep, &pkt_entry);
843 	if (OFI_UNLIKELY(!rx_entry)) {
844 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
845 		rxr_pkt_entry_release_rx(ep, pkt_entry);
846 		return -FI_ENOBUFS;
847 	}
848 
849 	if (rx_entry->state == RXR_RX_MATCHED) {
850 		err = rxr_pkt_proc_matched_rtm(ep, rx_entry, pkt_entry);
851 		if (OFI_UNLIKELY(err)) {
852 			if (rxr_cq_handle_rx_error(ep, rx_entry, err)) {
853 				assert(0 && "cannot write cq error entry");
854 				efa_eq_write_error(&ep->util_ep, -err, err);
855 			}
856 			rxr_pkt_entry_release_rx(ep, pkt_entry);
857 			rxr_release_rx_entry(ep, rx_entry);
858 			return err;
859 		}
860 	}
861 
862 	return 0;
863 }
864 
ssize_t rxr_pkt_proc_tagrtm(struct rxr_ep *ep,
866 			    struct rxr_pkt_entry *pkt_entry)
867 {
868 	ssize_t err;
869 	struct rxr_rx_entry *rx_entry;
870 
871 	rx_entry = rxr_pkt_get_tagrtm_rx_entry(ep, &pkt_entry);
872 	if (OFI_UNLIKELY(!rx_entry)) {
873 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
874 		rxr_pkt_entry_release_rx(ep, pkt_entry);
875 		return -FI_ENOBUFS;
876 	}
877 
878 	if (rx_entry->state == RXR_RX_MATCHED) {
879 		err = rxr_pkt_proc_matched_rtm(ep, rx_entry, pkt_entry);
880 		if (OFI_UNLIKELY(err)) {
881 			if (rxr_cq_handle_rx_error(ep, rx_entry, err)) {
882 				assert(0 && "cannot write error cq entry");
883 				efa_eq_write_error(&ep->util_ep, -err, err);
884 			}
885 			rxr_pkt_entry_release_rx(ep, pkt_entry);
886 			rxr_release_rx_entry(ep, rx_entry);
887 			return err;
888 		}
889 	}
890 
891 	return 0;
892 }
893 
894 /*
895  * proc() functions called by rxr_pkt_handle_recv_completion()
896  */
ssize_t rxr_pkt_proc_rtm_rta(struct rxr_ep *ep,
898 			     struct rxr_pkt_entry *pkt_entry)
899 {
900 	struct rxr_base_hdr *base_hdr;
901 
902 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
903 	assert(base_hdr->type >= RXR_BASELINE_REQ_PKT_BEGIN);
904 
905 	switch (base_hdr->type) {
906 	case RXR_EAGER_MSGRTM_PKT:
907 	case RXR_MEDIUM_MSGRTM_PKT:
908 	case RXR_LONG_MSGRTM_PKT:
909 	case RXR_READ_MSGRTM_PKT:
910 		return rxr_pkt_proc_msgrtm(ep, pkt_entry);
911 	case RXR_EAGER_TAGRTM_PKT:
912 	case RXR_MEDIUM_TAGRTM_PKT:
913 	case RXR_LONG_TAGRTM_PKT:
914 	case RXR_READ_TAGRTM_PKT:
915 		return rxr_pkt_proc_tagrtm(ep, pkt_entry);
916 	case RXR_WRITE_RTA_PKT:
917 		return rxr_pkt_proc_write_rta(ep, pkt_entry);
918 	case RXR_FETCH_RTA_PKT:
919 		return rxr_pkt_proc_fetch_rta(ep, pkt_entry);
920 	case RXR_COMPARE_RTA_PKT:
921 		return rxr_pkt_proc_compare_rta(ep, pkt_entry);
922 	default:
923 		FI_WARN(&rxr_prov, FI_LOG_EP_CTRL,
924 			"Unknown packet type ID: %d\n",
925 		       base_hdr->type);
926 		if (rxr_cq_handle_cq_error(ep, -FI_EINVAL))
927 			assert(0 && "failed to write err cq entry");
928 	}
929 
930 	return -FI_EINVAL;
931 }
932 
void rxr_pkt_handle_rtm_rta_recv(struct rxr_ep *ep,
934 				 struct rxr_pkt_entry *pkt_entry)
935 {
936 	struct rxr_base_hdr *base_hdr;
937 	struct rxr_peer *peer;
938 	bool need_ordering;
939 	int ret, msg_id;
940 
941 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
942 	assert(base_hdr->type >= RXR_BASELINE_REQ_PKT_BEGIN);
943 
944 	if (base_hdr->type == RXR_MEDIUM_MSGRTM_PKT || base_hdr->type == RXR_MEDIUM_TAGRTM_PKT) {
945 		struct rxr_rx_entry *rx_entry;
946 		struct rxr_pkt_entry *unexp_pkt_entry;
947 
948 		rx_entry = rxr_pkt_rx_map_lookup(ep, pkt_entry);
949 		if (rx_entry) {
950 			if (rx_entry->state == RXR_RX_MATCHED) {
951 				rxr_pkt_proc_matched_medium_rtm(ep, rx_entry, pkt_entry);
952 			} else {
953 				assert(rx_entry->unexp_pkt);
954 				unexp_pkt_entry = rxr_pkt_get_unexp(ep, &pkt_entry);
955 				rxr_pkt_entry_append(rx_entry->unexp_pkt, unexp_pkt_entry);
956 			}
957 
958 			return;
959 		}
960 	}
961 
962 	need_ordering = false;
963 	peer = rxr_ep_get_peer(ep, pkt_entry->addr);
964 	assert(peer);
965 	if (!peer->is_local) {
		/*
		 * only need to reorder msg for efa_ep
		 */
969 		base_hdr = (struct rxr_base_hdr *)pkt_entry->pkt;
970 		if ((base_hdr->flags & RXR_REQ_MSG) && rxr_need_sas_ordering(ep))
971 			need_ordering = true;
972 		else if (base_hdr->flags & RXR_REQ_ATOMIC)
973 			need_ordering = true;
974 	}
975 
976 	if (!need_ordering) {
977 		/* rxr_pkt_proc_rtm will write error cq entry if needed */
978 		rxr_pkt_proc_rtm_rta(ep, pkt_entry);
979 		return;
980 	}
981 
982 	msg_id = rxr_pkt_msg_id(pkt_entry);
983 	ret = rxr_cq_reorder_msg(ep, peer, pkt_entry);
984 	if (ret == 1) {
985 		/* Packet was queued */
986 		return;
987 	}
988 
989 	if (OFI_UNLIKELY(ret == -FI_EALREADY)) {
990 		/* Packet with same msg_id has been processed before */
991 		FI_WARN(&rxr_prov, FI_LOG_EP_CTRL,
992 			"Invalid msg_id: %" PRIu32
993 			" robuf->exp_msg_id: %" PRIu32 "\n",
994 		       msg_id, peer->robuf->exp_msg_id);
995 		efa_eq_write_error(&ep->util_ep, FI_EIO, ret);
996 		rxr_pkt_entry_release_rx(ep, pkt_entry);
997 		return;
998 	}
999 
1000 	if (OFI_UNLIKELY(ret == -FI_ENOMEM)) {
1001 		/* running out of memory while copy packet */
1002 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1003 		return;
1004 	}
1005 
1006 	if (OFI_UNLIKELY(ret < 0)) {
1007 		FI_WARN(&rxr_prov, FI_LOG_EP_CTRL,
1008 			"Unknown error %d processing REQ packet msg_id: %"
1009 			PRIu32 "\n", ret, msg_id);
1010 		efa_eq_write_error(&ep->util_ep, FI_EIO, ret);
1011 		return;
1012 	}
1013 
1014 
1015 	/*
1016 	 * rxr_pkt_proc_rtm_rta() will write error cq entry if needed,
1017 	 * thus we do not write error cq entry
1018 	 */
1019 	ret = rxr_pkt_proc_rtm_rta(ep, pkt_entry);
1020 	if (OFI_UNLIKELY(ret))
1021 		return;
1022 
1023 	ofi_recvwin_slide(peer->robuf);
1024 	/* process pending items in reorder buff */
1025 	rxr_cq_proc_pending_items_in_recvwin(ep, peer);
1026 }
1027 
/*
 * RTW packet type functions
 */
void rxr_pkt_init_rtw_data(struct rxr_ep *ep,
1032 			   struct rxr_tx_entry *tx_entry,
1033 			   struct rxr_pkt_entry *pkt_entry,
1034 			   struct fi_rma_iov *rma_iov)
1035 {
1036 	char *data;
1037 	size_t data_size;
1038 	int i;
1039 
1040 	for (i = 0; i < tx_entry->rma_iov_count; ++i) {
1041 		rma_iov[i].addr = tx_entry->rma_iov[i].addr;
1042 		rma_iov[i].len = tx_entry->rma_iov[i].len;
1043 		rma_iov[i].key = tx_entry->rma_iov[i].key;
1044 	}
1045 
1046 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1047 	data_size = ofi_copy_from_iov(data, ep->mtu_size - pkt_entry->hdr_size,
1048 				      tx_entry->iov, tx_entry->iov_count, 0);
1049 
1050 	pkt_entry->pkt_size = pkt_entry->hdr_size + data_size;
1051 	pkt_entry->x_entry = tx_entry;
1052 }
1053 
ssize_t rxr_pkt_init_eager_rtw(struct rxr_ep *ep,
1055 			       struct rxr_tx_entry *tx_entry,
1056 			       struct rxr_pkt_entry *pkt_entry)
1057 {
1058 	struct rxr_eager_rtw_hdr *rtw_hdr;
1059 
1060 	assert(tx_entry->op == ofi_op_write);
1061 
1062 	rtw_hdr = (struct rxr_eager_rtw_hdr *)pkt_entry->pkt;
1063 	rtw_hdr->rma_iov_count = tx_entry->rma_iov_count;
1064 	rxr_pkt_init_req_hdr(ep, tx_entry, RXR_EAGER_RTW_PKT, pkt_entry);
1065 	rxr_pkt_init_rtw_data(ep, tx_entry, pkt_entry, rtw_hdr->rma_iov);
1066 	return 0;
1067 }
1068 
ssize_t rxr_pkt_init_long_rtw(struct rxr_ep *ep,
1070 			      struct rxr_tx_entry *tx_entry,
1071 			      struct rxr_pkt_entry *pkt_entry)
1072 {
1073 	struct rxr_long_rtw_hdr *rtw_hdr;
1074 
1075 	assert(tx_entry->op == ofi_op_write);
1076 
1077 	rtw_hdr = (struct rxr_long_rtw_hdr *)pkt_entry->pkt;
1078 	rtw_hdr->rma_iov_count = tx_entry->rma_iov_count;
1079 	rtw_hdr->data_len = tx_entry->total_len;
1080 	rtw_hdr->tx_id = tx_entry->tx_id;
1081 	rtw_hdr->credit_request = tx_entry->credit_request;
1082 	rxr_pkt_init_req_hdr(ep, tx_entry, RXR_LONG_RTW_PKT, pkt_entry);
1083 	rxr_pkt_init_rtw_data(ep, tx_entry, pkt_entry, rtw_hdr->rma_iov);
1084 	return 0;
1085 }
1086 
ssize_t rxr_pkt_init_read_rtw(struct rxr_ep *ep,
1088 			      struct rxr_tx_entry *tx_entry,
1089 			      struct rxr_pkt_entry *pkt_entry)
1090 {
1091 	struct rxr_read_rtw_hdr *rtw_hdr;
1092 	struct fi_rma_iov *rma_iov, *read_iov;
1093 	int i, err;
1094 
1095 	assert(tx_entry->op == ofi_op_write);
1096 
1097 	rtw_hdr = (struct rxr_read_rtw_hdr *)pkt_entry->pkt;
1098 	rtw_hdr->rma_iov_count = tx_entry->rma_iov_count;
1099 	rtw_hdr->data_len = tx_entry->total_len;
1100 	rtw_hdr->tx_id = tx_entry->tx_id;
1101 	rtw_hdr->read_iov_count = tx_entry->iov_count;
1102 	rxr_pkt_init_req_hdr(ep, tx_entry, RXR_READ_RTW_PKT, pkt_entry);
1103 
1104 	rma_iov = rtw_hdr->rma_iov;
1105 	for (i = 0; i < tx_entry->rma_iov_count; ++i) {
1106 		rma_iov[i].addr = tx_entry->rma_iov[i].addr;
1107 		rma_iov[i].len = tx_entry->rma_iov[i].len;
1108 		rma_iov[i].key = tx_entry->rma_iov[i].key;
1109 	}
1110 
1111 	read_iov = (struct fi_rma_iov *)((char *)pkt_entry->pkt + pkt_entry->hdr_size);
1112 	err = rxr_read_init_iov(ep, tx_entry, read_iov);
1113 	if (OFI_UNLIKELY(err))
1114 		return err;
1115 
1116 	pkt_entry->pkt_size = pkt_entry->hdr_size + tx_entry->iov_count * sizeof(struct fi_rma_iov);
1117 	return 0;
1118 }
1119 
/*
 *     handle_sent() functions for RTW packet types
 *
 *         rxr_pkt_handle_eager_rtw_sent() is empty and is defined in rxr_pkt_type_req.h
 */
void rxr_pkt_handle_long_rtw_sent(struct rxr_ep *ep,
1126 				  struct rxr_pkt_entry *pkt_entry)
1127 {
1128 	struct rxr_tx_entry *tx_entry;
1129 
1130 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
1131 	tx_entry->bytes_sent += rxr_pkt_req_data_size(pkt_entry);
1132 	assert(tx_entry->bytes_sent < tx_entry->total_len);
1133 	if (efa_mr_cache_enable || rxr_ep_is_cuda_mr(tx_entry->desc[0]))
1134 		rxr_prepare_desc_send(rxr_ep_domain(ep), tx_entry);
1135 }
1136 
1137 /*
1138  *     handle_send_completion() functions
1139  */
void rxr_pkt_handle_eager_rtw_send_completion(struct rxr_ep *ep,
1141 					      struct rxr_pkt_entry *pkt_entry)
1142 {
1143 	struct rxr_tx_entry *tx_entry;
1144 
1145 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
1146 	assert(tx_entry->total_len == rxr_pkt_req_data_size(pkt_entry));
1147 	rxr_cq_handle_tx_completion(ep, tx_entry);
1148 }
1149 
void rxr_pkt_handle_long_rtw_send_completion(struct rxr_ep *ep,
1151 					     struct rxr_pkt_entry *pkt_entry)
1152 {
1153 	struct rxr_tx_entry *tx_entry;
1154 
1155 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
1156 	tx_entry->bytes_acked += rxr_pkt_req_data_size(pkt_entry);
1157 	if (tx_entry->total_len == tx_entry->bytes_acked)
1158 		rxr_cq_handle_tx_completion(ep, tx_entry);
1159 }
1160 
1161 /*
1162  *     handle_recv() functions
1163  */
1164 
1165 static
struct rxr_rx_entry *rxr_pkt_alloc_rtw_rx_entry(struct rxr_ep *ep,
1167 						struct rxr_pkt_entry *pkt_entry)
1168 {
1169 	struct rxr_rx_entry *rx_entry;
1170 	struct rxr_base_hdr *base_hdr;
1171 	struct fi_msg msg = {0};
1172 
1173 	msg.addr = pkt_entry->addr;
1174 	rx_entry = rxr_ep_get_rx_entry(ep, &msg, 0, ~0, ofi_op_write, 0);
1175 	if (OFI_UNLIKELY(!rx_entry))
1176 		return NULL;
1177 
1178 	base_hdr = rxr_get_base_hdr(pkt_entry->pkt);
1179 	if (base_hdr->flags & RXR_REQ_OPT_CQ_DATA_HDR) {
1180 		rx_entry->rxr_flags |= RXR_REMOTE_CQ_DATA;
1181 		rx_entry->cq_entry.flags |= FI_REMOTE_CQ_DATA;
1182 		rx_entry->cq_entry.data = pkt_entry->cq_data;
1183 	}
1184 
1185 	rx_entry->addr = pkt_entry->addr;
1186 	rx_entry->bytes_done = 0;
1187 	return rx_entry;
1188 }
1189 
void rxr_pkt_handle_eager_rtw_recv(struct rxr_ep *ep,
1191 				   struct rxr_pkt_entry *pkt_entry)
1192 {
1193 	struct rxr_rx_entry *rx_entry;
1194 	struct rxr_eager_rtw_hdr *rtw_hdr;
1195 	char *data;
1196 	size_t data_size;
1197 	ssize_t err, bytes_left;
1198 
1199 	rx_entry = rxr_pkt_alloc_rtw_rx_entry(ep, pkt_entry);
1200 	if (!rx_entry) {
1201 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1202 			"RX entries exhausted.\n");
1203 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1204 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1205 		return;
1206 	}
1207 
1208 	rtw_hdr = (struct rxr_eager_rtw_hdr *)pkt_entry->pkt;
1209 	rx_entry->iov_count = rtw_hdr->rma_iov_count;
1210 	err = rxr_rma_verified_copy_iov(ep, rtw_hdr->rma_iov, rtw_hdr->rma_iov_count,
1211 					FI_REMOTE_WRITE, rx_entry->iov);
1212 	if (OFI_UNLIKELY(err)) {
1213 		FI_WARN(&rxr_prov, FI_LOG_CQ, "RMA address verify failed!\n");
1214 		efa_eq_write_error(&ep->util_ep, FI_EIO, err);
1215 		rxr_release_rx_entry(ep, rx_entry);
1216 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1217 		return;
1218 	}
1219 
1220 	rx_entry->cq_entry.flags |= (FI_RMA | FI_WRITE);
1221 	rx_entry->cq_entry.len = ofi_total_iov_len(rx_entry->iov, rx_entry->iov_count);
1222 	rx_entry->cq_entry.buf = rx_entry->iov[0].iov_base;
1223 	rx_entry->total_len = rx_entry->cq_entry.len;
1224 
1225 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1226 	data_size = pkt_entry->pkt_size - pkt_entry->hdr_size;
1227 	bytes_left = rxr_pkt_req_copy_data(rx_entry, pkt_entry, data, data_size);
1228 	if (bytes_left != 0) {
1229 		FI_WARN(&rxr_prov, FI_LOG_CQ, "Eager RTM bytes_left is %ld, which should be 0.",
1230 			bytes_left);
1231 		FI_WARN(&rxr_prov, FI_LOG_CQ, "target buffer: %p length: %ld", rx_entry->iov[0].iov_base,
1232 			rx_entry->iov[0].iov_len);
1233 		efa_eq_write_error(&ep->util_ep, FI_EINVAL, -FI_EINVAL);
1234 		rxr_release_rx_entry(ep, rx_entry);
1235 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1236 		return;
1237 	}
1238 
1239 	if (rx_entry->cq_entry.flags & FI_REMOTE_CQ_DATA)
1240 		rxr_cq_write_rx_completion(ep, rx_entry);
1241 
1242 	rxr_release_rx_entry(ep, rx_entry);
1243 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1244 }
1245 
void rxr_pkt_handle_long_rtw_recv(struct rxr_ep *ep,
1247 				  struct rxr_pkt_entry *pkt_entry)
1248 {
1249 	struct rxr_rx_entry *rx_entry;
1250 	struct rxr_long_rtw_hdr *rtw_hdr;
1251 	char *data;
1252 	size_t data_size;
1253 	ssize_t err, bytes_left;
1254 
1255 	rx_entry = rxr_pkt_alloc_rtw_rx_entry(ep, pkt_entry);
1256 	if (!rx_entry) {
1257 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1258 			"RX entries exhausted.\n");
1259 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1260 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1261 		return;
1262 	}
1263 
1264 	rtw_hdr = (struct rxr_long_rtw_hdr *)pkt_entry->pkt;
1265 	rx_entry->iov_count = rtw_hdr->rma_iov_count;
1266 	err = rxr_rma_verified_copy_iov(ep, rtw_hdr->rma_iov, rtw_hdr->rma_iov_count,
1267 					FI_REMOTE_WRITE, rx_entry->iov);
1268 	if (OFI_UNLIKELY(err)) {
1269 		FI_WARN(&rxr_prov, FI_LOG_CQ, "RMA address verify failed!\n");
1270 		efa_eq_write_error(&ep->util_ep, FI_EIO, err);
1271 		rxr_release_rx_entry(ep, rx_entry);
1272 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1273 		return;
1274 	}
1275 
1276 	rx_entry->cq_entry.flags |= (FI_RMA | FI_WRITE);
1277 	rx_entry->cq_entry.len = ofi_total_iov_len(rx_entry->iov, rx_entry->iov_count);
1278 	rx_entry->cq_entry.buf = rx_entry->iov[0].iov_base;
1279 	rx_entry->total_len = rx_entry->cq_entry.len;
1280 
1281 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1282 	data_size = pkt_entry->pkt_size - pkt_entry->hdr_size;
1283 	bytes_left = rxr_pkt_req_copy_data(rx_entry, pkt_entry, data, data_size);
1284 	if (OFI_UNLIKELY(bytes_left <= 0)) {
1285 		FI_WARN(&rxr_prov, FI_LOG_CQ, "Long RTM bytes_left is %ld, which should be > 0.",
1286 			bytes_left);
1287 		FI_WARN(&rxr_prov, FI_LOG_CQ, "target buffer: %p length: %ld", rx_entry->iov[0].iov_base,
1288 			rx_entry->iov[0].iov_len);
1289 		efa_eq_write_error(&ep->util_ep, FI_EINVAL, -FI_EINVAL);
1290 		rxr_release_rx_entry(ep, rx_entry);
1291 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1292 		return;
1293 	}
1294 
1295 #if ENABLE_DEBUG
1296 	dlist_insert_tail(&rx_entry->rx_pending_entry, &ep->rx_pending_list);
1297 	ep->rx_pending++;
1298 #endif
1299 	rx_entry->state = RXR_RX_RECV;
1300 	rx_entry->tx_id = rtw_hdr->tx_id;
1301 	rx_entry->credit_request = rxr_env.tx_min_credits;
1302 	err = rxr_pkt_post_ctrl_or_queue(ep, RXR_RX_ENTRY, rx_entry, RXR_CTS_PKT, 0);
1303 	if (OFI_UNLIKELY(err)) {
1304 		FI_WARN(&rxr_prov, FI_LOG_CQ, "Cannot post CTS packet\n");
1305 		rxr_cq_handle_rx_error(ep, rx_entry, err);
1306 		rxr_release_rx_entry(ep, rx_entry);
1307 	}
1308 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1309 }
1310 
void rxr_pkt_handle_read_rtw_recv(struct rxr_ep *ep,
1312 				  struct rxr_pkt_entry *pkt_entry)
1313 {
1314 	struct rxr_rx_entry *rx_entry;
1315 	struct rxr_read_rtw_hdr *rtw_hdr;
1316 	struct fi_rma_iov *read_iov;
1317 	ssize_t err;
1318 
1319 	rx_entry = rxr_pkt_alloc_rtw_rx_entry(ep, pkt_entry);
1320 	if (!rx_entry) {
1321 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1322 			"RX entries exhausted.\n");
1323 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1324 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1325 		return;
1326 	}
1327 
1328 	rtw_hdr = (struct rxr_read_rtw_hdr *)pkt_entry->pkt;
1329 	rx_entry->iov_count = rtw_hdr->rma_iov_count;
1330 	err = rxr_rma_verified_copy_iov(ep, rtw_hdr->rma_iov, rtw_hdr->rma_iov_count,
1331 					FI_REMOTE_WRITE, rx_entry->iov);
1332 	if (OFI_UNLIKELY(err)) {
1333 		FI_WARN(&rxr_prov, FI_LOG_CQ, "RMA address verify failed!\n");
1334 		efa_eq_write_error(&ep->util_ep, FI_EINVAL, -FI_EINVAL);
1335 		rxr_release_rx_entry(ep, rx_entry);
1336 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1337 		return;
1338 	}
1339 
1340 	rx_entry->cq_entry.flags |= (FI_RMA | FI_WRITE);
1341 	rx_entry->cq_entry.len = ofi_total_iov_len(rx_entry->iov, rx_entry->iov_count);
1342 	rx_entry->cq_entry.buf = rx_entry->iov[0].iov_base;
1343 	rx_entry->total_len = rx_entry->cq_entry.len;
1344 
1345 	read_iov = (struct fi_rma_iov *)((char *)pkt_entry->pkt + pkt_entry->hdr_size);
1346 	rx_entry->addr = pkt_entry->addr;
1347 	rx_entry->tx_id = rtw_hdr->tx_id;
1348 	rx_entry->rma_iov_count = rtw_hdr->read_iov_count;
1349 	memcpy(rx_entry->rma_iov, read_iov,
1350 	       rx_entry->rma_iov_count * sizeof(struct fi_rma_iov));
1351 
1352 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1353 	err = rxr_read_post_or_queue(ep, RXR_RX_ENTRY, rx_entry);
1354 	if (OFI_UNLIKELY(err)) {
1355 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1356 			"RDMA post read or queue failed.\n");
1357 		efa_eq_write_error(&ep->util_ep, err, err);
1358 		rxr_release_rx_entry(ep, rx_entry);
1359 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1360 	}
1361 }
1362 
1363 /*
1364  * RTR packet functions
1365  *     init() functions for RTR packets
1366  */
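/*
 * RTR packets implement emulated read: the requester describes the remote
 * buffer in rma_iov, and the responder answers with READRSP/data packets.
 * read_req_rx_id and read_req_window let the responder address the
 * requester's rx_entry and respect its flow-control window.
 */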
void rxr_pkt_init_rtr(struct rxr_ep *ep,
1368 		      struct rxr_tx_entry *tx_entry,
1369 		      int pkt_type, int window,
1370 		      struct rxr_pkt_entry *pkt_entry)
1371 {
1372 	struct rxr_rtr_hdr *rtr_hdr;
1373 	int i;
1374 
1375 	assert(tx_entry->op == ofi_op_read_req);
1376 	rtr_hdr = (struct rxr_rtr_hdr *)pkt_entry->pkt;
1377 	rtr_hdr->rma_iov_count = tx_entry->rma_iov_count;
1378 	rxr_pkt_init_req_hdr(ep, tx_entry, pkt_type, pkt_entry);
1379 	rtr_hdr->data_len = tx_entry->total_len;
1380 	rtr_hdr->read_req_rx_id = tx_entry->rma_loc_rx_id;
1381 	rtr_hdr->read_req_window = window;
1382 	for (i = 0; i < tx_entry->rma_iov_count; ++i) {
1383 		rtr_hdr->rma_iov[i].addr = tx_entry->rma_iov[i].addr;
1384 		rtr_hdr->rma_iov[i].len = tx_entry->rma_iov[i].len;
1385 		rtr_hdr->rma_iov[i].key = tx_entry->rma_iov[i].key;
1386 	}
1387 
1388 	pkt_entry->pkt_size = pkt_entry->hdr_size;
1389 	pkt_entry->x_entry = tx_entry;
1390 }
1391 
ssize_t rxr_pkt_init_short_rtr(struct rxr_ep *ep,
1393 			       struct rxr_tx_entry *tx_entry,
1394 			       struct rxr_pkt_entry *pkt_entry)
1395 {
1396 	rxr_pkt_init_rtr(ep, tx_entry, RXR_SHORT_RTR_PKT, tx_entry->total_len, pkt_entry);
1397 	return 0;
1398 }
1399 
ssize_t rxr_pkt_init_long_rtr(struct rxr_ep *ep,
1401 			      struct rxr_tx_entry *tx_entry,
1402 			      struct rxr_pkt_entry *pkt_entry)
1403 {
1404 	rxr_pkt_init_rtr(ep, tx_entry, RXR_LONG_RTR_PKT, tx_entry->rma_window, pkt_entry);
1405 	return 0;
1406 }
1407 
1408 /*
1409  *     handle_sent() functions for RTR packet types
1410  */
void rxr_pkt_handle_rtr_sent(struct rxr_ep *ep,
1412 			     struct rxr_pkt_entry *pkt_entry)
1413 {
1414 	struct rxr_tx_entry *tx_entry;
1415 
1416 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
1417 	tx_entry->bytes_sent = 0;
1418 	tx_entry->state = RXR_TX_WAIT_READ_FINISH;
1419 }
1420 
/*
 *     handle_send_completion() function for RTR packet
 */
void rxr_pkt_handle_rtr_send_completion(struct rxr_ep *ep,
1425 					struct rxr_pkt_entry *pkt_entry)
1426 {
	/*
	 * Unlike other protocols, for emulated read the tx_entry
	 * is released in rxr_cq_handle_rx_completion(),
	 * so there is nothing to be done here.
	 */
1432 	return;
1433 }
1434 
1435 /*
1436  *     handle_recv() functions for RTR packet
1437  */
void rxr_pkt_handle_rtr_recv(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry)
1439 {
1440 	struct rxr_rtr_hdr *rtr_hdr;
1441 	struct rxr_rx_entry *rx_entry;
1442 	struct rxr_tx_entry *tx_entry;
1443 	ssize_t err;
1444 	struct fi_msg msg = {0};
1445 
1446 	msg.addr = pkt_entry->addr;
1447 	rx_entry = rxr_ep_get_rx_entry(ep, &msg, 0, ~0, ofi_op_read_rsp, 0);
1448 	if (OFI_UNLIKELY(!rx_entry)) {
1449 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1450 			"RX entries exhausted.\n");
1451 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1452 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1453 		return;
1454 	}
1455 
1456 	rx_entry->addr = pkt_entry->addr;
1457 	rx_entry->bytes_done = 0;
1458 	rx_entry->cq_entry.flags |= (FI_RMA | FI_READ);
1459 	rx_entry->cq_entry.len = ofi_total_iov_len(rx_entry->iov, rx_entry->iov_count);
1460 	rx_entry->cq_entry.buf = rx_entry->iov[0].iov_base;
1461 	rx_entry->total_len = rx_entry->cq_entry.len;
1462 
1463 	rtr_hdr = (struct rxr_rtr_hdr *)pkt_entry->pkt;
1464 	rx_entry->rma_initiator_rx_id = rtr_hdr->read_req_rx_id;
1465 	rx_entry->window = rtr_hdr->read_req_window;
1466 	rx_entry->iov_count = rtr_hdr->rma_iov_count;
1467 	err = rxr_rma_verified_copy_iov(ep, rtr_hdr->rma_iov, rtr_hdr->rma_iov_count,
1468 					FI_REMOTE_READ, rx_entry->iov);
1469 	if (OFI_UNLIKELY(err)) {
1470 		FI_WARN(&rxr_prov, FI_LOG_CQ, "RMA address verification failed!\n");
1471 		efa_eq_write_error(&ep->util_ep, FI_EINVAL, -FI_EINVAL);
1472 		rxr_release_rx_entry(ep, rx_entry);
1473 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1474 		return;
1475 	}
1476 
1477 	tx_entry = rxr_rma_alloc_readrsp_tx_entry(ep, rx_entry);
1478 	if (OFI_UNLIKELY(!tx_entry)) {
1479 		FI_WARN(&rxr_prov, FI_LOG_CQ, "Readrsp tx entry exhausted!\n");
1480 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1481 		rxr_release_rx_entry(ep, rx_entry);
1482 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1483 		return;
1484 	}
1485 
1486 	err = rxr_pkt_post_ctrl_or_queue(ep, RXR_TX_ENTRY, tx_entry, RXR_READRSP_PKT, 0);
1487 	if (OFI_UNLIKELY(err)) {
1488 		FI_WARN(&rxr_prov, FI_LOG_CQ, "Posting of readrsp packet failed! err=%ld\n", err);
1489 		efa_eq_write_error(&ep->util_ep, FI_EIO, err);
1490 		rxr_release_tx_entry(ep, tx_entry);
1491 		rxr_release_rx_entry(ep, rx_entry);
1492 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1493 		return;
1494 	}
1495 
1496 	rx_entry->state = RXR_RX_WAIT_READ_FINISH;
1497 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1498 }
1499 
1500 /*
1501  * RTA packet functions
1502  */
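/*
 * An RTA packet carries the REQ header fields set below (msg_id, the rma_iov
 * describing the target buffers, the atomic datatype and op, and the sender's
 * tx_id), followed immediately by the atomic operand data copied from
 * tx_entry->iov.
 */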
1503 ssize_t rxr_pkt_init_rta(struct rxr_ep *ep, struct rxr_tx_entry *tx_entry,
1504 			 int pkt_type, struct rxr_pkt_entry *pkt_entry)
1505 {
1506 	struct fi_rma_iov *rma_iov;
1507 	struct rxr_rta_hdr *rta_hdr;
1508 	char *data;
1509 	size_t data_size;
1510 	int i;
1511 
1512 	rta_hdr = (struct rxr_rta_hdr *)pkt_entry->pkt;
1513 	rta_hdr->msg_id = tx_entry->msg_id;
1514 	rta_hdr->rma_iov_count = tx_entry->rma_iov_count;
1515 	rta_hdr->atomic_datatype = tx_entry->atomic_hdr.datatype;
1516 	rta_hdr->atomic_op = tx_entry->atomic_hdr.atomic_op;
1517 	rta_hdr->tx_id = tx_entry->tx_id;
1518 	rxr_pkt_init_req_hdr(ep, tx_entry, pkt_type, pkt_entry);
1519 	rta_hdr->flags |= RXR_REQ_ATOMIC;
1520 	rma_iov = rta_hdr->rma_iov;
1521 	for (i = 0; i < tx_entry->rma_iov_count; ++i) {
1522 		rma_iov[i].addr = tx_entry->rma_iov[i].addr;
1523 		rma_iov[i].len = tx_entry->rma_iov[i].len;
1524 		rma_iov[i].key = tx_entry->rma_iov[i].key;
1525 	}
1526 
1527 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1528 	data_size = ofi_copy_from_iov(data, ep->mtu_size - pkt_entry->hdr_size,
1529 				      tx_entry->iov, tx_entry->iov_count, 0);
1530 
1531 	pkt_entry->pkt_size = pkt_entry->hdr_size + data_size;
1532 	pkt_entry->x_entry = tx_entry;
1533 	return 0;
1534 }
1535 
1536 ssize_t rxr_pkt_init_write_rta(struct rxr_ep *ep, struct rxr_tx_entry *tx_entry,
1537 			    struct rxr_pkt_entry *pkt_entry)
1538 {
1539 	rxr_pkt_init_rta(ep, tx_entry, RXR_WRITE_RTA_PKT, pkt_entry);
1540 	return 0;
1541 }
1542 
1543 ssize_t rxr_pkt_init_fetch_rta(struct rxr_ep *ep, struct rxr_tx_entry *tx_entry,
1544 			      struct rxr_pkt_entry *pkt_entry)
1545 {
1546 	rxr_pkt_init_rta(ep, tx_entry, RXR_FETCH_RTA_PKT, pkt_entry);
1547 	return 0;
1548 }
1549 
1550 ssize_t rxr_pkt_init_compare_rta(struct rxr_ep *ep, struct rxr_tx_entry *tx_entry,
1551 				 struct rxr_pkt_entry *pkt_entry)
1552 {
1553 	char *data;
1554 	size_t data_size;
1555 
1556 	rxr_pkt_init_rta(ep, tx_entry, RXR_COMPARE_RTA_PKT, pkt_entry);
1557 
1558 	/* rxr_pkt_init_rta() has copied the operand data from tx_entry->iov
1559 	 * into the packet entry; the following appends the data to be compared.
1560 	 */
1561 	data = (char *)pkt_entry->pkt + pkt_entry->pkt_size;
1562 	data_size = ofi_copy_from_iov(data, ep->mtu_size - pkt_entry->pkt_size,
1563 				      tx_entry->atomic_ex.comp_iov,
1564 				      tx_entry->atomic_ex.comp_iov_count, 0);
1565 	assert(data_size == tx_entry->total_len);
1566 	pkt_entry->pkt_size += data_size;
1567 	return 0;
1568 }
1569 
1570 void rxr_pkt_handle_write_rta_send_completion(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry)
1571 {
1572 	struct rxr_tx_entry *tx_entry;
1573 
1574 	tx_entry = (struct rxr_tx_entry *)pkt_entry->x_entry;
1575 	rxr_cq_handle_tx_completion(ep, tx_entry);
1576 }
1577 
1578 int rxr_pkt_proc_write_rta(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry)
1579 {
1580 	struct iovec iov[RXR_IOV_LIMIT];
1581 	struct rxr_rta_hdr *rta_hdr;
1582 	char *data;
1583 	int iov_count, op, dt, i;
1584 	size_t dtsize, offset;
1585 
1586 	rta_hdr = (struct rxr_rta_hdr *)pkt_entry->pkt;
1587 	op = rta_hdr->atomic_op;
1588 	dt = rta_hdr->atomic_datatype;
1589 	dtsize = ofi_datatype_size(dt);
1590 
1591 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1592 	iov_count = rta_hdr->rma_iov_count;
1593 	rxr_rma_verified_copy_iov(ep, rta_hdr->rma_iov, iov_count, FI_REMOTE_WRITE, iov);
1594 
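	/*
	 * Apply the atomic write handler for this op/datatype element-wise to
	 * each verified target buffer, consuming the operand data received in
	 * the packet payload.
	 */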
1595 	offset = 0;
1596 	for (i = 0; i < iov_count; ++i) {
1597 		ofi_atomic_write_handlers[op][dt](iov[i].iov_base,
1598 						  data + offset,
1599 						  iov[i].iov_len / dtsize);
1600 		offset += iov[i].iov_len;
1601 	}
1602 
1603 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1604 	return 0;
1605 }
1606 
1607 struct rxr_rx_entry *rxr_pkt_alloc_rta_rx_entry(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry, int op)
1608 {
1609 	struct rxr_rx_entry *rx_entry;
1610 	struct rxr_rta_hdr *rta_hdr;
1611 	struct fi_msg msg = {0};
1612 
1613 	msg.addr = pkt_entry->addr;
1614 	rx_entry = rxr_ep_get_rx_entry(ep, &msg, 0, ~0, op, 0);
1615 	if (OFI_UNLIKELY(!rx_entry)) {
1616 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1617 			"RX entries exhausted.\n");
1618 		return NULL;
1619 	}
1620 
1621 	rta_hdr = (struct rxr_rta_hdr *)pkt_entry->pkt;
1622 	rx_entry->atomic_hdr.atomic_op = rta_hdr->atomic_op;
1623 	rx_entry->atomic_hdr.datatype = rta_hdr->atomic_datatype;
1624 
1625 	rx_entry->iov_count = rta_hdr->rma_iov_count;
1626 	rxr_rma_verified_copy_iov(ep, rta_hdr->rma_iov, rx_entry->iov_count, FI_REMOTE_READ, rx_entry->iov);
1627 	rx_entry->tx_id = rta_hdr->tx_id;
1628 	rx_entry->total_len = ofi_total_iov_len(rx_entry->iov, rx_entry->iov_count);
1629 	/*
1630 	 * Prepare a packet entry to temporarily hold the response data.
1631 	 * An atomic op operates on 3 data buffers:
1632 	 *          local_data (input/output),
1633 	 *          request_data (input),
1634 	 *          response_data (output)
1635 	 * Because local_data is changed by the atomic op,
1636 	 * response_data is not reproducible.
1637 	 * Since sending the response packet can fail with
1638 	 * -FI_EAGAIN, we need a temporary buffer to hold response_data.
1639 	 * This packet entry will be released in rxr_handle_atomrsp_send_completion().
1640 	 */
1641 	rx_entry->atomrsp_pkt = rxr_pkt_entry_alloc(ep, ep->tx_pkt_efa_pool);
1642 	if (!rx_entry->atomrsp_pkt) {
1643 		FI_WARN(&rxr_prov, FI_LOG_CQ,
1644 			"pkt entries exhausted.\n");
1645 		rxr_release_rx_entry(ep, rx_entry);
1646 		return NULL;
1647 	}
1648 
1649 	return rx_entry;
1650 }
1651 
1652 int rxr_pkt_proc_fetch_rta(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry)
1653 {
1654 	struct rxr_rx_entry *rx_entry;
1655 	char *data;
1656 	int op, dt, i;
1657 	size_t offset, dtsize;
1658 	ssize_t err;
1659 
1660 	rx_entry = rxr_pkt_alloc_rta_rx_entry(ep, pkt_entry, ofi_op_atomic_fetch);
1661 	if (OFI_UNLIKELY(!rx_entry)) {
1662 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1663 		return -FI_ENOBUFS;
1664 	}
1665 
1666 	op = rx_entry->atomic_hdr.atomic_op;
1667 	dt = rx_entry->atomic_hdr.datatype;
1668 	dtsize = ofi_datatype_size(rx_entry->atomic_hdr.datatype);
1669 
1670 	data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1671 	rx_entry->atomrsp_buf = (char *)rx_entry->atomrsp_pkt->pkt + sizeof(struct rxr_atomrsp_hdr);
1672 
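	/*
	 * The readwrite handler for this op/datatype saves the original target
	 * buffer contents into atomrsp_buf (returned to the initiator in the
	 * ATOMRSP packet), then applies the atomic operation using the operand
	 * data received in the packet payload.
	 */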
1673 	offset = 0;
1674 	for (i = 0; i < rx_entry->iov_count; ++i) {
1675 		ofi_atomic_readwrite_handlers[op][dt](rx_entry->iov[i].iov_base,
1676 						      data + offset,
1677 						      rx_entry->atomrsp_buf + offset,
1678 						      rx_entry->iov[i].iov_len / dtsize);
1679 		offset += rx_entry->iov[i].iov_len;
1680 	}
1681 
1682 	err = rxr_pkt_post_ctrl_or_queue(ep, RXR_RX_ENTRY, rx_entry, RXR_ATOMRSP_PKT, 0);
1683 	if (OFI_UNLIKELY(err)) {
1684 		if (rxr_cq_handle_rx_error(ep, rx_entry, err))
1685 			assert(0 && "Cannot handle rx error");
1686 	}
1687 
1688 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1689 	return 0;
1690 }
1691 
1692 int rxr_pkt_proc_compare_rta(struct rxr_ep *ep, struct rxr_pkt_entry *pkt_entry)
1693 {
1694 	struct rxr_rx_entry *rx_entry;
1695 	char *src_data, *cmp_data;
1696 	int op, dt, i;
1697 	size_t offset, dtsize;
1698 	ssize_t err;
1699 
1700 	rx_entry = rxr_pkt_alloc_rta_rx_entry(ep, pkt_entry, ofi_op_atomic_compare);
1701 	if (OFI_UNLIKELY(!rx_entry)) {
1702 		efa_eq_write_error(&ep->util_ep, FI_ENOBUFS, -FI_ENOBUFS);
1703 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1704 		return -FI_ENOBUFS;
1705 	}
1706 
1707 	op = rx_entry->atomic_hdr.atomic_op;
1708 	dt = rx_entry->atomic_hdr.datatype;
1709 	dtsize = ofi_datatype_size(rx_entry->atomic_hdr.datatype);
1710 
1711 	src_data = (char *)pkt_entry->pkt + pkt_entry->hdr_size;
1712 	cmp_data = src_data + rx_entry->total_len;
1713 	rx_entry->atomrsp_buf = (char *)rx_entry->atomrsp_pkt->pkt + sizeof(struct rxr_atomrsp_hdr);
1714 
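	/*
	 * The swap handler for this compare-swap variant saves the original
	 * target buffer contents into atomrsp_buf and updates the target from
	 * src_data according to the comparison (or mask) against cmp_data,
	 * element by element.
	 */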
1715 	offset = 0;
1716 	for (i = 0; i < rx_entry->iov_count; ++i) {
1717 		ofi_atomic_swap_handlers[op - FI_CSWAP][dt](rx_entry->iov[i].iov_base,
1718 							    src_data + offset,
1719 							    cmp_data + offset,
1720 							    rx_entry->atomrsp_buf + offset,
1721 							    rx_entry->iov[i].iov_len / dtsize);
1722 		offset += rx_entry->iov[i].iov_len;
1723 	}
1724 
1725 	err = rxr_pkt_post_ctrl_or_queue(ep, RXR_RX_ENTRY, rx_entry, RXR_ATOMRSP_PKT, 0);
1726 	if (OFI_UNLIKELY(err)) {
1727 		efa_eq_write_error(&ep->util_ep, FI_EIO, err);
1728 		rxr_pkt_entry_release_tx(ep, rx_entry->atomrsp_pkt);
1729 		rxr_release_rx_entry(ep, rx_entry);
1730 		rxr_pkt_entry_release_rx(ep, pkt_entry);
1731 		return err;
1732 	}
1733 
1734 	rxr_pkt_entry_release_rx(ep, pkt_entry);
1735 	return 0;
1736 }
1737