1 /* -----------------------------------------------------------------
2  * Programmer(s): Slaven Peles, Cody J. Balos, Daniel McGreer @ LLNL
3  * -----------------------------------------------------------------
4  * SUNDIALS Copyright Start
5  * Copyright (c) 2002-2021, Lawrence Livermore National Security
6  * and Southern Methodist University.
7  * All rights reserved.
8  *
9  * See the top-level LICENSE and NOTICE files for details.
10  *
11  * SPDX-License-Identifier: BSD-3-Clause
12  * SUNDIALS Copyright End
13  * -----------------------------------------------------------------
 * This is the implementation file for a RAJA implementation
 * of the NVECTOR package. The RAJA backend (CUDA or HIP) is
 * selected at compile time.
16  * -----------------------------------------------------------------*/
17 
18 #include <stdio.h>
19 #include <stdlib.h>
20 
21 #include <RAJA/RAJA.hpp>
22 #include <nvector/nvector_raja.h>
23 
24 #include "sundials_debug.h"
25 
26 // RAJA defines
27 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
28 #include <sunmemory/sunmemory_cuda.h>
29 #include "sundials_cuda.h"
30 #define RAJA_NODE_TYPE RAJA::cuda_exec< 256 >
31 #define RAJA_REDUCE_TYPE RAJA::cuda_reduce
32 #define SUNDIALS_GPU_PREFIX(val) cuda ## val
33 #define SUNDIALS_GPU_VERIFY SUNDIALS_CUDA_VERIFY
34 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
35 #include <sunmemory/sunmemory_hip.h>
36 #include "sundials_hip.h"
37 #define RAJA_NODE_TYPE RAJA::hip_exec< 512 >
38 #define RAJA_REDUCE_TYPE RAJA::hip_reduce
39 #define SUNDIALS_GPU_PREFIX(val) hip ## val
40 #define SUNDIALS_GPU_VERIFY SUNDIALS_HIP_VERIFY
41 #endif
42 
43 #define RAJA_LAMBDA [=] __device__
44 
45 #define ZERO   RCONST(0.0)
46 #define HALF   RCONST(0.5)
47 #define ONE    RCONST(1.0)
48 #define ONEPT5 RCONST(1.5)
49 
50 extern "C" {
51 
// Static constants
// Lower bound used for every RAJA::RangeSegment built in this file.
static constexpr sunindextype zeroIdx = 0;

// Helpful macros
// Each macro takes an N_Vector `x`. The argument is expanded without
// parentheses, so pass a plain identifier or a postfix expression (e.g. X[0]).

// Content struct attached to the vector.
#define NVEC_RAJA_CONTENT(x) ((N_VectorContent_Raja)(x->content))
// Private (implementation-only) part of the content.
#define NVEC_RAJA_PRIVATE(x) ((N_PrivateVectorContent_Raja)(NVEC_RAJA_CONTENT(x)->priv))
// Size in bytes of `length` entries of type realtype.
#define NVEC_RAJA_MEMSIZE(x) (NVEC_RAJA_CONTENT(x)->length * sizeof(realtype))
// Memory helper stored in the content (may be NULL for an empty vector).
#define NVEC_RAJA_MEMHELP(x) (NVEC_RAJA_CONTENT(x)->mem_helper)
// Raw host / device data pointers. Both dereference the corresponding
// SUNMemory object, so host_data / device_data must be non-NULL.
#define NVEC_RAJA_HDATAp(x)  ((realtype*) NVEC_RAJA_CONTENT(x)->host_data->ptr)
#define NVEC_RAJA_DDATAp(x)  ((realtype*) NVEC_RAJA_CONTENT(x)->device_data->ptr)
62 
/*
 * Private structure definition
 *
 * Implementation-only state hung off the `priv` field of the vector content.
 * It is allocated in N_VNewEmpty_Raja and released in N_VDestroy_Raja.
 */

struct _N_PrivateVectorContent_Raja
{
  booleantype use_managed_mem; /* indicates if the data pointers and buffer pointers are managed memory */
};

/* Pointer alias used by the NVEC_RAJA_PRIVATE accessor macro */
typedef struct _N_PrivateVectorContent_Raja *N_PrivateVectorContent_Raja;
73 
74 
/*
 * Utility functions
 *
 * These are defined later in the file; the notes below are inferred from
 * how the visible callers use them.
 */

/* Allocates the data of v; callers treat a nonzero return as failure. */
static int AllocateData(N_Vector v);
/* Presumably fills *d_ptrs with a device-side table of the data pointers of
 * V[0..nvec-1] and returns the backing allocation in *d_ref, which the caller
 * releases with SUNMemoryHelper_Dealloc -- TODO confirm against definition. */
static void CreateArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
                                          int nvec, N_Vector *V);
/* Presumably the two-dimensional (nvec x nsum) analogue of the function
 * above -- TODO confirm against definition. */
static void Create2DArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
                                            int nvec, int nsum, N_Vector **V);
84 
N_VNewEmpty_Raja()85 N_Vector N_VNewEmpty_Raja()
86 {
87   N_Vector v;
88 
89   /* Create an empty vector object */
90   v = NULL;
91   v = N_VNewEmpty();
92   if (v == NULL) return(NULL);
93 
94   /* Attach operations */
95 
96   /* constructors, destructors, and utility operations */
97   v->ops->nvgetvectorid           = N_VGetVectorID_Raja;
98   v->ops->nvclone                 = N_VClone_Raja;
99   v->ops->nvcloneempty            = N_VCloneEmpty_Raja;
100   v->ops->nvdestroy               = N_VDestroy_Raja;
101   v->ops->nvspace                 = N_VSpace_Raja;
102   v->ops->nvgetlength             = N_VGetLength_Raja;
103   v->ops->nvgetarraypointer       = N_VGetHostArrayPointer_Raja;
104   v->ops->nvgetdevicearraypointer = N_VGetDeviceArrayPointer_Raja;
105   v->ops->nvsetarraypointer       = N_VSetHostArrayPointer_Raja;
106 
107 
108   /* standard vector operations */
109   v->ops->nvlinearsum    = N_VLinearSum_Raja;
110   v->ops->nvconst        = N_VConst_Raja;
111   v->ops->nvprod         = N_VProd_Raja;
112   v->ops->nvdiv          = N_VDiv_Raja;
113   v->ops->nvscale        = N_VScale_Raja;
114   v->ops->nvabs          = N_VAbs_Raja;
115   v->ops->nvinv          = N_VInv_Raja;
116   v->ops->nvaddconst     = N_VAddConst_Raja;
117   v->ops->nvdotprod      = N_VDotProd_Raja;
118   v->ops->nvmaxnorm      = N_VMaxNorm_Raja;
119   v->ops->nvmin          = N_VMin_Raja;
120   v->ops->nvl1norm       = N_VL1Norm_Raja;
121   v->ops->nvinvtest      = N_VInvTest_Raja;
122   v->ops->nvconstrmask   = N_VConstrMask_Raja;
123   v->ops->nvminquotient  = N_VMinQuotient_Raja;
124   v->ops->nvwrmsnormmask = N_VWrmsNormMask_Raja;
125   v->ops->nvwrmsnorm     = N_VWrmsNorm_Raja;
126   v->ops->nvwl2norm      = N_VWL2Norm_Raja;
127   v->ops->nvcompare      = N_VCompare_Raja;
128 
129   /* fused and vector array operations are disabled (NULL) by default */
130 
131   /* local reduction operations */
132   v->ops->nvwsqrsumlocal     = N_VWSqrSumLocal_Raja;
133   v->ops->nvwsqrsummasklocal = N_VWSqrSumMaskLocal_Raja;
134   v->ops->nvdotprodlocal     = N_VDotProd_Raja;
135   v->ops->nvmaxnormlocal     = N_VMaxNorm_Raja;
136   v->ops->nvminlocal         = N_VMin_Raja;
137   v->ops->nvl1normlocal      = N_VL1Norm_Raja;
138   v->ops->nvinvtestlocal     = N_VInvTest_Raja;
139   v->ops->nvconstrmasklocal  = N_VConstrMask_Raja;
140   v->ops->nvminquotientlocal = N_VMinQuotient_Raja;
141 
142   /* XBraid interface operations */
143   v->ops->nvbufsize   = N_VBufSize_Raja;
144   v->ops->nvbufpack   = N_VBufPack_Raja;
145   v->ops->nvbufunpack = N_VBufUnpack_Raja;
146 
147   /* print operation for debugging */
148   v->ops->nvprint            = N_VPrint_Raja;
149   v->ops->nvprintfile        = N_VPrintFile_Raja;
150 
151   v->content = (N_VectorContent_Raja) malloc(sizeof(_N_VectorContent_Raja));
152   if (v->content == NULL)
153   {
154     N_VDestroy(v);
155     return NULL;
156   }
157 
158   NVEC_RAJA_CONTENT(v)->priv = malloc(sizeof(_N_PrivateVectorContent_Raja));
159   if (NVEC_RAJA_CONTENT(v)->priv == NULL)
160   {
161     N_VDestroy(v);
162     return NULL;
163   }
164 
165   NVEC_RAJA_CONTENT(v)->length          = 0;
166   NVEC_RAJA_CONTENT(v)->mem_helper      = NULL;
167   NVEC_RAJA_CONTENT(v)->own_helper      = SUNFALSE;
168   NVEC_RAJA_CONTENT(v)->host_data       = NULL;
169   NVEC_RAJA_CONTENT(v)->device_data     = NULL;
170   NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNFALSE;
171 
172   return(v);
173 }
174 
N_VNew_Raja(sunindextype length)175 N_Vector N_VNew_Raja(sunindextype length)
176 {
177   N_Vector v;
178 
179   v = NULL;
180   v = N_VNewEmpty_Raja();
181   if (v == NULL) return(NULL);
182 
183   NVEC_RAJA_CONTENT(v)->length          = length;
184 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
185   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Cuda();
186 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
187   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Hip();
188 #endif
189   NVEC_RAJA_CONTENT(v)->own_helper      = SUNTRUE;
190   NVEC_RAJA_CONTENT(v)->host_data       = NULL;
191   NVEC_RAJA_CONTENT(v)->device_data     = NULL;
192   NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNFALSE;
193 
194   if (NVEC_RAJA_MEMHELP(v) == NULL)
195   {
196     SUNDIALS_DEBUG_PRINT("ERROR in N_VNew_Raja: memory helper is NULL\n");
197     N_VDestroy(v);
198     return(NULL);
199   }
200 
201   if (AllocateData(v))
202   {
203     SUNDIALS_DEBUG_PRINT("ERROR in N_VNew_Raja: AllocateData returned nonzero\n");
204     N_VDestroy(v);
205     return NULL;
206   }
207 
208   return(v);
209 }
210 
N_VNewWithMemHelp_Raja(sunindextype length,booleantype use_managed_mem,SUNMemoryHelper helper)211 N_Vector N_VNewWithMemHelp_Raja(sunindextype length, booleantype use_managed_mem, SUNMemoryHelper helper)
212 {
213   N_Vector v;
214 
215   if (helper == NULL)
216   {
217     SUNDIALS_DEBUG_PRINT("ERROR in N_VNewWithMemHelp_Raja: helper is NULL\n");
218     return(NULL);
219   }
220 
221   if (!SUNMemoryHelper_ImplementsRequiredOps(helper))
222   {
223     SUNDIALS_DEBUG_PRINT("ERROR in N_VNewWithMemHelp_Raja: helper doesn't implement all required ops\n");
224     return(NULL);
225   }
226 
227   v = NULL;
228   v = N_VNewEmpty_Raja();
229   if (v == NULL) return(NULL);
230 
231   NVEC_RAJA_CONTENT(v)->length          = length;
232   NVEC_RAJA_CONTENT(v)->mem_helper      = helper;
233   NVEC_RAJA_CONTENT(v)->own_helper      = SUNFALSE;
234   NVEC_RAJA_CONTENT(v)->host_data       = NULL;
235   NVEC_RAJA_CONTENT(v)->device_data     = NULL;
236   NVEC_RAJA_PRIVATE(v)->use_managed_mem = use_managed_mem;
237 
238   if (AllocateData(v))
239   {
240     SUNDIALS_DEBUG_PRINT("ERROR in N_VNewWithMemHelp_Raja: AllocateData returned nonzero\n");
241     N_VDestroy(v);
242     return(NULL);
243   }
244 
245   return(v);
246 }
247 
N_VNewManaged_Raja(sunindextype length)248 N_Vector N_VNewManaged_Raja(sunindextype length)
249 {
250   N_Vector v;
251 
252   v = NULL;
253   v = N_VNewEmpty_Raja();
254   if (v == NULL) return(NULL);
255 
256   NVEC_RAJA_CONTENT(v)->length          = length;
257 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
258   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Cuda();
259 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
260   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Hip();
261 #endif
262   NVEC_RAJA_CONTENT(v)->own_helper      = SUNTRUE;
263   NVEC_RAJA_CONTENT(v)->host_data       = NULL;
264   NVEC_RAJA_CONTENT(v)->device_data     = NULL;
265   NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNTRUE;
266 
267   if (NVEC_RAJA_MEMHELP(v) == NULL)
268   {
269     SUNDIALS_DEBUG_PRINT("ERROR in N_VNewManaged_Raja: memory helper is NULL\n");
270     N_VDestroy(v);
271     return(NULL);
272   }
273 
274   if (AllocateData(v))
275   {
276     SUNDIALS_DEBUG_PRINT("ERROR in N_VNewManaged_Raja: AllocateData returned nonzero\n");
277     N_VDestroy(v);
278     return NULL;
279   }
280 
281   return(v);
282 }
283 
N_VMake_Raja(sunindextype length,realtype * h_vdata,realtype * d_vdata)284 N_Vector N_VMake_Raja(sunindextype length, realtype *h_vdata, realtype *d_vdata)
285 {
286   N_Vector v;
287 
288   if (h_vdata == NULL || d_vdata == NULL) return(NULL);
289 
290   v = NULL;
291   v = N_VNewEmpty_Raja();
292   if (v == NULL) return(NULL);
293 
294   NVEC_RAJA_CONTENT(v)->length          = length;
295   NVEC_RAJA_CONTENT(v)->host_data       = SUNMemoryHelper_Wrap(h_vdata, SUNMEMTYPE_HOST);
296   NVEC_RAJA_CONTENT(v)->device_data     = SUNMemoryHelper_Wrap(d_vdata, SUNMEMTYPE_DEVICE);
297 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
298   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Cuda();
299 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
300   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Hip();
301 #endif
302   NVEC_RAJA_CONTENT(v)->own_helper      = SUNTRUE;
303   NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNFALSE;
304 
305   if (NVEC_RAJA_MEMHELP(v) == NULL)
306   {
307     SUNDIALS_DEBUG_PRINT("ERROR in N_VMake_Raja: memory helper is NULL\n");
308     N_VDestroy(v);
309     return(NULL);
310   }
311 
312 
313   if (NVEC_RAJA_CONTENT(v)->device_data == NULL ||
314       NVEC_RAJA_CONTENT(v)->host_data == NULL)
315   {
316     SUNDIALS_DEBUG_PRINT("ERROR in N_VMake_Raja: SUNMemoryHelper_Wrap returned NULL\n");
317     N_VDestroy(v);
318     return(NULL);
319   }
320 
321   return(v);
322 }
323 
N_VMakeManaged_Raja(sunindextype length,realtype * vdata)324 N_Vector N_VMakeManaged_Raja(sunindextype length, realtype *vdata)
325 {
326   N_Vector v;
327 
328   if (vdata == NULL) return(NULL);
329 
330   v = NULL;
331   v = N_VNewEmpty_Raja();
332   if (v == NULL) return(NULL);
333 
334   NVEC_RAJA_CONTENT(v)->length          = length;
335   NVEC_RAJA_CONTENT(v)->host_data       = SUNMemoryHelper_Wrap(vdata, SUNMEMTYPE_UVM);
336   NVEC_RAJA_CONTENT(v)->device_data     = SUNMemoryHelper_Alias(NVEC_RAJA_CONTENT(v)->host_data);
337 #if defined(SUNDIALS_RAJA_BACKENDS_CUDA)
338   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Cuda();
339 #elif defined(SUNDIALS_RAJA_BACKENDS_HIP)
340   NVEC_RAJA_CONTENT(v)->mem_helper      = SUNMemoryHelper_Hip();
341 #endif
342   NVEC_RAJA_CONTENT(v)->own_helper      = SUNTRUE;
343   NVEC_RAJA_PRIVATE(v)->use_managed_mem = SUNTRUE;
344 
345   if (NVEC_RAJA_MEMHELP(v) == NULL)
346   {
347     SUNDIALS_DEBUG_PRINT("ERROR in N_VMakeManaged_Raja: memory helper is NULL\n");
348     N_VDestroy(v);
349     return(NULL);
350   }
351 
352   if (NVEC_RAJA_CONTENT(v)->device_data == NULL ||
353       NVEC_RAJA_CONTENT(v)->host_data == NULL)
354   {
355     SUNDIALS_DEBUG_PRINT("ERROR in N_VMake_Raja: SUNMemoryHelper_Wrap returned NULL\n");
356     N_VDestroy(v);
357     return(NULL);
358   }
359 
360   return(v);
361 }
362 
363 /* -----------------------------------------------------------------
364  * Function to return the global length of the vector.
365  * This is defined as an inline function in nvector_raja.h, so
366  * we just mark it as extern here.
367  */
368 extern sunindextype N_VGetLength_Raja(N_Vector v);
369 
370 /* ----------------------------------------------------------------------------
371  * Return pointer to the raw host data.
372  * This is defined as an inline function in nvector_raja.h, so
373  * we just mark it as extern here.
374  */
375 
376 extern realtype *N_VGetHostArrayPointer_Raja(N_Vector x);
377 
378 /* ----------------------------------------------------------------------------
379  * Return pointer to the raw device data.
380  * This is defined as an inline function in nvector_raja.h, so
381  * we just mark it as extern here.
382  */
383 
384 extern realtype *N_VGetDeviceArrayPointer_Raja(N_Vector x);
385 
386 
387 /* ----------------------------------------------------------------------------
388  * Set pointer to the raw host data. Does not free the existing pointer.
389  */
390 
N_VSetHostArrayPointer_Raja(realtype * h_vdata,N_Vector v)391 void N_VSetHostArrayPointer_Raja(realtype* h_vdata, N_Vector v)
392 {
393   if (N_VIsManagedMemory_Raja(v))
394   {
395     if (NVEC_RAJA_CONTENT(v)->host_data)
396     {
397       NVEC_RAJA_CONTENT(v)->host_data->ptr = (void*) h_vdata;
398       NVEC_RAJA_CONTENT(v)->device_data->ptr = (void*) h_vdata;
399     }
400     else
401     {
402       NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Wrap((void*) h_vdata, SUNMEMTYPE_UVM);
403       NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Alias(NVEC_RAJA_CONTENT(v)->host_data);
404     }
405   }
406   else
407   {
408     if (NVEC_RAJA_CONTENT(v)->host_data)
409     {
410       NVEC_RAJA_CONTENT(v)->host_data->ptr = (void*) h_vdata;
411     }
412     else
413     {
414       NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Wrap((void*) h_vdata, SUNMEMTYPE_HOST);
415     }
416   }
417 }
418 
419 /* ----------------------------------------------------------------------------
420  * Set pointer to the raw device data
421  */
422 
N_VSetDeviceArrayPointer_Raja(realtype * d_vdata,N_Vector v)423 void N_VSetDeviceArrayPointer_Raja(realtype* d_vdata, N_Vector v)
424 {
425   if (N_VIsManagedMemory_Raja(v))
426   {
427     if (NVEC_RAJA_CONTENT(v)->device_data)
428     {
429       NVEC_RAJA_CONTENT(v)->device_data->ptr = (void*) d_vdata;
430       NVEC_RAJA_CONTENT(v)->host_data->ptr = (void*) d_vdata;
431     }
432     else
433     {
434       NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Wrap((void*) d_vdata, SUNMEMTYPE_UVM);
435       NVEC_RAJA_CONTENT(v)->host_data = SUNMemoryHelper_Alias(NVEC_RAJA_CONTENT(v)->device_data);
436     }
437   }
438   else
439   {
440     if (NVEC_RAJA_CONTENT(v)->device_data)
441     {
442       NVEC_RAJA_CONTENT(v)->device_data->ptr = (void*) d_vdata;
443     }
444     else
445     {
446       NVEC_RAJA_CONTENT(v)->device_data = SUNMemoryHelper_Wrap((void*) d_vdata, SUNMEMTYPE_DEVICE);
447     }
448   }
449 }
450 
451 /* ----------------------------------------------------------------------------
452  * Return a flag indicating if the memory for the vector data is managed
453  */
N_VIsManagedMemory_Raja(N_Vector x)454 booleantype N_VIsManagedMemory_Raja(N_Vector x)
455 {
456   return NVEC_RAJA_PRIVATE(x)->use_managed_mem;
457 }
458 
459 /* ----------------------------------------------------------------------------
460  * Copy vector data to the device
461  */
462 
N_VCopyToDevice_Raja(N_Vector x)463 void N_VCopyToDevice_Raja(N_Vector x)
464 {
465   int copy_fail;
466 
467   copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
468                                         NVEC_RAJA_CONTENT(x)->device_data,
469                                         NVEC_RAJA_CONTENT(x)->host_data,
470                                         NVEC_RAJA_MEMSIZE(x),
471                                         0);
472 
473   if (copy_fail)
474   {
475     SUNDIALS_DEBUG_PRINT("ERROR in N_VCopyToDevice_Raja: SUNMemoryHelper_CopyAsync returned nonzero\n");
476   }
477 
478   /* we synchronize with respect to the host, but on the default stream currently */
479   SUNDIALS_GPU_VERIFY(SUNDIALS_GPU_PREFIX(StreamSynchronize)(0));
480 }
481 
482 /* ----------------------------------------------------------------------------
483  * Copy vector data from the device to the host
484  */
485 
N_VCopyFromDevice_Raja(N_Vector x)486 void N_VCopyFromDevice_Raja(N_Vector x)
487 {
488   int copy_fail;
489 
490   copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
491                                         NVEC_RAJA_CONTENT(x)->host_data,
492                                         NVEC_RAJA_CONTENT(x)->device_data,
493                                         NVEC_RAJA_MEMSIZE(x),
494                                         0);
495 
496   if (copy_fail)
497   {
498     SUNDIALS_DEBUG_PRINT("ERROR in N_VCopyFromDevice_Raja: SUNMemoryHelper_CopyAsync returned nonzero\n");
499   }
500 
501   /* we synchronize with respect to the host, but only in this stream */
502   SUNDIALS_GPU_VERIFY(SUNDIALS_GPU_PREFIX(StreamSynchronize)(0));
503 }
504 
505 /* ----------------------------------------------------------------------------
506  * Function to print the a serial vector to stdout
507  */
508 
N_VPrint_Raja(N_Vector X)509 void N_VPrint_Raja(N_Vector X)
510 {
511   N_VPrintFile_Raja(X, stdout);
512 }
513 
514 /* ----------------------------------------------------------------------------
515  * Function to print the a serial vector to outfile
516  */
517 
N_VPrintFile_Raja(N_Vector X,FILE * outfile)518 void N_VPrintFile_Raja(N_Vector X, FILE *outfile)
519 {
520   sunindextype i;
521 
522   for (i = 0; i < NVEC_RAJA_CONTENT(X)->length; i++) {
523 #if defined(SUNDIALS_EXTENDED_PRECISION)
524     fprintf(outfile, "%35.32Lg\n", NVEC_RAJA_HDATAp(X)[i]);
525 #elif defined(SUNDIALS_DOUBLE_PRECISION)
526     fprintf(outfile, "%19.16g\n", NVEC_RAJA_HDATAp(X)[i]);
527 #else
528     fprintf(outfile, "%11.8g\n", NVEC_RAJA_HDATAp(X)[i]);
529 #endif
530   }
531   fprintf(outfile, "\n");
532 
533   return;
534 }
535 
536 /*
537  * -----------------------------------------------------------------
538  * implementation of vector operations
539  * -----------------------------------------------------------------
540  */
541 
N_VCloneEmpty_Raja(N_Vector w)542 N_Vector N_VCloneEmpty_Raja(N_Vector w)
543 {
544   N_Vector v;
545 
546   if (w == NULL) return(NULL);
547 
548   /* Create vector */
549   v = NULL;
550   v = N_VNewEmpty_Raja();
551   if (v == NULL) return(NULL);
552 
553   /* Attach operations */
554   if (N_VCopyOps(w, v)) { N_VDestroy(v); return(NULL); }
555 
556   /* Set content */
557   NVEC_RAJA_CONTENT(v)->length          = NVEC_RAJA_CONTENT(w)->length;
558   NVEC_RAJA_CONTENT(v)->host_data       = NULL;
559   NVEC_RAJA_CONTENT(v)->device_data     = NULL;
560   NVEC_RAJA_PRIVATE(v)->use_managed_mem = NVEC_RAJA_PRIVATE(w)->use_managed_mem;
561 
562 
563   return(v);
564 }
565 
N_VClone_Raja(N_Vector w)566 N_Vector N_VClone_Raja(N_Vector w)
567 {
568   N_Vector v;
569   v = NULL;
570   v = N_VCloneEmpty_Raja(w);
571   if (v == NULL) return(NULL);
572 
573   NVEC_RAJA_CONTENT(v)->mem_helper = SUNMemoryHelper_Clone(NVEC_RAJA_MEMHELP(w));
574   NVEC_RAJA_CONTENT(v)->own_helper = SUNTRUE;
575 
576   if (AllocateData(v))
577   {
578     SUNDIALS_DEBUG_PRINT("ERROR in N_VClone_Raja: AllocateData returned nonzero\n");
579     N_VDestroy(v);
580     return NULL;
581   }
582 
583 return(v);
584 
585 }
586 
587 
N_VDestroy_Raja(N_Vector v)588 void N_VDestroy_Raja(N_Vector v)
589 {
590   N_VectorContent_Raja vc;
591   N_PrivateVectorContent_Raja vcp;
592 
593   if (v == NULL) return;
594 
595   /* free ops structure */
596   if (v->ops != NULL)
597   {
598     free(v->ops);
599     v->ops = NULL;
600   }
601 
602   /* extract content */
603   vc = NVEC_RAJA_CONTENT(v);
604   if (vc == NULL)
605   {
606     free(v);
607     v = NULL;
608     return;
609   }
610 
611   /* free private content */
612   vcp = (N_PrivateVectorContent_Raja) vc->priv;
613   if (vcp != NULL)
614   {
615     /* free items in private content */
616     free(vcp);
617     vc->priv = NULL;
618   }
619 
620   /* free items in content */
621   if (NVEC_RAJA_MEMHELP(v))
622   {
623     SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(v), vc->host_data);
624     vc->host_data = NULL;
625     SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(v), vc->device_data);
626     vc->device_data = NULL;
627     if (vc->own_helper) SUNMemoryHelper_Destroy(vc->mem_helper);
628     vc->mem_helper = NULL;
629   }
630   else
631   {
632     SUNDIALS_DEBUG_PRINT("WARNING in N_VDestroy_Raja: mem_helper was NULL when trying to dealloc data, this could result in a memory leak\n");
633   }
634 
635   /* free content struct */
636   free(vc);
637 
638   /* free vector */
639   free(v);
640 
641   return;
642 }
643 
N_VSpace_Raja(N_Vector X,sunindextype * lrw,sunindextype * liw)644 void N_VSpace_Raja(N_Vector X, sunindextype *lrw, sunindextype *liw)
645 {
646   *lrw = NVEC_RAJA_CONTENT(X)->length;
647   *liw = 2;
648 }
649 
N_VConst_Raja(realtype c,N_Vector Z)650 void N_VConst_Raja(realtype c, N_Vector Z)
651 {
652   const sunindextype N = NVEC_RAJA_CONTENT(Z)->length;
653   realtype *zdata = NVEC_RAJA_DDATAp(Z);
654 
655   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N), RAJA_LAMBDA(sunindextype i) {
656      zdata[i] = c;
657   });
658 }
659 
N_VLinearSum_Raja(realtype a,N_Vector X,realtype b,N_Vector Y,N_Vector Z)660 void N_VLinearSum_Raja(realtype a, N_Vector X, realtype b, N_Vector Y, N_Vector Z)
661 {
662   const realtype *xdata = NVEC_RAJA_DDATAp(X);
663   const realtype *ydata = NVEC_RAJA_DDATAp(Y);
664   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
665   realtype *zdata = NVEC_RAJA_DDATAp(Z);
666 
667   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
668     RAJA_LAMBDA(sunindextype i) {
669       zdata[i] = a*xdata[i] + b*ydata[i];
670     }
671   );
672 }
673 
N_VProd_Raja(N_Vector X,N_Vector Y,N_Vector Z)674 void N_VProd_Raja(N_Vector X, N_Vector Y, N_Vector Z)
675 {
676   const realtype *xdata = NVEC_RAJA_DDATAp(X);
677   const realtype *ydata = NVEC_RAJA_DDATAp(Y);
678   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
679   realtype *zdata = NVEC_RAJA_DDATAp(Z);
680 
681   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
682     RAJA_LAMBDA(sunindextype i) {
683       zdata[i] = xdata[i] * ydata[i];
684     }
685   );
686 }
687 
N_VDiv_Raja(N_Vector X,N_Vector Y,N_Vector Z)688 void N_VDiv_Raja(N_Vector X, N_Vector Y, N_Vector Z)
689 {
690   const realtype *xdata = NVEC_RAJA_DDATAp(X);
691   const realtype *ydata = NVEC_RAJA_DDATAp(Y);
692   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
693   realtype *zdata = NVEC_RAJA_DDATAp(Z);
694 
695   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
696     RAJA_LAMBDA(sunindextype i) {
697       zdata[i] = xdata[i] / ydata[i];
698     }
699   );
700 }
701 
N_VScale_Raja(realtype c,N_Vector X,N_Vector Z)702 void N_VScale_Raja(realtype c, N_Vector X, N_Vector Z)
703 {
704   const realtype *xdata = NVEC_RAJA_DDATAp(X);
705   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
706   realtype *zdata = NVEC_RAJA_DDATAp(Z);
707 
708   RAJA::forall<RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
709     RAJA_LAMBDA(sunindextype i) {
710       zdata[i] = c * xdata[i];
711     }
712   );
713 }
714 
N_VAbs_Raja(N_Vector X,N_Vector Z)715 void N_VAbs_Raja(N_Vector X, N_Vector Z)
716 {
717   const realtype *xdata = NVEC_RAJA_DDATAp(X);
718   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
719   realtype *zdata = NVEC_RAJA_DDATAp(Z);
720 
721   RAJA::forall<RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
722     RAJA_LAMBDA(sunindextype i) {
723       zdata[i] = abs(xdata[i]);
724     }
725   );
726 }
727 
N_VInv_Raja(N_Vector X,N_Vector Z)728 void N_VInv_Raja(N_Vector X, N_Vector Z)
729 {
730   const realtype *xdata = NVEC_RAJA_DDATAp(X);
731   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
732   realtype *zdata = NVEC_RAJA_DDATAp(Z);
733 
734   RAJA::forall<RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
735     RAJA_LAMBDA(sunindextype i) {
736       zdata[i] = ONE / xdata[i];
737     }
738   );
739 }
740 
N_VAddConst_Raja(N_Vector X,realtype b,N_Vector Z)741 void N_VAddConst_Raja(N_Vector X, realtype b, N_Vector Z)
742 {
743   const realtype *xdata = NVEC_RAJA_DDATAp(X);
744   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
745   realtype *zdata = NVEC_RAJA_DDATAp(Z);
746 
747   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
748     RAJA_LAMBDA(sunindextype i) {
749       zdata[i] = xdata[i] + b;
750     }
751   );
752 }
753 
N_VDotProd_Raja(N_Vector X,N_Vector Y)754 realtype N_VDotProd_Raja(N_Vector X, N_Vector Y)
755 {
756   const realtype *xdata = NVEC_RAJA_DDATAp(X);
757   const realtype *ydata = NVEC_RAJA_DDATAp(Y);
758   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
759 
760   RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
761   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
762     RAJA_LAMBDA(sunindextype i) {
763       gpu_result += xdata[i] * ydata[i] ;
764     }
765   );
766 
767   return (static_cast<realtype>(gpu_result));
768 }
769 
N_VMaxNorm_Raja(N_Vector X)770 realtype N_VMaxNorm_Raja(N_Vector X)
771 {
772   const realtype *xdata = NVEC_RAJA_DDATAp(X);
773   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
774 
775   RAJA::ReduceMax< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
776   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
777     RAJA_LAMBDA(sunindextype i) {
778       gpu_result.max(abs(xdata[i]));
779     }
780   );
781 
782   return (static_cast<realtype>(gpu_result));
783 }
784 
N_VWSqrSumLocal_Raja(N_Vector X,N_Vector W)785 realtype N_VWSqrSumLocal_Raja(N_Vector X, N_Vector W)
786 {
787   const realtype *xdata = NVEC_RAJA_DDATAp(X);
788   const realtype *wdata = NVEC_RAJA_DDATAp(W);
789   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
790 
791   RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
792   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
793     RAJA_LAMBDA(sunindextype i) {
794       gpu_result += (xdata[i] * wdata[i] * xdata[i] * wdata[i]);
795     }
796   );
797 
798   return (static_cast<realtype>(gpu_result));
799 }
800 
N_VWrmsNorm_Raja(N_Vector X,N_Vector W)801 realtype N_VWrmsNorm_Raja(N_Vector X, N_Vector W)
802 {
803   const realtype sum = N_VWSqrSumLocal_Raja(X, W);
804   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
805   return std::sqrt(sum/N);
806 }
807 
N_VWSqrSumMaskLocal_Raja(N_Vector X,N_Vector W,N_Vector ID)808 realtype N_VWSqrSumMaskLocal_Raja(N_Vector X, N_Vector W, N_Vector ID)
809 {
810   const realtype *xdata = NVEC_RAJA_DDATAp(X);
811   const realtype *wdata = NVEC_RAJA_DDATAp(W);
812   const realtype *iddata = NVEC_RAJA_DDATAp(ID);
813   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
814 
815   RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
816   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
817     RAJA_LAMBDA(sunindextype i) {
818       if (iddata[i] > ZERO)
819         gpu_result += (xdata[i] * wdata[i] * xdata[i] * wdata[i]);
820     }
821   );
822 
823   return (static_cast<realtype>(gpu_result));
824 }
825 
N_VWrmsNormMask_Raja(N_Vector X,N_Vector W,N_Vector ID)826 realtype N_VWrmsNormMask_Raja(N_Vector X, N_Vector W, N_Vector ID)
827 {
828   const realtype sum = N_VWSqrSumMaskLocal_Raja(X, W, ID);
829   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
830   return std::sqrt(sum/N);
831 }
832 
N_VMin_Raja(N_Vector X)833 realtype N_VMin_Raja(N_Vector X)
834 {
835   const realtype *xdata = NVEC_RAJA_DDATAp(X);
836   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
837 
838   RAJA::ReduceMin< RAJA_REDUCE_TYPE, realtype> gpu_result(std::numeric_limits<realtype>::max());
839   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
840     RAJA_LAMBDA(sunindextype i) {
841       gpu_result.min(xdata[i]);
842     }
843   );
844 
845   return (static_cast<realtype>(gpu_result));
846 }
847 
N_VWL2Norm_Raja(N_Vector X,N_Vector W)848 realtype N_VWL2Norm_Raja(N_Vector X, N_Vector W)
849 {
850   return std::sqrt(N_VWSqrSumLocal_Raja(X, W));
851 }
852 
N_VL1Norm_Raja(N_Vector X)853 realtype N_VL1Norm_Raja(N_Vector X)
854 {
855   const realtype *xdata = NVEC_RAJA_DDATAp(X);
856   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
857 
858   RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(0.0);
859   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
860     RAJA_LAMBDA(sunindextype i) {
861       gpu_result += (abs(xdata[i]));
862     }
863   );
864 
865   return (static_cast<realtype>(gpu_result));
866 }
867 
N_VCompare_Raja(realtype c,N_Vector X,N_Vector Z)868 void N_VCompare_Raja(realtype c, N_Vector X, N_Vector Z)
869 {
870   const realtype *xdata = NVEC_RAJA_DDATAp(X);
871   const sunindextype N = NVEC_RAJA_CONTENT(X)->length;
872   realtype *zdata = NVEC_RAJA_DDATAp(Z);
873 
874   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
875     RAJA_LAMBDA(sunindextype i) {
876       zdata[i] = abs(xdata[i]) >= c ? ONE : ZERO;
877     }
878   );
879 }
880 
N_VInvTest_Raja(N_Vector x,N_Vector z)881 booleantype N_VInvTest_Raja(N_Vector x, N_Vector z)
882 {
883   const realtype *xdata = NVEC_RAJA_DDATAp(x);
884   const sunindextype N = NVEC_RAJA_CONTENT(x)->length;
885   realtype *zdata = NVEC_RAJA_DDATAp(z);
886 
887   RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(ZERO);
888   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
889     RAJA_LAMBDA(sunindextype i) {
890       if (xdata[i] == ZERO) {
891         gpu_result += ONE;
892       } else {
893         zdata[i] = ONE/xdata[i];
894       }
895     }
896   );
897   realtype minimum = static_cast<realtype>(gpu_result);
898   return (minimum < HALF);
899 }
900 
N_VConstrMask_Raja(N_Vector c,N_Vector x,N_Vector m)901 booleantype N_VConstrMask_Raja(N_Vector c, N_Vector x, N_Vector m)
902 {
903   const realtype *cdata = NVEC_RAJA_DDATAp(c);
904   const realtype *xdata = NVEC_RAJA_DDATAp(x);
905   const sunindextype N = NVEC_RAJA_CONTENT(x)->length;
906   realtype *mdata = NVEC_RAJA_DDATAp(m);
907 
908   RAJA::ReduceSum< RAJA_REDUCE_TYPE, realtype> gpu_result(ZERO);
909   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
910     RAJA_LAMBDA(sunindextype i) {
911       bool test = (abs(cdata[i]) > ONEPT5 && cdata[i]*xdata[i] <= ZERO) ||
912                   (abs(cdata[i]) > HALF   && cdata[i]*xdata[i] <  ZERO);
913       mdata[i] = test ? ONE : ZERO;
914       gpu_result += mdata[i];
915     }
916   );
917 
918   realtype sum = static_cast<realtype>(gpu_result);
919   return(sum < HALF);
920 }
921 
N_VMinQuotient_Raja(N_Vector num,N_Vector denom)922 realtype N_VMinQuotient_Raja(N_Vector num, N_Vector denom)
923 {
924   const realtype *ndata = NVEC_RAJA_DDATAp(num);
925   const realtype *ddata = NVEC_RAJA_DDATAp(denom);
926   const sunindextype N = NVEC_RAJA_CONTENT(num)->length;
927 
928   RAJA::ReduceMin< RAJA_REDUCE_TYPE, realtype> gpu_result(std::numeric_limits<realtype>::max());
929   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
930     RAJA_LAMBDA(sunindextype i) {
931       if (ddata[i] != ZERO)
932         gpu_result.min(ndata[i]/ddata[i]);
933     }
934   );
935   return (static_cast<realtype>(gpu_result));
936 }
937 
938 
939 /*
940  * -----------------------------------------------------------------------------
941  * fused vector operations
942  * -----------------------------------------------------------------------------
943  */
944 
N_VLinearCombination_Raja(int nvec,realtype * c,N_Vector * X,N_Vector z)945 int N_VLinearCombination_Raja(int nvec, realtype* c, N_Vector* X, N_Vector z)
946 {
947   int retval;
948 
949   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(X[0]);
950   sunindextype N = NVEC_RAJA_CONTENT(z)->length;
951   realtype* d_zd = NVEC_RAJA_DDATAp(z);
952 
953   // Create device c array for device
954   SUNMemory h_c, d_c;
955   h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
956   retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nvec, SUNMEMTYPE_DEVICE);
957   if (retval) return(-1);
958 
959   // Copy c array to device
960   retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nvec);
961   if (retval) return(-1);
962 
963   SUNMemory d_X;
964   realtype** d_Xd;
965   CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
966 
967   // Shortcut to the arrays to work on
968   realtype* d_cd  = (realtype*) d_c->ptr;
969   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
970     RAJA_LAMBDA(sunindextype i) {
971       d_zd[i] = d_cd[0] * d_Xd[0][i];
972       for (int j=1; j<nvec; j++)
973         d_zd[i] += d_cd[j] * d_Xd[j][i];
974     }
975   );
976 
977   SUNMemoryHelper_Dealloc(h, h_c);
978   SUNMemoryHelper_Dealloc(h, d_c);
979   SUNMemoryHelper_Dealloc(h, d_X);
980 
981   return(0);
982 }
983 
984 
N_VScaleAddMulti_Raja(int nvec,realtype * c,N_Vector x,N_Vector * Y,N_Vector * Z)985 int N_VScaleAddMulti_Raja(int nvec, realtype* c, N_Vector x, N_Vector* Y, N_Vector* Z)
986 {
987   int retval;
988 
989   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(x);
990   sunindextype N = NVEC_RAJA_CONTENT(x)->length;
991   realtype* d_xd = NVEC_RAJA_DDATAp(x);
992 
993   // Create c array for device
994   SUNMemory h_c, d_c;
995   h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
996   retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nvec, SUNMEMTYPE_DEVICE);
997   if (retval) return(-1);
998 
999   // Copy c array to device
1000   retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nvec);
1001   if (retval) return(-1);
1002 
1003   SUNMemory d_Y, d_Z;
1004   realtype **d_Yd, **d_Zd;
1005   CreateArrayOfPointersOnDevice(&d_Yd, &d_Y, nvec, Y);
1006   CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1007 
1008   // Perform operation
1009   realtype* d_cd  = (realtype*) d_c->ptr;
1010   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1011      RAJA_LAMBDA(sunindextype i) {
1012       for (int j=0; j<nvec; j++)
1013         d_Zd[j][i] = d_cd[j] * d_xd[i] + d_Yd[j][i];
1014     }
1015   );
1016 
1017   SUNMemoryHelper_Dealloc(h, h_c);
1018   SUNMemoryHelper_Dealloc(h, d_c);
1019   SUNMemoryHelper_Dealloc(h, d_Y);
1020   SUNMemoryHelper_Dealloc(h, d_Z);
1021 
1022   return(0);
1023 }
1024 
1025 
1026 /*
1027  * -----------------------------------------------------------------------------
1028  * vector array operations
1029  * -----------------------------------------------------------------------------
1030  */
1031 
N_VLinearSumVectorArray_Raja(int nvec,realtype a,N_Vector * X,realtype b,N_Vector * Y,N_Vector * Z)1032 int N_VLinearSumVectorArray_Raja(int nvec,
1033                                  realtype a, N_Vector* X,
1034                                  realtype b, N_Vector* Y,
1035                                  N_Vector* Z)
1036 {
1037   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1038   sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1039 
1040   SUNMemory d_X, d_Y, d_Z;
1041   realtype **d_Xd, **d_Yd, **d_Zd;
1042   CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
1043   CreateArrayOfPointersOnDevice(&d_Yd, &d_Y, nvec, Y);
1044   CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1045 
1046   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1047     RAJA_LAMBDA(sunindextype i) {
1048       for (int j=0; j<nvec; j++)
1049         d_Zd[j][i] = a * d_Xd[j][i] + b * d_Yd[j][i];
1050     }
1051   );
1052 
1053   SUNMemoryHelper_Dealloc(h, d_X);
1054   SUNMemoryHelper_Dealloc(h, d_Y);
1055   SUNMemoryHelper_Dealloc(h, d_Z);
1056 
1057   return(0);
1058 }
1059 
1060 
N_VScaleVectorArray_Raja(int nvec,realtype * c,N_Vector * X,N_Vector * Z)1061 int N_VScaleVectorArray_Raja(int nvec, realtype* c, N_Vector* X, N_Vector* Z)
1062 {
1063   int retval;
1064 
1065   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1066   sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1067 
1068   // Create c array for device
1069   SUNMemory h_c, d_c;
1070   h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
1071   retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nvec, SUNMEMTYPE_DEVICE);
1072   if (retval) return(-1);
1073 
1074   // Copy c array to device
1075   retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nvec);
1076   if (retval) return(-1);
1077 
1078   SUNMemory d_X, d_Z;
1079   realtype **d_Xd, **d_Zd;
1080   CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
1081   CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1082 
1083   realtype* d_cd  = (realtype*) d_c->ptr;
1084   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1085     RAJA_LAMBDA(sunindextype i) {
1086       for (int j=0; j<nvec; j++)
1087         d_Zd[j][i] = d_cd[j] * d_Xd[j][i];
1088     }
1089   );
1090 
1091   SUNMemoryHelper_Dealloc(h, h_c);
1092   SUNMemoryHelper_Dealloc(h, d_c);
1093   SUNMemoryHelper_Dealloc(h, d_X);
1094   SUNMemoryHelper_Dealloc(h, d_Z);
1095 
1096   return(0);
1097 }
1098 
1099 
N_VConstVectorArray_Raja(int nvec,realtype c,N_Vector * Z)1100 int N_VConstVectorArray_Raja(int nvec, realtype c, N_Vector* Z)
1101 {
1102   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1103   sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1104 
1105   SUNMemory d_Z;
1106   realtype** d_Zd;
1107   CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1108 
1109   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1110     RAJA_LAMBDA(sunindextype i) {
1111       for (int j=0; j<nvec; j++)
1112         d_Zd[j][i] = c;
1113     }
1114   );
1115 
1116   SUNMemoryHelper_Dealloc(h, d_Z);
1117 
1118   return(0);
1119 }
1120 
1121 
N_VScaleAddMultiVectorArray_Raja(int nvec,int nsum,realtype * c,N_Vector * X,N_Vector ** Y,N_Vector ** Z)1122 int N_VScaleAddMultiVectorArray_Raja(int nvec, int nsum, realtype* c,
1123                                      N_Vector* X, N_Vector** Y, N_Vector** Z)
1124 {
1125   int retval;
1126 
1127   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(X[0]);
1128   sunindextype N = NVEC_RAJA_CONTENT(X[0])->length;
1129 
1130   // Create c array for device
1131   SUNMemory h_c, d_c;
1132   h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
1133   retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nsum, SUNMEMTYPE_DEVICE);
1134   if (retval) return(-1);
1135 
1136   // Copy c array to device
1137   retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nsum);
1138   if (retval) return(-1);
1139 
1140   SUNMemory d_X, d_Y, d_Z;
1141   realtype **d_Xd, **d_Yd, **d_Zd;
1142   CreateArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, X);
1143   Create2DArrayOfPointersOnDevice(&d_Yd, &d_Y, nvec, nsum, Y);
1144   Create2DArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, nsum, Z);
1145 
1146   realtype* d_cd = (realtype*) d_c->ptr;
1147   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1148     RAJA_LAMBDA(sunindextype i) {
1149       for (int j=0; j<nvec; j++)
1150         for (int k=0; k<nsum; k++)
1151           d_Zd[j*nsum+k][i] = d_cd[k] * d_Xd[j][i] + d_Yd[j*nsum+k][i];
1152     }
1153   );
1154 
1155   SUNMemoryHelper_Dealloc(h, h_c);
1156   SUNMemoryHelper_Dealloc(h, d_c);
1157   SUNMemoryHelper_Dealloc(h, d_X);
1158   SUNMemoryHelper_Dealloc(h, d_Y);
1159   SUNMemoryHelper_Dealloc(h, d_Z);
1160 
1161   return(0);
1162 }
1163 
1164 
N_VLinearCombinationVectorArray_Raja(int nvec,int nsum,realtype * c,N_Vector ** X,N_Vector * Z)1165 int N_VLinearCombinationVectorArray_Raja(int nvec, int nsum, realtype* c,
1166                                          N_Vector** X, N_Vector* Z)
1167 {
1168   int retval;
1169 
1170   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(Z[0]);
1171   sunindextype N = NVEC_RAJA_CONTENT(Z[0])->length;
1172 
1173   // Create c array for device
1174   SUNMemory h_c, d_c;
1175   h_c = SUNMemoryHelper_Wrap(c, SUNMEMTYPE_HOST);
1176   retval = SUNMemoryHelper_Alloc(h, &d_c, sizeof(realtype)*nsum, SUNMEMTYPE_DEVICE);
1177   if (retval) return(-1);
1178 
1179   // Copy c array to device
1180   retval = SUNMemoryHelper_Copy(h, d_c, h_c, sizeof(realtype)*nsum);
1181   if (retval) return(-1);
1182 
1183   SUNMemory d_X, d_Z;
1184   realtype **d_Xd, **d_Zd;
1185   CreateArrayOfPointersOnDevice(&d_Zd, &d_Z, nvec, Z);
1186   Create2DArrayOfPointersOnDevice(&d_Xd, &d_X, nvec, nsum, X);
1187 
1188   realtype *d_cd = (realtype*) d_c->ptr;
1189   RAJA::forall< RAJA_NODE_TYPE >(RAJA::RangeSegment(zeroIdx, N),
1190     RAJA_LAMBDA(sunindextype i) {
1191       for (int j=0; j<nvec; j++) {
1192         d_Zd[j][i] = d_cd[0] * d_Xd[j*nsum][i];
1193         for (int k=1; k<nsum; k++) {
1194           d_Zd[j][i] += d_cd[k] * d_Xd[j*nsum+k][i];
1195         }
1196       }
1197     }
1198   );
1199 
1200   SUNMemoryHelper_Dealloc(h, h_c);
1201   SUNMemoryHelper_Dealloc(h, d_c);
1202   SUNMemoryHelper_Dealloc(h, d_X);
1203   SUNMemoryHelper_Dealloc(h, d_Z);
1204 
1205   return(0);
1206 }
1207 
1208 
1209 /*
1210  * -----------------------------------------------------------------
1211  * OPTIONAL XBraid interface operations
1212  * -----------------------------------------------------------------
1213  */
1214 
1215 
N_VBufSize_Raja(N_Vector x,sunindextype * size)1216 int N_VBufSize_Raja(N_Vector x, sunindextype *size)
1217 {
1218   if (x == NULL) return(-1);
1219   *size = (sunindextype)NVEC_RAJA_MEMSIZE(x);
1220   return(0);
1221 }
1222 
1223 
N_VBufPack_Raja(N_Vector x,void * buf)1224 int N_VBufPack_Raja(N_Vector x, void *buf)
1225 {
1226   int copy_fail = 0;
1227   SUNDIALS_GPU_PREFIX(Error_t) cuerr;
1228 
1229   if (x == NULL || buf == NULL) return(-1);
1230 
1231   SUNMemory buf_mem = SUNMemoryHelper_Wrap(buf, SUNMEMTYPE_HOST);
1232   if (buf_mem == NULL) return(-1);
1233 
1234   copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
1235                                         buf_mem,
1236                                         NVEC_RAJA_CONTENT(x)->device_data,
1237                                         NVEC_RAJA_MEMSIZE(x),
1238                                         0);
1239 
1240   /* we synchronize with respect to the host, but only in this stream */
1241   cuerr = SUNDIALS_GPU_PREFIX(StreamSynchronize)(0);
1242 
1243   SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(x), buf_mem);
1244 
1245   return (!SUNDIALS_GPU_VERIFY(cuerr) || copy_fail ? -1 : 0);
1246 }
1247 
1248 
N_VBufUnpack_Raja(N_Vector x,void * buf)1249 int N_VBufUnpack_Raja(N_Vector x, void *buf)
1250 {
1251   int copy_fail = 0;
1252   SUNDIALS_GPU_PREFIX(Error_t) cuerr;
1253 
1254   if (x == NULL || buf == NULL) return(-1);
1255 
1256   SUNMemory buf_mem = SUNMemoryHelper_Wrap(buf, SUNMEMTYPE_HOST);
1257   if (buf_mem == NULL) return(-1);
1258 
1259   copy_fail = SUNMemoryHelper_CopyAsync(NVEC_RAJA_MEMHELP(x),
1260                                         NVEC_RAJA_CONTENT(x)->device_data,
1261                                         buf_mem,
1262                                         NVEC_RAJA_MEMSIZE(x),
1263                                         0);
1264 
1265   /* we synchronize with respect to the host, but only in this stream */
1266   cuerr = SUNDIALS_GPU_PREFIX(StreamSynchronize)(0);
1267 
1268   SUNMemoryHelper_Dealloc(NVEC_RAJA_MEMHELP(x), buf_mem);
1269 
1270   return (!SUNDIALS_GPU_VERIFY(cuerr) || copy_fail ? -1 : 0);
1271 }
1272 
1273 
1274 /*
1275  * -----------------------------------------------------------------
1276  * Enable / Disable fused and vector array operations
1277  * -----------------------------------------------------------------
1278  */
1279 
N_VEnableFusedOps_Raja(N_Vector v,booleantype tf)1280 int N_VEnableFusedOps_Raja(N_Vector v, booleantype tf)
1281 {
1282   /* check that vector is non-NULL */
1283   if (v == NULL) return(-1);
1284 
1285   /* check that ops structure is non-NULL */
1286   if (v->ops == NULL) return(-1);
1287 
1288   if (tf) {
1289     /* enable all fused vector operations */
1290     v->ops->nvlinearcombination = N_VLinearCombination_Raja;
1291     v->ops->nvscaleaddmulti     = N_VScaleAddMulti_Raja;
1292     v->ops->nvdotprodmulti      = NULL;
1293     /* enable all vector array operations */
1294     v->ops->nvlinearsumvectorarray         = N_VLinearSumVectorArray_Raja;
1295     v->ops->nvscalevectorarray             = N_VScaleVectorArray_Raja;
1296     v->ops->nvconstvectorarray             = N_VConstVectorArray_Raja;
1297     v->ops->nvwrmsnormvectorarray          = NULL;
1298     v->ops->nvwrmsnormmaskvectorarray      = NULL;
1299     v->ops->nvscaleaddmultivectorarray     = N_VScaleAddMultiVectorArray_Raja;
1300     v->ops->nvlinearcombinationvectorarray = N_VLinearCombinationVectorArray_Raja;
1301   } else {
1302     /* disable all fused vector operations */
1303     v->ops->nvlinearcombination = NULL;
1304     v->ops->nvscaleaddmulti     = NULL;
1305     v->ops->nvdotprodmulti      = NULL;
1306     /* disable all vector array operations */
1307     v->ops->nvlinearsumvectorarray         = NULL;
1308     v->ops->nvscalevectorarray             = NULL;
1309     v->ops->nvconstvectorarray             = NULL;
1310     v->ops->nvwrmsnormvectorarray          = NULL;
1311     v->ops->nvwrmsnormmaskvectorarray      = NULL;
1312     v->ops->nvscaleaddmultivectorarray     = NULL;
1313     v->ops->nvlinearcombinationvectorarray = NULL;
1314   }
1315 
1316   /* return success */
1317   return(0);
1318 }
1319 
N_VEnableLinearCombination_Raja(N_Vector v,booleantype tf)1320 int N_VEnableLinearCombination_Raja(N_Vector v, booleantype tf)
1321 {
1322   /* check that vector is non-NULL */
1323   if (v == NULL) return(-1);
1324 
1325   /* check that ops structure is non-NULL */
1326   if (v->ops == NULL) return(-1);
1327 
1328   /* enable/disable operation */
1329   if (tf)
1330     v->ops->nvlinearcombination = N_VLinearCombination_Raja;
1331   else
1332     v->ops->nvlinearcombination = NULL;
1333 
1334   /* return success */
1335   return(0);
1336 }
1337 
N_VEnableScaleAddMulti_Raja(N_Vector v,booleantype tf)1338 int N_VEnableScaleAddMulti_Raja(N_Vector v, booleantype tf)
1339 {
1340   /* check that vector is non-NULL */
1341   if (v == NULL) return(-1);
1342 
1343   /* check that ops structure is non-NULL */
1344   if (v->ops == NULL) return(-1);
1345 
1346   /* enable/disable operation */
1347   if (tf)
1348     v->ops->nvscaleaddmulti = N_VScaleAddMulti_Raja;
1349   else
1350     v->ops->nvscaleaddmulti = NULL;
1351 
1352   /* return success */
1353   return(0);
1354 }
1355 
N_VEnableLinearSumVectorArray_Raja(N_Vector v,booleantype tf)1356 int N_VEnableLinearSumVectorArray_Raja(N_Vector v, booleantype tf)
1357 {
1358   /* check that vector is non-NULL */
1359   if (v == NULL) return(-1);
1360 
1361   /* check that ops structure is non-NULL */
1362   if (v->ops == NULL) return(-1);
1363 
1364   /* enable/disable operation */
1365   if (tf)
1366     v->ops->nvlinearsumvectorarray = N_VLinearSumVectorArray_Raja;
1367   else
1368     v->ops->nvlinearsumvectorarray = NULL;
1369 
1370   /* return success */
1371   return(0);
1372 }
1373 
N_VEnableScaleVectorArray_Raja(N_Vector v,booleantype tf)1374 int N_VEnableScaleVectorArray_Raja(N_Vector v, booleantype tf)
1375 {
1376   /* check that vector is non-NULL */
1377   if (v == NULL) return(-1);
1378 
1379   /* check that ops structure is non-NULL */
1380   if (v->ops == NULL) return(-1);
1381 
1382   /* enable/disable operation */
1383   if (tf)
1384     v->ops->nvscalevectorarray = N_VScaleVectorArray_Raja;
1385   else
1386     v->ops->nvscalevectorarray = NULL;
1387 
1388   /* return success */
1389   return(0);
1390 }
1391 
N_VEnableConstVectorArray_Raja(N_Vector v,booleantype tf)1392 int N_VEnableConstVectorArray_Raja(N_Vector v, booleantype tf)
1393 {
1394   /* check that vector is non-NULL */
1395   if (v == NULL) return(-1);
1396 
1397   /* check that ops structure is non-NULL */
1398   if (v->ops == NULL) return(-1);
1399 
1400   /* enable/disable operation */
1401   if (tf)
1402     v->ops->nvconstvectorarray = N_VConstVectorArray_Raja;
1403   else
1404     v->ops->nvconstvectorarray = NULL;
1405 
1406   /* return success */
1407   return(0);
1408 }
1409 
N_VEnableScaleAddMultiVectorArray_Raja(N_Vector v,booleantype tf)1410 int N_VEnableScaleAddMultiVectorArray_Raja(N_Vector v, booleantype tf)
1411 {
1412   /* check that vector is non-NULL */
1413   if (v == NULL) return(-1);
1414 
1415   /* check that ops structure is non-NULL */
1416   if (v->ops == NULL) return(-1);
1417 
1418   /* enable/disable operation */
1419   if (tf)
1420     v->ops->nvscaleaddmultivectorarray = N_VScaleAddMultiVectorArray_Raja;
1421   else
1422     v->ops->nvscaleaddmultivectorarray = NULL;
1423 
1424   /* return success */
1425   return(0);
1426 }
1427 
N_VEnableLinearCombinationVectorArray_Raja(N_Vector v,booleantype tf)1428 int N_VEnableLinearCombinationVectorArray_Raja(N_Vector v, booleantype tf)
1429 {
1430   /* check that vector is non-NULL */
1431   if (v == NULL) return(-1);
1432 
1433   /* check that ops structure is non-NULL */
1434   if (v->ops == NULL) return(-1);
1435 
1436   /* enable/disable operation */
1437   if (tf)
1438     v->ops->nvlinearcombinationvectorarray = N_VLinearCombinationVectorArray_Raja;
1439   else
1440     v->ops->nvlinearcombinationvectorarray = NULL;
1441 
1442   /* return success */
1443   return(0);
1444 }
1445 
1446 
1447 /*
1448  * -----------------------------------------------------------------
1449  * Private utility functions
1450  * -----------------------------------------------------------------
1451  */
1452 
AllocateData(N_Vector v)1453 int AllocateData(N_Vector v)
1454 {
1455   int alloc_fail = 0;
1456   N_VectorContent_Raja vc = NVEC_RAJA_CONTENT(v);
1457   N_PrivateVectorContent_Raja vcp = NVEC_RAJA_PRIVATE(v);
1458 
1459   if (N_VGetLength_Raja(v) == 0) return(0);
1460 
1461   if (vcp->use_managed_mem)
1462   {
1463     alloc_fail = SUNMemoryHelper_Alloc(NVEC_RAJA_MEMHELP(v), &(vc->device_data),
1464                                        NVEC_RAJA_MEMSIZE(v), SUNMEMTYPE_UVM);
1465     if (alloc_fail)
1466     {
1467       SUNDIALS_DEBUG_PRINT("ERROR in AllocateData: SUNMemoryHelper_Alloc failed for SUNMEMTYPE_UVM\n");
1468     }
1469     vc->host_data = SUNMemoryHelper_Alias(vc->device_data);
1470   }
1471   else
1472   {
1473     alloc_fail = SUNMemoryHelper_Alloc(NVEC_RAJA_MEMHELP(v), &(vc->host_data),
1474                                        NVEC_RAJA_MEMSIZE(v), SUNMEMTYPE_HOST);
1475     if (alloc_fail)
1476     {
1477       SUNDIALS_DEBUG_PRINT("ERROR in AllocateData: SUNMemoryHelper_Alloc failed to alloc SUNMEMTYPE_HOST\n");
1478     }
1479 
1480     alloc_fail = SUNMemoryHelper_Alloc(NVEC_RAJA_MEMHELP(v), &(vc->device_data),
1481                                        NVEC_RAJA_MEMSIZE(v), SUNMEMTYPE_DEVICE);
1482     if (alloc_fail)
1483     {
1484       SUNDIALS_DEBUG_PRINT("ERROR in AllocateData: SUNMemoryHelper_Alloc failed to alloc SUNMEMTYPE_DEVICE\n");
1485     }
1486   }
1487 
1488   return(alloc_fail ? -1 : 0);
1489 }
1490 
1491 
CreateArrayOfPointersOnDevice(realtype *** d_ptrs,SUNMemory * d_ref,int nvec,N_Vector * V)1492 void CreateArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
1493                                    int nvec, N_Vector *V)
1494 {
1495   size_t bytes = sizeof(realtype*)*nvec;
1496   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(V[0]);
1497 
1498   // Default return values
1499   *d_ref  = nullptr;
1500   *d_ptrs = nullptr;
1501 
1502   // Create space for host and device pointers
1503   SUNMemory h_mem;
1504   SUNMemoryHelper_Alloc(h, &h_mem, bytes, SUNMEMTYPE_HOST);
1505 
1506   SUNMemory d_mem;
1507   SUNMemoryHelper_Alloc(h, &d_mem, bytes, SUNMEMTYPE_DEVICE);
1508 
1509   // Fill the host memory with the pointers
1510   realtype** h_array = (realtype**) h_mem->ptr;
1511   for (int j=0; j<nvec; j++) {
1512     h_array[j] = NVEC_RAJA_DDATAp(V[j]);
1513   }
1514 
1515   // Copy the host memory to the device
1516   SUNMemoryHelper_Copy(h, d_mem, h_mem, bytes);
1517 
1518   // Return the device SUNMemory, and the raw pointer array
1519   *d_ref  = d_mem;
1520   *d_ptrs = (realtype**) d_mem->ptr;
1521 
1522   // Free the host SUNMemory
1523   SUNMemoryHelper_Dealloc(h, h_mem);
1524 }
1525 
Create2DArrayOfPointersOnDevice(realtype *** d_ptrs,SUNMemory * d_ref,int nvec,int nsum,N_Vector ** V)1526 void Create2DArrayOfPointersOnDevice(realtype*** d_ptrs, SUNMemory* d_ref,
1527                                      int nvec, int nsum, N_Vector **V)
1528 {
1529   size_t bytes = sizeof(realtype*)*nsum*nvec;
1530   SUNMemoryHelper h = NVEC_RAJA_MEMHELP(V[0][0]);
1531 
1532   // Default return values
1533   *d_ref  = nullptr;
1534   *d_ptrs = nullptr;
1535 
1536   // Create space for host and device pointers
1537   SUNMemory h_mem;
1538   SUNMemoryHelper_Alloc(h, &h_mem, bytes, SUNMEMTYPE_HOST);
1539 
1540   SUNMemory d_mem;
1541   SUNMemoryHelper_Alloc(h, &d_mem, bytes, SUNMEMTYPE_DEVICE);
1542 
1543   // Fill the host memory with the pointers
1544   realtype** h_array = (realtype**) h_mem->ptr;
1545   for (int j=0; j<nvec; j++)
1546     for (int k=0; k<nsum; k++)
1547       h_array[j*nsum+k] = NVEC_RAJA_DDATAp(V[k][j]);
1548 
1549   // Copy the host memory to the device
1550   SUNMemoryHelper_Copy(h, d_mem, h_mem, bytes);
1551 
1552   // Return the device SUNMemory, and the raw pointer array
1553   *d_ref  = d_mem;
1554   *d_ptrs = (realtype**) d_mem->ptr;
1555 
1556   // Free the host SUNMemory
1557   SUNMemoryHelper_Dealloc(h, h_mem);
1558 }
1559 
1560 } // extern "C"
1561