1 /*
2  * Copyright (c) 2013-2018 Intel Corporation. All rights reserved
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  */
32 
33 #include <stdlib.h>
34 #include <string.h>
35 #include <sys/uio.h>
36 
37 #include "ofi_iov.h"
38 #include "smr.h"
39 
40 
/* Record the destination RMA ioc list in the command's rma section. */
static void smr_format_rma_ioc(struct smr_cmd *cmd, const struct fi_rma_ioc *rma_ioc,
			       size_t ioc_count)
{
	size_t i;

	cmd->rma.rma_count = ioc_count;
	for (i = 0; i < ioc_count; i++)
		cmd->rma.rma_ioc[i] = rma_ioc[i];
}
47 
/* Stamp the atomic operation and datatype into the command header. */
static void smr_generic_atomic_format(struct smr_cmd *cmd, uint8_t datatype,
				      uint8_t atomic_op)
{
	cmd->msg.hdr.atomic_op = atomic_op;
	cmd->msg.hdr.datatype = datatype;
}
54 
smr_format_inline_atomic(struct smr_cmd * cmd,const struct iovec * iov,size_t count,const struct iovec * compv,size_t comp_count)55 static void smr_format_inline_atomic(struct smr_cmd *cmd,
56 				     const struct iovec *iov, size_t count,
57 				     const struct iovec *compv,
58 				     size_t comp_count)
59 {
60 	size_t comp_size;
61 
62 	cmd->msg.hdr.op_src = smr_src_inline;
63 
64 	switch (cmd->msg.hdr.op) {
65 	case ofi_op_atomic:
66 	case ofi_op_atomic_fetch:
67 		cmd->msg.hdr.size = ofi_copy_from_iov(cmd->msg.data.msg,
68 						SMR_MSG_DATA_LEN, iov, count, 0);
69 		break;
70 	case ofi_op_atomic_compare:
71 		cmd->msg.hdr.size = ofi_copy_from_iov(cmd->msg.data.buf,
72 						SMR_MSG_DATA_LEN, iov, count, 0);
73 		comp_size = ofi_copy_from_iov(cmd->msg.data.comp,
74 					      SMR_MSG_DATA_LEN, compv,
75 					      comp_count, 0);
76 		if (comp_size != cmd->msg.hdr.size)
77 			FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
78 				"atomic and compare buffer size mismatch\n");
79 		break;
80 	default:
81 		break;
82 	}
83 }
84 
smr_format_inject_atomic(struct smr_cmd * cmd,const struct iovec * iov,size_t count,const struct iovec * resultv,size_t result_count,const struct iovec * compv,size_t comp_count,struct smr_region * smr,struct smr_inject_buf * tx_buf)85 static void smr_format_inject_atomic(struct smr_cmd *cmd,
86 			const struct iovec *iov, size_t count,
87 			const struct iovec *resultv, size_t result_count,
88 			const struct iovec *compv, size_t comp_count,
89 			struct smr_region *smr, struct smr_inject_buf *tx_buf)
90 {
91 	size_t comp_size;
92 
93 	cmd->msg.hdr.op_src = smr_src_inject;
94 	cmd->msg.hdr.src_data = smr_get_offset(smr, tx_buf);
95 
96 	switch (cmd->msg.hdr.op) {
97 	case ofi_op_atomic:
98 	case ofi_op_atomic_fetch:
99 		if (cmd->msg.hdr.atomic_op == FI_ATOMIC_READ)
100 			cmd->msg.hdr.size = ofi_total_iov_len(resultv, result_count);
101 		else
102 			cmd->msg.hdr.size = ofi_copy_from_iov(tx_buf->data,
103 						SMR_INJECT_SIZE, iov, count, 0);
104 		break;
105 	case ofi_op_atomic_compare:
106 		cmd->msg.hdr.size = ofi_copy_from_iov(tx_buf->buf,
107 						SMR_COMP_INJECT_SIZE, iov, count, 0);
108 		comp_size = ofi_copy_from_iov(tx_buf->comp, SMR_COMP_INJECT_SIZE,
109 					      compv, comp_count, 0);
110 		if (comp_size != cmd->msg.hdr.size)
111 			FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
112 				"atomic and compare buffer size mismatch\n");
113 		break;
114 	default:
115 		break;
116 	}
117 }
118 
/*
 * Common send path for all atomic variants (write, fetch, compare).
 * Two commands are committed to the peer's command queue: the atomic
 * command itself, followed by a command carrying the target rma_ioc
 * list -- which is why two free command slots are required up front.
 * Returns 0 on success or a negative fi_errno value (-FI_EAGAIN when
 * queue/response resources are temporarily exhausted).
 */
static ssize_t smr_generic_atomic(struct smr_ep *ep,
			const struct fi_ioc *ioc, void **desc, size_t count,
			const struct fi_ioc *compare_ioc, void **compare_desc,
			size_t compare_count, struct fi_ioc *result_ioc,
			void **result_desc, size_t result_count,
			fi_addr_t addr, const struct fi_rma_ioc *rma_ioc,
			size_t rma_count, enum fi_datatype datatype,
			enum fi_op atomic_op, void *context, uint32_t op,
			uint64_t op_flags)
{
	struct smr_region *peer_smr;
	struct smr_inject_buf *tx_buf;
	struct smr_tx_entry *pend;
	struct smr_resp *resp = NULL;
	struct smr_cmd *cmd;
	struct iovec iov[SMR_IOV_LIMIT];
	struct iovec compare_iov[SMR_IOV_LIMIT];
	struct iovec result_iov[SMR_IOV_LIMIT];
	int id, peer_id, err = 0;
	uint16_t flags = 0;
	ssize_t ret = 0;
	size_t total_len;

	assert(count <= SMR_IOV_LIMIT);
	assert(result_count <= SMR_IOV_LIMIT);
	assert(compare_count <= SMR_IOV_LIMIT);
	assert(rma_count <= SMR_IOV_LIMIT);

	id = (int) addr;
	peer_id = smr_peer_data(ep->region)[id].addr.addr;

	ret = smr_verify_peer(ep, id);
	if (ret)
		return ret;

	peer_smr = smr_peer_region(ep->region, id);
	fastlock_acquire(&peer_smr->lock);
	/* This path consumes two command slots (atomic cmd + rma ioc cmd);
	 * also back off while a SAR transfer to this peer is in progress. */
	if (peer_smr->cmd_cnt < 2 || smr_peer_data(ep->region)[id].sar_status) {
		ret = -FI_EAGAIN;
		goto unlock_region;
	}

	/* Reserve a completion slot before committing anything to the peer. */
	fastlock_acquire(&ep->util_ep.tx_cq->cq_lock);
	if (ofi_cirque_isfull(ep->util_ep.tx_cq->cirq)) {
		ret = -FI_EAGAIN;
		goto unlock_cq;
	}

	cmd = ofi_cirque_tail(smr_cmd_queue(peer_smr));
	total_len = ofi_datatype_size(datatype) * ofi_total_ioc_cnt(ioc, count);

	/* Build the iovecs each op variant needs.  Cases intentionally fall
	 * through: compare needs compare+result+source, fetch needs
	 * result+source, plain atomic needs only the source. */
	switch (op) {
	case ofi_op_atomic_compare:
		assert(compare_ioc);
		ofi_ioc_to_iov(compare_ioc, compare_iov, compare_count,
			       ofi_datatype_size(datatype));
		total_len *= 2;	/* operand + compare buffer */
		/* fall through */
	case ofi_op_atomic_fetch:
		assert(result_ioc);
		ofi_ioc_to_iov(result_ioc, result_iov, result_count,
			       ofi_datatype_size(datatype));
		flags |= SMR_RMA_REQ;	/* peer must send result data back */
		/* fall through */
	case ofi_op_atomic:
		if (atomic_op != FI_ATOMIC_READ) {
			assert(ioc);
			ofi_ioc_to_iov(ioc, iov, count, ofi_datatype_size(datatype));
		} else {
			/* FI_ATOMIC_READ carries no source data. */
			count = 0;
		}
		break;
	default:
		break;
	}

	smr_generic_format(cmd, peer_id, op, 0, 0, op_flags);
	smr_generic_atomic_format(cmd, datatype, atomic_op);

	/* Inline protocol only when no response is needed; otherwise stage
	 * the payload in an inject buffer and set up a response slot. */
	if (total_len <= SMR_MSG_DATA_LEN && !(flags & SMR_RMA_REQ) &&
	    !(op_flags & FI_DELIVERY_COMPLETE)) {
		smr_format_inline_atomic(cmd, iov, count, compare_iov,
					 compare_count);
	} else if (total_len <= SMR_INJECT_SIZE) {
		tx_buf = smr_freestack_pop(smr_inject_pool(peer_smr));
		smr_format_inject_atomic(cmd, iov, count, result_iov,
					 result_count, compare_iov, compare_count,
					 peer_smr, tx_buf);
		if (flags & SMR_RMA_REQ || op_flags & FI_DELIVERY_COMPLETE) {
			if (ofi_cirque_isfull(smr_resp_queue(ep->region))) {
				/* Undo the buffer pop before backing off. */
				smr_freestack_push(smr_inject_pool(peer_smr), tx_buf);
				ret = -FI_EAGAIN;
				goto unlock_cq;
			}
			/* Track the pending op; the peer signals completion
			 * through the resp entry at the offset in hdr.data. */
			resp = ofi_cirque_tail(smr_resp_queue(ep->region));
			pend = freestack_pop(ep->pend_fs);
			smr_format_pend_resp(pend, cmd, context, result_iov,
					     result_count, id, resp);
			cmd->msg.hdr.data = smr_get_offset(ep->region, resp);
			ofi_cirque_commit(smr_resp_queue(ep->region));
		}
	} else {
		FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
			"message too large\n");
		ret = -FI_EINVAL;
		goto unlock_cq;
	}
	cmd->msg.hdr.op_flags |= flags;

	ofi_cirque_commit(smr_cmd_queue(peer_smr));
	peer_smr->cmd_cnt--;

	/* No response expected: the op completes immediately (err is 0). */
	if (!resp) {
		ret = smr_complete_tx(ep, context, op, cmd->msg.hdr.op_flags,
				      err);
		if (ret) {
			FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
				"unable to process tx completion\n");
		}
	}

	/* Second command: the target rma_ioc list for the atomic above. */
	cmd = ofi_cirque_tail(smr_cmd_queue(peer_smr));
	smr_format_rma_ioc(cmd, rma_ioc, rma_count);
	ofi_cirque_commit(smr_cmd_queue(peer_smr));
	peer_smr->cmd_cnt--;
unlock_cq:
	fastlock_release(&ep->util_ep.tx_cq->cq_lock);
unlock_region:
	fastlock_release(&peer_smr->lock);
	return ret;
}
250 
smr_atomic_writemsg(struct fid_ep * ep_fid,const struct fi_msg_atomic * msg,uint64_t flags)251 static ssize_t smr_atomic_writemsg(struct fid_ep *ep_fid,
252 			const struct fi_msg_atomic *msg, uint64_t flags)
253 {
254 	struct smr_ep *ep;
255 
256 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
257 
258 	return smr_generic_atomic(ep, msg->msg_iov, msg->desc, msg->iov_count,
259 				  NULL, NULL, 0, NULL, NULL, 0, msg->addr,
260 				  msg->rma_iov, msg->rma_iov_count,
261 				  msg->datatype, msg->op, msg->context,
262 				  ofi_op_atomic, flags | ep->util_ep.tx_msg_flags);
263 }
264 
smr_atomic_writev(struct fid_ep * ep_fid,const struct fi_ioc * iov,void ** desc,size_t count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)265 static ssize_t smr_atomic_writev(struct fid_ep *ep_fid,
266 			const struct fi_ioc *iov, void **desc, size_t count,
267 			fi_addr_t dest_addr, uint64_t addr, uint64_t key,
268 			enum fi_datatype datatype, enum fi_op op, void *context)
269 {
270 	struct smr_ep *ep;
271 	struct fi_rma_ioc rma_iov;
272 
273 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
274 
275 	rma_iov.addr = addr;
276 	rma_iov.count = ofi_total_ioc_cnt(iov, count);
277 	rma_iov.key = key;
278 
279 	return smr_generic_atomic(ep, iov, desc, count, NULL, NULL, 0, NULL,
280 				  NULL, 0, dest_addr, &rma_iov, 1, datatype,
281 				  op, context, ofi_op_atomic, smr_ep_tx_flags(ep));
282 }
283 
smr_atomic_write(struct fid_ep * ep_fid,const void * buf,size_t count,void * desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)284 static ssize_t smr_atomic_write(struct fid_ep *ep_fid, const void *buf, size_t count,
285 			void *desc, fi_addr_t dest_addr, uint64_t addr,
286 			uint64_t key, enum fi_datatype datatype, enum fi_op op,
287 			void *context)
288 {
289 	struct smr_ep *ep;
290 	struct fi_ioc iov;
291 	struct fi_rma_ioc rma_iov;
292 
293 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
294 
295 	iov.addr = (void *) buf;
296 	iov.count = count;
297 
298 	rma_iov.addr = addr;
299 	rma_iov.count = count;
300 	rma_iov.key = key;
301 
302 	return smr_generic_atomic(ep, &iov, &desc, 1, NULL, NULL, 0, NULL, NULL, 0,
303 				  dest_addr, &rma_iov, 1, datatype, op, context,
304 				  ofi_op_atomic, smr_ep_tx_flags(ep));
305 }
306 
/*
 * fi_inject_atomic: fire-and-forget atomic write; no completion entry
 * is generated, only the tx counter is bumped.  Commits two commands
 * to the peer's queue (atomic data, then the target rma_ioc), hence
 * the two-slot requirement below.
 * Returns 0 on success or a negative fi_errno value.
 */
static ssize_t smr_atomic_inject(struct fid_ep *ep_fid, const void *buf,
			size_t count, fi_addr_t dest_addr, uint64_t addr,
			uint64_t key, enum fi_datatype datatype, enum fi_op op)
{
	struct smr_ep *ep;
	struct smr_region *peer_smr;
	struct smr_inject_buf *tx_buf;
	struct smr_cmd *cmd;
	struct iovec iov;
	struct fi_rma_ioc rma_ioc;
	int id, peer_id;
	ssize_t ret = 0;
	size_t total_len;

	/* NOTE: this asserts on the element count; the actual byte length
	 * (count * datatype size) is validated below. */
	assert(count <= SMR_INJECT_SIZE);

	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);

	id = (int) dest_addr;
	peer_id = smr_peer_data(ep->region)[id].addr.addr;

	ret = smr_verify_peer(ep, id);
	if (ret)
		return ret;

	peer_smr = smr_peer_region(ep->region, id);
	fastlock_acquire(&peer_smr->lock);
	/* Two command slots are consumed; also back off while a SAR
	 * transfer to this peer is in progress. */
	if (peer_smr->cmd_cnt < 2 || smr_peer_data(ep->region)[id].sar_status) {
		ret = -FI_EAGAIN;
		goto unlock_region;
	}

	cmd = ofi_cirque_tail(smr_cmd_queue(peer_smr));
	total_len = count * ofi_datatype_size(datatype);

	iov.iov_base = (void *) buf;
	iov.iov_len = total_len;

	rma_ioc.addr = addr;
	rma_ioc.count = count;
	rma_ioc.key = key;

	smr_generic_format(cmd, peer_id, ofi_op_atomic, 0, 0, 0);
	smr_generic_atomic_format(cmd, datatype, op);

	if (total_len <= SMR_MSG_DATA_LEN) {
		smr_format_inline_atomic(cmd, &iov, 1, NULL, 0);
	} else if (total_len <= SMR_INJECT_SIZE) {
		tx_buf = smr_freestack_pop(smr_inject_pool(peer_smr));
		smr_format_inject_atomic(cmd, &iov, 1, NULL, 0, NULL, 0,
					 peer_smr, tx_buf);
	} else {
		/* Fix: an oversized payload previously fell through both
		 * branches and an unformatted command was still committed
		 * to the peer.  Reject it instead, matching the
		 * "message too large" path in smr_generic_atomic(). */
		FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
			"message too large\n");
		ret = -FI_EINVAL;
		goto unlock_region;
	}

	ofi_cirque_commit(smr_cmd_queue(peer_smr));
	peer_smr->cmd_cnt--;
	/* Second command: the target rma_ioc for the atomic above. */
	cmd = ofi_cirque_tail(smr_cmd_queue(peer_smr));
	smr_format_rma_ioc(cmd, &rma_ioc, 1);
	ofi_cirque_commit(smr_cmd_queue(peer_smr));
	peer_smr->cmd_cnt--;

	ofi_ep_tx_cntr_inc_func(&ep->util_ep, ofi_op_atomic);
unlock_region:
	fastlock_release(&peer_smr->lock);
	return ret;
}
372 
smr_atomic_readwritemsg(struct fid_ep * ep_fid,const struct fi_msg_atomic * msg,struct fi_ioc * resultv,void ** result_desc,size_t result_count,uint64_t flags)373 static ssize_t smr_atomic_readwritemsg(struct fid_ep *ep_fid,
374 			const struct fi_msg_atomic *msg, struct fi_ioc *resultv,
375 			void **result_desc, size_t result_count, uint64_t flags)
376 {
377 	struct smr_ep *ep;
378 
379 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
380 
381 	return smr_generic_atomic(ep, msg->msg_iov, msg->desc, msg->iov_count,
382 				  NULL, NULL, 0, resultv, result_desc,
383 				  result_count, msg->addr,
384 				  msg->rma_iov, msg->rma_iov_count,
385 				  msg->datatype, msg->op, msg->context,
386 				  ofi_op_atomic_fetch,
387 				  flags | ep->util_ep.tx_msg_flags);
388 }
389 
smr_atomic_readwritev(struct fid_ep * ep_fid,const struct fi_ioc * iov,void ** desc,size_t count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)390 static ssize_t smr_atomic_readwritev(struct fid_ep *ep_fid,
391 			const struct fi_ioc *iov, void **desc, size_t count,
392 			struct fi_ioc *resultv, void **result_desc,
393 			size_t result_count, fi_addr_t dest_addr, uint64_t addr,
394 			uint64_t key, enum fi_datatype datatype, enum fi_op op,
395 			void *context)
396 {
397 	struct smr_ep *ep;
398 	struct fi_rma_ioc rma_iov;
399 
400 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
401 
402 	rma_iov.addr = addr;
403 	rma_iov.count = ofi_total_ioc_cnt(iov, count);
404 	rma_iov.key = key;
405 
406 	return smr_generic_atomic(ep, iov, desc, count, NULL, NULL, 0, resultv,
407 				  result_desc, result_count, dest_addr,
408 				  &rma_iov, 1, datatype, op, context,
409 				  ofi_op_atomic_fetch, smr_ep_tx_flags(ep));
410 }
411 
smr_atomic_readwrite(struct fid_ep * ep_fid,const void * buf,size_t count,void * desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)412 static ssize_t smr_atomic_readwrite(struct fid_ep *ep_fid, const void *buf,
413 			size_t count, void *desc, void *result,
414 			void *result_desc, fi_addr_t dest_addr, uint64_t addr,
415 			uint64_t key, enum fi_datatype datatype, enum fi_op op,
416 			void *context)
417 {
418 	struct smr_ep *ep;
419 	struct fi_ioc iov, resultv;
420 	struct fi_rma_ioc rma_iov;
421 
422 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
423 
424 	iov.addr = (void *) buf;
425 	iov.count = count;
426 
427 	resultv.addr = result;
428 	resultv.count = count;
429 
430 	rma_iov.addr = addr;
431 	rma_iov.count = count;
432 	rma_iov.key = key;
433 
434 	return smr_generic_atomic(ep, &iov, &desc, 1, NULL, NULL, 0, &resultv,
435 				  &result_desc, 1, dest_addr, &rma_iov, 1,
436 				  datatype, op, context, ofi_op_atomic_fetch,
437 				  smr_ep_tx_flags(ep));
438 }
439 
smr_atomic_compwritemsg(struct fid_ep * ep_fid,const struct fi_msg_atomic * msg,const struct fi_ioc * comparev,void ** compare_desc,size_t compare_count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,uint64_t flags)440 static ssize_t smr_atomic_compwritemsg(struct fid_ep *ep_fid,
441 			const struct fi_msg_atomic *msg,
442 			const struct fi_ioc *comparev, void **compare_desc,
443 			size_t compare_count, struct fi_ioc *resultv,
444 			void **result_desc, size_t result_count, uint64_t flags)
445 {
446 	struct smr_ep *ep;
447 
448 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
449 
450 	return smr_generic_atomic(ep, msg->msg_iov, msg->desc, msg->iov_count,
451 				  comparev, compare_desc, compare_count,
452 				  resultv, result_desc,
453 				  result_count, msg->addr,
454 				  msg->rma_iov, msg->rma_iov_count,
455 				  msg->datatype, msg->op, msg->context,
456 				  ofi_op_atomic_compare,
457 				  flags | ep->util_ep.tx_msg_flags);
458 }
459 
smr_atomic_compwritev(struct fid_ep * ep_fid,const struct fi_ioc * iov,void ** desc,size_t count,const struct fi_ioc * comparev,void ** compare_desc,size_t compare_count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)460 static ssize_t smr_atomic_compwritev(struct fid_ep *ep_fid,
461 			const struct fi_ioc *iov, void **desc, size_t count,
462 			const struct fi_ioc *comparev, void **compare_desc,
463 			size_t compare_count, struct fi_ioc *resultv,
464 			void **result_desc, size_t result_count,
465 			fi_addr_t dest_addr, uint64_t addr, uint64_t key,
466 			enum fi_datatype datatype, enum fi_op op, void *context)
467 {
468 	struct smr_ep *ep;
469 	struct fi_rma_ioc rma_iov;
470 
471 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
472 
473 	rma_iov.addr = addr;
474 	rma_iov.count = ofi_total_ioc_cnt(iov, count);
475 	rma_iov.key = key;
476 
477 	return smr_generic_atomic(ep, iov, desc, count, comparev, compare_desc,
478 				  compare_count, resultv, result_desc,
479 				  result_count, dest_addr, &rma_iov, 1,
480 				  datatype, op, context, ofi_op_atomic_compare,
481 				  smr_ep_tx_flags(ep));
482 }
483 
smr_atomic_compwrite(struct fid_ep * ep_fid,const void * buf,size_t count,void * desc,const void * compare,void * compare_desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)484 static ssize_t smr_atomic_compwrite(struct fid_ep *ep_fid, const void *buf,
485 			size_t count, void *desc, const void *compare,
486 			void *compare_desc, void *result, void *result_desc,
487 			fi_addr_t dest_addr, uint64_t addr, uint64_t key,
488 			enum fi_datatype datatype, enum fi_op op, void *context)
489 {
490 	struct smr_ep *ep;
491 	struct fi_ioc iov, resultv, comparev;
492 	struct fi_rma_ioc rma_iov;
493 
494 	ep = container_of(ep_fid, struct smr_ep, util_ep.ep_fid.fid);
495 
496 	iov.addr = (void *) buf;
497 	iov.count = count;
498 
499 	resultv.addr = result;
500 	resultv.count = count;
501 
502 	comparev.addr = (void *) compare;
503 	comparev.count = count;
504 
505 	rma_iov.addr = addr;
506 	rma_iov.count = count;
507 	rma_iov.key = key;
508 
509 	return smr_generic_atomic(ep, &iov, &desc, 1, &comparev, &compare_desc,
510 				  1, &resultv, &result_desc, 1, dest_addr,
511 				  &rma_iov, 1, datatype, op, context,
512 				  ofi_op_atomic_compare, smr_ep_tx_flags(ep));
513 }
514 
/*
 * Report the supported element size and max element count for an
 * atomic datatype/op combination.  attr may be NULL to validate only.
 * Returns 0 on success or a negative fi_errno value.
 */
int smr_query_atomic(struct fid_domain *domain, enum fi_datatype datatype,
		     enum fi_op op, struct fi_atomic_attr *attr, uint64_t flags)
{
	size_t max_bytes;
	int ret;

	if (flags & FI_TAGGED) {
		FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
			"tagged atomic op not supported\n");
		return -FI_EINVAL;
	}

	ret = ofi_atomic_valid(&smr_prov, datatype, op, flags);
	if (ret || !attr)
		return ret;

	/* Compare atomics stage operand and compare buffers together,
	 * so their per-op byte budget is smaller. */
	max_bytes = (flags & FI_COMPARE_ATOMIC) ? SMR_COMP_INJECT_SIZE :
		    SMR_INJECT_SIZE;

	attr->size = ofi_datatype_size(datatype);
	attr->count = max_bytes / attr->size;

	return ret;
}
539 
/* fi_atomicvalid: max element count for a plain atomic write. */
static int smr_atomic_valid(struct fid_ep *ep, enum fi_datatype datatype,
			    enum fi_op op, size_t *count)
{
	struct fi_atomic_attr attr;
	int ret = smr_query_atomic(NULL, datatype, op, &attr, 0);

	if (ret)
		return ret;

	*count = attr.count;
	return 0;
}
553 
/* fi_fetch_atomicvalid: max element count for a fetch atomic. */
static int smr_atomic_fetch_valid(struct fid_ep *ep, enum fi_datatype datatype,
				  enum fi_op op, size_t *count)
{
	struct fi_atomic_attr attr;
	int ret = smr_query_atomic(NULL, datatype, op, &attr, FI_FETCH_ATOMIC);

	if (ret)
		return ret;

	*count = attr.count;
	return 0;
}
567 
/* fi_compare_atomicvalid: max element count for a compare atomic. */
static int smr_atomic_comp_valid(struct fid_ep *ep, enum fi_datatype datatype,
				 enum fi_op op, size_t *count)
{
	struct fi_atomic_attr attr;
	int ret = smr_query_atomic(NULL, datatype, op, &attr,
				   FI_COMPARE_ATOMIC);

	if (ret)
		return ret;

	*count = attr.count;
	return 0;
}
581 
/* Atomic operation dispatch table installed on smr endpoints. */
struct fi_ops_atomic smr_atomic_ops = {
	.size = sizeof(struct fi_ops_atomic),
	.write = smr_atomic_write,
	.writev = smr_atomic_writev,
	.writemsg = smr_atomic_writemsg,
	.inject = smr_atomic_inject,
	.readwrite = smr_atomic_readwrite,
	.readwritev = smr_atomic_readwritev,
	.readwritemsg = smr_atomic_readwritemsg,
	.compwrite = smr_atomic_compwrite,
	.compwritev = smr_atomic_compwritev,
	.compwritemsg = smr_atomic_compwritemsg,
	.writevalid = smr_atomic_valid,
	.readwritevalid = smr_atomic_fetch_valid,
	.compwritevalid = smr_atomic_comp_valid,
};
598