/*
   Include guard.  Identifiers that begin with two underscores are reserved for the
   compiler/standard library (C11 7.1.3, C++ [lex.name]), so the guard uses a
   PETSC_-prefixed name rather than one starting with "__".
*/
#if !defined(PETSC_CUDAVECIMPL_H)
#define PETSC_CUDAVECIMPL_H

#include <petscvec.h>
#include <petsccublas.h>
#include <petsc/private/vecimpl.h>

/*
   Vec_CUDA - CUDA-side state kept for a vector.

   GPUarray           - this always holds the GPU data.
   GPUarray_allocated - if the array was allocated by PETSc this is its pointer
                        (NOTE(review): presumably distinct from GPUarray when a
                        user-supplied device array has been placed -- confirm against
                        VecPlaceArray_SeqCUDA / VecReplaceArray_SeqCUDA).
   stream             - a stream for doing asynchronous data transfers.
*/
typedef struct {
  PetscScalar  *GPUarray;           /* this always holds the GPU data */
  PetscScalar  *GPUarray_allocated; /* if the array was allocated by PETSc this is its pointer */
  cudaStream_t stream;              /* A stream for doing asynchronous data transfers */
} Vec_CUDA;

/*
   Sequential CUDA vector operations (first group).  Most are PETSC_INTERN; the few
   declared PETSC_EXTERN (VecSet, VecCopy, VecAXPY) keep that linkage because other
   code presumably calls them directly -- do not change it without checking callers.
*/
PETSC_INTERN PetscErrorCode VecCUDAGetArrays_Private(Vec,const PetscScalar**,const PetscScalar**,PetscOffloadMask*);
PETSC_INTERN PetscErrorCode VecDotNorm2_SeqCUDA(Vec,Vec,PetscScalar*,PetscScalar*);
PETSC_INTERN PetscErrorCode VecPointwiseDivide_SeqCUDA(Vec,Vec,Vec);
PETSC_INTERN PetscErrorCode VecWAXPY_SeqCUDA(Vec,PetscScalar,Vec,Vec);
PETSC_INTERN PetscErrorCode VecMDot_SeqCUDA(Vec,PetscInt,const Vec[],PetscScalar*);
PETSC_EXTERN PetscErrorCode VecSet_SeqCUDA(Vec,PetscScalar);
PETSC_INTERN PetscErrorCode VecMAXPY_SeqCUDA(Vec,PetscInt,const PetscScalar*,Vec*);
PETSC_INTERN PetscErrorCode VecAXPBYPCZ_SeqCUDA(Vec,PetscScalar,PetscScalar,PetscScalar,Vec,Vec);
PETSC_INTERN PetscErrorCode VecPointwiseMult_SeqCUDA(Vec,Vec,Vec);
PETSC_INTERN PetscErrorCode VecPlaceArray_SeqCUDA(Vec,const PetscScalar*);
PETSC_INTERN PetscErrorCode VecResetArray_SeqCUDA(Vec);
PETSC_INTERN PetscErrorCode VecReplaceArray_SeqCUDA(Vec,const PetscScalar*);
PETSC_INTERN PetscErrorCode VecDot_SeqCUDA(Vec,Vec,PetscScalar*);
PETSC_INTERN PetscErrorCode VecTDot_SeqCUDA(Vec,Vec,PetscScalar*);
PETSC_INTERN PetscErrorCode VecScale_SeqCUDA(Vec,PetscScalar);
PETSC_EXTERN PetscErrorCode VecCopy_SeqCUDA(Vec,Vec);
PETSC_INTERN PetscErrorCode VecSwap_SeqCUDA(Vec,Vec);
PETSC_EXTERN PetscErrorCode VecAXPY_SeqCUDA(Vec,PetscScalar,Vec);
PETSC_INTERN PetscErrorCode VecAXPBY_SeqCUDA(Vec,PetscScalar,PetscScalar,Vec);
PETSC_INTERN PetscErrorCode VecDuplicate_SeqCUDA(Vec,Vec*);
PETSC_INTERN PetscErrorCode VecConjugate_SeqCUDA(Vec);
/* --- Norms and host->device transfer / device allocation helpers --- */
PETSC_INTERN PetscErrorCode VecNorm_SeqCUDA(Vec,NormType,PetscReal*);
PETSC_INTERN PetscErrorCode VecCUDACopyToGPU(Vec);
PETSC_INTERN PetscErrorCode VecCUDAAllocateCheck(Vec);
PETSC_INTERN PetscErrorCode VecCUDACopyToGPU_Public(Vec);
PETSC_INTERN PetscErrorCode VecCUDAAllocateCheck_Public(Vec);
PETSC_INTERN PetscErrorCode VecCUDACopyToGPUSome(Vec,PetscCUDAIndices,ScatterMode);
PETSC_INTERN PetscErrorCode VecCUDACopyFromGPUSome(Vec,PetscCUDAIndices,ScatterMode);

/* --- VecCreate_* / VecDestroy_* entry points (VecCreate_SeqCUDA keeps PETSC_EXTERN linkage) --- */
PETSC_EXTERN PetscErrorCode VecCreate_SeqCUDA(Vec);
PETSC_INTERN PetscErrorCode VecCreate_SeqCUDA_Private(Vec,const PetscScalar*);
PETSC_INTERN PetscErrorCode VecCreate_MPICUDA(Vec);
PETSC_INTERN PetscErrorCode VecCreate_MPICUDA_Private(Vec,PetscBool,PetscInt,const PetscScalar*);
PETSC_INTERN PetscErrorCode VecCreate_CUDA(Vec);
PETSC_INTERN PetscErrorCode VecDestroy_SeqCUDA(Vec);
PETSC_INTERN PetscErrorCode VecDestroy_SeqCUDA_Private(Vec v);
PETSC_INTERN PetscErrorCode VecDestroy_MPICUDA(Vec);

/* --- Remaining sequential CUDA operations and their _Private helpers --- */
PETSC_INTERN PetscErrorCode VecAYPX_SeqCUDA(Vec,PetscScalar,Vec);
PETSC_INTERN PetscErrorCode VecSetRandom_SeqCUDA(Vec,PetscRandom);
PETSC_INTERN PetscErrorCode VecSetRandom_SeqCUDA_Private(Vec xin,PetscRandom r);
PETSC_INTERN PetscErrorCode VecGetLocalVector_SeqCUDA(Vec,Vec);
PETSC_INTERN PetscErrorCode VecRestoreLocalVector_SeqCUDA(Vec,Vec);
PETSC_INTERN PetscErrorCode VecGetArrayWrite_SeqCUDA(Vec,PetscScalar**);
PETSC_INTERN PetscErrorCode VecCopy_SeqCUDA_Private(Vec xin,Vec yin);
PETSC_INTERN PetscErrorCode VecResetArray_SeqCUDA_Private(Vec vin);

/* --- Creation, destruction and use of CUDA scatter index objects --- */
PETSC_INTERN PetscErrorCode VecScatterCUDAIndicesCreate_PtoP(PetscInt,PetscInt*,PetscInt,PetscInt*,PetscCUDAIndices*);
PETSC_INTERN PetscErrorCode VecScatterCUDAIndicesCreate_StoS(PetscInt,PetscInt,PetscInt,PetscInt,PetscInt,PetscInt*,PetscInt*,PetscCUDAIndices*);
PETSC_INTERN PetscErrorCode VecScatterCUDAIndicesDestroy(PetscCUDAIndices*);
PETSC_INTERN PetscErrorCode VecScatterCUDA_StoS(Vec,Vec,PetscCUDAIndices,InsertMode,ScatterMode);

/* Which concrete index struct a _p_PetscCUDAIndices carries (values spelled out explicitly). */
typedef enum {
  VEC_SCATTER_CUDA_STOS = 0,
  VEC_SCATTER_CUDA_PTOP = 1
} VecCUDAScatterType;

/* How one side (from/to) of a sequential scatter addresses its entries. */
typedef enum {
  VEC_SCATTER_CUDA_GENERAL = 0,
  VEC_SCATTER_CUDA_STRIDED = 1
} VecCUDASequentialScatterMode;

/*
   Index data for a PtoP scatter.  Field semantics are established by
   VecScatterCUDAIndicesCreate_PtoP (not visible here); the names suggest send/receive
   counts and the lowest send/receive index -- NOTE(review): confirm.
*/
struct _p_VecScatterCUDAIndices_PtoP {
  PetscInt ns;
  PetscInt sendLowestIndex;
  PetscInt nr;
  PetscInt recvLowestIndex;
};

/*
   Index data for a sequential-to-sequential scatter.  Each side has an explicit slot
   array plus first/step values, and a VecCUDASequentialScatterMode choosing GENERAL
   or STRIDED (presumably slots for GENERAL, first/step for STRIDED -- confirm).
   Member order is part of the layout and must not change.
*/
struct _p_VecScatterCUDAIndices_StoS {
  /* source ("from") side */
  PetscInt                     *fslots;
  PetscInt                     fromFirst;
  PetscInt                     fromStep;
  VecCUDASequentialScatterMode fromMode;

  /* destination ("to") side */
  PetscInt                     *tslots;
  PetscInt                     toFirst;
  PetscInt                     toStep;
  VecCUDASequentialScatterMode toMode;

  /* sizing / launch bookkeeping */
  PetscInt                     n;
  PetscInt                     MAX_BLOCKS;
  PetscInt                     MAX_CORESIDENT_THREADS;
  cudaStream_t                 stream;
};

/*
   Type-erased scatter index handle: `scatter` is an untyped pointer and `scatterType`
   records its kind (presumably selecting between the PtoP and StoS structs above).
*/
struct _p_PetscCUDAIndices {
  void               *scatter;
  VecCUDAScatterType scatterType;
};

/*
   cublasX* precision dispatch: map one generic name onto the cuBLAS routine matching
   PetscScalar.  Complex builds cast PetscScalar pointers to the cuBLAS complex types.
   The replacement lists and parameter names are kept exactly as before so that any
   identical redefinition elsewhere stays legal.
*/
#if defined(PETSC_USE_COMPLEX) && defined(PETSC_USE_REAL_SINGLE)
/* complex single */
#define cublasXaxpy(a,b,c,d,e,f,g)               cublasCaxpy((a),(b),(cuComplex*)(c),(cuComplex*)(d),(e),(cuComplex*)(f),(g))
#define cublasXscal(a,b,c,d,e)                   cublasCscal((a),(b),(cuComplex*)(c),(cuComplex*)(d),(e))
#define cublasXdotu(a,b,c,d,e,f,g)               cublasCdotu((a),(b),(cuComplex*)(c),(d),(cuComplex*)(e),(f),(cuComplex*)(g))
#define cublasXdot(a,b,c,d,e,f,g)                cublasCdotc((a),(b),(cuComplex*)(c),(d),(cuComplex*)(e),(f),(cuComplex*)(g))
#define cublasXswap(a,b,c,d,e,f)                 cublasCswap((a),(b),(cuComplex*)(c),(d),(cuComplex*)(e),(f))
#define cublasXnrm2(a,b,c,d,e)                   cublasScnrm2((a),(b),(cuComplex*)(c),(d),(e))
#define cublasIXamax(a,b,c,d,e)                  cublasIcamax((a),(b),(cuComplex*)(c),(d),(e))
#define cublasXasum(a,b,c,d,e)                   cublasScasum((a),(b),(cuComplex*)(c),(d),(e))
#define cublasXgemv(a,b,c,d,e,f,g,h,i,j,k,l)     cublasCgemv((a),(b),(c),(d),(cuComplex*)(e),(cuComplex*)(f),(g),(cuComplex*)(h),(i),(cuComplex*)(j),(cuComplex*)(k),(l))
#define cublasXgemm(a,b,c,d,e,f,g,h,i,j,k,l,m,n) cublasCgemm((a),(b),(c),(d),(e),(f),(cuComplex*)(g),(cuComplex*)(h),(i),(cuComplex*)(j),(k),(cuComplex*)(l),(cuComplex*)(m),(n))
#define cublasXgeam(a,b,c,d,e,f,g,h,i,j,k,l,m)   cublasCgeam((a),(b),(c),(d),(e),(cuComplex*)(f),(cuComplex*)(g),(h),(cuComplex*)(i),(cuComplex*)(j),(k),(cuComplex*)(l),(m))
#elif defined(PETSC_USE_COMPLEX)
/* complex double */
#define cublasXaxpy(a,b,c,d,e,f,g)               cublasZaxpy((a),(b),(cuDoubleComplex*)(c),(cuDoubleComplex*)(d),(e),(cuDoubleComplex*)(f),(g))
#define cublasXscal(a,b,c,d,e)                   cublasZscal((a),(b),(cuDoubleComplex*)(c),(cuDoubleComplex*)(d),(e))
#define cublasXdotu(a,b,c,d,e,f,g)               cublasZdotu((a),(b),(cuDoubleComplex*)(c),(d),(cuDoubleComplex*)(e),(f),(cuDoubleComplex*)(g))
#define cublasXdot(a,b,c,d,e,f,g)                cublasZdotc((a),(b),(cuDoubleComplex*)(c),(d),(cuDoubleComplex*)(e),(f),(cuDoubleComplex*)(g))
#define cublasXswap(a,b,c,d,e,f)                 cublasZswap((a),(b),(cuDoubleComplex*)(c),(d),(cuDoubleComplex*)(e),(f))
#define cublasXnrm2(a,b,c,d,e)                   cublasDznrm2((a),(b),(cuDoubleComplex*)(c),(d),(e))
#define cublasIXamax(a,b,c,d,e)                  cublasIzamax((a),(b),(cuDoubleComplex*)(c),(d),(e))
#define cublasXasum(a,b,c,d,e)                   cublasDzasum((a),(b),(cuDoubleComplex*)(c),(d),(e))
#define cublasXgemv(a,b,c,d,e,f,g,h,i,j,k,l)     cublasZgemv((a),(b),(c),(d),(cuDoubleComplex*)(e),(cuDoubleComplex*)(f),(g),(cuDoubleComplex*)(h),(i),(cuDoubleComplex*)(j),(cuDoubleComplex*)(k),(l))
#define cublasXgemm(a,b,c,d,e,f,g,h,i,j,k,l,m,n) cublasZgemm((a),(b),(c),(d),(e),(f),(cuDoubleComplex*)(g),(cuDoubleComplex*)(h),(i),(cuDoubleComplex*)(j),(k),(cuDoubleComplex*)(l),(cuDoubleComplex*)(m),(n))
#define cublasXgeam(a,b,c,d,e,f,g,h,i,j,k,l,m)   cublasZgeam((a),(b),(c),(d),(e),(cuDoubleComplex*)(f),(cuDoubleComplex*)(g),(h),(cuDoubleComplex*)(i),(cuDoubleComplex*)(j),(k),(cuDoubleComplex*)(l),(m))
#elif defined(PETSC_USE_REAL_SINGLE)
/* real single: plain name aliases, no casts needed (dotu and dot both map to Sdot) */
#define cublasXaxpy  cublasSaxpy
#define cublasXscal  cublasSscal
#define cublasXdotu  cublasSdot
#define cublasXdot   cublasSdot
#define cublasXswap  cublasSswap
#define cublasXnrm2  cublasSnrm2
#define cublasIXamax cublasIsamax
#define cublasXasum  cublasSasum
#define cublasXgemv  cublasSgemv
#define cublasXgemm  cublasSgemm
#define cublasXgeam  cublasSgeam
#else
/* real double: plain name aliases, no casts needed (dotu and dot both map to Ddot) */
#define cublasXaxpy  cublasDaxpy
#define cublasXscal  cublasDscal
#define cublasXdotu  cublasDdot
#define cublasXdot   cublasDdot
#define cublasXswap  cublasDswap
#define cublasXnrm2  cublasDnrm2
#define cublasIXamax cublasIdamax
#define cublasXasum  cublasDasum
#define cublasXgemv  cublasDgemv
#define cublasXgemm  cublasDgemm
#define cublasXgeam  cublasDgeam
#endif

#endif /* end of include guard */