1 #if !defined(__CUDAVECIMPL)
2 #define __CUDAVECIMPL
3 
4 #include <petscvec.h>
5 #include <petsccublas.h>
6 #include <petsc/private/vecimpl.h>
7 
8 typedef struct {
9   PetscScalar  *GPUarray;           /* this always holds the GPU data */
10   PetscScalar  *GPUarray_allocated; /* if the array was allocated by PETSc this is its pointer */
11   cudaStream_t stream;              /* A stream for doing asynchronous data transfers */
12 } Vec_CUDA;
13 
14 PETSC_INTERN PetscErrorCode VecCUDAGetArrays_Private(Vec,const PetscScalar**,const PetscScalar**,PetscOffloadMask*);
15 PETSC_INTERN PetscErrorCode VecDotNorm2_SeqCUDA(Vec,Vec,PetscScalar*, PetscScalar*);
16 PETSC_INTERN PetscErrorCode VecPointwiseDivide_SeqCUDA(Vec,Vec,Vec);
17 PETSC_INTERN PetscErrorCode VecWAXPY_SeqCUDA(Vec,PetscScalar,Vec,Vec);
18 PETSC_INTERN PetscErrorCode VecMDot_SeqCUDA(Vec,PetscInt,const Vec[],PetscScalar*);
19 PETSC_EXTERN PetscErrorCode VecSet_SeqCUDA(Vec,PetscScalar);
20 PETSC_INTERN PetscErrorCode VecMAXPY_SeqCUDA(Vec,PetscInt,const PetscScalar*,Vec*);
21 PETSC_INTERN PetscErrorCode VecAXPBYPCZ_SeqCUDA(Vec,PetscScalar,PetscScalar,PetscScalar,Vec,Vec);
22 PETSC_INTERN PetscErrorCode VecPointwiseMult_SeqCUDA(Vec,Vec,Vec);
23 PETSC_INTERN PetscErrorCode VecPlaceArray_SeqCUDA(Vec,const PetscScalar*);
24 PETSC_INTERN PetscErrorCode VecResetArray_SeqCUDA(Vec);
25 PETSC_INTERN PetscErrorCode VecReplaceArray_SeqCUDA(Vec,const PetscScalar*);
26 PETSC_INTERN PetscErrorCode VecDot_SeqCUDA(Vec,Vec,PetscScalar*);
27 PETSC_INTERN PetscErrorCode VecTDot_SeqCUDA(Vec,Vec,PetscScalar*);
28 PETSC_INTERN PetscErrorCode VecScale_SeqCUDA(Vec,PetscScalar);
29 PETSC_EXTERN PetscErrorCode VecCopy_SeqCUDA(Vec,Vec);
30 PETSC_INTERN PetscErrorCode VecSwap_SeqCUDA(Vec,Vec);
31 PETSC_EXTERN PetscErrorCode VecAXPY_SeqCUDA(Vec,PetscScalar,Vec);
32 PETSC_INTERN PetscErrorCode VecAXPBY_SeqCUDA(Vec,PetscScalar,PetscScalar,Vec);
33 PETSC_INTERN PetscErrorCode VecDuplicate_SeqCUDA(Vec,Vec*);
34 PETSC_INTERN PetscErrorCode VecConjugate_SeqCUDA(Vec xin);
35 PETSC_INTERN PetscErrorCode VecNorm_SeqCUDA(Vec,NormType,PetscReal*);
36 PETSC_INTERN PetscErrorCode VecCUDACopyToGPU(Vec);
37 PETSC_INTERN PetscErrorCode VecCUDAAllocateCheck(Vec);
38 PETSC_EXTERN PetscErrorCode VecCreate_SeqCUDA(Vec);
39 PETSC_INTERN PetscErrorCode VecCreate_SeqCUDA_Private(Vec,const PetscScalar*);
40 PETSC_INTERN PetscErrorCode VecCreate_MPICUDA(Vec);
41 PETSC_INTERN PetscErrorCode VecCreate_MPICUDA_Private(Vec,PetscBool,PetscInt,const PetscScalar*);
42 PETSC_INTERN PetscErrorCode VecCreate_CUDA(Vec);
43 PETSC_INTERN PetscErrorCode VecDestroy_SeqCUDA(Vec);
44 PETSC_INTERN PetscErrorCode VecDestroy_MPICUDA(Vec);
45 PETSC_INTERN PetscErrorCode VecAYPX_SeqCUDA(Vec,PetscScalar,Vec);
46 PETSC_INTERN PetscErrorCode VecSetRandom_SeqCUDA(Vec,PetscRandom);
47 PETSC_INTERN PetscErrorCode VecGetLocalVector_SeqCUDA(Vec,Vec);
48 PETSC_INTERN PetscErrorCode VecRestoreLocalVector_SeqCUDA(Vec,Vec);
49 PETSC_INTERN PetscErrorCode VecGetArrayWrite_SeqCUDA(Vec,PetscScalar**);
50 PETSC_INTERN PetscErrorCode VecCopy_SeqCUDA_Private(Vec xin,Vec yin);
51 PETSC_INTERN PetscErrorCode VecSetRandom_SeqCUDA_Private(Vec xin,PetscRandom r);
52 PETSC_INTERN PetscErrorCode VecDestroy_SeqCUDA_Private(Vec v);
53 PETSC_INTERN PetscErrorCode VecResetArray_SeqCUDA_Private(Vec vin);
54 PETSC_INTERN PetscErrorCode VecCUDACopyToGPU_Public(Vec);
55 PETSC_INTERN PetscErrorCode VecCUDAAllocateCheck_Public(Vec);
56 PETSC_INTERN PetscErrorCode VecCUDACopyToGPUSome(Vec,PetscCUDAIndices,ScatterMode);
57 PETSC_INTERN PetscErrorCode VecCUDACopyFromGPUSome(Vec,PetscCUDAIndices,ScatterMode);
58 
59 PETSC_INTERN PetscErrorCode VecScatterCUDAIndicesCreate_PtoP(PetscInt, PetscInt*,PetscInt, PetscInt*,PetscCUDAIndices*);
60 PETSC_INTERN PetscErrorCode VecScatterCUDAIndicesCreate_StoS(PetscInt,PetscInt,PetscInt,PetscInt,PetscInt,PetscInt*,PetscInt*,PetscCUDAIndices*);
61 PETSC_INTERN PetscErrorCode VecScatterCUDAIndicesDestroy(PetscCUDAIndices*);
62 PETSC_INTERN PetscErrorCode VecScatterCUDA_StoS(Vec,Vec,PetscCUDAIndices,InsertMode,ScatterMode);
63 
/* Tag telling which descriptor struct a PetscCUDAIndices object's 'scatter' pointer refers to. */
typedef enum {
  VEC_SCATTER_CUDA_STOS = 0, /* sequential-to-sequential descriptor */
  VEC_SCATTER_CUDA_PTOP = 1  /* parallel-to-parallel descriptor */
} VecCUDAScatterType;

/* Indexing form of one side (from/to) of a StoS scatter: presumably an explicit
   index array (GENERAL) or a first/step stride (STRIDED) -- confirm in the .cu sources. */
typedef enum {
  VEC_SCATTER_CUDA_GENERAL = 0,
  VEC_SCATTER_CUDA_STRIDED = 1
} VecCUDASequentialScatterMode;
66 
67 struct  _p_VecScatterCUDAIndices_PtoP {
68   PetscInt ns;
69   PetscInt sendLowestIndex;
70   PetscInt nr;
71   PetscInt recvLowestIndex;
72 };
73 
74 struct  _p_VecScatterCUDAIndices_StoS {
75   /* from indices data */
76   PetscInt *fslots;
77   PetscInt fromFirst;
78   PetscInt fromStep;
79   VecCUDASequentialScatterMode fromMode;
80 
81   /* to indices data */
82   PetscInt *tslots;
83   PetscInt toFirst;
84   PetscInt toStep;
85   VecCUDASequentialScatterMode toMode;
86 
87   PetscInt n;
88   PetscInt MAX_BLOCKS;
89   PetscInt MAX_CORESIDENT_THREADS;
90   cudaStream_t stream;
91 };
92 
93 struct  _p_PetscCUDAIndices {
94   void * scatter;
95   VecCUDAScatterType scatterType;
96 };
97 
/*
   Precision-neutral names for the cuBLAS routines used by the CUDA vector code.
   The PETSc configuration selects one of four variants:

     PETSC_USE_COMPLEX  PETSC_USE_REAL_SINGLE   PetscScalar   cuBLAS prefix
     yes                yes                     complex float  C / Sc / Ic
     yes                no                      complex double Z / Dz / Iz
     no                 yes                     float          S / Is
     no                 no                      double         D / Id

   Real builds alias the names directly.  Complex builds use function-like macros
   that cast scalar pointer arguments to cuComplex / cuDoubleComplex.  These casts
   rely on PETSc's complex scalar being layout-compatible with the cuBLAS complex
   type; that assumption is not checked here.

   In complex builds cublasXdot is the conjugating dot product (?dotc) and
   cublasXdotu the non-conjugating one (?dotu).  In real builds both are ?dot.
   cublasXnrm2 / cublasXasum write a real-valued result in every variant.
*/
#if defined(PETSC_USE_COMPLEX) && defined(PETSC_USE_REAL_SINGLE)
/* complex single */
#define cublasXaxpy(a,b,c,d,e,f,g)               cublasCaxpy((a),(b),(cuComplex*)(c),(cuComplex*)(d),(e),(cuComplex*)(f),(g))
#define cublasXscal(a,b,c,d,e)                   cublasCscal((a),(b),(cuComplex*)(c),(cuComplex*)(d),(e))
#define cublasXdotu(a,b,c,d,e,f,g)               cublasCdotu((a),(b),(cuComplex*)(c),(d),(cuComplex*)(e),(f),(cuComplex*)(g))
#define cublasXdot(a,b,c,d,e,f,g)                cublasCdotc((a),(b),(cuComplex*)(c),(d),(cuComplex*)(e),(f),(cuComplex*)(g))
#define cublasXswap(a,b,c,d,e,f)                 cublasCswap((a),(b),(cuComplex*)(c),(d),(cuComplex*)(e),(f))
#define cublasXnrm2(a,b,c,d,e)                   cublasScnrm2((a),(b),(cuComplex*)(c),(d),(e))
#define cublasIXamax(a,b,c,d,e)                  cublasIcamax((a),(b),(cuComplex*)(c),(d),(e))
#define cublasXasum(a,b,c,d,e)                   cublasScasum((a),(b),(cuComplex*)(c),(d),(e))
#define cublasXgemv(a,b,c,d,e,f,g,h,i,j,k,l)     cublasCgemv((a),(b),(c),(d),(cuComplex*)(e),(cuComplex*)(f),(g),(cuComplex*)(h),(i),(cuComplex*)(j),(cuComplex*)(k),(l))
#define cublasXgemm(a,b,c,d,e,f,g,h,i,j,k,l,m,n) cublasCgemm((a),(b),(c),(d),(e),(f),(cuComplex*)(g),(cuComplex*)(h),(i),(cuComplex*)(j),(k),(cuComplex*)(l),(cuComplex*)(m),(n))
#define cublasXgeam(a,b,c,d,e,f,g,h,i,j,k,l,m)   cublasCgeam((a),(b),(c),(d),(e),(cuComplex*)(f),(cuComplex*)(g),(h),(cuComplex*)(i),(cuComplex*)(j),(k),(cuComplex*)(l),(m))
#elif defined(PETSC_USE_COMPLEX)
/* complex double */
#define cublasXaxpy(a,b,c,d,e,f,g)               cublasZaxpy((a),(b),(cuDoubleComplex*)(c),(cuDoubleComplex*)(d),(e),(cuDoubleComplex*)(f),(g))
#define cublasXscal(a,b,c,d,e)                   cublasZscal((a),(b),(cuDoubleComplex*)(c),(cuDoubleComplex*)(d),(e))
#define cublasXdotu(a,b,c,d,e,f,g)               cublasZdotu((a),(b),(cuDoubleComplex*)(c),(d),(cuDoubleComplex*)(e),(f),(cuDoubleComplex*)(g))
#define cublasXdot(a,b,c,d,e,f,g)                cublasZdotc((a),(b),(cuDoubleComplex*)(c),(d),(cuDoubleComplex*)(e),(f),(cuDoubleComplex*)(g))
#define cublasXswap(a,b,c,d,e,f)                 cublasZswap((a),(b),(cuDoubleComplex*)(c),(d),(cuDoubleComplex*)(e),(f))
#define cublasXnrm2(a,b,c,d,e)                   cublasDznrm2((a),(b),(cuDoubleComplex*)(c),(d),(e))
#define cublasIXamax(a,b,c,d,e)                  cublasIzamax((a),(b),(cuDoubleComplex*)(c),(d),(e))
#define cublasXasum(a,b,c,d,e)                   cublasDzasum((a),(b),(cuDoubleComplex*)(c),(d),(e))
#define cublasXgemv(a,b,c,d,e,f,g,h,i,j,k,l)     cublasZgemv((a),(b),(c),(d),(cuDoubleComplex*)(e),(cuDoubleComplex*)(f),(g),(cuDoubleComplex*)(h),(i),(cuDoubleComplex*)(j),(cuDoubleComplex*)(k),(l))
#define cublasXgemm(a,b,c,d,e,f,g,h,i,j,k,l,m,n) cublasZgemm((a),(b),(c),(d),(e),(f),(cuDoubleComplex*)(g),(cuDoubleComplex*)(h),(i),(cuDoubleComplex*)(j),(k),(cuDoubleComplex*)(l),(cuDoubleComplex*)(m),(n))
#define cublasXgeam(a,b,c,d,e,f,g,h,i,j,k,l,m)   cublasZgeam((a),(b),(c),(d),(e),(cuDoubleComplex*)(f),(cuDoubleComplex*)(g),(h),(cuDoubleComplex*)(i),(cuDoubleComplex*)(j),(k),(cuDoubleComplex*)(l),(m))
#elif defined(PETSC_USE_REAL_SINGLE)
/* real single */
#define cublasXaxpy  cublasSaxpy
#define cublasXscal  cublasSscal
#define cublasXdotu  cublasSdot
#define cublasXdot   cublasSdot
#define cublasXswap  cublasSswap
#define cublasXnrm2  cublasSnrm2
#define cublasIXamax cublasIsamax
#define cublasXasum  cublasSasum
#define cublasXgemv  cublasSgemv
#define cublasXgemm  cublasSgemm
#define cublasXgeam  cublasSgeam
#else
/* real double */
#define cublasXaxpy  cublasDaxpy
#define cublasXscal  cublasDscal
#define cublasXdotu  cublasDdot
#define cublasXdot   cublasDdot
#define cublasXswap  cublasDswap
#define cublasXnrm2  cublasDnrm2
#define cublasIXamax cublasIdamax
#define cublasXasum  cublasDasum
#define cublasXgemv  cublasDgemv
#define cublasXgemm  cublasDgemm
#define cublasXgeam  cublasDgeam
#endif
152 
153 #endif
154