1 /*
2  * Copyright (c) 2013-2017 Intel Corporation. All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  */
32 
33 #include "psmx.h"
34 
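/* Atomics protocol:
 *
 * Atomics REQ:
 *	args[0].u32w0	cmd
 *	args[0].u32w1	count
 *	args[1].u64	req	(initiator's request handle, echoed in the REP)
 *	args[2].u64	addr
 *	args[3].u64	key
 *	args[4].u32w0	datatype
 *	args[4].u32w1	op
 *
 *	Payload: the operand array (count elements of datatype).  For
 *	FI_ATOMIC_READ the operand is not needed and may be omitted; the
 *	target derives the size from count.  COMPWRITE sends the operand
 *	array immediately followed by the compare array, both count elements.
 *
 * Atomics REP:
 *	args[0].u32w0	cmd
 *	args[0].u32w1	error
 *	args[1].u64	req
 *
 *	Payload (READWRITE/COMPWRITE only): the previous contents of the
 *	target, count elements; absent when the request failed.
 */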
51 
52 static fastlock_t psmx_atomic_lock;
53 
psmx_atomic_init(void)54 void psmx_atomic_init(void)
55 {
56 	fastlock_init(&psmx_atomic_lock);
57 }
58 
psmx_atomic_fini(void)59 void psmx_atomic_fini(void)
60 {
61 	fastlock_destroy(&psmx_atomic_lock);
62 }
63 
/*
 * Type dispatch helpers.
 *
 * CASE_<family>_TYPE(FUNC, args...) expands to the case labels of one family
 * of fi_datatype values.  Each label invokes FUNC(args..., <C type>) and
 * then breaks out of the switch.
 */
#define CASE_INT_TYPE(FUNC, ...)						\
		case FI_INT8:	FUNC(__VA_ARGS__, int8_t); break;		\
		case FI_UINT8:	FUNC(__VA_ARGS__, uint8_t); break;		\
		case FI_INT16:	FUNC(__VA_ARGS__, int16_t); break;		\
		case FI_UINT16:	FUNC(__VA_ARGS__, uint16_t); break;		\
		case FI_INT32:	FUNC(__VA_ARGS__, int32_t); break;		\
		case FI_UINT32:	FUNC(__VA_ARGS__, uint32_t); break;		\
		case FI_INT64:	FUNC(__VA_ARGS__, int64_t); break;		\
		case FI_UINT64:	FUNC(__VA_ARGS__, uint64_t); break;

#define CASE_FP_TYPE(FUNC, ...)							\
		case FI_FLOAT:		FUNC(__VA_ARGS__, float); break;	\
		case FI_DOUBLE:		FUNC(__VA_ARGS__, double); break;	\
		case FI_LONG_DOUBLE:	FUNC(__VA_ARGS__, long double); break;

#define CASE_COMPLEX_TYPE(FUNC, ...)						\
		case FI_FLOAT_COMPLEX:						\
			FUNC(__VA_ARGS__, float complex); break;		\
		case FI_DOUBLE_COMPLEX:						\
			FUNC(__VA_ARGS__, double complex); break;		\
		case FI_LONG_DOUBLE_COMPLEX:					\
			FUNC(__VA_ARGS__, long double complex); break;

/*
 * SWITCH_<family>_TYPE(type, FUNC, args...) dispatches on an fi_datatype and
 * makes the enclosing function return -FI_EOPNOTSUPP for any type outside
 * the family:
 *   INT - integer types only
 *   ORD - integer and real floating types
 *   ALL - integer, real floating and complex types
 */
#define SWITCH_INT_TYPE(type, ...)						\
		switch (type) {							\
		CASE_INT_TYPE(__VA_ARGS__)					\
		default:							\
			return -FI_EOPNOTSUPP;					\
		}

#define SWITCH_ORD_TYPE(type, ...)						\
		switch (type) {							\
		CASE_INT_TYPE(__VA_ARGS__)					\
		CASE_FP_TYPE(__VA_ARGS__)					\
		default:							\
			return -FI_EOPNOTSUPP;					\
		}

#define SWITCH_ALL_TYPE(type, ...)						\
		switch (type) {							\
		CASE_INT_TYPE(__VA_ARGS__)					\
		CASE_FP_TYPE(__VA_ARGS__)					\
		CASE_COMPLEX_TYPE(__VA_ARGS__)					\
		default:							\
			return -FI_EOPNOTSUPP;					\
		}
104 
/*
 * Element-wise update operators used by the PSMX_ATOMIC_* loops:
 * OP(dst, src) folds src into dst in place.
 *
 * PSMX_MIN/PSMX_MAX contain an `if`, so they are wrapped in
 * do { } while (0); a bare `if` would capture a following `else` when the
 * macro is used as the body of an if/else.
 */
#define PSMX_MIN(dst,src)	do { if ((dst) > (src)) (dst) = (src); } while (0)
#define PSMX_MAX(dst,src)	do { if ((dst) < (src)) (dst) = (src); } while (0)
#define PSMX_SUM(dst,src)	(dst) += (src)
#define PSMX_PROD(dst,src)	(dst) *= (src)
#define PSMX_LOR(dst,src)	(dst) = (dst) || (src)
#define PSMX_LAND(dst,src)	(dst) = (dst) && (src)
#define PSMX_BOR(dst,src)	(dst) |= (src)
#define PSMX_BAND(dst,src)	(dst) &= (src)
/* logical xor: result is 1 when exactly one operand is non-zero */
#define PSMX_LXOR(dst,src)	(dst) = ((dst) && !(src)) || (!(dst) && (src))
#define PSMX_BXOR(dst,src)	(dst) ^= (src)
#define PSMX_COPY(dst,src)	(dst) = (src)
116 
117 #define PSMX_ATOMIC_READ(dst,res,cnt,TYPE) \
118 		do { \
119 			int i; \
120 			TYPE *d = (dst); \
121 			TYPE *r = (res); \
122 			fastlock_acquire(&psmx_atomic_lock); \
123 			for (i=0; i<(cnt); i++) \
124 				r[i] = d[i]; \
125 			fastlock_release(&psmx_atomic_lock); \
126 		} while (0)
127 
128 #define PSMX_ATOMIC_WRITE(dst,src,cnt,OP,TYPE) \
129 		do { \
130 			int i; \
131 			TYPE *d = (dst); \
132 			TYPE *s = (src); \
133 			fastlock_acquire(&psmx_atomic_lock); \
134 			for (i=0; i<cnt; i++) \
135 				OP(d[i],s[i]); \
136 			fastlock_release(&psmx_atomic_lock); \
137 		} while (0)
138 
139 #define PSMX_ATOMIC_READWRITE(dst,src,res,cnt,OP,TYPE) \
140 		do { \
141 			int i; \
142 			TYPE *d = (dst); \
143 			TYPE *s = (src); \
144 			TYPE *r = (res); \
145 			fastlock_acquire(&psmx_atomic_lock); \
146 			for (i=0; i<(cnt); i++) {\
147 				r[i] = d[i]; \
148 				OP(d[i],s[i]); \
149 			} \
150 			fastlock_release(&psmx_atomic_lock); \
151 		} while (0)
152 
153 #define PSMX_ATOMIC_CSWAP(dst,src,cmp,res,cnt,CMP_OP,TYPE) \
154 		do { \
155 			int i; \
156 			TYPE *d = (dst); \
157 			TYPE *s = (src); \
158 			TYPE *c = (cmp); \
159 			TYPE *r = (res); \
160 			fastlock_acquire(&psmx_atomic_lock); \
161 			for (i=0; i<(cnt); i++) { \
162 				r[i] = d[i]; \
163 				if (c[i] CMP_OP d[i]) \
164 					d[i] = s[i]; \
165 			} \
166 			fastlock_release(&psmx_atomic_lock); \
167 		} while (0)
168 
169 #define PSMX_ATOMIC_MSWAP(dst,src,cmp,res,cnt,TYPE) \
170 		do { \
171 			int i; \
172 			TYPE *d = (dst); \
173 			TYPE *s = (src); \
174 			TYPE *c = (cmp); \
175 			TYPE *r = (res); \
176 			fastlock_acquire(&psmx_atomic_lock); \
177 			for (i=0; i<(cnt); i++) { \
178 				r[i] = d[i]; \
179 				d[i] = (s[i] & c[i]) | (d[i] & ~c[i]); \
180 			} \
181 			fastlock_release(&psmx_atomic_lock); \
182 		} while (0)
183 
psmx_atomic_do_write(void * dest,void * src,int datatype,int op,int count)184 static int psmx_atomic_do_write(void *dest, void *src,
185 				int datatype, int op, int count)
186 {
187 	switch (op) {
188 	case FI_MIN:
189 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_WRITE,
190 				dest,src,count,PSMX_MIN);
191 		break;
192 
193 	case FI_MAX:
194 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_WRITE,
195 				dest,src,count,PSMX_MAX);
196 		break;
197 
198 	case FI_SUM:
199 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_WRITE,
200 				dest,src,count,PSMX_SUM);
201 		break;
202 
203 	case FI_PROD:
204 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_WRITE,
205 				dest,src,count,PSMX_PROD);
206 		break;
207 
208 	case FI_LOR:
209 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
210 				dest,src,count,PSMX_LOR);
211 		break;
212 
213 	case FI_LAND:
214 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
215 				dest,src,count,PSMX_LAND);
216 		break;
217 
218 	case FI_BOR:
219 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
220 				dest,src,count,PSMX_BOR);
221 		break;
222 
223 	case FI_BAND:
224 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
225 				dest,src,count,PSMX_BAND);
226 		break;
227 
228 	case FI_LXOR:
229 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
230 				dest,src,count,PSMX_LXOR);
231 		break;
232 
233 	case FI_BXOR:
234 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
235 				dest,src,count,PSMX_BXOR);
236 		break;
237 
238 	case FI_ATOMIC_WRITE:
239 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_WRITE,
240 				dest,src,count,PSMX_COPY);
241 		break;
242 
243 	default:
244 		return -FI_EOPNOTSUPP;
245 	}
246 
247 	return 0;
248 }
249 
psmx_atomic_do_readwrite(void * dest,void * src,void * result,int datatype,int op,int count)250 static int psmx_atomic_do_readwrite(void *dest, void *src, void *result,
251 				    int datatype, int op, int count)
252 {
253 	switch (op) {
254 	case FI_MIN:
255 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_READWRITE,
256 				dest,src,result,count,PSMX_MIN);
257 		break;
258 
259 	case FI_MAX:
260 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_READWRITE,
261 				dest,src,result,count,PSMX_MAX);
262 		break;
263 
264 	case FI_SUM:
265 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READWRITE,
266 				dest,src,result,count,PSMX_SUM);
267 		break;
268 
269 	case FI_PROD:
270 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READWRITE,
271 				dest,src,result,count,PSMX_PROD);
272 		break;
273 
274 	case FI_LOR:
275 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
276 				dest,src,result,count,PSMX_LOR);
277 		break;
278 
279 	case FI_LAND:
280 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
281 				dest,src,result,count,PSMX_LAND);
282 		break;
283 
284 	case FI_BOR:
285 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
286 				dest,src,result,count,PSMX_BOR);
287 		break;
288 
289 	case FI_BAND:
290 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
291 				dest,src,result,count,PSMX_BAND);
292 		break;
293 
294 	case FI_LXOR:
295 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
296 				dest,src,result,count,PSMX_LXOR);
297 		break;
298 
299 	case FI_BXOR:
300 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
301 				dest,src,result,count,PSMX_BXOR);
302 		break;
303 
304 	case FI_ATOMIC_READ:
305 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READ,
306 				dest,result,count);
307 		break;
308 
309 	case FI_ATOMIC_WRITE:
310 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READWRITE,
311 				dest,src,result,count,PSMX_COPY);
312 		break;
313 
314 	default:
315 		return -FI_EOPNOTSUPP;
316 	}
317 
318 	return 0;
319 }
320 
psmx_atomic_do_compwrite(void * dest,void * src,void * compare,void * result,int datatype,int op,int count)321 static int psmx_atomic_do_compwrite(void *dest, void *src, void *compare,
322 				    void *result, int datatype, int op,
323 				    int count)
324 {
325 	switch (op) {
326 	case FI_CSWAP:
327 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_CSWAP,
328 				dest,src,compare,result,count,==);
329 		break;
330 
331 	case FI_CSWAP_NE:
332 		SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_CSWAP,
333 				dest,src,compare,result,count,!=);
334 		break;
335 
336 	case FI_CSWAP_LE:
337 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
338 				dest,src,compare,result,count,<=);
339 		break;
340 
341 	case FI_CSWAP_LT:
342 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
343 				dest,src,compare,result,count,<);
344 		break;
345 
346 	case FI_CSWAP_GE:
347 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
348 				dest,src,compare,result,count,>=);
349 		break;
350 
351 	case FI_CSWAP_GT:
352 		SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
353 				dest,src,compare,result,count,>);
354 		break;
355 
356 	case FI_MSWAP:
357 		SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_MSWAP,
358 				dest,src,compare,result,count);
359 		break;
360 
361 	default:
362 		return -FI_EOPNOTSUPP;
363 	}
364 
365 	return 0;
366 }
367 
/*
 * Completion callback handed to psm_am_reply_short() for replies whose
 * payload was malloc()ed by the atomic handler: releases that buffer.
 * free(NULL) is a no-op, so replies sent without a payload need no guard.
 */
static void psmx_am_atomic_completion(void *buf)
{
	free(buf);
}
373 
psmx_am_atomic_handler(psm_am_token_t token,psm_epaddr_t epaddr,psm_amarg_t * args,int nargs,void * src,uint32_t len)374 int psmx_am_atomic_handler(psm_am_token_t token, psm_epaddr_t epaddr,
375 			   psm_amarg_t *args, int nargs, void *src,
376 			   uint32_t len)
377 {
378 	psm_amarg_t rep_args[8];
379 	int count;
380 	uint8_t *addr;
381 	uint64_t key;
382 	int datatype, op;
383 	int err = 0;
384 	int op_error = 0;
385 	struct psmx_am_request *req;
386 	struct psmx_cq_event *event;
387 	struct psmx_fid_mr *mr;
388 	struct psmx_fid_ep *target_ep;
389 	struct psmx_fid_cntr *cntr = NULL;
390 	struct psmx_fid_cntr *mr_cntr = NULL;
391 	void *tmp_buf;
392 
393 	switch (args[0].u32w0 & PSMX_AM_OP_MASK) {
394 	case PSMX_AM_REQ_ATOMIC_WRITE:
395 		count = args[0].u32w1;
396 		addr = (uint8_t *)(uintptr_t)args[2].u64;
397 		key = args[3].u64;
398 		datatype = args[4].u32w0;
399 		op = args[4].u32w1;
400 		assert(len == ofi_datatype_size(datatype) * count);
401 
402 		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
403 		op_error = mr ?
404 			psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_WRITE) :
405 			-FI_EINVAL;
406 
407 		if (!op_error) {
408 			addr += mr->offset;
409 			psmx_atomic_do_write(addr, src, datatype, op, count);
410 
411 			target_ep = mr->domain->atomics_ep;
412 			if (target_ep->caps & FI_RMA_EVENT) {
413 				cntr = target_ep->remote_write_cntr;
414 				mr_cntr = mr->cntr;
415 
416 				if (cntr)
417 					psmx_cntr_inc(cntr);
418 
419 				if (mr_cntr && mr_cntr != cntr)
420 					psmx_cntr_inc(mr_cntr);
421 			}
422 		}
423 
424 		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_WRITE;
425 		rep_args[0].u32w1 = op_error;
426 		rep_args[1].u64 = args[1].u64;
427 		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
428 				rep_args, 2, NULL, 0, 0,
429 				NULL, NULL );
430 		break;
431 
432 	case PSMX_AM_REQ_ATOMIC_READWRITE:
433 		count = args[0].u32w1;
434 		addr = (uint8_t *)(uintptr_t)args[2].u64;
435 		key = args[3].u64;
436 		datatype = args[4].u32w0;
437 		op = args[4].u32w1;
438 
439 		if (op == FI_ATOMIC_READ)
440 			len = ofi_datatype_size(datatype) * count;
441 
442 		assert(len == ofi_datatype_size(datatype) * count);
443 
444 		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
445 		op_error = mr ?
446 			psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) :
447 			-FI_EINVAL;
448 
449 		if (!op_error) {
450 			addr += mr->offset;
451 			tmp_buf = malloc(len);
452 			if (tmp_buf)
453 				psmx_atomic_do_readwrite(addr, src, tmp_buf,
454 							 datatype, op, count);
455 			else
456 				op_error = -FI_ENOMEM;
457 
458 			target_ep = mr->domain->atomics_ep;
459 			if (target_ep->caps & FI_RMA_EVENT) {
460 				if (op == FI_ATOMIC_READ) {
461 					cntr = target_ep->remote_read_cntr;
462 				} else {
463 					cntr = target_ep->remote_write_cntr;
464 					mr_cntr = mr->cntr;
465 				}
466 
467 				if (cntr)
468 					psmx_cntr_inc(cntr);
469 
470 				if (mr_cntr && mr_cntr != cntr)
471 					psmx_cntr_inc(mr_cntr);
472 			}
473 		} else {
474 			tmp_buf = NULL;
475 		}
476 
477 		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE;
478 		rep_args[0].u32w1 = op_error;
479 		rep_args[1].u64 = args[1].u64;
480 		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
481 				rep_args, 2, tmp_buf, (tmp_buf?len:0), 0,
482 				psmx_am_atomic_completion, tmp_buf );
483 		break;
484 
485 	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
486 		count = args[0].u32w1;
487 		addr = (uint8_t *)(uintptr_t)args[2].u64;
488 		key = args[3].u64;
489 		datatype = args[4].u32w0;
490 		op = args[4].u32w1;
491 		len /= 2;
492 		assert(len == ofi_datatype_size(datatype) * count);
493 
494 		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
495 		op_error = mr ?
496 			psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) :
497 			-FI_EINVAL;
498 
499 		if (!op_error) {
500 			addr += mr->offset;
501 			tmp_buf = malloc(len);
502 			if (tmp_buf)
503 				psmx_atomic_do_compwrite(addr, src, (uint8_t *)src + len,
504 							 tmp_buf, datatype, op, count);
505 			else
506 				op_error = -FI_ENOMEM;
507 
508 			target_ep = mr->domain->atomics_ep;
509 			if (target_ep->caps & FI_RMA_EVENT) {
510 				cntr = target_ep->remote_write_cntr;
511 				mr_cntr = mr->cntr;
512 
513 				if (cntr)
514 					psmx_cntr_inc(cntr);
515 
516 				if (mr_cntr && mr_cntr != cntr)
517 					psmx_cntr_inc(mr_cntr);
518 			}
519 		} else {
520 			tmp_buf = NULL;
521 		}
522 
523 		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE;
524 		rep_args[0].u32w1 = op_error;
525 		rep_args[1].u64 = args[1].u64;
526 		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
527 				rep_args, 2, tmp_buf, (tmp_buf?len:0), 0,
528 				psmx_am_atomic_completion, tmp_buf );
529 		break;
530 
531 	case PSMX_AM_REP_ATOMIC_WRITE:
532 		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
533 		op_error = (int)args[0].u32w1;
534 		assert(req->op == PSMX_AM_REQ_ATOMIC_WRITE);
535 		if (req->ep->send_cq && (!req->no_event || op_error)) {
536 			event = psmx_cq_create_event(
537 					req->ep->send_cq,
538 					req->atomic.context,
539 					req->atomic.buf,
540 					req->cq_flags,
541 					req->atomic.len,
542 					0, /* data */
543 					0, /* tag */
544 					0, /* olen */
545 					op_error);
546 			if (event)
547 				psmx_cq_enqueue_event(req->ep->send_cq, event);
548 			else
549 				err = -FI_ENOMEM;
550 		}
551 
552 		if (req->ep->write_cntr)
553 			psmx_cntr_inc(req->ep->write_cntr);
554 
555 		free(req);
556 		break;
557 
558 	case PSMX_AM_REP_ATOMIC_READWRITE:
559 	case PSMX_AM_REP_ATOMIC_COMPWRITE:
560 		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
561 		op_error = (int)args[0].u32w1;
562 		assert(op_error || req->atomic.len == len);
563 
564 		if (!op_error)
565 			memcpy(req->atomic.result, src, len);
566 
567 		if (req->ep->send_cq && (!req->no_event || op_error)) {
568 			event = psmx_cq_create_event(
569 					req->ep->send_cq,
570 					req->atomic.context,
571 					req->atomic.buf,
572 					req->cq_flags,
573 					req->atomic.len,
574 					0, /* data */
575 					0, /* tag */
576 					0, /* olen */
577 					op_error);
578 			if (event)
579 				psmx_cq_enqueue_event(req->ep->send_cq, event);
580 			else
581 				err = -FI_ENOMEM;
582 		}
583 
584 		if (req->ep->read_cntr)
585 			psmx_cntr_inc(req->ep->read_cntr);
586 
587 		free(req);
588 		break;
589 
590 	default:
591 		err = -FI_EINVAL;
592 	}
593 	return err;
594 }
595 
psmx_atomic_self(int am_cmd,struct psmx_fid_ep * ep,const void * buf,size_t count,void * desc,const void * compare,void * compare_desc,void * result,void * result_desc,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)596 static int psmx_atomic_self(int am_cmd,
597 			    struct psmx_fid_ep *ep,
598 			    const void *buf,
599 			    size_t count, void *desc,
600 			    const void *compare, void *compare_desc,
601 			    void *result, void *result_desc,
602 			    uint64_t addr, uint64_t key,
603 			    enum fi_datatype datatype,
604 			    enum fi_op op, void *context,
605 			    uint64_t flags)
606 {
607 	struct psmx_fid_mr *mr;
608 	struct psmx_cq_event *event;
609 	struct psmx_fid_ep *target_ep;
610 	struct psmx_fid_cntr *cntr = NULL;
611 	struct psmx_fid_cntr *mr_cntr = NULL;
612 	void *tmp_buf;
613 	size_t len;
614 	int no_event;
615 	int err = 0;
616 	int op_error;
617 	int access;
618 	uint64_t cq_flags = 0;
619 
620 	if (am_cmd == PSMX_AM_REQ_ATOMIC_WRITE)
621 		access = FI_REMOTE_WRITE;
622 	else
623 		access = FI_REMOTE_READ | FI_REMOTE_WRITE;
624 
625 	len = ofi_datatype_size(datatype) * count;
626 	mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
627 	op_error = mr ?  psmx_mr_validate(mr, addr, len, access) : -FI_EINVAL;
628 
629 	if (op_error)
630 		goto gen_local_event;
631 
632 	addr += mr->offset;
633 
634 	switch (am_cmd) {
635 	case PSMX_AM_REQ_ATOMIC_WRITE:
636 		err = psmx_atomic_do_write((void *)addr, (void *)buf,
637 					   (int)datatype, (int)op, (int)count);
638 		cq_flags = FI_WRITE | FI_ATOMIC;
639 		break;
640 
641 	case PSMX_AM_REQ_ATOMIC_READWRITE:
642 		if (result != buf) {
643 			err = psmx_atomic_do_readwrite((void *)addr, (void *)buf,
644 						       (void *)result, (int)datatype,
645 						       (int)op, (int)count);
646 		} else {
647 			tmp_buf = malloc(len);
648 			if (tmp_buf) {
649 				memcpy(tmp_buf, result, len);
650 				err = psmx_atomic_do_readwrite((void *)addr, (void *)buf,
651 							       tmp_buf, (int)datatype,
652 							       (int)op, (int)count);
653 				memcpy(result, tmp_buf, len);
654 				free(tmp_buf);
655 			} else {
656 				err = -FI_ENOMEM;
657 			}
658 		}
659 		if (op == FI_ATOMIC_READ)
660 			cq_flags = FI_READ | FI_ATOMIC;
661 		else
662 			cq_flags = FI_WRITE | FI_ATOMIC;
663 		break;
664 
665 	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
666 		if (result != buf && result != compare) {
667 			err = psmx_atomic_do_compwrite((void *)addr, (void *)buf,
668 						       (void *)compare, (void *)result,
669 						       (int)datatype, (int)op, (int)count);
670 		} else {
671 			tmp_buf = malloc(len);
672 			if (tmp_buf) {
673 				memcpy(tmp_buf, result, len);
674 				err = psmx_atomic_do_compwrite((void *)addr, (void *)buf,
675 							       (void *)compare, tmp_buf,
676 							       (int)datatype, (int)op, (int)count);
677 				memcpy(result, tmp_buf, len);
678 				free(tmp_buf);
679 			} else {
680 				err = -FI_ENOMEM;
681 			}
682 		}
683 		cq_flags = FI_WRITE | FI_ATOMIC;
684 		break;
685 	}
686 
687 	target_ep = mr->domain->atomics_ep;
688 	if (target_ep->caps & FI_RMA_EVENT) {
689 		if (op == FI_ATOMIC_READ) {
690 			cntr = target_ep->remote_read_cntr;
691 		} else {
692 			cntr = target_ep->remote_write_cntr;
693 			mr_cntr = mr->cntr;
694 		}
695 
696 		if (cntr)
697 			psmx_cntr_inc(cntr);
698 
699 		if (mr_cntr && mr_cntr != cntr)
700 			psmx_cntr_inc(mr_cntr);
701 	}
702 
703 gen_local_event:
704 	no_event = ((flags & PSMX_NO_COMPLETION) ||
705 		    (ep->send_selective_completion && !(flags & FI_COMPLETION)));
706 	if (ep->send_cq && (!no_event || op_error)) {
707 		event = psmx_cq_create_event(
708 				ep->send_cq,
709 				context,
710 				(void *)buf,
711 				cq_flags,
712 				len,
713 				0, /* data */
714 				0, /* tag */
715 				0, /* olen */
716 				op_error);
717 		if (event)
718 			psmx_cq_enqueue_event(ep->send_cq, event);
719 		else
720 			err = -FI_ENOMEM;
721 	}
722 
723 	switch (am_cmd) {
724 	case PSMX_AM_REQ_ATOMIC_WRITE:
725 		if (ep->write_cntr)
726 			psmx_cntr_inc(ep->write_cntr);
727 		break;
728 	case PSMX_AM_REQ_ATOMIC_READWRITE:
729 	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
730 		if (ep->read_cntr)
731 			psmx_cntr_inc(ep->read_cntr);
732 		break;
733 	}
734 
735 	return err;
736 }
737 
_psmx_atomic_write(struct fid_ep * ep,const void * buf,size_t count,void * desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)738 ssize_t _psmx_atomic_write(struct fid_ep *ep,
739 			   const void *buf,
740 			   size_t count, void *desc,
741 			   fi_addr_t dest_addr,
742 			   uint64_t addr, uint64_t key,
743 			   enum fi_datatype datatype,
744 			   enum fi_op op, void *context,
745 			   uint64_t flags)
746 {
747 	struct psmx_fid_ep *ep_priv;
748 	struct psmx_fid_av *av;
749 	struct psmx_epaddr_context *epaddr_context;
750 	struct psmx_am_request *req;
751 	psm_amarg_t args[8];
752 	int am_flags = PSM_AM_FLAG_ASYNC;
753 	int chunk_size, len;
754 	size_t idx;
755 
756 	ep_priv = container_of(ep, struct psmx_fid_ep, ep);
757 
758 	if (flags & FI_TRIGGER) {
759 		struct psmx_trigger *trigger;
760 		struct fi_triggered_context *ctxt = context;
761 
762 		trigger = calloc(1, sizeof(*trigger));
763 		if (!trigger)
764 			return -FI_ENOMEM;
765 
766 		trigger->op = PSMX_TRIGGERED_ATOMIC_WRITE;
767 		trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
768 					     struct psmx_fid_cntr, cntr);
769 		trigger->threshold = ctxt->trigger.threshold.threshold;
770 		trigger->atomic_write.ep = ep;
771 		trigger->atomic_write.buf = buf;
772 		trigger->atomic_write.count = count;
773 		trigger->atomic_write.desc = desc;
774 		trigger->atomic_write.dest_addr = dest_addr;
775 		trigger->atomic_write.addr = addr;
776 		trigger->atomic_write.key = key;
777 		trigger->atomic_write.datatype = datatype;
778 		trigger->atomic_write.atomic_op = op;
779 		trigger->atomic_write.context = context;
780 		trigger->atomic_write.flags = flags & ~FI_TRIGGER;
781 
782 		psmx_cntr_add_trigger(trigger->cntr, trigger);
783 		return 0;
784 	}
785 
786 	if (!buf)
787 		return -FI_EINVAL;
788 
789 	if (datatype >= FI_DATATYPE_LAST)
790 		return -FI_EINVAL;
791 
792 	if (op >= FI_ATOMIC_OP_LAST)
793 		return -FI_EINVAL;
794 
795 	av = ep_priv->av;
796 	if (av && av->type == FI_AV_TABLE) {
797 		idx = dest_addr;
798 		if (idx >= av->last)
799 			return -FI_EINVAL;
800 
801 		dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
802 	} else if (!dest_addr) {
803 		return -FI_EINVAL;
804 	}
805 
806 	epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
807 	if (epaddr_context->epid == ep_priv->domain->psm_epid)
808 		return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_WRITE,
809 					ep_priv, buf, count, desc,
810 					NULL, NULL, NULL, NULL,
811 					addr, key, datatype, op,
812 					context, flags);
813 
814 	chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
815 	len = ofi_datatype_size(datatype)* count;
816 	if (len > chunk_size)
817 		return -FI_EMSGSIZE;
818 
819 	if (flags & FI_INJECT) {
820 		req = malloc(sizeof(*req) + len);
821 		if (!req)
822 			return -FI_ENOMEM;
823 		memset(req, 0, sizeof(*req));
824 		memcpy((uint8_t *)req+sizeof(*req), (void *)buf, len);
825 		buf = (uint8_t *)req + sizeof(*req);
826 	} else {
827 		req = calloc(1, sizeof(*req));
828 		if (!req)
829 			return -FI_ENOMEM;
830 	}
831 
832 	req->no_event = (flags & PSMX_NO_COMPLETION) ||
833 			(ep_priv->send_selective_completion && !(flags & FI_COMPLETION));
834 
835 	req->op = PSMX_AM_REQ_ATOMIC_WRITE;
836 	req->atomic.buf = (void *)buf;
837 	req->atomic.len = len;
838 	req->atomic.addr = addr;
839 	req->atomic.key = key;
840 	req->atomic.context = context;
841 	req->ep = ep_priv;
842 	req->cq_flags = FI_WRITE | FI_ATOMIC;
843 
844 	args[0].u32w0 = PSMX_AM_REQ_ATOMIC_WRITE;
845 	args[0].u32w1 = count;
846 	args[1].u64 = (uint64_t)(uintptr_t)req;
847 	args[2].u64 = addr;
848 	args[3].u64 = key;
849 	args[4].u32w0 = datatype;
850 	args[4].u32w1 = op;
851 	psm_am_request_short((psm_epaddr_t) dest_addr,
852 				PSMX_AM_ATOMIC_HANDLER, args, 5,
853 				(void *)buf, len, am_flags, NULL, NULL);
854 
855 	return 0;
856 }
857 
psmx_atomic_write(struct fid_ep * ep,const void * buf,size_t count,void * desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)858 static ssize_t psmx_atomic_write(struct fid_ep *ep,
859 			       const void *buf,
860 			       size_t count, void *desc,
861 			       fi_addr_t dest_addr,
862 			       uint64_t addr, uint64_t key,
863 			       enum fi_datatype datatype,
864 			       enum fi_op op, void *context)
865 {
866 	struct psmx_fid_ep *ep_priv;
867 
868 	ep_priv = container_of(ep, struct psmx_fid_ep, ep);
869 	return _psmx_atomic_write(ep, buf, count, desc,
870 				  dest_addr, addr, key,
871 				  datatype, op, context, ep_priv->tx_flags);
872 }
873 
psmx_atomic_writemsg(struct fid_ep * ep,const struct fi_msg_atomic * msg,uint64_t flags)874 static ssize_t psmx_atomic_writemsg(struct fid_ep *ep,
875 				const struct fi_msg_atomic *msg,
876 				uint64_t flags)
877 {
878 	if (!msg || msg->iov_count != 1 || !msg->msg_iov || !msg->rma_iov)
879 		return -FI_EINVAL;
880 
881 	return _psmx_atomic_write(ep, msg->msg_iov[0].addr,
882 				  msg->msg_iov[0].count,
883 				  msg->desc ? msg->desc[0] : NULL,
884 				  msg->addr, msg->rma_iov[0].addr,
885 				  msg->rma_iov[0].key, msg->datatype,
886 				  msg->op, msg->context, flags);
887 }
888 
psmx_atomic_writev(struct fid_ep * ep,const struct fi_ioc * iov,void ** desc,size_t count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)889 static ssize_t psmx_atomic_writev(struct fid_ep *ep,
890 			      const struct fi_ioc *iov,
891 			      void **desc, size_t count,
892 			      fi_addr_t dest_addr,
893 			      uint64_t addr, uint64_t key,
894 			      enum fi_datatype datatype,
895 			      enum fi_op op, void *context)
896 {
897 	if (!iov || count != 1)
898 		return -FI_EINVAL;
899 
900 	return psmx_atomic_write(ep, iov->addr, iov->count,
901 				 desc ? desc[0] : NULL, dest_addr, addr, key,
902 				 datatype, op, context);
903 }
904 
psmx_atomic_inject(struct fid_ep * ep,const void * buf,size_t count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op)905 static ssize_t psmx_atomic_inject(struct fid_ep *ep,
906 			       const void *buf,
907 			       size_t count, /*void *desc,*/
908 			       fi_addr_t dest_addr,
909 			       uint64_t addr, uint64_t key,
910 			       enum fi_datatype datatype,
911 			       enum fi_op op)
912 {
913 	struct psmx_fid_ep *ep_priv;
914 
915 	ep_priv = container_of(ep, struct psmx_fid_ep, ep);
916 	return _psmx_atomic_write(ep, buf, count, NULL/*desc*/,
917 				  dest_addr, addr, key,
918 				  datatype, op, NULL,
919 				  ep_priv->tx_flags | FI_INJECT | PSMX_NO_COMPLETION);
920 }
921 
_psmx_atomic_readwrite(struct fid_ep * ep,const void * buf,size_t count,void * desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)922 ssize_t _psmx_atomic_readwrite(struct fid_ep *ep,
923 				const void *buf,
924 				size_t count, void *desc,
925 				void *result, void *result_desc,
926 				fi_addr_t dest_addr,
927 				uint64_t addr, uint64_t key,
928 				enum fi_datatype datatype,
929 				enum fi_op op, void *context,
930 				uint64_t flags)
931 {
932 	struct psmx_fid_ep *ep_priv;
933 	struct psmx_fid_av *av;
934 	struct psmx_epaddr_context *epaddr_context;
935 	struct psmx_am_request *req;
936 	psm_amarg_t args[8];
937 	int am_flags = PSM_AM_FLAG_ASYNC;
938 	int chunk_size, len;
939 	size_t idx;
940 
941 	ep_priv = container_of(ep, struct psmx_fid_ep, ep);
942 
943 	if (flags & FI_TRIGGER) {
944 		struct psmx_trigger *trigger;
945 		struct fi_triggered_context *ctxt = context;
946 
947 		trigger = calloc(1, sizeof(*trigger));
948 		if (!trigger)
949 			return -FI_ENOMEM;
950 
951 		trigger->op = PSMX_TRIGGERED_ATOMIC_READWRITE;
952 		trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
953 					     struct psmx_fid_cntr, cntr);
954 		trigger->threshold = ctxt->trigger.threshold.threshold;
955 		trigger->atomic_readwrite.ep = ep;
956 		trigger->atomic_readwrite.buf = buf;
957 		trigger->atomic_readwrite.count = count;
958 		trigger->atomic_readwrite.desc = desc;
959 		trigger->atomic_readwrite.result = result;
960 		trigger->atomic_readwrite.result_desc = result_desc;
961 		trigger->atomic_readwrite.dest_addr = dest_addr;
962 		trigger->atomic_readwrite.addr = addr;
963 		trigger->atomic_readwrite.key = key;
964 		trigger->atomic_readwrite.datatype = datatype;
965 		trigger->atomic_readwrite.atomic_op = op;
966 		trigger->atomic_readwrite.context = context;
967 		trigger->atomic_readwrite.flags = flags & ~FI_TRIGGER;
968 
969 		psmx_cntr_add_trigger(trigger->cntr, trigger);
970 		return 0;
971 	}
972 
973 	if (!buf && op != FI_ATOMIC_READ)
974 		return -FI_EINVAL;
975 
976 	if (datatype >= FI_DATATYPE_LAST)
977 		return -FI_EINVAL;
978 
979 	if (op >= FI_ATOMIC_OP_LAST)
980 		return -FI_EINVAL;
981 
982 	av = ep_priv->av;
983 	if (av && av->type == FI_AV_TABLE) {
984 		idx = dest_addr;
985 		if (idx >= av->last)
986 			return -FI_EINVAL;
987 
988 		dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
989 	} else if (!dest_addr) {
990 		return -FI_EINVAL;
991 	}
992 
993 	epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
994 	if (epaddr_context->epid == ep_priv->domain->psm_epid)
995 		return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_READWRITE,
996 					ep_priv, buf, count, desc,
997 					NULL, NULL, result, result_desc,
998 					addr, key, datatype, op,
999 					context, flags);
1000 
1001 	chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
1002 	len = ofi_datatype_size(datatype) * count;
1003 	if (len > chunk_size)
1004 		return -FI_EMSGSIZE;
1005 
1006 	if ((flags & FI_INJECT) && op != FI_ATOMIC_READ) {
1007 		req = malloc(sizeof(*req) + len);
1008 		if (!req)
1009 			return -FI_ENOMEM;
1010 		memset(req, 0, sizeof(*req));
1011 		memcpy((uint8_t *)req+sizeof(*req), (void *)buf, len);
1012 		buf = (uint8_t *)req + sizeof(*req);
1013 	} else {
1014 		req = calloc(1, sizeof(*req));
1015 		if (!req)
1016 			return -FI_ENOMEM;
1017 	}
1018 
1019 	req->no_event = (flags & PSMX_NO_COMPLETION) ||
1020 			(ep_priv->send_selective_completion && !(flags & FI_COMPLETION));
1021 
1022 	req->op = PSMX_AM_REQ_ATOMIC_READWRITE;
1023 	req->atomic.buf = (void *)buf;
1024 	req->atomic.len = len;
1025 	req->atomic.addr = addr;
1026 	req->atomic.key = key;
1027 	req->atomic.context = context;
1028 	req->atomic.result = result;
1029 	req->ep = ep_priv;
1030 	if (op == FI_ATOMIC_READ)
1031 		req->cq_flags = FI_READ | FI_ATOMIC;
1032 	else
1033 		req->cq_flags = FI_WRITE | FI_ATOMIC;
1034 
1035 	args[0].u32w0 = PSMX_AM_REQ_ATOMIC_READWRITE;
1036 	args[0].u32w1 = count;
1037 	args[1].u64 = (uint64_t)(uintptr_t)req;
1038 	args[2].u64 = addr;
1039 	args[3].u64 = key;
1040 	args[4].u32w0 = datatype;
1041 	args[4].u32w1 = op;
1042 	psm_am_request_short((psm_epaddr_t) dest_addr,
1043 				PSMX_AM_ATOMIC_HANDLER, args, 5,
1044 				(void *)buf, (buf?len:0), am_flags, NULL, NULL);
1045 
1046 	return 0;
1047 }
1048 
/*
 * readwrite entry of psmx_atomic_ops: a fetching atomic issued with the
 * endpoint's default transmit flags.  Argument validation and the actual
 * transfer are done by _psmx_atomic_readwrite().
 */
static ssize_t psmx_atomic_readwrite(struct fid_ep *ep,
				   const void *buf,
				   size_t count, void *desc,
				   void *result, void *result_desc,
				   fi_addr_t dest_addr,
				   uint64_t addr, uint64_t key,
				   enum fi_datatype datatype,
				   enum fi_op op, void *context)
{
	struct psmx_fid_ep *endpoint = container_of(ep, struct psmx_fid_ep, ep);
	uint64_t tx_flags = endpoint->tx_flags;

	return _psmx_atomic_readwrite(ep, buf, count, desc, result,
				      result_desc, dest_addr, addr, key,
				      datatype, op, context, tx_flags);
}
1066 
psmx_atomic_readwritemsg(struct fid_ep * ep,const struct fi_msg_atomic * msg,struct fi_ioc * resultv,void ** result_desc,size_t result_count,uint64_t flags)1067 static ssize_t psmx_atomic_readwritemsg(struct fid_ep *ep,
1068 				    const struct fi_msg_atomic *msg,
1069 				    struct fi_ioc *resultv,
1070 				    void **result_desc,
1071 				    size_t result_count,
1072 				    uint64_t flags)
1073 {
1074 	void *buf;
1075 	size_t count;
1076 
1077 	if (!msg || !msg->rma_iov)
1078 		return -FI_EINVAL;
1079 
1080 	if (msg->op == FI_ATOMIC_READ) {
1081 		if (result_count != 1 || !resultv)
1082 			return -FI_EINVAL;
1083 
1084 		buf = NULL;
1085 		count = resultv[0].count;
1086 	} else {
1087 		if (msg->iov_count != 1 || !msg->msg_iov)
1088 			return -FI_EINVAL;
1089 
1090 		buf = msg->msg_iov[0].addr;
1091 		count = msg->msg_iov[0].count;
1092 	}
1093 
1094 	return _psmx_atomic_readwrite(ep, buf, count,
1095 					msg->desc ? msg->desc[0] : NULL,
1096 					resultv[0].addr,
1097 					result_desc ? result_desc[0] : NULL,
1098 					msg->addr, msg->rma_iov[0].addr,
1099 					msg->rma_iov[0].key, msg->datatype,
1100 					msg->op, msg->context, flags);
1101 }
1102 
/*
 * readwritev entry of psmx_atomic_ops: vector form of the fetching atomic.
 * Only a single source element is accepted; element 0 of the source,
 * descriptor and result arrays is forwarded to psmx_atomic_readwrite().
 */
static ssize_t psmx_atomic_readwritev(struct fid_ep *ep,
				  const struct fi_ioc *iov,
				  void **desc, size_t count,
				  struct fi_ioc *resultv,
				  void **result_desc, size_t result_count,
				  fi_addr_t dest_addr,
				  uint64_t addr, uint64_t key,
				  enum fi_datatype datatype,
				  enum fi_op op, void *context)
{
	void *src_desc;
	void *res_desc;

	if (count != 1 || !iov || !resultv)
		return -FI_EINVAL;

	src_desc = desc ? desc[0] : NULL;
	res_desc = result_desc ? result_desc[0] : NULL;

	return psmx_atomic_readwrite(ep, iov[0].addr, iov[0].count, src_desc,
				     resultv[0].addr, res_desc,
				     dest_addr, addr, key, datatype, op, context);
}
1122 
/*
 * Issue a compare atomic (compwrite) on endpoint @ep.
 *
 * With FI_TRIGGER set, the operation is queued on the triggering counter
 * (flags are stored with FI_TRIGGER cleared) and 0 is returned immediately.
 *
 * Otherwise the destination is resolved through the endpoint's AV (an
 * FI_AV_TABLE index is translated to a PSM endpoint address).  Requests whose
 * target epid equals this domain's own PSM epid go to psmx_atomic_self();
 * all others are sent as one short active message whose payload is the
 * operand data immediately followed by the compare data (2 * len bytes).
 * With FI_INJECT, both operands are copied behind the request so the caller's
 * buffers may be reused at once.
 *
 * Returns 0 on success, -FI_EINVAL for missing operands, an unsupported
 * datatype/op, or a bad destination, -FI_EMSGSIZE if the payload does not fit
 * in one AM chunk, and -FI_ENOMEM on allocation failure.
 */
ssize_t _psmx_atomic_compwrite(struct fid_ep *ep,
				const void *buf,
				size_t count, void *desc,
				const void *compare, void *compare_desc,
				void *result, void *result_desc,
				fi_addr_t dest_addr,
				uint64_t addr, uint64_t key,
				enum fi_datatype datatype,
				enum fi_op op, void *context,
				uint64_t flags)
{
	struct psmx_fid_ep *ep_priv;
	struct psmx_fid_av *av;
	struct psmx_epaddr_context *epaddr_context;
	struct psmx_am_request *req;
	psm_amarg_t args[8];
	int am_flags = PSM_AM_FLAG_ASYNC;
	size_t chunk_size, dt_size, len;
	void *tmp_buf = NULL;
	size_t idx;

	ep_priv = container_of(ep, struct psmx_fid_ep, ep);

	if (flags & FI_TRIGGER) {
		struct psmx_trigger *trigger;
		struct fi_triggered_context *ctxt = context;

		trigger = calloc(1, sizeof(*trigger));
		if (!trigger)
			return -FI_ENOMEM;

		trigger->op = PSMX_TRIGGERED_ATOMIC_COMPWRITE;
		trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
					     struct psmx_fid_cntr, cntr);
		trigger->threshold = ctxt->trigger.threshold.threshold;
		trigger->atomic_compwrite.ep = ep;
		trigger->atomic_compwrite.buf = buf;
		trigger->atomic_compwrite.count = count;
		trigger->atomic_compwrite.desc = desc;
		trigger->atomic_compwrite.compare = compare;
		trigger->atomic_compwrite.compare_desc = compare_desc;
		trigger->atomic_compwrite.result = result;
		trigger->atomic_compwrite.result_desc = result_desc;
		trigger->atomic_compwrite.dest_addr = dest_addr;
		trigger->atomic_compwrite.addr = addr;
		trigger->atomic_compwrite.key = key;
		trigger->atomic_compwrite.datatype = datatype;
		trigger->atomic_compwrite.atomic_op = op;
		trigger->atomic_compwrite.context = context;
		trigger->atomic_compwrite.flags = flags & ~FI_TRIGGER;

		psmx_cntr_add_trigger(trigger->cntr, trigger);
		return 0;
	}

	/* Both operands are read below (memcpy and/or AM payload). */
	if (!buf || !compare)
		return -FI_EINVAL;

	if (datatype >= FI_DATATYPE_LAST)
		return -FI_EINVAL;

	if (op >= FI_ATOMIC_OP_LAST)
		return -FI_EINVAL;

	av = ep_priv->av;
	if (av && av->type == FI_AV_TABLE) {
		idx = dest_addr;
		if (idx >= av->last)
			return -FI_EINVAL;

		dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
	} else if (!dest_addr) {
		return -FI_EINVAL;
	}

	epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
	if (epaddr_context->epid == ep_priv->domain->psm_epid)
		return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_COMPWRITE,
					ep_priv, buf, count, desc,
					compare, compare_desc,
					result, result_desc,
					addr, key, datatype, op,
					context, flags);

	/*
	 * Bound count before multiplying: dt_size * count (and 2 * len) could
	 * otherwise wrap to a small value and slip past the size limit.  The
	 * check is equivalent to 2 * dt_size * count > chunk_size.
	 */
	chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
	dt_size = ofi_datatype_size(datatype);
	if (!dt_size)
		return -FI_EINVAL;
	if (count > chunk_size / (2 * dt_size))
		return -FI_EMSGSIZE;
	len = dt_size * count;

	if (flags & FI_INJECT) {
		/* Copy operand then compare data right behind the request. */
		req = malloc(sizeof(*req) + len + len);
		if (!req)
			return -FI_ENOMEM;
		memset(req, 0, sizeof(*req));
		memcpy((uint8_t *)req + sizeof(*req), (void *)buf, len);
		memcpy((uint8_t *)req + sizeof(*req) + len, (void *)compare, len);
		buf = (uint8_t *)req + sizeof(*req);
		compare = (uint8_t *)buf + len;
	} else {
		req = calloc(1, sizeof(*req));
		if (!req)
			return -FI_ENOMEM;

		/*
		 * The AM payload must be contiguous (operand then compare);
		 * stage it in a temporary buffer unless the caller already
		 * laid it out that way.  tmp_buf is handed to the completion
		 * callback below.
		 */
		if ((uintptr_t)compare != (uintptr_t)buf + len) {
			tmp_buf = malloc(len * 2);
			if (!tmp_buf) {
				free(req);
				return -FI_ENOMEM;
			}

			memcpy(tmp_buf, buf, len);
			memcpy((uint8_t *)tmp_buf + len, compare, len);
		}
	}

	req->no_event = (flags & PSMX_NO_COMPLETION) ||
			(ep_priv->send_selective_completion && !(flags & FI_COMPLETION));

	req->op = PSMX_AM_REQ_ATOMIC_COMPWRITE;
	req->atomic.buf = (void *)buf;
	req->atomic.len = len;
	req->atomic.addr = addr;
	req->atomic.key = key;
	req->atomic.context = context;
	req->atomic.result = result;
	req->ep = ep_priv;
	req->cq_flags = FI_WRITE | FI_ATOMIC;

	args[0].u32w0 = PSMX_AM_REQ_ATOMIC_COMPWRITE;
	args[0].u32w1 = count;
	args[1].u64 = (uint64_t)(uintptr_t)req;
	args[2].u64 = addr;
	args[3].u64 = key;
	args[4].u32w0 = datatype;
	args[4].u32w1 = op;
	psm_am_request_short((psm_epaddr_t) dest_addr,
				PSMX_AM_ATOMIC_HANDLER, args, 5,
				tmp_buf ? tmp_buf : (void *)buf,
				len * 2, am_flags,
				psmx_am_atomic_completion, tmp_buf);

	return 0;
}
1266 
/*
 * compwrite entry of psmx_atomic_ops: a compare atomic issued with the
 * endpoint's default transmit flags.  Argument validation and the actual
 * transfer are done by _psmx_atomic_compwrite().
 */
static ssize_t psmx_atomic_compwrite(struct fid_ep *ep,
				   const void *buf,
				   size_t count, void *desc,
				   const void *compare, void *compare_desc,
				   void *result, void *result_desc,
				   fi_addr_t dest_addr,
				   uint64_t addr, uint64_t key,
				   enum fi_datatype datatype,
				   enum fi_op op, void *context)
{
	struct psmx_fid_ep *endpoint = container_of(ep, struct psmx_fid_ep, ep);
	uint64_t tx_flags = endpoint->tx_flags;

	return _psmx_atomic_compwrite(ep, buf, count, desc, compare,
				      compare_desc, result, result_desc,
				      dest_addr, addr, key, datatype, op,
				      context, tx_flags);
}
1286 
psmx_atomic_compwritemsg(struct fid_ep * ep,const struct fi_msg_atomic * msg,const struct fi_ioc * comparev,void ** compare_desc,size_t compare_count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,uint64_t flags)1287 static ssize_t psmx_atomic_compwritemsg(struct fid_ep *ep,
1288 				    const struct fi_msg_atomic *msg,
1289 				    const struct fi_ioc *comparev,
1290 				    void **compare_desc,
1291 				    size_t compare_count,
1292 				    struct fi_ioc *resultv,
1293 				    void **result_desc,
1294 				    size_t result_count,
1295 				    uint64_t flags)
1296 {
1297 	if (!msg || msg->iov_count != 1 || !msg->msg_iov || !msg->rma_iov || !resultv)
1298 		return -FI_EINVAL;
1299 
1300 	return _psmx_atomic_compwrite(ep, msg->msg_iov[0].addr,
1301 					msg->msg_iov[0].count,
1302 					msg->desc ? msg->desc[0] : NULL,
1303 					comparev[0].addr,
1304 					compare_desc ? compare_desc[0] : NULL,
1305 					resultv[0].addr,
1306 					result_desc ? result_desc[0] : NULL,
1307 					msg->addr, msg->rma_iov[0].addr,
1308 					msg->rma_iov[0].key, msg->datatype,
1309 					msg->op, msg->context, flags);
1310 }
1311 
/*
 * compwritev entry of psmx_atomic_ops: vector form of the compare atomic.
 * Only a single source element is accepted; element 0 of the source,
 * compare and result arrays (and their descriptors, when given) is
 * forwarded to psmx_atomic_compwrite().
 */
static ssize_t psmx_atomic_compwritev(struct fid_ep *ep,
				  const struct fi_ioc *iov,
				  void **desc, size_t count,
				  const struct fi_ioc *comparev,
				  void **compare_desc,
				  size_t compare_count,
				  struct fi_ioc *resultv,
				  void **result_desc,
				  size_t result_count,
				  fi_addr_t dest_addr,
				  uint64_t addr, uint64_t key,
				  enum fi_datatype datatype,
				  enum fi_op op, void *context)
{
	void *src_desc;
	void *cmp_desc;
	void *res_desc;

	if (count != 1 || !iov || !comparev || !resultv)
		return -FI_EINVAL;

	src_desc = desc ? desc[0] : NULL;
	cmp_desc = compare_desc ? compare_desc[0] : NULL;
	res_desc = result_desc ? result_desc[0] : NULL;

	return psmx_atomic_compwrite(ep, iov[0].addr, iov[0].count, src_desc,
				     comparev[0].addr, cmp_desc,
				     resultv[0].addr, res_desc,
				     dest_addr, addr, key, datatype, op, context);
}
1337 
psmx_atomic_writevalid(struct fid_ep * ep,enum fi_datatype datatype,enum fi_op op,size_t * count)1338 static int psmx_atomic_writevalid(struct fid_ep *ep,
1339 				  enum fi_datatype datatype,
1340 				  enum fi_op op, size_t *count)
1341 {
1342 	int chunk_size;
1343 
1344 	if (datatype >= FI_DATATYPE_LAST)
1345 		return -FI_EOPNOTSUPP;
1346 
1347 	switch (op) {
1348 	case FI_MIN:
1349 	case FI_MAX:
1350 	case FI_SUM:
1351 	case FI_PROD:
1352 	case FI_LOR:
1353 	case FI_LAND:
1354 	case FI_BOR:
1355 	case FI_BAND:
1356 	case FI_LXOR:
1357 	case FI_BXOR:
1358 	case FI_ATOMIC_WRITE:
1359 		break;
1360 
1361 	default:
1362 		return -FI_EOPNOTSUPP;
1363 	}
1364 
1365 	if (count) {
1366 		chunk_size = MIN(PSMX_AM_CHUNK_SIZE,
1367 				 psmx_am_param.max_request_short);
1368 		*count = chunk_size / ofi_datatype_size(datatype);
1369 	}
1370 	return 0;
1371 }
1372 
psmx_atomic_readwritevalid(struct fid_ep * ep,enum fi_datatype datatype,enum fi_op op,size_t * count)1373 static int psmx_atomic_readwritevalid(struct fid_ep *ep,
1374 				      enum fi_datatype datatype,
1375 				      enum fi_op op, size_t *count)
1376 {
1377 	int chunk_size;
1378 
1379 	if (datatype >= FI_DATATYPE_LAST)
1380 		return -FI_EOPNOTSUPP;
1381 
1382 	switch (op) {
1383 	case FI_MIN:
1384 	case FI_MAX:
1385 	case FI_SUM:
1386 	case FI_PROD:
1387 	case FI_LOR:
1388 	case FI_LAND:
1389 	case FI_BOR:
1390 	case FI_BAND:
1391 	case FI_LXOR:
1392 	case FI_BXOR:
1393 	case FI_ATOMIC_READ:
1394 	case FI_ATOMIC_WRITE:
1395 		break;
1396 
1397 	default:
1398 		return -FI_EOPNOTSUPP;
1399 	}
1400 
1401 	if (count) {
1402 		chunk_size = MIN(PSMX_AM_CHUNK_SIZE,
1403 				 psmx_am_param.max_request_short);
1404 		*count = chunk_size / ofi_datatype_size(datatype);
1405 	}
1406 	return 0;
1407 }
1408 
psmx_atomic_compwritevalid(struct fid_ep * ep,enum fi_datatype datatype,enum fi_op op,size_t * count)1409 static int psmx_atomic_compwritevalid(struct fid_ep *ep,
1410 				      enum fi_datatype datatype,
1411 				      enum fi_op op, size_t *count)
1412 {
1413 	int chunk_size;
1414 
1415 	if (datatype >= FI_DATATYPE_LAST)
1416 		return -FI_EOPNOTSUPP;
1417 
1418 	switch (op) {
1419 	case FI_CSWAP:
1420 	case FI_CSWAP_NE:
1421 		break;
1422 
1423 	case FI_CSWAP_LE:
1424 	case FI_CSWAP_LT:
1425 	case FI_CSWAP_GE:
1426 	case FI_CSWAP_GT:
1427 		if (datatype == FI_FLOAT_COMPLEX ||
1428 		    datatype == FI_DOUBLE_COMPLEX ||
1429 		    datatype == FI_LONG_DOUBLE_COMPLEX)
1430 			return -FI_EOPNOTSUPP;
1431 		break;
1432 
1433 	case FI_MSWAP:
1434 		if (datatype == FI_FLOAT_COMPLEX ||
1435 		    datatype == FI_DOUBLE_COMPLEX ||
1436 		    datatype == FI_LONG_DOUBLE_COMPLEX ||
1437 		    datatype == FI_FLOAT ||
1438 		    datatype == FI_DOUBLE ||
1439 		    datatype == FI_LONG_DOUBLE)
1440 			return -FI_EOPNOTSUPP;
1441 		break;
1442 
1443 	default:
1444 		return -FI_EOPNOTSUPP;
1445 	}
1446 
1447 	if (count) {
1448 		chunk_size = MIN(PSMX_AM_CHUNK_SIZE,
1449 				 psmx_am_param.max_request_short);
1450 		*count = chunk_size / (2 * ofi_datatype_size(datatype));
1451 	}
1452 	return 0;
1453 }
1454 
psmx_query_atomic(struct fid_domain * doamin,enum fi_datatype datatype,enum fi_op op,struct fi_atomic_attr * attr,uint64_t flags)1455 int psmx_query_atomic(struct fid_domain *doamin, enum fi_datatype datatype,
1456 		      enum fi_op op, struct fi_atomic_attr *attr, uint64_t flags)
1457 {
1458 	int ret;
1459 	size_t count;
1460 
1461 	if (flags & FI_TAGGED)
1462 		return -FI_EOPNOTSUPP;
1463 
1464 	if (flags & FI_COMPARE_ATOMIC) {
1465 		if (flags & FI_FETCH_ATOMIC)
1466 			return -FI_EINVAL;
1467 		ret = psmx_atomic_compwritevalid(NULL, datatype, op, &count);
1468 	} else if (flags & FI_FETCH_ATOMIC) {
1469 		ret = psmx_atomic_readwritevalid(NULL, datatype, op, &count);
1470 	} else {
1471 		ret = psmx_atomic_writevalid(NULL, datatype, op, &count);
1472 	}
1473 
1474 	if (attr && !ret) {
1475 		attr->size = ofi_datatype_size(datatype);
1476 		attr->count = count;
1477 	}
1478 
1479 	return ret;
1480 }
1481 
1482 struct fi_ops_atomic psmx_atomic_ops = {
1483 	.size = sizeof(struct fi_ops_atomic),
1484 	.write = psmx_atomic_write,
1485 	.writev = psmx_atomic_writev,
1486 	.writemsg = psmx_atomic_writemsg,
1487 	.inject = psmx_atomic_inject,
1488 	.readwrite = psmx_atomic_readwrite,
1489 	.readwritev = psmx_atomic_readwritev,
1490 	.readwritemsg = psmx_atomic_readwritemsg,
1491 	.compwrite = psmx_atomic_compwrite,
1492 	.compwritev = psmx_atomic_compwritev,
1493 	.compwritemsg = psmx_atomic_compwritemsg,
1494 	.writevalid = psmx_atomic_writevalid,
1495 	.readwritevalid = psmx_atomic_readwritevalid,
1496 	.compwritevalid = psmx_atomic_compwritevalid,
1497 };
1498 
1499