1 /*
2 * Copyright (c) 2013-2017 Intel Corporation. All rights reserved.
3 *
4 * This software is available to you under a choice of one of two
5 * licenses. You may choose to be licensed under the terms of the GNU
6 * General Public License (GPL) Version 2, available from the file
7 * COPYING in the main directory of this source tree, or the
8 * BSD license below:
9 *
10 * Redistribution and use in source and binary forms, with or
11 * without modification, are permitted provided that the following
12 * conditions are met:
13 *
14 * - Redistributions of source code must retain the above
15 * copyright notice, this list of conditions and the following
16 * disclaimer.
17 *
18 * - Redistributions in binary form must reproduce the above
19 * copyright notice, this list of conditions and the following
20 * disclaimer in the documentation and/or other materials
21 * provided with the distribution.
22 *
23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 * SOFTWARE.
31 */
32
33 #include "psmx.h"
34
35 /* Atomics protocol:
36 *
37 * Atomics REQ:
38 * args[0].u32w0 cmd
39 * args[0].u32w1 count
40 * args[1].u64 req
41 * args[2].u64 addr
42 * args[3].u64 key
43 * args[4].u32w0 datatype
44 * args[4].u32w1 op
45 *
46 * Atomics REP:
47 * args[0].u32w0 cmd
48 * args[0].u32w1 error
49 * args[1].u64 req
50 */
51
52 static fastlock_t psmx_atomic_lock;
53
psmx_atomic_init(void)54 void psmx_atomic_init(void)
55 {
56 fastlock_init(&psmx_atomic_lock);
57 }
58
psmx_atomic_fini(void)59 void psmx_atomic_fini(void)
60 {
61 fastlock_destroy(&psmx_atomic_lock);
62 }
63
/*
 * Datatype dispatch helpers.
 *
 * CASE_*_TYPE(FUNC, args...) expands to a group of `case` labels; each one
 * invokes FUNC(args..., <C type>) for the matching libfabric datatype and
 * then breaks out of the enclosing switch.
 */
#define CASE_INT_TYPE(FUNC,...) \
	case FI_INT8:	FUNC(__VA_ARGS__,int8_t); break; \
	case FI_UINT8:	FUNC(__VA_ARGS__,uint8_t); break; \
	case FI_INT16:	FUNC(__VA_ARGS__,int16_t); break; \
	case FI_UINT16: FUNC(__VA_ARGS__,uint16_t); break; \
	case FI_INT32:	FUNC(__VA_ARGS__,int32_t); break; \
	case FI_UINT32: FUNC(__VA_ARGS__,uint32_t); break; \
	case FI_INT64:	FUNC(__VA_ARGS__,int64_t); break; \
	case FI_UINT64: FUNC(__VA_ARGS__,uint64_t); break;

#define CASE_FP_TYPE(FUNC,...) \
	case FI_FLOAT:	FUNC(__VA_ARGS__,float); break; \
	case FI_DOUBLE:	FUNC(__VA_ARGS__,double); break; \
	case FI_LONG_DOUBLE: FUNC(__VA_ARGS__,long double); break;

#define CASE_COMPLEX_TYPE(FUNC,...) \
	case FI_FLOAT_COMPLEX:	FUNC(__VA_ARGS__,float complex); break; \
	case FI_DOUBLE_COMPLEX:	FUNC(__VA_ARGS__,double complex); break; \
	case FI_LONG_DOUBLE_COMPLEX: FUNC(__VA_ARGS__,long double complex); break;

/*
 * SWITCH_*_TYPE(type, FUNC, args...) expands to a complete switch on `type`
 * covering integer types only (INT), integer and real types (ORD, i.e. the
 * types with an ordering), or all types (ALL).  An unlisted datatype makes
 * the *calling function* return -FI_EOPNOTSUPP.
 */
#define SWITCH_INT_TYPE(type,...) \
	switch (type) { \
	CASE_INT_TYPE(__VA_ARGS__) \
	default: return -FI_EOPNOTSUPP; \
	}

#define SWITCH_ORD_TYPE(type,...) \
	switch (type) { \
	CASE_INT_TYPE(__VA_ARGS__) \
	CASE_FP_TYPE(__VA_ARGS__) \
	default: return -FI_EOPNOTSUPP; \
	}

#define SWITCH_ALL_TYPE(type,...) \
	switch (type) { \
	CASE_INT_TYPE(__VA_ARGS__) \
	CASE_FP_TYPE(__VA_ARGS__) \
	CASE_COMPLEX_TYPE(__VA_ARGS__) \
	default: return -FI_EOPNOTSUPP; \
	}

/*
 * Element combiners passed as the OP argument of PSMX_ATOMIC_WRITE and
 * PSMX_ATOMIC_READWRITE: each updates `dst` in place using `src`.
 * PSMX_MIN/PSMX_MAX contain an `if`, so they are wrapped in
 * do { } while (0) to behave as a single statement (e.g. inside an
 * unbraced if/else).
 */
#define PSMX_MIN(dst,src)	do { if ((dst) > (src)) (dst) = (src); } while (0)
#define PSMX_MAX(dst,src)	do { if ((dst) < (src)) (dst) = (src); } while (0)
#define PSMX_SUM(dst,src)	(dst) += (src)
#define PSMX_PROD(dst,src)	(dst) *= (src)
#define PSMX_LOR(dst,src)	(dst) = (dst) || (src)
#define PSMX_LAND(dst,src)	(dst) = (dst) && (src)
#define PSMX_BOR(dst,src)	(dst) |= (src)
#define PSMX_BAND(dst,src)	(dst) &= (src)
#define PSMX_LXOR(dst,src)	(dst) = ((dst) && !(src)) || (!(dst) && (src))
#define PSMX_BXOR(dst,src)	(dst) ^= (src)
#define PSMX_COPY(dst,src)	(dst) = (src)
116
/*
 * Locked element-wise primitives.
 *
 * `dst`, `src`, `cmp` and `res` point to arrays of `cnt` elements of TYPE.
 * Each primitive holds psmx_atomic_lock for its whole loop, so the array
 * update is atomic with respect to every other primitive using that lock.
 * Internal locals carry a trailing underscore so they cannot capture
 * caller arguments with the same name, and every macro argument is
 * parenthesized (including `cnt`, which may be an arbitrary expression).
 */

/* res[i] = dst[i] */
#define PSMX_ATOMIC_READ(dst,res,cnt,TYPE) \
	do { \
		int i_; \
		TYPE *d_ = (dst); \
		TYPE *r_ = (res); \
		fastlock_acquire(&psmx_atomic_lock); \
		for (i_ = 0; i_ < (cnt); i_++) \
			r_[i_] = d_[i_]; \
		fastlock_release(&psmx_atomic_lock); \
	} while (0)

/* OP(dst[i], src[i]) -- dst[i] is updated in place */
#define PSMX_ATOMIC_WRITE(dst,src,cnt,OP,TYPE) \
	do { \
		int i_; \
		TYPE *d_ = (dst); \
		TYPE *s_ = (src); \
		fastlock_acquire(&psmx_atomic_lock); \
		for (i_ = 0; i_ < (cnt); i_++) \
			OP(d_[i_], s_[i_]); \
		fastlock_release(&psmx_atomic_lock); \
	} while (0)

/* res[i] = old dst[i]; then OP(dst[i], src[i]) */
#define PSMX_ATOMIC_READWRITE(dst,src,res,cnt,OP,TYPE) \
	do { \
		int i_; \
		TYPE *d_ = (dst); \
		TYPE *s_ = (src); \
		TYPE *r_ = (res); \
		fastlock_acquire(&psmx_atomic_lock); \
		for (i_ = 0; i_ < (cnt); i_++) { \
			r_[i_] = d_[i_]; \
			OP(d_[i_], s_[i_]); \
		} \
		fastlock_release(&psmx_atomic_lock); \
	} while (0)

/* res[i] = old dst[i]; if (cmp[i] CMP_OP dst[i]) dst[i] = src[i] */
#define PSMX_ATOMIC_CSWAP(dst,src,cmp,res,cnt,CMP_OP,TYPE) \
	do { \
		int i_; \
		TYPE *d_ = (dst); \
		TYPE *s_ = (src); \
		TYPE *c_ = (cmp); \
		TYPE *r_ = (res); \
		fastlock_acquire(&psmx_atomic_lock); \
		for (i_ = 0; i_ < (cnt); i_++) { \
			r_[i_] = d_[i_]; \
			if (c_[i_] CMP_OP d_[i_]) \
				d_[i_] = s_[i_]; \
		} \
		fastlock_release(&psmx_atomic_lock); \
	} while (0)

/*
 * Masked swap: res[i] = old dst[i]; bits set in cmp[i] are taken from
 * src[i], the remaining bits are kept from dst[i].
 */
#define PSMX_ATOMIC_MSWAP(dst,src,cmp,res,cnt,TYPE) \
	do { \
		int i_; \
		TYPE *d_ = (dst); \
		TYPE *s_ = (src); \
		TYPE *c_ = (cmp); \
		TYPE *r_ = (res); \
		fastlock_acquire(&psmx_atomic_lock); \
		for (i_ = 0; i_ < (cnt); i_++) { \
			r_[i_] = d_[i_]; \
			d_[i_] = (s_[i_] & c_[i_]) | (d_[i_] & ~c_[i_]); \
		} \
		fastlock_release(&psmx_atomic_lock); \
	} while (0)
183
psmx_atomic_do_write(void * dest,void * src,int datatype,int op,int count)184 static int psmx_atomic_do_write(void *dest, void *src,
185 int datatype, int op, int count)
186 {
187 switch (op) {
188 case FI_MIN:
189 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_WRITE,
190 dest,src,count,PSMX_MIN);
191 break;
192
193 case FI_MAX:
194 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_WRITE,
195 dest,src,count,PSMX_MAX);
196 break;
197
198 case FI_SUM:
199 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_WRITE,
200 dest,src,count,PSMX_SUM);
201 break;
202
203 case FI_PROD:
204 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_WRITE,
205 dest,src,count,PSMX_PROD);
206 break;
207
208 case FI_LOR:
209 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
210 dest,src,count,PSMX_LOR);
211 break;
212
213 case FI_LAND:
214 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
215 dest,src,count,PSMX_LAND);
216 break;
217
218 case FI_BOR:
219 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
220 dest,src,count,PSMX_BOR);
221 break;
222
223 case FI_BAND:
224 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
225 dest,src,count,PSMX_BAND);
226 break;
227
228 case FI_LXOR:
229 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
230 dest,src,count,PSMX_LXOR);
231 break;
232
233 case FI_BXOR:
234 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_WRITE,
235 dest,src,count,PSMX_BXOR);
236 break;
237
238 case FI_ATOMIC_WRITE:
239 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_WRITE,
240 dest,src,count,PSMX_COPY);
241 break;
242
243 default:
244 return -FI_EOPNOTSUPP;
245 }
246
247 return 0;
248 }
249
psmx_atomic_do_readwrite(void * dest,void * src,void * result,int datatype,int op,int count)250 static int psmx_atomic_do_readwrite(void *dest, void *src, void *result,
251 int datatype, int op, int count)
252 {
253 switch (op) {
254 case FI_MIN:
255 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_READWRITE,
256 dest,src,result,count,PSMX_MIN);
257 break;
258
259 case FI_MAX:
260 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_READWRITE,
261 dest,src,result,count,PSMX_MAX);
262 break;
263
264 case FI_SUM:
265 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READWRITE,
266 dest,src,result,count,PSMX_SUM);
267 break;
268
269 case FI_PROD:
270 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READWRITE,
271 dest,src,result,count,PSMX_PROD);
272 break;
273
274 case FI_LOR:
275 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
276 dest,src,result,count,PSMX_LOR);
277 break;
278
279 case FI_LAND:
280 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
281 dest,src,result,count,PSMX_LAND);
282 break;
283
284 case FI_BOR:
285 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
286 dest,src,result,count,PSMX_BOR);
287 break;
288
289 case FI_BAND:
290 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
291 dest,src,result,count,PSMX_BAND);
292 break;
293
294 case FI_LXOR:
295 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
296 dest,src,result,count,PSMX_LXOR);
297 break;
298
299 case FI_BXOR:
300 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_READWRITE,
301 dest,src,result,count,PSMX_BXOR);
302 break;
303
304 case FI_ATOMIC_READ:
305 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READ,
306 dest,result,count);
307 break;
308
309 case FI_ATOMIC_WRITE:
310 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_READWRITE,
311 dest,src,result,count,PSMX_COPY);
312 break;
313
314 default:
315 return -FI_EOPNOTSUPP;
316 }
317
318 return 0;
319 }
320
psmx_atomic_do_compwrite(void * dest,void * src,void * compare,void * result,int datatype,int op,int count)321 static int psmx_atomic_do_compwrite(void *dest, void *src, void *compare,
322 void *result, int datatype, int op,
323 int count)
324 {
325 switch (op) {
326 case FI_CSWAP:
327 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_CSWAP,
328 dest,src,compare,result,count,==);
329 break;
330
331 case FI_CSWAP_NE:
332 SWITCH_ALL_TYPE(datatype,PSMX_ATOMIC_CSWAP,
333 dest,src,compare,result,count,!=);
334 break;
335
336 case FI_CSWAP_LE:
337 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
338 dest,src,compare,result,count,<=);
339 break;
340
341 case FI_CSWAP_LT:
342 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
343 dest,src,compare,result,count,<);
344 break;
345
346 case FI_CSWAP_GE:
347 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
348 dest,src,compare,result,count,>=);
349 break;
350
351 case FI_CSWAP_GT:
352 SWITCH_ORD_TYPE(datatype,PSMX_ATOMIC_CSWAP,
353 dest,src,compare,result,count,>);
354 break;
355
356 case FI_MSWAP:
357 SWITCH_INT_TYPE(datatype,PSMX_ATOMIC_MSWAP,
358 dest,src,compare,result,count);
359 break;
360
361 default:
362 return -FI_EOPNOTSUPP;
363 }
364
365 return 0;
366 }
367
/*
 * Completion callback handed to psm_am_reply_short() for replies that carry
 * a heap-allocated payload (the fetched data returned by
 * psmx_am_atomic_handler()).  Releases that buffer.  `buf` may be NULL;
 * free(NULL) is a no-op, so no guard is needed.
 */
static void psmx_am_atomic_completion(void *buf)
{
	free(buf);
}
373
/*
 * Active-message handler for the atomics protocol described at the top of
 * this file.
 *
 * Request commands (PSMX_AM_REQ_ATOMIC_*) run on the target.  The memory
 * region named by args[3] is looked up and validated, the operation is
 * applied, and remote counters are bumped only if the operation really took
 * place.  A reply is then sent carrying the status in u32w1 and, for
 * fetching operations that succeeded, the original target data.
 *
 * Reply commands (PSMX_AM_REP_ATOMIC_*) run on the initiator.  They copy
 * any fetched data into the user's result buffer, generate the local
 * completion, bump the local counter, and free the request.
 *
 * `epaddr` and `nargs` are not used.
 *
 * Returns 0 on success, -FI_ENOMEM if a completion event could not be
 * allocated, -FI_EINVAL for an unknown command, or the value returned by
 * psm_am_reply_short().
 */
int psmx_am_atomic_handler(psm_am_token_t token, psm_epaddr_t epaddr,
			   psm_amarg_t *args, int nargs, void *src,
			   uint32_t len)
{
	psm_amarg_t rep_args[8];
	int count;
	uint8_t *addr;
	uint64_t key;
	int datatype, op;
	int err = 0;
	int op_error = 0;
	struct psmx_am_request *req;
	struct psmx_cq_event *event;
	struct psmx_fid_mr *mr;
	struct psmx_fid_ep *target_ep;
	struct psmx_fid_cntr *cntr = NULL;
	struct psmx_fid_cntr *mr_cntr = NULL;
	void *tmp_buf = NULL;

	switch (args[0].u32w0 & PSMX_AM_OP_MASK) {
	case PSMX_AM_REQ_ATOMIC_WRITE:
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		assert(len == ofi_datatype_size(datatype) * count);

		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_WRITE) :
			-FI_EINVAL;

		if (!op_error) {
			addr += mr->offset;
			/*
			 * Report unsupported op/datatype pairs to the initiator
			 * instead of acknowledging a write that never happened.
			 */
			op_error = psmx_atomic_do_write(addr, src, datatype,
							op, count);
		}

		/* Only count operations that actually updated target memory. */
		if (!op_error) {
			target_ep = mr->domain->atomics_ep;
			if (target_ep->caps & FI_RMA_EVENT) {
				cntr = target_ep->remote_write_cntr;
				mr_cntr = mr->cntr;

				if (cntr)
					psmx_cntr_inc(cntr);

				if (mr_cntr && mr_cntr != cntr)
					psmx_cntr_inc(mr_cntr);
			}
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_WRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, NULL, 0, 0,
					 NULL, NULL);
		break;

	case PSMX_AM_REQ_ATOMIC_READWRITE:
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;

		/*
		 * FI_ATOMIC_READ carries no source payload; size the returned
		 * data from the element count instead.
		 */
		if (op == FI_ATOMIC_READ)
			len = ofi_datatype_size(datatype) * count;

		assert(len == ofi_datatype_size(datatype) * count);

		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) :
			-FI_EINVAL;

		if (!op_error) {
			addr += mr->offset;
			/* receives the pre-operation values sent back in the reply */
			tmp_buf = malloc(len);
			if (tmp_buf)
				op_error = psmx_atomic_do_readwrite(addr, src, tmp_buf,
								    datatype, op, count);
			else
				op_error = -FI_ENOMEM;
		}

		if (op_error) {
			/* Nothing valid to return: never send uninitialized memory. */
			free(tmp_buf);
			tmp_buf = NULL;
		} else {
			target_ep = mr->domain->atomics_ep;
			if (target_ep->caps & FI_RMA_EVENT) {
				if (op == FI_ATOMIC_READ) {
					cntr = target_ep->remote_read_cntr;
				} else {
					cntr = target_ep->remote_write_cntr;
					mr_cntr = mr->cntr;
				}

				if (cntr)
					psmx_cntr_inc(cntr);

				if (mr_cntr && mr_cntr != cntr)
					psmx_cntr_inc(mr_cntr);
			}
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_READWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, tmp_buf, (tmp_buf?len:0), 0,
					 psmx_am_atomic_completion, tmp_buf);
		break;

	case PSMX_AM_REQ_ATOMIC_COMPWRITE:
		count = args[0].u32w1;
		addr = (uint8_t *)(uintptr_t)args[2].u64;
		key = args[3].u64;
		datatype = args[4].u32w0;
		op = args[4].u32w1;
		/* payload is the source operand followed by the compare operand */
		len /= 2;
		assert(len == ofi_datatype_size(datatype) * count);

		mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
		op_error = mr ?
			psmx_mr_validate(mr, (uint64_t)addr, len, FI_REMOTE_READ|FI_REMOTE_WRITE) :
			-FI_EINVAL;

		if (!op_error) {
			addr += mr->offset;
			/* receives the pre-operation values sent back in the reply */
			tmp_buf = malloc(len);
			if (tmp_buf)
				op_error = psmx_atomic_do_compwrite(addr, src,
								    (uint8_t *)src + len,
								    tmp_buf, datatype,
								    op, count);
			else
				op_error = -FI_ENOMEM;
		}

		if (op_error) {
			/* Nothing valid to return: never send uninitialized memory. */
			free(tmp_buf);
			tmp_buf = NULL;
		} else {
			target_ep = mr->domain->atomics_ep;
			if (target_ep->caps & FI_RMA_EVENT) {
				cntr = target_ep->remote_write_cntr;
				mr_cntr = mr->cntr;

				if (cntr)
					psmx_cntr_inc(cntr);

				if (mr_cntr && mr_cntr != cntr)
					psmx_cntr_inc(mr_cntr);
			}
		}

		rep_args[0].u32w0 = PSMX_AM_REP_ATOMIC_COMPWRITE;
		rep_args[0].u32w1 = op_error;
		rep_args[1].u64 = args[1].u64;
		err = psm_am_reply_short(token, PSMX_AM_ATOMIC_HANDLER,
					 rep_args, 2, tmp_buf, (tmp_buf?len:0), 0,
					 psmx_am_atomic_completion, tmp_buf);
		break;

	case PSMX_AM_REP_ATOMIC_WRITE:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(req->op == PSMX_AM_REQ_ATOMIC_WRITE);
		/* error completions are reported even when events are suppressed */
		if (req->ep->send_cq && (!req->no_event || op_error)) {
			event = psmx_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx_cq_enqueue_event(req->ep->send_cq, event);
			else
				err = -FI_ENOMEM;
		}

		if (req->ep->write_cntr)
			psmx_cntr_inc(req->ep->write_cntr);

		free(req);
		break;

	case PSMX_AM_REP_ATOMIC_READWRITE:
	case PSMX_AM_REP_ATOMIC_COMPWRITE:
		req = (struct psmx_am_request *)(uintptr_t)args[1].u64;
		op_error = (int)args[0].u32w1;
		assert(op_error || req->atomic.len == len);

		/* the payload is the fetched data; copy it to the user buffer */
		if (!op_error)
			memcpy(req->atomic.result, src, len);

		/* error completions are reported even when events are suppressed */
		if (req->ep->send_cq && (!req->no_event || op_error)) {
			event = psmx_cq_create_event(
					req->ep->send_cq,
					req->atomic.context,
					req->atomic.buf,
					req->cq_flags,
					req->atomic.len,
					0, /* data */
					0, /* tag */
					0, /* olen */
					op_error);
			if (event)
				psmx_cq_enqueue_event(req->ep->send_cq, event);
			else
				err = -FI_ENOMEM;
		}

		if (req->ep->read_cntr)
			psmx_cntr_inc(req->ep->read_cntr);

		free(req);
		break;

	default:
		err = -FI_EINVAL;
	}
	return err;
}
595
psmx_atomic_self(int am_cmd,struct psmx_fid_ep * ep,const void * buf,size_t count,void * desc,const void * compare,void * compare_desc,void * result,void * result_desc,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)596 static int psmx_atomic_self(int am_cmd,
597 struct psmx_fid_ep *ep,
598 const void *buf,
599 size_t count, void *desc,
600 const void *compare, void *compare_desc,
601 void *result, void *result_desc,
602 uint64_t addr, uint64_t key,
603 enum fi_datatype datatype,
604 enum fi_op op, void *context,
605 uint64_t flags)
606 {
607 struct psmx_fid_mr *mr;
608 struct psmx_cq_event *event;
609 struct psmx_fid_ep *target_ep;
610 struct psmx_fid_cntr *cntr = NULL;
611 struct psmx_fid_cntr *mr_cntr = NULL;
612 void *tmp_buf;
613 size_t len;
614 int no_event;
615 int err = 0;
616 int op_error;
617 int access;
618 uint64_t cq_flags = 0;
619
620 if (am_cmd == PSMX_AM_REQ_ATOMIC_WRITE)
621 access = FI_REMOTE_WRITE;
622 else
623 access = FI_REMOTE_READ | FI_REMOTE_WRITE;
624
625 len = ofi_datatype_size(datatype) * count;
626 mr = psmx_mr_get(psmx_active_fabric->active_domain, key);
627 op_error = mr ? psmx_mr_validate(mr, addr, len, access) : -FI_EINVAL;
628
629 if (op_error)
630 goto gen_local_event;
631
632 addr += mr->offset;
633
634 switch (am_cmd) {
635 case PSMX_AM_REQ_ATOMIC_WRITE:
636 err = psmx_atomic_do_write((void *)addr, (void *)buf,
637 (int)datatype, (int)op, (int)count);
638 cq_flags = FI_WRITE | FI_ATOMIC;
639 break;
640
641 case PSMX_AM_REQ_ATOMIC_READWRITE:
642 if (result != buf) {
643 err = psmx_atomic_do_readwrite((void *)addr, (void *)buf,
644 (void *)result, (int)datatype,
645 (int)op, (int)count);
646 } else {
647 tmp_buf = malloc(len);
648 if (tmp_buf) {
649 memcpy(tmp_buf, result, len);
650 err = psmx_atomic_do_readwrite((void *)addr, (void *)buf,
651 tmp_buf, (int)datatype,
652 (int)op, (int)count);
653 memcpy(result, tmp_buf, len);
654 free(tmp_buf);
655 } else {
656 err = -FI_ENOMEM;
657 }
658 }
659 if (op == FI_ATOMIC_READ)
660 cq_flags = FI_READ | FI_ATOMIC;
661 else
662 cq_flags = FI_WRITE | FI_ATOMIC;
663 break;
664
665 case PSMX_AM_REQ_ATOMIC_COMPWRITE:
666 if (result != buf && result != compare) {
667 err = psmx_atomic_do_compwrite((void *)addr, (void *)buf,
668 (void *)compare, (void *)result,
669 (int)datatype, (int)op, (int)count);
670 } else {
671 tmp_buf = malloc(len);
672 if (tmp_buf) {
673 memcpy(tmp_buf, result, len);
674 err = psmx_atomic_do_compwrite((void *)addr, (void *)buf,
675 (void *)compare, tmp_buf,
676 (int)datatype, (int)op, (int)count);
677 memcpy(result, tmp_buf, len);
678 free(tmp_buf);
679 } else {
680 err = -FI_ENOMEM;
681 }
682 }
683 cq_flags = FI_WRITE | FI_ATOMIC;
684 break;
685 }
686
687 target_ep = mr->domain->atomics_ep;
688 if (target_ep->caps & FI_RMA_EVENT) {
689 if (op == FI_ATOMIC_READ) {
690 cntr = target_ep->remote_read_cntr;
691 } else {
692 cntr = target_ep->remote_write_cntr;
693 mr_cntr = mr->cntr;
694 }
695
696 if (cntr)
697 psmx_cntr_inc(cntr);
698
699 if (mr_cntr && mr_cntr != cntr)
700 psmx_cntr_inc(mr_cntr);
701 }
702
703 gen_local_event:
704 no_event = ((flags & PSMX_NO_COMPLETION) ||
705 (ep->send_selective_completion && !(flags & FI_COMPLETION)));
706 if (ep->send_cq && (!no_event || op_error)) {
707 event = psmx_cq_create_event(
708 ep->send_cq,
709 context,
710 (void *)buf,
711 cq_flags,
712 len,
713 0, /* data */
714 0, /* tag */
715 0, /* olen */
716 op_error);
717 if (event)
718 psmx_cq_enqueue_event(ep->send_cq, event);
719 else
720 err = -FI_ENOMEM;
721 }
722
723 switch (am_cmd) {
724 case PSMX_AM_REQ_ATOMIC_WRITE:
725 if (ep->write_cntr)
726 psmx_cntr_inc(ep->write_cntr);
727 break;
728 case PSMX_AM_REQ_ATOMIC_READWRITE:
729 case PSMX_AM_REQ_ATOMIC_COMPWRITE:
730 if (ep->read_cntr)
731 psmx_cntr_inc(ep->read_cntr);
732 break;
733 }
734
735 return err;
736 }
737
_psmx_atomic_write(struct fid_ep * ep,const void * buf,size_t count,void * desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)738 ssize_t _psmx_atomic_write(struct fid_ep *ep,
739 const void *buf,
740 size_t count, void *desc,
741 fi_addr_t dest_addr,
742 uint64_t addr, uint64_t key,
743 enum fi_datatype datatype,
744 enum fi_op op, void *context,
745 uint64_t flags)
746 {
747 struct psmx_fid_ep *ep_priv;
748 struct psmx_fid_av *av;
749 struct psmx_epaddr_context *epaddr_context;
750 struct psmx_am_request *req;
751 psm_amarg_t args[8];
752 int am_flags = PSM_AM_FLAG_ASYNC;
753 int chunk_size, len;
754 size_t idx;
755
756 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
757
758 if (flags & FI_TRIGGER) {
759 struct psmx_trigger *trigger;
760 struct fi_triggered_context *ctxt = context;
761
762 trigger = calloc(1, sizeof(*trigger));
763 if (!trigger)
764 return -FI_ENOMEM;
765
766 trigger->op = PSMX_TRIGGERED_ATOMIC_WRITE;
767 trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
768 struct psmx_fid_cntr, cntr);
769 trigger->threshold = ctxt->trigger.threshold.threshold;
770 trigger->atomic_write.ep = ep;
771 trigger->atomic_write.buf = buf;
772 trigger->atomic_write.count = count;
773 trigger->atomic_write.desc = desc;
774 trigger->atomic_write.dest_addr = dest_addr;
775 trigger->atomic_write.addr = addr;
776 trigger->atomic_write.key = key;
777 trigger->atomic_write.datatype = datatype;
778 trigger->atomic_write.atomic_op = op;
779 trigger->atomic_write.context = context;
780 trigger->atomic_write.flags = flags & ~FI_TRIGGER;
781
782 psmx_cntr_add_trigger(trigger->cntr, trigger);
783 return 0;
784 }
785
786 if (!buf)
787 return -FI_EINVAL;
788
789 if (datatype >= FI_DATATYPE_LAST)
790 return -FI_EINVAL;
791
792 if (op >= FI_ATOMIC_OP_LAST)
793 return -FI_EINVAL;
794
795 av = ep_priv->av;
796 if (av && av->type == FI_AV_TABLE) {
797 idx = dest_addr;
798 if (idx >= av->last)
799 return -FI_EINVAL;
800
801 dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
802 } else if (!dest_addr) {
803 return -FI_EINVAL;
804 }
805
806 epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
807 if (epaddr_context->epid == ep_priv->domain->psm_epid)
808 return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_WRITE,
809 ep_priv, buf, count, desc,
810 NULL, NULL, NULL, NULL,
811 addr, key, datatype, op,
812 context, flags);
813
814 chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
815 len = ofi_datatype_size(datatype)* count;
816 if (len > chunk_size)
817 return -FI_EMSGSIZE;
818
819 if (flags & FI_INJECT) {
820 req = malloc(sizeof(*req) + len);
821 if (!req)
822 return -FI_ENOMEM;
823 memset(req, 0, sizeof(*req));
824 memcpy((uint8_t *)req+sizeof(*req), (void *)buf, len);
825 buf = (uint8_t *)req + sizeof(*req);
826 } else {
827 req = calloc(1, sizeof(*req));
828 if (!req)
829 return -FI_ENOMEM;
830 }
831
832 req->no_event = (flags & PSMX_NO_COMPLETION) ||
833 (ep_priv->send_selective_completion && !(flags & FI_COMPLETION));
834
835 req->op = PSMX_AM_REQ_ATOMIC_WRITE;
836 req->atomic.buf = (void *)buf;
837 req->atomic.len = len;
838 req->atomic.addr = addr;
839 req->atomic.key = key;
840 req->atomic.context = context;
841 req->ep = ep_priv;
842 req->cq_flags = FI_WRITE | FI_ATOMIC;
843
844 args[0].u32w0 = PSMX_AM_REQ_ATOMIC_WRITE;
845 args[0].u32w1 = count;
846 args[1].u64 = (uint64_t)(uintptr_t)req;
847 args[2].u64 = addr;
848 args[3].u64 = key;
849 args[4].u32w0 = datatype;
850 args[4].u32w1 = op;
851 psm_am_request_short((psm_epaddr_t) dest_addr,
852 PSMX_AM_ATOMIC_HANDLER, args, 5,
853 (void *)buf, len, am_flags, NULL, NULL);
854
855 return 0;
856 }
857
psmx_atomic_write(struct fid_ep * ep,const void * buf,size_t count,void * desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)858 static ssize_t psmx_atomic_write(struct fid_ep *ep,
859 const void *buf,
860 size_t count, void *desc,
861 fi_addr_t dest_addr,
862 uint64_t addr, uint64_t key,
863 enum fi_datatype datatype,
864 enum fi_op op, void *context)
865 {
866 struct psmx_fid_ep *ep_priv;
867
868 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
869 return _psmx_atomic_write(ep, buf, count, desc,
870 dest_addr, addr, key,
871 datatype, op, context, ep_priv->tx_flags);
872 }
873
psmx_atomic_writemsg(struct fid_ep * ep,const struct fi_msg_atomic * msg,uint64_t flags)874 static ssize_t psmx_atomic_writemsg(struct fid_ep *ep,
875 const struct fi_msg_atomic *msg,
876 uint64_t flags)
877 {
878 if (!msg || msg->iov_count != 1 || !msg->msg_iov || !msg->rma_iov)
879 return -FI_EINVAL;
880
881 return _psmx_atomic_write(ep, msg->msg_iov[0].addr,
882 msg->msg_iov[0].count,
883 msg->desc ? msg->desc[0] : NULL,
884 msg->addr, msg->rma_iov[0].addr,
885 msg->rma_iov[0].key, msg->datatype,
886 msg->op, msg->context, flags);
887 }
888
psmx_atomic_writev(struct fid_ep * ep,const struct fi_ioc * iov,void ** desc,size_t count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)889 static ssize_t psmx_atomic_writev(struct fid_ep *ep,
890 const struct fi_ioc *iov,
891 void **desc, size_t count,
892 fi_addr_t dest_addr,
893 uint64_t addr, uint64_t key,
894 enum fi_datatype datatype,
895 enum fi_op op, void *context)
896 {
897 if (!iov || count != 1)
898 return -FI_EINVAL;
899
900 return psmx_atomic_write(ep, iov->addr, iov->count,
901 desc ? desc[0] : NULL, dest_addr, addr, key,
902 datatype, op, context);
903 }
904
psmx_atomic_inject(struct fid_ep * ep,const void * buf,size_t count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op)905 static ssize_t psmx_atomic_inject(struct fid_ep *ep,
906 const void *buf,
907 size_t count, /*void *desc,*/
908 fi_addr_t dest_addr,
909 uint64_t addr, uint64_t key,
910 enum fi_datatype datatype,
911 enum fi_op op)
912 {
913 struct psmx_fid_ep *ep_priv;
914
915 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
916 return _psmx_atomic_write(ep, buf, count, NULL/*desc*/,
917 dest_addr, addr, key,
918 datatype, op, NULL,
919 ep_priv->tx_flags | FI_INJECT | PSMX_NO_COMPLETION);
920 }
921
_psmx_atomic_readwrite(struct fid_ep * ep,const void * buf,size_t count,void * desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)922 ssize_t _psmx_atomic_readwrite(struct fid_ep *ep,
923 const void *buf,
924 size_t count, void *desc,
925 void *result, void *result_desc,
926 fi_addr_t dest_addr,
927 uint64_t addr, uint64_t key,
928 enum fi_datatype datatype,
929 enum fi_op op, void *context,
930 uint64_t flags)
931 {
932 struct psmx_fid_ep *ep_priv;
933 struct psmx_fid_av *av;
934 struct psmx_epaddr_context *epaddr_context;
935 struct psmx_am_request *req;
936 psm_amarg_t args[8];
937 int am_flags = PSM_AM_FLAG_ASYNC;
938 int chunk_size, len;
939 size_t idx;
940
941 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
942
943 if (flags & FI_TRIGGER) {
944 struct psmx_trigger *trigger;
945 struct fi_triggered_context *ctxt = context;
946
947 trigger = calloc(1, sizeof(*trigger));
948 if (!trigger)
949 return -FI_ENOMEM;
950
951 trigger->op = PSMX_TRIGGERED_ATOMIC_READWRITE;
952 trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
953 struct psmx_fid_cntr, cntr);
954 trigger->threshold = ctxt->trigger.threshold.threshold;
955 trigger->atomic_readwrite.ep = ep;
956 trigger->atomic_readwrite.buf = buf;
957 trigger->atomic_readwrite.count = count;
958 trigger->atomic_readwrite.desc = desc;
959 trigger->atomic_readwrite.result = result;
960 trigger->atomic_readwrite.result_desc = result_desc;
961 trigger->atomic_readwrite.dest_addr = dest_addr;
962 trigger->atomic_readwrite.addr = addr;
963 trigger->atomic_readwrite.key = key;
964 trigger->atomic_readwrite.datatype = datatype;
965 trigger->atomic_readwrite.atomic_op = op;
966 trigger->atomic_readwrite.context = context;
967 trigger->atomic_readwrite.flags = flags & ~FI_TRIGGER;
968
969 psmx_cntr_add_trigger(trigger->cntr, trigger);
970 return 0;
971 }
972
973 if (!buf && op != FI_ATOMIC_READ)
974 return -FI_EINVAL;
975
976 if (datatype >= FI_DATATYPE_LAST)
977 return -FI_EINVAL;
978
979 if (op >= FI_ATOMIC_OP_LAST)
980 return -FI_EINVAL;
981
982 av = ep_priv->av;
983 if (av && av->type == FI_AV_TABLE) {
984 idx = dest_addr;
985 if (idx >= av->last)
986 return -FI_EINVAL;
987
988 dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
989 } else if (!dest_addr) {
990 return -FI_EINVAL;
991 }
992
993 epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
994 if (epaddr_context->epid == ep_priv->domain->psm_epid)
995 return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_READWRITE,
996 ep_priv, buf, count, desc,
997 NULL, NULL, result, result_desc,
998 addr, key, datatype, op,
999 context, flags);
1000
1001 chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
1002 len = ofi_datatype_size(datatype) * count;
1003 if (len > chunk_size)
1004 return -FI_EMSGSIZE;
1005
1006 if ((flags & FI_INJECT) && op != FI_ATOMIC_READ) {
1007 req = malloc(sizeof(*req) + len);
1008 if (!req)
1009 return -FI_ENOMEM;
1010 memset(req, 0, sizeof(*req));
1011 memcpy((uint8_t *)req+sizeof(*req), (void *)buf, len);
1012 buf = (uint8_t *)req + sizeof(*req);
1013 } else {
1014 req = calloc(1, sizeof(*req));
1015 if (!req)
1016 return -FI_ENOMEM;
1017 }
1018
1019 req->no_event = (flags & PSMX_NO_COMPLETION) ||
1020 (ep_priv->send_selective_completion && !(flags & FI_COMPLETION));
1021
1022 req->op = PSMX_AM_REQ_ATOMIC_READWRITE;
1023 req->atomic.buf = (void *)buf;
1024 req->atomic.len = len;
1025 req->atomic.addr = addr;
1026 req->atomic.key = key;
1027 req->atomic.context = context;
1028 req->atomic.result = result;
1029 req->ep = ep_priv;
1030 if (op == FI_ATOMIC_READ)
1031 req->cq_flags = FI_READ | FI_ATOMIC;
1032 else
1033 req->cq_flags = FI_WRITE | FI_ATOMIC;
1034
1035 args[0].u32w0 = PSMX_AM_REQ_ATOMIC_READWRITE;
1036 args[0].u32w1 = count;
1037 args[1].u64 = (uint64_t)(uintptr_t)req;
1038 args[2].u64 = addr;
1039 args[3].u64 = key;
1040 args[4].u32w0 = datatype;
1041 args[4].u32w1 = op;
1042 psm_am_request_short((psm_epaddr_t) dest_addr,
1043 PSMX_AM_ATOMIC_HANDLER, args, 5,
1044 (void *)buf, (buf?len:0), am_flags, NULL, NULL);
1045
1046 return 0;
1047 }
1048
psmx_atomic_readwrite(struct fid_ep * ep,const void * buf,size_t count,void * desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)1049 static ssize_t psmx_atomic_readwrite(struct fid_ep *ep,
1050 const void *buf,
1051 size_t count, void *desc,
1052 void *result, void *result_desc,
1053 fi_addr_t dest_addr,
1054 uint64_t addr, uint64_t key,
1055 enum fi_datatype datatype,
1056 enum fi_op op, void *context)
1057 {
1058 struct psmx_fid_ep *ep_priv;
1059
1060 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
1061 return _psmx_atomic_readwrite(ep, buf, count, desc,
1062 result, result_desc, dest_addr,
1063 addr, key, datatype, op,
1064 context, ep_priv->tx_flags);
1065 }
1066
psmx_atomic_readwritemsg(struct fid_ep * ep,const struct fi_msg_atomic * msg,struct fi_ioc * resultv,void ** result_desc,size_t result_count,uint64_t flags)1067 static ssize_t psmx_atomic_readwritemsg(struct fid_ep *ep,
1068 const struct fi_msg_atomic *msg,
1069 struct fi_ioc *resultv,
1070 void **result_desc,
1071 size_t result_count,
1072 uint64_t flags)
1073 {
1074 void *buf;
1075 size_t count;
1076
1077 if (!msg || !msg->rma_iov)
1078 return -FI_EINVAL;
1079
1080 if (msg->op == FI_ATOMIC_READ) {
1081 if (result_count != 1 || !resultv)
1082 return -FI_EINVAL;
1083
1084 buf = NULL;
1085 count = resultv[0].count;
1086 } else {
1087 if (msg->iov_count != 1 || !msg->msg_iov)
1088 return -FI_EINVAL;
1089
1090 buf = msg->msg_iov[0].addr;
1091 count = msg->msg_iov[0].count;
1092 }
1093
1094 return _psmx_atomic_readwrite(ep, buf, count,
1095 msg->desc ? msg->desc[0] : NULL,
1096 resultv[0].addr,
1097 result_desc ? result_desc[0] : NULL,
1098 msg->addr, msg->rma_iov[0].addr,
1099 msg->rma_iov[0].key, msg->datatype,
1100 msg->op, msg->context, flags);
1101 }
1102
psmx_atomic_readwritev(struct fid_ep * ep,const struct fi_ioc * iov,void ** desc,size_t count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)1103 static ssize_t psmx_atomic_readwritev(struct fid_ep *ep,
1104 const struct fi_ioc *iov,
1105 void **desc, size_t count,
1106 struct fi_ioc *resultv,
1107 void **result_desc, size_t result_count,
1108 fi_addr_t dest_addr,
1109 uint64_t addr, uint64_t key,
1110 enum fi_datatype datatype,
1111 enum fi_op op, void *context)
1112 {
1113 if (!iov || count != 1 || !resultv)
1114 return -FI_EINVAL;
1115
1116 return psmx_atomic_readwrite(ep, iov->addr, iov->count,
1117 desc ? desc[0] : NULL,
1118 resultv[0].addr,
1119 result_desc ? result_desc[0] : NULL,
1120 dest_addr, addr, key, datatype, op, context);
1121 }
1122
_psmx_atomic_compwrite(struct fid_ep * ep,const void * buf,size_t count,void * desc,const void * compare,void * compare_desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context,uint64_t flags)1123 ssize_t _psmx_atomic_compwrite(struct fid_ep *ep,
1124 const void *buf,
1125 size_t count, void *desc,
1126 const void *compare, void *compare_desc,
1127 void *result, void *result_desc,
1128 fi_addr_t dest_addr,
1129 uint64_t addr, uint64_t key,
1130 enum fi_datatype datatype,
1131 enum fi_op op, void *context,
1132 uint64_t flags)
1133 {
1134 struct psmx_fid_ep *ep_priv;
1135 struct psmx_fid_av *av;
1136 struct psmx_epaddr_context *epaddr_context;
1137 struct psmx_am_request *req;
1138 psm_amarg_t args[8];
1139 int am_flags = PSM_AM_FLAG_ASYNC;
1140 int chunk_size, len;
1141 void *tmp_buf = NULL;
1142 size_t idx;
1143
1144 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
1145
1146 if (flags & FI_TRIGGER) {
1147 struct psmx_trigger *trigger;
1148 struct fi_triggered_context *ctxt = context;
1149
1150 trigger = calloc(1, sizeof(*trigger));
1151 if (!trigger)
1152 return -FI_ENOMEM;
1153
1154 trigger->op = PSMX_TRIGGERED_ATOMIC_COMPWRITE;
1155 trigger->cntr = container_of(ctxt->trigger.threshold.cntr,
1156 struct psmx_fid_cntr, cntr);
1157 trigger->threshold = ctxt->trigger.threshold.threshold;
1158 trigger->atomic_compwrite.ep = ep;
1159 trigger->atomic_compwrite.buf = buf;
1160 trigger->atomic_compwrite.count = count;
1161 trigger->atomic_compwrite.desc = desc;
1162 trigger->atomic_compwrite.compare = compare;
1163 trigger->atomic_compwrite.compare_desc = compare_desc;
1164 trigger->atomic_compwrite.result = result;
1165 trigger->atomic_compwrite.result_desc = result_desc;
1166 trigger->atomic_compwrite.dest_addr = dest_addr;
1167 trigger->atomic_compwrite.addr = addr;
1168 trigger->atomic_compwrite.key = key;
1169 trigger->atomic_compwrite.datatype = datatype;
1170 trigger->atomic_compwrite.atomic_op = op;
1171 trigger->atomic_compwrite.context = context;
1172 trigger->atomic_compwrite.flags = flags & ~FI_TRIGGER;
1173
1174 psmx_cntr_add_trigger(trigger->cntr, trigger);
1175 return 0;
1176 }
1177
1178 if (!buf)
1179 return -FI_EINVAL;
1180
1181 if (datatype >= FI_DATATYPE_LAST)
1182 return -FI_EINVAL;
1183
1184 if (op >= FI_ATOMIC_OP_LAST)
1185 return -FI_EINVAL;
1186
1187 av = ep_priv->av;
1188 if (av && av->type == FI_AV_TABLE) {
1189 idx = dest_addr;
1190 if (idx >= av->last)
1191 return -FI_EINVAL;
1192
1193 dest_addr = (fi_addr_t) av->psm_epaddrs[idx];
1194 } else if (!dest_addr) {
1195 return -FI_EINVAL;
1196 }
1197
1198 epaddr_context = psm_epaddr_getctxt((void *)dest_addr);
1199 if (epaddr_context->epid == ep_priv->domain->psm_epid)
1200 return psmx_atomic_self(PSMX_AM_REQ_ATOMIC_COMPWRITE,
1201 ep_priv, buf, count, desc,
1202 compare, compare_desc,
1203 result, result_desc,
1204 addr, key, datatype, op,
1205 context, flags);
1206
1207 chunk_size = MIN(PSMX_AM_CHUNK_SIZE, psmx_am_param.max_request_short);
1208 len = ofi_datatype_size(datatype) * count;
1209 if (len * 2 > chunk_size)
1210 return -FI_EMSGSIZE;
1211
1212 if (flags & FI_INJECT) {
1213 req = malloc(sizeof(*req) + len + len);
1214 if (!req)
1215 return -FI_ENOMEM;
1216 memset(req, 0, sizeof(*req));
1217 memcpy((uint8_t *)req + sizeof(*req), (void *)buf, len);
1218 memcpy((uint8_t *)req + sizeof(*req) + len, (void *)compare, len);
1219 buf = (uint8_t *)req + sizeof(*req);
1220 compare = (uint8_t *)buf + len;
1221 } else {
1222 req = calloc(1, sizeof(*req));
1223 if (!req)
1224 return -FI_ENOMEM;
1225
1226 if ((uintptr_t)compare != (uintptr_t)buf + len) {
1227 tmp_buf = malloc(len * 2);
1228 if (!tmp_buf) {
1229 free(req);
1230 return -FI_ENOMEM;
1231 }
1232
1233 memcpy(tmp_buf, buf, len);
1234 memcpy((uint8_t *)tmp_buf + len, compare, len);
1235 }
1236 }
1237
1238 req->no_event = (flags & PSMX_NO_COMPLETION) ||
1239 (ep_priv->send_selective_completion && !(flags & FI_COMPLETION));
1240
1241 req->op = PSMX_AM_REQ_ATOMIC_COMPWRITE;
1242 req->atomic.buf = (void *)buf;
1243 req->atomic.len = len;
1244 req->atomic.addr = addr;
1245 req->atomic.key = key;
1246 req->atomic.context = context;
1247 req->atomic.result = result;
1248 req->ep = ep_priv;
1249 req->cq_flags = FI_WRITE | FI_ATOMIC;
1250
1251 args[0].u32w0 = PSMX_AM_REQ_ATOMIC_COMPWRITE;
1252 args[0].u32w1 = count;
1253 args[1].u64 = (uint64_t)(uintptr_t)req;
1254 args[2].u64 = addr;
1255 args[3].u64 = key;
1256 args[4].u32w0 = datatype;
1257 args[4].u32w1 = op;
1258 psm_am_request_short((psm_epaddr_t) dest_addr,
1259 PSMX_AM_ATOMIC_HANDLER, args, 5,
1260 tmp_buf ? tmp_buf : (void *)buf,
1261 len * 2, am_flags,
1262 psmx_am_atomic_completion, tmp_buf);
1263
1264 return 0;
1265 }
1266
psmx_atomic_compwrite(struct fid_ep * ep,const void * buf,size_t count,void * desc,const void * compare,void * compare_desc,void * result,void * result_desc,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)1267 static ssize_t psmx_atomic_compwrite(struct fid_ep *ep,
1268 const void *buf,
1269 size_t count, void *desc,
1270 const void *compare, void *compare_desc,
1271 void *result, void *result_desc,
1272 fi_addr_t dest_addr,
1273 uint64_t addr, uint64_t key,
1274 enum fi_datatype datatype,
1275 enum fi_op op, void *context)
1276 {
1277 struct psmx_fid_ep *ep_priv;
1278
1279 ep_priv = container_of(ep, struct psmx_fid_ep, ep);
1280 return _psmx_atomic_compwrite(ep, buf, count, desc,
1281 compare, compare_desc,
1282 result, result_desc,
1283 dest_addr, addr, key,
1284 datatype, op, context, ep_priv->tx_flags);
1285 }
1286
psmx_atomic_compwritemsg(struct fid_ep * ep,const struct fi_msg_atomic * msg,const struct fi_ioc * comparev,void ** compare_desc,size_t compare_count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,uint64_t flags)1287 static ssize_t psmx_atomic_compwritemsg(struct fid_ep *ep,
1288 const struct fi_msg_atomic *msg,
1289 const struct fi_ioc *comparev,
1290 void **compare_desc,
1291 size_t compare_count,
1292 struct fi_ioc *resultv,
1293 void **result_desc,
1294 size_t result_count,
1295 uint64_t flags)
1296 {
1297 if (!msg || msg->iov_count != 1 || !msg->msg_iov || !msg->rma_iov || !resultv)
1298 return -FI_EINVAL;
1299
1300 return _psmx_atomic_compwrite(ep, msg->msg_iov[0].addr,
1301 msg->msg_iov[0].count,
1302 msg->desc ? msg->desc[0] : NULL,
1303 comparev[0].addr,
1304 compare_desc ? compare_desc[0] : NULL,
1305 resultv[0].addr,
1306 result_desc ? result_desc[0] : NULL,
1307 msg->addr, msg->rma_iov[0].addr,
1308 msg->rma_iov[0].key, msg->datatype,
1309 msg->op, msg->context, flags);
1310 }
1311
psmx_atomic_compwritev(struct fid_ep * ep,const struct fi_ioc * iov,void ** desc,size_t count,const struct fi_ioc * comparev,void ** compare_desc,size_t compare_count,struct fi_ioc * resultv,void ** result_desc,size_t result_count,fi_addr_t dest_addr,uint64_t addr,uint64_t key,enum fi_datatype datatype,enum fi_op op,void * context)1312 static ssize_t psmx_atomic_compwritev(struct fid_ep *ep,
1313 const struct fi_ioc *iov,
1314 void **desc, size_t count,
1315 const struct fi_ioc *comparev,
1316 void **compare_desc,
1317 size_t compare_count,
1318 struct fi_ioc *resultv,
1319 void **result_desc,
1320 size_t result_count,
1321 fi_addr_t dest_addr,
1322 uint64_t addr, uint64_t key,
1323 enum fi_datatype datatype,
1324 enum fi_op op, void *context)
1325 {
1326 if (!iov || count != 1 || !comparev || !resultv)
1327 return -FI_EINVAL;
1328
1329 return psmx_atomic_compwrite(ep, iov->addr, iov->count,
1330 desc ? desc[0] : NULL,
1331 comparev[0].addr,
1332 compare_desc ? compare_desc[0] : NULL,
1333 resultv[0].addr,
1334 result_desc ? result_desc[0] : NULL,
1335 dest_addr, addr, key, datatype, op, context);
1336 }
1337
psmx_atomic_writevalid(struct fid_ep * ep,enum fi_datatype datatype,enum fi_op op,size_t * count)1338 static int psmx_atomic_writevalid(struct fid_ep *ep,
1339 enum fi_datatype datatype,
1340 enum fi_op op, size_t *count)
1341 {
1342 int chunk_size;
1343
1344 if (datatype >= FI_DATATYPE_LAST)
1345 return -FI_EOPNOTSUPP;
1346
1347 switch (op) {
1348 case FI_MIN:
1349 case FI_MAX:
1350 case FI_SUM:
1351 case FI_PROD:
1352 case FI_LOR:
1353 case FI_LAND:
1354 case FI_BOR:
1355 case FI_BAND:
1356 case FI_LXOR:
1357 case FI_BXOR:
1358 case FI_ATOMIC_WRITE:
1359 break;
1360
1361 default:
1362 return -FI_EOPNOTSUPP;
1363 }
1364
1365 if (count) {
1366 chunk_size = MIN(PSMX_AM_CHUNK_SIZE,
1367 psmx_am_param.max_request_short);
1368 *count = chunk_size / ofi_datatype_size(datatype);
1369 }
1370 return 0;
1371 }
1372
psmx_atomic_readwritevalid(struct fid_ep * ep,enum fi_datatype datatype,enum fi_op op,size_t * count)1373 static int psmx_atomic_readwritevalid(struct fid_ep *ep,
1374 enum fi_datatype datatype,
1375 enum fi_op op, size_t *count)
1376 {
1377 int chunk_size;
1378
1379 if (datatype >= FI_DATATYPE_LAST)
1380 return -FI_EOPNOTSUPP;
1381
1382 switch (op) {
1383 case FI_MIN:
1384 case FI_MAX:
1385 case FI_SUM:
1386 case FI_PROD:
1387 case FI_LOR:
1388 case FI_LAND:
1389 case FI_BOR:
1390 case FI_BAND:
1391 case FI_LXOR:
1392 case FI_BXOR:
1393 case FI_ATOMIC_READ:
1394 case FI_ATOMIC_WRITE:
1395 break;
1396
1397 default:
1398 return -FI_EOPNOTSUPP;
1399 }
1400
1401 if (count) {
1402 chunk_size = MIN(PSMX_AM_CHUNK_SIZE,
1403 psmx_am_param.max_request_short);
1404 *count = chunk_size / ofi_datatype_size(datatype);
1405 }
1406 return 0;
1407 }
1408
psmx_atomic_compwritevalid(struct fid_ep * ep,enum fi_datatype datatype,enum fi_op op,size_t * count)1409 static int psmx_atomic_compwritevalid(struct fid_ep *ep,
1410 enum fi_datatype datatype,
1411 enum fi_op op, size_t *count)
1412 {
1413 int chunk_size;
1414
1415 if (datatype >= FI_DATATYPE_LAST)
1416 return -FI_EOPNOTSUPP;
1417
1418 switch (op) {
1419 case FI_CSWAP:
1420 case FI_CSWAP_NE:
1421 break;
1422
1423 case FI_CSWAP_LE:
1424 case FI_CSWAP_LT:
1425 case FI_CSWAP_GE:
1426 case FI_CSWAP_GT:
1427 if (datatype == FI_FLOAT_COMPLEX ||
1428 datatype == FI_DOUBLE_COMPLEX ||
1429 datatype == FI_LONG_DOUBLE_COMPLEX)
1430 return -FI_EOPNOTSUPP;
1431 break;
1432
1433 case FI_MSWAP:
1434 if (datatype == FI_FLOAT_COMPLEX ||
1435 datatype == FI_DOUBLE_COMPLEX ||
1436 datatype == FI_LONG_DOUBLE_COMPLEX ||
1437 datatype == FI_FLOAT ||
1438 datatype == FI_DOUBLE ||
1439 datatype == FI_LONG_DOUBLE)
1440 return -FI_EOPNOTSUPP;
1441 break;
1442
1443 default:
1444 return -FI_EOPNOTSUPP;
1445 }
1446
1447 if (count) {
1448 chunk_size = MIN(PSMX_AM_CHUNK_SIZE,
1449 psmx_am_param.max_request_short);
1450 *count = chunk_size / (2 * ofi_datatype_size(datatype));
1451 }
1452 return 0;
1453 }
1454
psmx_query_atomic(struct fid_domain * doamin,enum fi_datatype datatype,enum fi_op op,struct fi_atomic_attr * attr,uint64_t flags)1455 int psmx_query_atomic(struct fid_domain *doamin, enum fi_datatype datatype,
1456 enum fi_op op, struct fi_atomic_attr *attr, uint64_t flags)
1457 {
1458 int ret;
1459 size_t count;
1460
1461 if (flags & FI_TAGGED)
1462 return -FI_EOPNOTSUPP;
1463
1464 if (flags & FI_COMPARE_ATOMIC) {
1465 if (flags & FI_FETCH_ATOMIC)
1466 return -FI_EINVAL;
1467 ret = psmx_atomic_compwritevalid(NULL, datatype, op, &count);
1468 } else if (flags & FI_FETCH_ATOMIC) {
1469 ret = psmx_atomic_readwritevalid(NULL, datatype, op, &count);
1470 } else {
1471 ret = psmx_atomic_writevalid(NULL, datatype, op, &count);
1472 }
1473
1474 if (attr && !ret) {
1475 attr->size = ofi_datatype_size(datatype);
1476 attr->count = count;
1477 }
1478
1479 return ret;
1480 }
1481
1482 struct fi_ops_atomic psmx_atomic_ops = {
1483 .size = sizeof(struct fi_ops_atomic),
1484 .write = psmx_atomic_write,
1485 .writev = psmx_atomic_writev,
1486 .writemsg = psmx_atomic_writemsg,
1487 .inject = psmx_atomic_inject,
1488 .readwrite = psmx_atomic_readwrite,
1489 .readwritev = psmx_atomic_readwritev,
1490 .readwritemsg = psmx_atomic_readwritemsg,
1491 .compwrite = psmx_atomic_compwrite,
1492 .compwritev = psmx_atomic_compwritev,
1493 .compwritemsg = psmx_atomic_compwritemsg,
1494 .writevalid = psmx_atomic_writevalid,
1495 .readwritevalid = psmx_atomic_readwritevalid,
1496 .compwritevalid = psmx_atomic_compwritevalid,
1497 };
1498
1499