1 /* -----------------------------------------------------------------
2 * Programmer(s): Slaven Peles, Cody J. Balos, Daniel McGreer @ LLNL
3 * -----------------------------------------------------------------
4 * SUNDIALS Copyright Start
5 * Copyright (c) 2002-2021, Lawrence Livermore National Security
6 * and Southern Methodist University.
7 * All rights reserved.
8 *
9 * See the top-level LICENSE and NOTICE files for details.
10 *
11 * SPDX-License-Identifier: BSD-3-Clause
12 * SUNDIALS Copyright End
13 * -----------------------------------------------------------------
14 * This is the implementation file for a RAJA implementation
15 * of the NVECTOR package. This will support CUDA and HIP
16 * -----------------------------------------------------------------*/
17
18 #include <stdio.h>
19 #include <stdlib.h>
20
21 #include <RAJA/RAJA.hpp>
22 #include <nvector/nvector_raja.h>
23
24 #include "sundials_debug.h"
25
26 // RAJA defines
27 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
28 #include <sunmemory/sunmemory_cuda.h>
29 #include "sundials_cuda.h"
30 #define RAJA_NODE_TYPE RAJA::cuda_exec< 256 >
31 #define RAJA_REDUCE_TYPE RAJA::cuda_reduce
32 #define SUNDIALS_GPU_PREFIX(val) cuda ## val
33 #define SUNDIALS_GPU_VERIFY SUNDIALS_CUDA_VERIFY
34 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
35 #include <sunmemory/sunmemory_hip.h>
36 #include "sundials_hip.h"
37 #define RAJA_NODE_TYPE RAJA::hip_exec< 512 >
38 #define RAJA_REDUCE_TYPE RAJA::hip_reduce
39 #define SUNDIALS_GPU_PREFIX(val) hip ## val
40 #define SUNDIALS_GPU_VERIFY SUNDIALS_HIP_VERIFY
41 #endif
42
43 #define RAJA_LAMBDA [=] __device__
44
45 #define ZERO RCONST(0.0)
46 #define HALF RCONST(0.5)
47 #define ONE RCONST(1.0)
48 #define ONEPT5 RCONST(1.5)
49
50 extern "C" {
51
52 // Static constants
53 static constexpr sunindextype zeroIdx = 0;
54
55 // Helpful macros
56 #define NVEC_RAJA_CONTENT(x) ((N_VectorContent_Raja)(x->content))
57 #define NVEC_RAJA_PRIVATE(x) ((N_PrivateVectorContent_Raja)(NVEC_RAJA_CONTENT(x)->priv))
58 #define NVEC_RAJA_MEMSIZE(x) (NVEC_RAJA_CONTENT(x)->length * sizeof(realtype))
59 #define NVEC_RAJA_MEMHELP(x) (NVEC_RAJA_CONTENT(x)->mem_helper)
60 #define NVEC_RAJA_HDATAp(x) ((realtype*) NVEC_RAJA_CONTENT(x)->host_data->ptr)
61 #define NVEC_RAJA_DDATAp(x) ((realtype*) NVEC_RAJA_CONTENT(x)->device_data->ptr)
62
63 /*
64 * Private structure definition
65 */
66
/* Implementation-only content reached through N_VectorContent_Raja::priv.
 * It is malloc'ed in N_VNewEmpty_Raja and freed in N_VDestroy_Raja. */
struct _N_PrivateVectorContent_Raja
{
  booleantype use_managed_mem; /* indicates if the data pointers and buffer pointers are managed memory */
};

/* Handle to the private content; accessed via NVEC_RAJA_PRIVATE(x). */
typedef struct _N_PrivateVectorContent_Raja *N_PrivateVectorContent_Raja;
73
74
75 /*
76 * Utility functions
77 */
78
/* Allocates the host/device data of v through its memory helper;
 * callers treat a nonzero return value as failure. */
static int AllocateData(N_Vector v);
/* Builds a device-accessible array of the nvec device data pointers of V.
 * *d_ptrs is used inside kernels; *d_ref must later be released with
 * SUNMemoryHelper_Dealloc (as the fused/array operations below do).
 * NOTE(review): return type is void, so failures cannot be reported. */
static void CreateArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
                                          int nvec, N_Vector *V);
/* 2D variant of the above for an nvec x nsum array of vectors
 * (presumably indexed [nvec][nsum] -- confirm against the definition). */
static void Create2DArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
                                            int nvec, int nsum, N_Vector **V);
84
N_VNewEmpty_Raja()85 N_Vector N_VNewEmpty_Raja()
86 {
87 N_Vector v;
88
89 /* Create an empty vector object */
90 v = NULL;
91 v = N_VNewEmpty();
92 if (v == NULL) return(NULL);
93
94 /* Attach operations */
95
96 /* constructors, destructors, and utility operations */
97 v->ops->nvgetvectorid = N_VGetVectorID_Raja;
98 v->ops->nvclone = N_VClone_Raja;
99 v->ops->nvcloneempty = N_VCloneEmpty_Raja;
100 v->ops->nvdestroy = N_VDestroy_Raja;
101 v->ops->nvspace = N_VSpace_Raja;
102 v->ops->nvgetlength = N_VGetLength_Raja;
103 v->ops->nvgetarraypointer = N_VGetHostArrayPointer_Raja;
104 v->ops->nvgetdevicearraypointer = N_VGetDeviceArrayPointer_Raja;
105 v->ops->nvsetarraypointer = N_VSetHostArrayPointer_Raja;
106
107
108 /* standard vector operations */
109 v->ops->nvlinearsum = N_VLinearSum_Raja;
110 v->ops->nvconst = N_VConst_Raja;
111 v->ops->nvprod = N_VProd_Raja;
112 v->ops->nvdiv = N_VDiv_Raja;
113 v->ops->nvscale = N_VScale_Raja;
114 v->ops->nvabs = N_VAbs_Raja;
115 v->ops->nvinv = N_VInv_Raja;
116 v->ops->nvaddconst = N_VAddConst_Raja;
117 v->ops->nvdotprod = N_VDotProd_Raja;
118 v->ops->nvmaxnorm = N_VMaxNorm_Raja;
119 v->ops->nvmin = N_VMin_Raja;
120 v->ops->nvl1norm = N_VL1Norm_Raja;
121 v->ops->nvinvtest = N_VInvTest_Raja;
122 v->ops->nvconstrmask = N_VConstrMask_Raja;
123 v->ops->nvminquotient = N_VMinQuotient_Raja;
124 v->ops->nvwrmsnormmask = N_VWrmsNormMask_Raja;
125 v->ops->nvwrmsnorm = N_VWrmsNorm_Raja;
126 v->ops->nvwl2norm = N_VWL2Norm_Raja;
127 v->ops->nvcompare = N_VCompare_Raja;
128
129 /* fused and vector array operations are disabled (NULL) by default */
130
131 /* local reduction operations */
132 v->ops->nvwsqrsumlocal = N_VWSqrSumLocal_Raja;
133 v->ops->nvwsqrsummasklocal = N_VWSqrSumMaskLocal_Raja;
134 v->ops->nvdotprodlocal = N_VDotProd_Raja;
135 v->ops->nvmaxnormlocal = N_VMaxNorm_Raja;
136 v->ops->nvminlocal = N_VMin_Raja;
137 v->ops->nvl1normlocal = N_VL1Norm_Raja;
138 v->ops->nvinvtestlocal = N_VInvTest_Raja;
139 v->ops->nvconstrmasklocal = N_VConstrMask_Raja;
140 v->ops->nvminquotientlocal = N_VMinQuotient_Raja;
141
142 /* XBraid interface operations */
143 v->ops->nvbufsize = N_VBufSize_Raja;
144 v->ops->nvbufpack = N_VBufPack_Raja;
145 v->ops->nvbufunpack = N_VBufUnpack_Raja;
146
147 /* print operation for debugging */
148 v->ops->nvprint = N_VPrint_Raja;
149 v->ops->nvprintfile = N_VPrintFile_Raja;
150
151 v->content = (N_VectorContent_Raja) malloc(sizeof(_N_VectorContent_Raja));
152 if (v->content == NULL)
153 {
154 N_VDestroy(v);
155 return NULL;
156 }
157
158 NVEC_RAJA_CONTENT(v)->priv = malloc(sizeof(_N_PrivateVectorContent_Raja));
159 if (NVEC_RAJA_CONTENT(v)->priv == NULL)
160 {
161 N_VDestroy(v);
162 return NULL;
163 }
164
165 NVEC_RAJA_CONTENT(v)->length = 0;
166 NVEC_RAJA_CONTENT(v)->mem_helper = NULL;
167 NVEC_RAJA_CONTENT(v)->own_helper = SUNFALSE;
168 NVEC_RAJA_CONTENT(v)->host_data = NULL;
169 NVEC_RAJA_CONTENT(v)->device_data = NULL;
170 NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNFALSE;
171
172 return(v);
173 }
174
N_VNew_Raja(sunindextype length)175 N_Vector N_VNew_Raja(sunindextype length)
176 {
177 N_Vector v;
178
179 v = NULL;
180 v = N_VNewEmpty_Raja();
181 if (v == NULL) return(NULL);
182
183 NVEC_RAJA_CONTENT(v)->length = length;
184 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
185 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Cuda();
186 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
187 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Hip();
188 #endif
189 NVEC_RAJA_CONTENT(v)->own_helper = SUNTRUE;
190 NVEC_RAJA_CONTENT(v)->host_data = NULL;
191 NVEC_RAJA_CONTENT(v)->device_data = NULL;
192 NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNFALSE;
193
194 if (NVEC_RAJA_MEMHELP(v) == NULL)
195 {
196 SUNDIALS_DEBUG_PRINT("ERROR in N_VNew_Raja: memory helper is NULL\n");
197 N_VDestroy(v);
198 return(NULL);
199 }
200
201 if (AllocateData(v))
202 {
203 SUNDIALS_DEBUG_PRINT("ERROR in N_VNew_Raja: AllocateData returned nonzero\n");
204 N_VDestroy(v);
205 return NULL;
206 }
207
208 return(v);
209 }
210
N_VNewWithMemHelp_Raja(sunindextype length,booleantype use_managed_mem,SUNMemoryHelper helper)211 N_Vector N_VNewWithMemHelp_Raja(sunindextype length, booleantype use_managed_mem, SUNMemoryHelper helper)
212 {
213 N_Vector v;
214
215 if (helper == NULL)
216 {
217 SUNDIALS_DEBUG_PRINT("ERROR in N_VNewWithMemHelp_Raja: helper is NULL\n");
218 return(NULL);
219 }
220
221 if (!SUNMemoryHelper_ImplementsRequiredOps(helper))
222 {
223 SUNDIALS_DEBUG_PRINT("ERROR in N_VNewWithMemHelp_Raja: helper doesn't implement all required ops\n");
224 return(NULL);
225 }
226
227 v = NULL;
228 v = N_VNewEmpty_Raja();
229 if (v == NULL) return(NULL);
230
231 NVEC_RAJA_CONTENT(v)->length = length;
232 NVEC_RAJA_CONTENT(v)->mem_helper = helper;
233 NVEC_RAJA_CONTENT(v)->own_helper = SUNFALSE;
234 NVEC_RAJA_CONTENT(v)->host_data = NULL;
235 NVEC_RAJA_CONTENT(v)->device_data = NULL;
236 NVEC_RAJA_PRIVATE(v)->use_managed_mem = use_managed_mem;
237
238 if (AllocateData(v))
239 {
240 SUNDIALS_DEBUG_PRINT("ERROR in N_VNewWithMemHelp_Raja: AllocateData returned nonzero\n");
241 N_VDestroy(v);
242 return(NULL);
243 }
244
245 return(v);
246 }
247
N_VNewManaged_Raja(sunindextype length)248 N_Vector N_VNewManaged_Raja(sunindextype length)
249 {
250 N_Vector v;
251
252 v = NULL;
253 v = N_VNewEmpty_Raja();
254 if (v == NULL) return(NULL);
255
256 NVEC_RAJA_CONTENT(v)->length = length;
257 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
258 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Cuda();
259 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
260 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Hip();
261 #endif
262 NVEC_RAJA_CONTENT(v)->own_helper = SUNTRUE;
263 NVEC_RAJA_CONTENT(v)->host_data = NULL;
264 NVEC_RAJA_CONTENT(v)->device_data = NULL;
265 NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNTRUE;
266
267 if (NVEC_RAJA_MEMHELP(v) == NULL)
268 {
269 SUNDIALS_DEBUG_PRINT("ERROR in N_VNewManaged_Raja: memory helper is NULL\n");
270 N_VDestroy(v);
271 return(NULL);
272 }
273
274 if (AllocateData(v))
275 {
276 SUNDIALS_DEBUG_PRINT("ERROR in N_VNewManaged_Raja: AllocateData returned nonzero\n");
277 N_VDestroy(v);
278 return NULL;
279 }
280
281 return(v);
282 }
283
N_VMake_Raja(sunindextype length,realtype * h_vdata,realtype * d_vdata)284 N_Vector N_VMake_Raja(sunindextype length, realtype *h_vdata, realtype *d_vdata)
285 {
286 N_Vector v;
287
288 if (h_vdata == NULL || d_vdata == NULL) return(NULL);
289
290 v = NULL;
291 v = N_VNewEmpty_Raja();
292 if (v == NULL) return(NULL);
293
294 NVEC_RAJA_CONTENT(v)->length = length;
295 NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Wrap(h_vdata, SUNMEMTYPE_HOST);
296 NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Wrap(d_vdata, SUNMEMTYPE_DEVICE);
297 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
298 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Cuda();
299 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
300 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Hip();
301 #endif
302 NVEC_RAJA_CONTENT(v)->own_helper = SUNTRUE;
303 NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNFALSE;
304
305 if (NVEC_RAJA_MEMHELP(v) == NULL)
306 {
307 SUNDIALS_DEBUG_PRINT("ERROR in N_VMake_Raja: memory helper is NULL\n");
308 N_VDestroy(v);
309 return(NULL);
310 }
311
312
313 if (NVEC_RAJA_CONTENT(v)->device_data == NULL ||
314 NVEC_RAJA_CONTENT(v)->host_data == NULL)
315 {
316 SUNDIALS_DEBUG_PRINT("ERROR in N_VMake_Raja: SUNMemoryHelper_Wrap returned NULL\n");
317 N_VDestroy(v);
318 return(NULL);
319 }
320
321 return(v);
322 }
323
N_VMakeManaged_Raja(sunindextype length,realtype * vdata)324 N_Vector N_VMakeManaged_Raja(sunindextype length, realtype *vdata)
325 {
326 N_Vector v;
327
328 if (vdata == NULL) return(NULL);
329
330 v = NULL;
331 v = N_VNewEmpty_Raja();
332 if (v == NULL) return(NULL);
333
334 NVEC_RAJA_CONTENT(v)->length = length;
335 NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Wrap(vdata, SUNMEMTYPE_UVM);
336 NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Alias(NVEC_RAJA_CONTENT(v)->host_data);
337 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
338 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Cuda();
339 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
340 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Hip();
341 #endif
342 NVEC_RAJA_CONTENT(v)->own_helper = SUNTRUE;
343 NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNTRUE;
344
345 if (NVEC_RAJA_MEMHELP(v) == NULL)
346 {
347 SUNDIALS_DEBUG_PRINT("ERROR in N_VMakeManaged_Raja: memory helper is NULL\n");
348 N_VDestroy(v);
349 return(NULL);
350 }
351
352 if (NVEC_RAJA_CONTENT(v)->device_data == NULL ||
353 NVEC_RAJA_CONTENT(v)->host_data == NULL)
354 {
355 SUNDIALS_DEBUG_PRINT("ERROR in N_VMake_Raja: SUNMemoryHelper_Wrap returned NULL\n");
356 N_VDestroy(v);
357 return(NULL);
358 }
359
360 return(v);
361 }
362
363 /* -----------------------------------------------------------------
364 * Function to return the global length of the vector.
365 * This is defined as an inline function in nvector_raja.h, so
366 * we just mark it as extern here.
367 */
368 extern sunindextype N_VGetLength_Raja(N_Vector v);
369
370 /* ----------------------------------------------------------------------------
371 * Return pointer to the raw host data.
372 * This is defined as an inline function in nvector_raja.h, so
373 * we just mark it as extern here.
374 */
375
376 extern realtype *N_VGetHostArrayPointer_Raja(N_Vector x);
377
378 /* ----------------------------------------------------------------------------
379 * Return pointer to the raw device data.
380 * This is defined as an inline function in nvector_raja.h, so
381 * we just mark it as extern here.
382 */
383
384 extern realtype *N_VGetDeviceArrayPointer_Raja(N_Vector x);
385
386
387 /* ----------------------------------------------------------------------------
388 * Set pointer to the raw host data. Does not free the existing pointer.
389 */
390
N_VSetHostArrayPointer_Raja(realtype * h_vdata,N_Vector v)391 void N_VSetHostArrayPointer_Raja(realtype* h_vdata, N_Vector v)
392 {
393 if (N_VIsManagedMemory_Raja(v))
394 {
395 if (NVEC_RAJA_CONTENT(v)->host_data)
396 {
397 NVEC_RAJA_CONTENT(v)->host_data->ptr = (void*) h_vdata;
398 NVEC_RAJA_CONTENT(v)->device_data->ptr = (void*) h_vdata;
399 }
400 else
401 {
402 NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Wrap((void*) h_vdata, SUNMEMTYPE_UVM);
403 NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Alias(NVEC_RAJA_CONTENT(v)->host_data);
404 }
405 }
406 else
407 {
408 if (NVEC_RAJA_CONTENT(v)->host_data)
409 {
410 NVEC_RAJA_CONTENT(v)->host_data->ptr = (void*) h_vdata;
411 }
412 else
413 {
414 NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Wrap((void*) h_vdata, SUNMEMTYPE_HOST);
415 }
416 }
417 }
418
419 /* ----------------------------------------------------------------------------
* Set pointer to the raw device data. Does not free the existing pointer.
421 */
422
N_VSetDeviceArrayPointer_Raja(realtype * d_vdata,N_Vector v)423 void N_VSetDeviceArrayPointer_Raja(realtype* d_vdata, N_Vector v)
424 {
425 if (N_VIsManagedMemory_Raja(v))
426 {
427 if (NVEC_RAJA_CONTENT(v)->device_data)
428 {
429 NVEC_RAJA_CONTENT(v)->device_data->ptr = (void*) d_vdata;
430 NVEC_RAJA_CONTENT(v)->host_data->ptr = (void*) d_vdata;
431 }
432 else
433 {
434 NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Wrap((void*) d_vdata, SUNMEMTYPE_UVM);
435 NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Alias(NVEC_RAJA_CONTENT(v)->device_data);
436 }
437 }
438 else
439 {
440 if (NVEC_RAJA_CONTENT(v)->device_data)
441 {
442 NVEC_RAJA_CONTENT(v)->device_data->ptr = (void*) d_vdata;
443 }
444 else
445 {
446 NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Wrap((void*) d_vdata, SUNMEMTYPE_DEVICE);
447 }
448 }
449 }
450
451 /* ----------------------------------------------------------------------------
452 * Return a flag indicating if the memory for the vector data is managed
453 */
N_VIsManagedMemory_Raja(N_Vector x)454 booleantype N_VIsManagedMemory_Raja(N_Vector x)
455 {
456 return NVEC_RAJA_PRIVATE(x)->use_managed_mem;
457 }
458
459 /* ----------------------------------------------------------------------------
460 * Copy vector data to the device
461 */
462
N_VCopyToDevice_Raja(N_Vector x)463 void N_VCopyToDevice_Raja(N_Vector x)
464 {
465 int copy_fail;
466
467 copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
468 NVEC_RAJA_CONTENT(x)->device_data,
469 NVEC_RAJA_CONTENT(x)->host_data,
470 NVEC_RAJA_MEMSIZE(x),
471 0);
472
473 if (copy_fail)
474 {
475 SUNDIALS_DEBUG_PRINT("ERROR in N_VCopyToDevice_Raja: SUNMemoryHelper_CopyAsync returned nonzero\n");
476 }
477
478 /* we synchronize with respect to the host, but on the default stream currently */
479 SUNDIALS_GPU_VERIFY(SUNDIALS_GPU_PREFIX(StreamSynchronize)(0));
480 }
481
482 /* ----------------------------------------------------------------------------
483 * Copy vector data from the device to the host
484 */
485
N_VCopyFromDevice_Raja(N_Vector x)486 void N_VCopyFromDevice_Raja(N_Vector x)
487 {
488 int copy_fail;
489
490 copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
491 NVEC_RAJA_CONTENT(x)->host_data,
492 NVEC_RAJA_CONTENT(x)->device_data,
493 NVEC_RAJA_MEMSIZE(x),
494 0);
495
496 if (copy_fail)
497 {
498 SUNDIALS_DEBUG_PRINT("ERROR in N_VCopyFromDevice_Raja: SUNMemoryHelper_CopyAsync returned nonzero\n");
499 }
500
501 /* we synchronize with respect to the host, but only in this stream */
502 SUNDIALS_GPU_VERIFY(SUNDIALS_GPU_PREFIX(StreamSynchronize)(0));
503 }
504
505 /* ----------------------------------------------------------------------------
* Function to print a RAJA vector (its host data) to stdout
507 */
508
N_VPrint_Raja(N_Vector X)509 void N_VPrint_Raja(N_Vector X)
510 {
511 N_VPrintFile_Raja(X, stdout);
512 }
513
514 /* ----------------------------------------------------------------------------
* Function to print a RAJA vector (its host data) to outfile
516 */
517
N_VPrintFile_Raja(N_Vector X,FILE * outfile)518 void N_VPrintFile_Raja(N_Vector X, FILE *outfile)
519 {
520 sunindextype i;
521
522 for (i = 0; i < NVEC_RAJA_CONTENT(X)->length; i++) {
523 #if defined(SUNDIALS_EXTENDED_PRECISION)
524 fprintf(outfile, "%35.32Lg\n", NVEC_RAJA_HDATAp(X)[i]);
525 #elif defined(SUNDIALS_DOUBLE_PRECISION)
526 fprintf(outfile, "%19.16g\n", NVEC_RAJA_HDATAp(X)[i]);
527 #else
528 fprintf(outfile, "%11.8g\n", NVEC_RAJA_HDATAp(X)[i]);
529 #endif
530 }
531 fprintf(outfile, "\n");
532
533 return;
534 }
535
536 /*
537 * -----------------------------------------------------------------
538 * implementation of vector operations
539 * -----------------------------------------------------------------
540 */
541
N_VCloneEmpty_Raja(N_Vector w)542 N_Vector N_VCloneEmpty_Raja(N_Vector w)
543 {
544 N_Vector v;
545
546 if (w == NULL) return(NULL);
547
548 /* Create vector */
549 v = NULL;
550 v = N_VNewEmpty_Raja();
551 if (v == NULL) return(NULL);
552
553 /* Attach operations */
554 if (N_VCopyOps(w, v)) { N_VDestroy(v); return(NULL); }
555
556 /* Set content */
557 NVEC_RAJA_CONTENT(v)->length = NVEC_RAJA_CONTENT(w)->length;
558 NVEC_RAJA_CONTENT(v)->host_data = NULL;
559 NVEC_RAJA_CONTENT(v)->device_data = NULL;
560 NVEC_RAJA_PRIVATE(v)->use_managed_mem = NVEC_RAJA_PRIVATE(w)->use_managed_mem;
561
562
563 return(v);
564 }
565
N_VClone_Raja(N_Vector w)566 N_Vector N_VClone_Raja(N_Vector w)
567 {
568 N_Vector v;
569 v = NULL;
570 v = N_VCloneEmpty_Raja(w);
571 if (v == NULL) return(NULL);
572
573 NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Clone(NVEC_RAJA_MEMHELP(w));
574 NVEC_RAJA_CONTENT(v)->own_helper = SUNTRUE;
575
576 if (AllocateData(v))
577 {
578 SUNDIALS_DEBUG_PRINT("ERROR in N_VClone_Raja: AllocateData returned nonzero\n");
579 N_VDestroy(v);
580 return NULL;
581 }
582
583 return(v);
584
585 }
586
587
N_VDestroy_Raja(N_Vector v)588 void N_VDestroy_Raja(N_Vector v)
589 {
590 N_VectorContent_Raja vc;
591 N_PrivateVectorContent_Raja vcp;
592
593 if (v == NULL) return;
594
595 /* free ops structure */
596 if (v->ops != NULL)
597 {
598 free(v->ops);
599 v->ops = NULL;
600 }
601
602 /* extract content */
603 vc = NVEC_RAJA_CONTENT(v);
604 if (vc == NULL)
605 {
606 free(v);
607 v = NULL;
608 return;
609 }
610
611 /* free private content */
612 vcp = (N_PrivateVectorContent_Raja) vc->priv;
613 if (vcp != NULL)
614 {
615 /* free items in private content */
616 free(vcp);
617 vc->priv = NULL;
618 }
619
620 /* free items in content */
621 if (NVEC_RAJA_MEMHELP(v))
622 {
623 SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(v), vc->host_data);
624 vc->host_data = NULL;
625 SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(v), vc->device_data);
626 vc->device_data = NULL;
627 if (vc->own_helper) SUNMemoryHelper_Destroy(vc->mem_helper);
628 vc->mem_helper = NULL;
629 }
630 else
631 {
632 SUNDIALS_DEBUG_PRINT("WARNING in N_VDestroy_Raja: mem_helper was NULL when trying to dealloc data, this could result in a memory leak\n");
633 }
634
635 /* free content struct */
636 free(vc);
637
638 /* free vector */
639 free(v);
640
641 return;
642 }
643
N_VSpace_Raja(N_Vector X,sunindextype * lrw,sunindextype * liw)644 void N_VSpace_Raja(N_Vector X, sunindextype *lrw, sunindextype *liw)
645 {
646 *lrw = NVEC_RAJA_CONTENT(X)->length;
647 *liw = 2;
648 }
649
N_VConst_Raja(realtype c,N_Vector Z)650 void N_VConst_Raja(realtype c, N_Vector Z)
651 {
652 const sunindextype N = NVEC_RAJA_CONTENT(Z)->length;
653 realtype *zdata = NVEC_RAJA_DDATAp(Z);
654
655 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N), RAJA_LAMBDA(sunindextype i) {
656 zdata[i] = c;
657 });
658 }
659
N_VLinearSum_Raja(realtype a,N_Vector X,realtype b,N_Vector Y,N_Vector Z)660 void N_VLinearSum_Raja(realtype a, N_Vector X, realtype b, N_Vector Y, N_Vector Z)
661 {
662 const realtype *xdata = NVEC_RAJA_DDATAp(X);
663 const realtype *ydata = NVEC_RAJA_DDATAp(Y);
664 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
665 realtype *zdata = NVEC_RAJA_DDATAp(Z);
666
667 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
668 RAJA_LAMBDA(sunindextype i) {
669 zdata[i] = a*xdata[i] + b*ydata[i];
670 }
671 );
672 }
673
N_VProd_Raja(N_Vector X,N_Vector Y,N_Vector Z)674 void N_VProd_Raja(N_Vector X, N_Vector Y, N_Vector Z)
675 {
676 const realtype *xdata = NVEC_RAJA_DDATAp(X);
677 const realtype *ydata = NVEC_RAJA_DDATAp(Y);
678 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
679 realtype *zdata = NVEC_RAJA_DDATAp(Z);
680
681 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
682 RAJA_LAMBDA(sunindextype i) {
683 zdata[i] = xdata[i] * ydata[i];
684 }
685 );
686 }
687
N_VDiv_Raja(N_Vector X,N_Vector Y,N_Vector Z)688 void N_VDiv_Raja(N_Vector X, N_Vector Y, N_Vector Z)
689 {
690 const realtype *xdata = NVEC_RAJA_DDATAp(X);
691 const realtype *ydata = NVEC_RAJA_DDATAp(Y);
692 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
693 realtype *zdata = NVEC_RAJA_DDATAp(Z);
694
695 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
696 RAJA_LAMBDA(sunindextype i) {
697 zdata[i] = xdata[i] / ydata[i];
698 }
699 );
700 }
701
N_VScale_Raja(realtype c,N_Vector X,N_Vector Z)702 void N_VScale_Raja(realtype c, N_Vector X, N_Vector Z)
703 {
704 const realtype *xdata = NVEC_RAJA_DDATAp(X);
705 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
706 realtype *zdata = NVEC_RAJA_DDATAp(Z);
707
708 RAJA::forall<RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
709 RAJA_LAMBDA(sunindextype i) {
710 zdata[i] = c * xdata[i];
711 }
712 );
713 }
714
N_VAbs_Raja(N_Vector X,N_Vector Z)715 void N_VAbs_Raja(N_Vector X, N_Vector Z)
716 {
717 const realtype *xdata = NVEC_RAJA_DDATAp(X);
718 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
719 realtype *zdata = NVEC_RAJA_DDATAp(Z);
720
721 RAJA::forall<RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
722 RAJA_LAMBDA(sunindextype i) {
723 zdata[i] = abs(xdata[i]);
724 }
725 );
726 }
727
N_VInv_Raja(N_Vector X,N_Vector Z)728 void N_VInv_Raja(N_Vector X, N_Vector Z)
729 {
730 const realtype *xdata = NVEC_RAJA_DDATAp(X);
731 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
732 realtype *zdata = NVEC_RAJA_DDATAp(Z);
733
734 RAJA::forall<RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
735 RAJA_LAMBDA(sunindextype i) {
736 zdata[i] = ONE / xdata[i];
737 }
738 );
739 }
740
N_VAddConst_Raja(N_Vector X,realtype b,N_Vector Z)741 void N_VAddConst_Raja(N_Vector X, realtype b, N_Vector Z)
742 {
743 const realtype *xdata = NVEC_RAJA_DDATAp(X);
744 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
745 realtype *zdata = NVEC_RAJA_DDATAp(Z);
746
747 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
748 RAJA_LAMBDA(sunindextype i) {
749 zdata[i] = xdata[i] + b;
750 }
751 );
752 }
753
N_VDotProd_Raja(N_Vector X,N_Vector Y)754 realtype N_VDotProd_Raja(N_Vector X, N_Vector Y)
755 {
756 const realtype *xdata = NVEC_RAJA_DDATAp(X);
757 const realtype *ydata = NVEC_RAJA_DDATAp(Y);
758 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
759
760 RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
761 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
762 RAJA_LAMBDA(sunindextype i) {
763 gpu_result += xdata[i] * ydata[i] ;
764 }
765 );
766
767 return (static_cast<realtype>(gpu_result));
768 }
769
N_VMaxNorm_Raja(N_Vector X)770 realtype N_VMaxNorm_Raja(N_Vector X)
771 {
772 const realtype *xdata = NVEC_RAJA_DDATAp(X);
773 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
774
775 RAJA::ReduceMax< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
776 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
777 RAJA_LAMBDA(sunindextype i) {
778 gpu_result.max(abs(xdata[i]));
779 }
780 );
781
782 return (static_cast<realtype>(gpu_result));
783 }
784
N_VWSqrSumLocal_Raja(N_Vector X,N_Vector W)785 realtype N_VWSqrSumLocal_Raja(N_Vector X, N_Vector W)
786 {
787 const realtype *xdata = NVEC_RAJA_DDATAp(X);
788 const realtype *wdata = NVEC_RAJA_DDATAp(W);
789 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
790
791 RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
792 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
793 RAJA_LAMBDA(sunindextype i) {
794 gpu_result += (xdata[i] * wdata[i] * xdata[i] * wdata[i]);
795 }
796 );
797
798 return (static_cast<realtype>(gpu_result));
799 }
800
N_VWrmsNorm_Raja(N_Vector X,N_Vector W)801 realtype N_VWrmsNorm_Raja(N_Vector X, N_Vector W)
802 {
803 const realtype sum = N_VWSqrSumLocal_Raja(X, W);
804 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
805 return std::sqrt(sum/N);
806 }
807
N_VWSqrSumMaskLocal_Raja(N_Vector X,N_Vector W,N_Vector ID)808 realtype N_VWSqrSumMaskLocal_Raja(N_Vector X, N_Vector W, N_Vector ID)
809 {
810 const realtype *xdata = NVEC_RAJA_DDATAp(X);
811 const realtype *wdata = NVEC_RAJA_DDATAp(W);
812 const realtype *iddata = NVEC_RAJA_DDATAp(ID);
813 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
814
815 RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
816 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
817 RAJA_LAMBDA(sunindextype i) {
818 if (iddata[i] > ZERO)
819 gpu_result += (xdata[i] * wdata[i] * xdata[i] * wdata[i]);
820 }
821 );
822
823 return (static_cast<realtype>(gpu_result));
824 }
825
N_VWrmsNormMask_Raja(N_Vector X,N_Vector W,N_Vector ID)826 realtype N_VWrmsNormMask_Raja(N_Vector X, N_Vector W, N_Vector ID)
827 {
828 const realtype sum = N_VWSqrSumMaskLocal_Raja(X, W, ID);
829 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
830 return std::sqrt(sum/N);
831 }
832
N_VMin_Raja(N_Vector X)833 realtype N_VMin_Raja(N_Vector X)
834 {
835 const realtype *xdata = NVEC_RAJA_DDATAp(X);
836 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
837
838 RAJA::ReduceMin< RAJA_REDUCE_TYPE, realtype> gpu_result(std::numeric_limits<realtype>::max());
839 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
840 RAJA_LAMBDA(sunindextype i) {
841 gpu_result.min(xdata[i]);
842 }
843 );
844
845 return (static_cast<realtype>(gpu_result));
846 }
847
N_VWL2Norm_Raja(N_Vector X,N_Vector W)848 realtype N_VWL2Norm_Raja(N_Vector X, N_Vector W)
849 {
850 return std::sqrt(N_VWSqrSumLocal_Raja(X, W));
851 }
852
N_VL1Norm_Raja(N_Vector X)853 realtype N_VL1Norm_Raja(N_Vector X)
854 {
855 const realtype *xdata = NVEC_RAJA_DDATAp(X);
856 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
857
858 RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
859 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
860 RAJA_LAMBDA(sunindextype i) {
861 gpu_result += (abs(xdata[i]));
862 }
863 );
864
865 return (static_cast<realtype>(gpu_result));
866 }
867
N_VCompare_Raja(realtype c,N_Vector X,N_Vector Z)868 void N_VCompare_Raja(realtype c, N_Vector X, N_Vector Z)
869 {
870 const realtype *xdata = NVEC_RAJA_DDATAp(X);
871 const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
872 realtype *zdata = NVEC_RAJA_DDATAp(Z);
873
874 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
875 RAJA_LAMBDA(sunindextype i) {
876 zdata[i] = abs(xdata[i]) >= c ? ONE : ZERO;
877 }
878 );
879 }
880
N_VInvTest_Raja(N_Vector x,N_Vector z)881 booleantype N_VInvTest_Raja(N_Vector x, N_Vector z)
882 {
883 const realtype *xdata = NVEC_RAJA_DDATAp(x);
884 const sunindextype N = NVEC_RAJA_CONTENT(x)->length;
885 realtype *zdata = NVEC_RAJA_DDATAp(z);
886
887 RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(ZERO);
888 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
889 RAJA_LAMBDA(sunindextype i) {
890 if (xdata[i] == ZERO) {
891 gpu_result += ONE;
892 } else {
893 zdata[i] = ONE/xdata[i];
894 }
895 }
896 );
897 realtype minimum = static_cast<realtype>(gpu_result);
898 return (minimum < HALF);
899 }
900
N_VConstrMask_Raja(N_Vector c,N_Vector x,N_Vector m)901 booleantype N_VConstrMask_Raja(N_Vector c, N_Vector x, N_Vector m)
902 {
903 const realtype *cdata = NVEC_RAJA_DDATAp(c);
904 const realtype *xdata = NVEC_RAJA_DDATAp(x);
905 const sunindextype N = NVEC_RAJA_CONTENT(x)->length;
906 realtype *mdata = NVEC_RAJA_DDATAp(m);
907
908 RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(ZERO);
909 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
910 RAJA_LAMBDA(sunindextype i) {
911 bool test = (abs(cdata[i]) > ONEPT5 && cdata[i]*xdata[i] <= ZERO) ||
912 (abs(cdata[i]) > HALF && cdata[i]*xdata[i] < ZERO);
913 mdata[i] = test ? ONE : ZERO;
914 gpu_result += mdata[i];
915 }
916 );
917
918 realtype sum = static_cast<realtype>(gpu_result);
919 return(sum < HALF);
920 }
921
N_VMinQuotient_Raja(N_Vector num,N_Vector denom)922 realtype N_VMinQuotient_Raja(N_Vector num, N_Vector denom)
923 {
924 const realtype *ndata = NVEC_RAJA_DDATAp(num);
925 const realtype *ddata = NVEC_RAJA_DDATAp(denom);
926 const sunindextype N = NVEC_RAJA_CONTENT(num)->length;
927
928 RAJA::ReduceMin< RAJA_REDUCE_TYPE, realtype> gpu_result(std::numeric_limits<realtype>::max());
929 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
930 RAJA_LAMBDA(sunindextype i) {
931 if (ddata[i] != ZERO)
932 gpu_result.min(ndata[i]/ddata[i]);
933 }
934 );
935 return (static_cast<realtype>(gpu_result));
936 }
937
938
939 /*
940 * -----------------------------------------------------------------------------
941 * fused vector operations
942 * -----------------------------------------------------------------------------
943 */
944
N_VLinearCombination_Raja(int nvec,realtype * c,N_Vector * X,N_Vector z)945 int N_VLinearCombination_Raja(int nvec, realtype* c, N_Vector* X, N_Vector z)
946 {
947 int retval;
948
949 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(X[0]);
950 sunindextype N = NVEC_RAJA_CONTENT(z)->length;
951 realtype* d_zd = NVEC_RAJA_DDATAp(z);
952
953 // Create device c array for device
954 SUNMemory h_c, d_c;
955 h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
956 retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nvec, SUNMEMTYPE_DEVICE);
957 if (retval) return(-1);
958
959 // Copy c array to device
960 retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nvec);
961 if (retval) return(-1);
962
963 SUNMemory d_X;
964 realtype** d_Xd;
965 CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
966
967 // Shortcut to the arrays to work on
968 realtype* d_cd = (realtype*) d_c->ptr;
969 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
970 RAJA_LAMBDA(sunindextype i) {
971 d_zd[i] = d_cd[0] * d_Xd[0][i];
972 for (int j=1; j<nvec; j++)
973 d_zd[i] += d_cd[j] * d_Xd[j][i];
974 }
975 );
976
977 SUNMemoryHelper_Dealloc(h, h_c);
978 SUNMemoryHelper_Dealloc(h, d_c);
979 SUNMemoryHelper_Dealloc(h, d_X);
980
981 return(0);
982 }
983
984
N_VScaleAddMulti_Raja(int nvec,realtype * c,N_Vector x,N_Vector * Y,N_Vector * Z)985 int N_VScaleAddMulti_Raja(int nvec, realtype* c, N_Vector x, N_Vector* Y, N_Vector* Z)
986 {
987 int retval;
988
989 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(x);
990 sunindextype N = NVEC_RAJA_CONTENT(x)->length;
991 realtype* d_xd = NVEC_RAJA_DDATAp(x);
992
993 // Create c array for device
994 SUNMemory h_c, d_c;
995 h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
996 retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nvec, SUNMEMTYPE_DEVICE);
997 if (retval) return(-1);
998
999 // Copy c array to device
1000 retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nvec);
1001 if (retval) return(-1);
1002
1003 SUNMemory d_Y, d_Z;
1004 realtype **d_Yd, **d_Zd;
1005 CreateArrayOfPointersOnDevice(&d_Yd, &d_Y, nvec, Y);
1006 CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1007
1008 // Perform operation
1009 realtype* d_cd = (realtype*) d_c->ptr;
1010 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1011 RAJA_LAMBDA(sunindextype i) {
1012 for (int j=0; j<nvec; j++)
1013 d_Zd[j][i] = d_cd[j] * d_xd[i] + d_Yd[j][i];
1014 }
1015 );
1016
1017 SUNMemoryHelper_Dealloc(h, h_c);
1018 SUNMemoryHelper_Dealloc(h, d_c);
1019 SUNMemoryHelper_Dealloc(h, d_Y);
1020 SUNMemoryHelper_Dealloc(h, d_Z);
1021
1022 return(0);
1023 }
1024
1025
1026 /*
1027 * -----------------------------------------------------------------------------
1028 * vector array operations
1029 * -----------------------------------------------------------------------------
1030 */
1031
N_VLinearSumVectorArray_Raja(int nvec,realtype a,N_Vector * X,realtype b,N_Vector * Y,N_Vector * Z)1032 int N_VLinearSumVectorArray_Raja(int nvec,
1033 realtype a, N_Vector* X,
1034 realtype b, N_Vector* Y,
1035 N_Vector* Z)
1036 {
1037 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1038 sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1039
1040 SUNMemory d_X, d_Y, d_Z;
1041 realtype **d_Xd, **d_Yd, **d_Zd;
1042 CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
1043 CreateArrayOfPointersOnDevice(&d_Yd, &d_Y, nvec, Y);
1044 CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1045
1046 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1047 RAJA_LAMBDA(sunindextype i) {
1048 for (int j=0; j<nvec; j++)
1049 d_Zd[j][i] = a * d_Xd[j][i] + b * d_Yd[j][i];
1050 }
1051 );
1052
1053 SUNMemoryHelper_Dealloc(h, d_X);
1054 SUNMemoryHelper_Dealloc(h, d_Y);
1055 SUNMemoryHelper_Dealloc(h, d_Z);
1056
1057 return(0);
1058 }
1059
1060
N_VScaleVectorArray_Raja(int nvec,realtype * c,N_Vector * X,N_Vector * Z)1061 int N_VScaleVectorArray_Raja(int nvec, realtype* c, N_Vector* X, N_Vector* Z)
1062 {
1063 int retval;
1064
1065 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1066 sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1067
1068 // Create c array for device
1069 SUNMemory h_c, d_c;
1070 h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
1071 retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nvec, SUNMEMTYPE_DEVICE);
1072 if (retval) return(-1);
1073
1074 // Copy c array to device
1075 retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nvec);
1076 if (retval) return(-1);
1077
1078 SUNMemory d_X, d_Z;
1079 realtype **d_Xd, **d_Zd;
1080 CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
1081 CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1082
1083 realtype* d_cd = (realtype*) d_c->ptr;
1084 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1085 RAJA_LAMBDA(sunindextype i) {
1086 for (int j=0; j<nvec; j++)
1087 d_Zd[j][i] = d_cd[j] * d_Xd[j][i];
1088 }
1089 );
1090
1091 SUNMemoryHelper_Dealloc(h, h_c);
1092 SUNMemoryHelper_Dealloc(h, d_c);
1093 SUNMemoryHelper_Dealloc(h, d_X);
1094 SUNMemoryHelper_Dealloc(h, d_Z);
1095
1096 return(0);
1097 }
1098
1099
N_VConstVectorArray_Raja(int nvec,realtype c,N_Vector * Z)1100 int N_VConstVectorArray_Raja(int nvec, realtype c, N_Vector* Z)
1101 {
1102 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1103 sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1104
1105 SUNMemory d_Z;
1106 realtype** d_Zd;
1107 CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1108
1109 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1110 RAJA_LAMBDA(sunindextype i) {
1111 for (int j=0; j<nvec; j++)
1112 d_Zd[j][i] = c;
1113 }
1114 );
1115
1116 SUNMemoryHelper_Dealloc(h, d_Z);
1117
1118 return(0);
1119 }
1120
1121
N_VScaleAddMultiVectorArray_Raja(int nvec,int nsum,realtype * c,N_Vector * X,N_Vector ** Y,N_Vector ** Z)1122 int N_VScaleAddMultiVectorArray_Raja(int nvec, int nsum, realtype* c,
1123 N_Vector* X, N_Vector** Y, N_Vector** Z)
1124 {
1125 int retval;
1126
1127 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(X[0]);
1128 sunindextype N = NVEC_RAJA_CONTENT(X[0])->length;
1129
1130 // Create c array for device
1131 SUNMemory h_c, d_c;
1132 h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
1133 retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nsum, SUNMEMTYPE_DEVICE);
1134 if (retval) return(-1);
1135
1136 // Copy c array to device
1137 retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nsum);
1138 if (retval) return(-1);
1139
1140 SUNMemory d_X, d_Y, d_Z;
1141 realtype **d_Xd, **d_Yd, **d_Zd;
1142 CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
1143 Create2DArrayOfPointersOnDevice(&d_Yd, &d_Y, nvec, nsum, Y);
1144 Create2DArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, nsum, Z);
1145
1146 realtype* d_cd = (realtype*) d_c->ptr;
1147 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1148 RAJA_LAMBDA(sunindextype i) {
1149 for (int j=0; j<nvec; j++)
1150 for (int k=0; k<nsum; k++)
1151 d_Zd[j*nsum+k][i] = d_cd[k] * d_Xd[j][i] + d_Yd[j*nsum+k][i];
1152 }
1153 );
1154
1155 SUNMemoryHelper_Dealloc(h, h_c);
1156 SUNMemoryHelper_Dealloc(h, d_c);
1157 SUNMemoryHelper_Dealloc(h, d_X);
1158 SUNMemoryHelper_Dealloc(h, d_Y);
1159 SUNMemoryHelper_Dealloc(h, d_Z);
1160
1161 return(0);
1162 }
1163
1164
N_VLinearCombinationVectorArray_Raja(int nvec,int nsum,realtype * c,N_Vector ** X,N_Vector * Z)1165 int N_VLinearCombinationVectorArray_Raja(int nvec, int nsum, realtype* c,
1166 N_Vector** X, N_Vector* Z)
1167 {
1168 int retval;
1169
1170 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1171 sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1172
1173 // Create c array for device
1174 SUNMemory h_c, d_c;
1175 h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
1176 retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nsum, SUNMEMTYPE_DEVICE);
1177 if (retval) return(-1);
1178
1179 // Copy c array to device
1180 retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nsum);
1181 if (retval) return(-1);
1182
1183 SUNMemory d_X, d_Z;
1184 realtype **d_Xd, **d_Zd;
1185 CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1186 Create2DArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, nsum, X);
1187
1188 realtype *d_cd = (realtype*) d_c->ptr;
1189 RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1190 RAJA_LAMBDA(sunindextype i) {
1191 for (int j=0; j<nvec; j++) {
1192 d_Zd[j][i] = d_cd[0] * d_Xd[j*nsum][i];
1193 for (int k=1; k<nsum; k++) {
1194 d_Zd[j][i] += d_cd[k] * d_Xd[j*nsum+k][i];
1195 }
1196 }
1197 }
1198 );
1199
1200 SUNMemoryHelper_Dealloc(h, h_c);
1201 SUNMemoryHelper_Dealloc(h, d_c);
1202 SUNMemoryHelper_Dealloc(h, d_X);
1203 SUNMemoryHelper_Dealloc(h, d_Z);
1204
1205 return(0);
1206 }
1207
1208
1209 /*
1210 * -----------------------------------------------------------------
1211 * OPTIONAL XBraid interface operations
1212 * -----------------------------------------------------------------
1213 */
1214
1215
N_VBufSize_Raja(N_Vector x,sunindextype * size)1216 int N_VBufSize_Raja(N_Vector x, sunindextype *size)
1217 {
1218 if (x == NULL) return(-1);
1219 *size = (sunindextype)NVEC_RAJA_MEMSIZE(x);
1220 return(0);
1221 }
1222
1223
N_VBufPack_Raja(N_Vector x,void * buf)1224 int N_VBufPack_Raja(N_Vector x, void *buf)
1225 {
1226 int copy_fail = 0;
1227 SUNDIALS_GPU_PREFIX(Error_t) cuerr;
1228
1229 if (x == NULL || buf == NULL) return(-1);
1230
1231 SUNMemory buf_mem = SUNMemoryHelper_Wrap(buf, SUNMEMTYPE_HOST);
1232 if (buf_mem == NULL) return(-1);
1233
1234 copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
1235 buf_mem,
1236 NVEC_RAJA_CONTENT(x)->device_data,
1237 NVEC_RAJA_MEMSIZE(x),
1238 0);
1239
1240 /* we synchronize with respect to the host, but only in this stream */
1241 cuerr = SUNDIALS_GPU_PREFIX(StreamSynchronize)(0);
1242
1243 SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(x), buf_mem);
1244
1245 return (!SUNDIALS_GPU_VERIFY(cuerr) || copy_fail ? -1 : 0);
1246 }
1247
1248
N_VBufUnpack_Raja(N_Vector x,void * buf)1249 int N_VBufUnpack_Raja(N_Vector x, void *buf)
1250 {
1251 int copy_fail = 0;
1252 SUNDIALS_GPU_PREFIX(Error_t) cuerr;
1253
1254 if (x == NULL || buf == NULL) return(-1);
1255
1256 SUNMemory buf_mem = SUNMemoryHelper_Wrap(buf, SUNMEMTYPE_HOST);
1257 if (buf_mem == NULL) return(-1);
1258
1259 copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
1260 NVEC_RAJA_CONTENT(x)->device_data,
1261 buf_mem,
1262 NVEC_RAJA_MEMSIZE(x),
1263 0);
1264
1265 /* we synchronize with respect to the host, but only in this stream */
1266 cuerr = SUNDIALS_GPU_PREFIX(StreamSynchronize)(0);
1267
1268 SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(x), buf_mem);
1269
1270 return (!SUNDIALS_GPU_VERIFY(cuerr) || copy_fail ? -1 : 0);
1271 }
1272
1273
1274 /*
1275 * -----------------------------------------------------------------
1276 * Enable / Disable fused and vector array operations
1277 * -----------------------------------------------------------------
1278 */
1279
N_VEnableFusedOps_Raja(N_Vector v,booleantype tf)1280 int N_VEnableFusedOps_Raja(N_Vector v, booleantype tf)
1281 {
1282 /* check that vector is non-NULL */
1283 if (v == NULL) return(-1);
1284
1285 /* check that ops structure is non-NULL */
1286 if (v->ops == NULL) return(-1);
1287
1288 if (tf) {
1289 /* enable all fused vector operations */
1290 v->ops->nvlinearcombination = N_VLinearCombination_Raja;
1291 v->ops->nvscaleaddmulti = N_VScaleAddMulti_Raja;
1292 v->ops->nvdotprodmulti = NULL;
1293 /* enable all vector array operations */
1294 v->ops->nvlinearsumvectorarray = N_VLinearSumVectorArray_Raja;
1295 v->ops->nvscalevectorarray = N_VScaleVectorArray_Raja;
1296 v->ops->nvconstvectorarray = N_VConstVectorArray_Raja;
1297 v->ops->nvwrmsnormvectorarray = NULL;
1298 v->ops->nvwrmsnormmaskvectorarray = NULL;
1299 v->ops->nvscaleaddmultivectorarray = N_VScaleAddMultiVectorArray_Raja;
1300 v->ops->nvlinearcombinationvectorarray = N_VLinearCombinationVectorArray_Raja;
1301 } else {
1302 /* disable all fused vector operations */
1303 v->ops->nvlinearcombination = NULL;
1304 v->ops->nvscaleaddmulti = NULL;
1305 v->ops->nvdotprodmulti = NULL;
1306 /* disable all vector array operations */
1307 v->ops->nvlinearsumvectorarray = NULL;
1308 v->ops->nvscalevectorarray = NULL;
1309 v->ops->nvconstvectorarray = NULL;
1310 v->ops->nvwrmsnormvectorarray = NULL;
1311 v->ops->nvwrmsnormmaskvectorarray = NULL;
1312 v->ops->nvscaleaddmultivectorarray = NULL;
1313 v->ops->nvlinearcombinationvectorarray = NULL;
1314 }
1315
1316 /* return success */
1317 return(0);
1318 }
1319
N_VEnableLinearCombination_Raja(N_Vector v,booleantype tf)1320 int N_VEnableLinearCombination_Raja(N_Vector v, booleantype tf)
1321 {
1322 /* check that vector is non-NULL */
1323 if (v == NULL) return(-1);
1324
1325 /* check that ops structure is non-NULL */
1326 if (v->ops == NULL) return(-1);
1327
1328 /* enable/disable operation */
1329 if (tf)
1330 v->ops->nvlinearcombination = N_VLinearCombination_Raja;
1331 else
1332 v->ops->nvlinearcombination = NULL;
1333
1334 /* return success */
1335 return(0);
1336 }
1337
N_VEnableScaleAddMulti_Raja(N_Vector v,booleantype tf)1338 int N_VEnableScaleAddMulti_Raja(N_Vector v, booleantype tf)
1339 {
1340 /* check that vector is non-NULL */
1341 if (v == NULL) return(-1);
1342
1343 /* check that ops structure is non-NULL */
1344 if (v->ops == NULL) return(-1);
1345
1346 /* enable/disable operation */
1347 if (tf)
1348 v->ops->nvscaleaddmulti = N_VScaleAddMulti_Raja;
1349 else
1350 v->ops->nvscaleaddmulti = NULL;
1351
1352 /* return success */
1353 return(0);
1354 }
1355
N_VEnableLinearSumVectorArray_Raja(N_Vector v,booleantype tf)1356 int N_VEnableLinearSumVectorArray_Raja(N_Vector v, booleantype tf)
1357 {
1358 /* check that vector is non-NULL */
1359 if (v == NULL) return(-1);
1360
1361 /* check that ops structure is non-NULL */
1362 if (v->ops == NULL) return(-1);
1363
1364 /* enable/disable operation */
1365 if (tf)
1366 v->ops->nvlinearsumvectorarray = N_VLinearSumVectorArray_Raja;
1367 else
1368 v->ops->nvlinearsumvectorarray = NULL;
1369
1370 /* return success */
1371 return(0);
1372 }
1373
N_VEnableScaleVectorArray_Raja(N_Vector v,booleantype tf)1374 int N_VEnableScaleVectorArray_Raja(N_Vector v, booleantype tf)
1375 {
1376 /* check that vector is non-NULL */
1377 if (v == NULL) return(-1);
1378
1379 /* check that ops structure is non-NULL */
1380 if (v->ops == NULL) return(-1);
1381
1382 /* enable/disable operation */
1383 if (tf)
1384 v->ops->nvscalevectorarray = N_VScaleVectorArray_Raja;
1385 else
1386 v->ops->nvscalevectorarray = NULL;
1387
1388 /* return success */
1389 return(0);
1390 }
1391
N_VEnableConstVectorArray_Raja(N_Vector v,booleantype tf)1392 int N_VEnableConstVectorArray_Raja(N_Vector v, booleantype tf)
1393 {
1394 /* check that vector is non-NULL */
1395 if (v == NULL) return(-1);
1396
1397 /* check that ops structure is non-NULL */
1398 if (v->ops == NULL) return(-1);
1399
1400 /* enable/disable operation */
1401 if (tf)
1402 v->ops->nvconstvectorarray = N_VConstVectorArray_Raja;
1403 else
1404 v->ops->nvconstvectorarray = NULL;
1405
1406 /* return success */
1407 return(0);
1408 }
1409
N_VEnableScaleAddMultiVectorArray_Raja(N_Vector v,booleantype tf)1410 int N_VEnableScaleAddMultiVectorArray_Raja(N_Vector v, booleantype tf)
1411 {
1412 /* check that vector is non-NULL */
1413 if (v == NULL) return(-1);
1414
1415 /* check that ops structure is non-NULL */
1416 if (v->ops == NULL) return(-1);
1417
1418 /* enable/disable operation */
1419 if (tf)
1420 v->ops->nvscaleaddmultivectorarray = N_VScaleAddMultiVectorArray_Raja;
1421 else
1422 v->ops->nvscaleaddmultivectorarray = NULL;
1423
1424 /* return success */
1425 return(0);
1426 }
1427
N_VEnableLinearCombinationVectorArray_Raja(N_Vector v,booleantype tf)1428 int N_VEnableLinearCombinationVectorArray_Raja(N_Vector v, booleantype tf)
1429 {
1430 /* check that vector is non-NULL */
1431 if (v == NULL) return(-1);
1432
1433 /* check that ops structure is non-NULL */
1434 if (v->ops == NULL) return(-1);
1435
1436 /* enable/disable operation */
1437 if (tf)
1438 v->ops->nvlinearcombinationvectorarray = N_VLinearCombinationVectorArray_Raja;
1439 else
1440 v->ops->nvlinearcombinationvectorarray = NULL;
1441
1442 /* return success */
1443 return(0);
1444 }
1445
1446
1447 /*
1448 * -----------------------------------------------------------------
1449 * Private utility functions
1450 * -----------------------------------------------------------------
1451 */
1452
AllocateData(N_Vector v)1453 int AllocateData(N_Vector v)
1454 {
1455 int alloc_fail = 0;
1456 N_VectorContent_Raja vc = NVEC_RAJA_CONTENT(v);
1457 N_PrivateVectorContent_Raja vcp = NVEC_RAJA_PRIVATE(v);
1458
1459 if (N_VGetLength_Raja(v) == 0) return(0);
1460
1461 if (vcp->use_managed_mem)
1462 {
1463 alloc_fail = SUNMemoryHelper_Alloc(NVEC_RAJA_MEMHELP(v), &(vc->device_data),
1464 NVEC_RAJA_MEMSIZE(v), SUNMEMTYPE_UVM);
1465 if (alloc_fail)
1466 {
1467 SUNDIALS_DEBUG_PRINT("ERROR in AllocateData: SUNMemoryHelper_Alloc failed for SUNMEMTYPE_UVM\n");
1468 }
1469 vc->host_data = SUNMemoryHelper_Alias(vc->device_data);
1470 }
1471 else
1472 {
1473 alloc_fail = SUNMemoryHelper_Alloc(NVEC_RAJA_MEMHELP(v), &(vc->host_data),
1474 NVEC_RAJA_MEMSIZE(v), SUNMEMTYPE_HOST);
1475 if (alloc_fail)
1476 {
1477 SUNDIALS_DEBUG_PRINT("ERROR in AllocateData: SUNMemoryHelper_Alloc failed to alloc SUNMEMTYPE_HOST\n");
1478 }
1479
1480 alloc_fail = SUNMemoryHelper_Alloc(NVEC_RAJA_MEMHELP(v), &(vc->device_data),
1481 NVEC_RAJA_MEMSIZE(v), SUNMEMTYPE_DEVICE);
1482 if (alloc_fail)
1483 {
1484 SUNDIALS_DEBUG_PRINT("ERROR in AllocateData: SUNMemoryHelper_Alloc failed to alloc SUNMEMTYPE_DEVICE\n");
1485 }
1486 }
1487
1488 return(alloc_fail ? -1 : 0);
1489 }
1490
1491
CreateArrayOfPointersOnDevice(realtype *** d_ptrs,SUNMemory * d_ref,int nvec,N_Vector * V)1492 void CreateArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
1493 int nvec, N_Vector *V)
1494 {
1495 size_t bytes = sizeof(realtype*)*nvec;
1496 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(V[0]);
1497
1498 // Default return values
1499 *d_ref = nullptr;
1500 *d_ptrs = nullptr;
1501
1502 // Create space for host and device pointers
1503 SUNMemory h_mem;
1504 SUNMemoryHelper_Alloc(h, &h_mem, bytes, SUNMEMTYPE_HOST);
1505
1506 SUNMemory d_mem;
1507 SUNMemoryHelper_Alloc(h, &d_mem, bytes, SUNMEMTYPE_DEVICE);
1508
1509 // Fill the host memory with the pointers
1510 realtype** h_array = (realtype**) h_mem->ptr;
1511 for (int j=0; j<nvec; j++) {
1512 h_array[j] = NVEC_RAJA_DDATAp(V[j]);
1513 }
1514
1515 // Copy the host memory to the device
1516 SUNMemoryHelper_Copy(h, d_mem, h_mem, bytes);
1517
1518 // Return the device SUNMemory, and the raw pointer array
1519 *d_ref = d_mem;
1520 *d_ptrs = (realtype**) d_mem->ptr;
1521
1522 // Free the host SUNMemory
1523 SUNMemoryHelper_Dealloc(h, h_mem);
1524 }
1525
Create2DArrayOfPointersOnDevice(realtype *** d_ptrs,SUNMemory * d_ref,int nvec,int nsum,N_Vector ** V)1526 void Create2DArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
1527 int nvec, int nsum, N_Vector **V)
1528 {
1529 size_t bytes = sizeof(realtype*)*nsum*nvec;
1530 SUNMemoryHelper h = NVEC_RAJA_MEMHELP(V[0][0]);
1531
1532 // Default return values
1533 *d_ref = nullptr;
1534 *d_ptrs = nullptr;
1535
1536 // Create space for host and device pointers
1537 SUNMemory h_mem;
1538 SUNMemoryHelper_Alloc(h, &h_mem, bytes, SUNMEMTYPE_HOST);
1539
1540 SUNMemory d_mem;
1541 SUNMemoryHelper_Alloc(h, &d_mem, bytes, SUNMEMTYPE_DEVICE);
1542
1543 // Fill the host memory with the pointers
1544 realtype** h_array = (realtype**) h_mem->ptr;
1545 for (int j=0; j<nvec; j++)
1546 for (int k=0; k<nsum; k++)
1547 h_array[j*nsum+k] = NVEC_RAJA_DDATAp(V[k][j]);
1548
1549 // Copy the host memory to the device
1550 SUNMemoryHelper_Copy(h, d_mem, h_mem, bytes);
1551
1552 // Return the device SUNMemory, and the raw pointer array
1553 *d_ref = d_mem;
1554 *d_ptrs = (realtype**) d_mem->ptr;
1555
1556 // Free the host SUNMemory
1557 SUNMemoryHelper_Dealloc(h, h_mem);
1558 }
1559
1560 } // extern "C"
1561