1 /*------------------------------------------------------------------------------------------------*
2  * Copyright (C) by the DBCSR developers group - All rights reserved                              *
3  * This file is part of the DBCSR library.                                                        *
4  *                                                                                                *
5  * For information on the license, see the LICENSE file.                                          *
6  * For further information please visit https://dbcsr.cp2k.org                                    *
7  * SPDX-License-Identifier: GPL-2.0+                                                              *
8  *------------------------------------------------------------------------------------------------*/
9 
10 #ifdef __CUDA
11 #include "cuda/acc_cuda.h"
12 #else
13 #include "hip/acc_hip.h"
14 #endif
15 
16 #include <stdio.h>
17 #include <math.h>
18 #include "acc_error.h"
19 #include "include/acc.h"
20 
21 static const int verbose_print = 0;
22 
23 
24 /****************************************************************************/
acc_dev_mem_allocate(void ** dev_mem,size_t n)25 extern "C" int acc_dev_mem_allocate(void **dev_mem, size_t n){
26   ACC_API_CALL(Malloc, ((void **) dev_mem, (size_t) n));
27   if (dev_mem == NULL)
28     return -2;
29   if (verbose_print)
30     printf ("Device allocation address %p, size %ld\n", *dev_mem, (long) n);
31 
32   return 0;
33 }
34 
35 
36 /****************************************************************************/
acc_dev_mem_deallocate(void * dev_mem)37 extern "C" int acc_dev_mem_deallocate(void *dev_mem){
38   if (verbose_print)
39     printf ("Device deallocation address %p\n", dev_mem);
40   ACC_API_CALL(Free, ((void *) dev_mem));
41 
42   return 0;
43 }
44 
45 
46 /****************************************************************************/
acc_host_mem_allocate(void ** host_mem,size_t n,void * stream)47 extern "C" int acc_host_mem_allocate(void **host_mem, size_t n, void *stream){
48   unsigned int flag = ACC(HostAllocDefault);
49 
50   ACC_API_CALL(HostAlloc, ((void **) host_mem, (size_t) n, flag));
51   if (host_mem == NULL)
52     return -2;
53   if (verbose_print)
54     printf ("Allocating %zd bytes of host pinned memory at %p\n", n, *host_mem);
55 
56   return 0;
57 }
58 
59 
60 /****************************************************************************/
acc_host_mem_deallocate(void * host_mem,void * stream)61 extern "C" int acc_host_mem_deallocate(void *host_mem, void *stream){
62   if (verbose_print)
63     printf ("Host pinned deallocation address %p\n", host_mem);
64   ACC_API_CALL(FreeHost, ((void *) host_mem));
65 
66   return 0;
67 }
68 
69 /****************************************************************************/
/* Make *dev_mem point lb bytes past the start of the buffer `other`.
 *
 * Nothing is allocated or copied: the resulting pointer aliases `other`.
 * Always returns 0. */
extern "C" int acc_dev_mem_set_ptr(void **dev_mem, void *other, size_t lb){
  char *const base = (char *) other;
  char *const shifted = base + lb;
  *dev_mem = shifted;
  return 0;
}
76 
77 /****************************************************************************/
acc_memcpy_h2d(const void * host_mem,void * dev_mem,size_t count,void * stream)78 extern "C" int acc_memcpy_h2d(const void *host_mem, void *dev_mem, size_t count, void* stream){
79   ACC(Stream_t)* acc_stream = (ACC(Stream_t)*) stream;
80   if (verbose_print)
81       printf ("Copying %zd bytes from host address %p to device address %p \n", count, host_mem, dev_mem);
82 
83   ACC_API_CALL(MemcpyAsync, (dev_mem, host_mem, count, ACC(MemcpyHostToDevice), *acc_stream));
84 
85   return 0;
86 }
87 
88 
89 /****************************************************************************/
acc_memcpy_d2h(const void * dev_mem,void * host_mem,size_t count,void * stream)90 extern "C" int acc_memcpy_d2h(const void *dev_mem, void *host_mem, size_t count, void* stream){
91   ACC(Stream_t)* acc_stream = (ACC(Stream_t)*) stream;
92   if (verbose_print)
93       printf ("Copying %zd bytes from device address %p to host address %p\n", count, dev_mem, host_mem);
94 
95   ACC_API_CALL(MemcpyAsync, (host_mem, dev_mem, count, ACC(MemcpyDeviceToHost), *acc_stream));
96 
97   if (verbose_print)
98     printf ("d2h %f\n", *((double *) host_mem));
99 
100   return 0;
101 }
102 
103 
104 /****************************************************************************/
acc_memcpy_d2d(const void * devmem_src,void * devmem_dst,size_t count,void * stream)105 extern "C" int acc_memcpy_d2d(const void *devmem_src, void *devmem_dst, size_t count, void* stream){
106   ACC(Stream_t)* acc_stream = (ACC(Stream_t)*) stream;
107   if (verbose_print)
108       printf ("Copying %zd bytes from device address %p to device address %p \n", count, devmem_src, devmem_dst);
109 
110 
111   if(stream == NULL){
112       ACC_API_CALL(Memcpy, (devmem_dst, devmem_src, count, ACC(MemcpyDeviceToDevice)));
113   } else {
114       ACC_API_CALL(MemcpyAsync, (devmem_dst, devmem_src, count, ACC(MemcpyDeviceToDevice), *acc_stream));
115   }
116 
117   return 0;
118 }
119 
120 
121 /****************************************************************************/
acc_memset_zero(void * dev_mem,size_t offset,size_t length,void * stream)122 extern "C" int acc_memset_zero(void *dev_mem, size_t offset, size_t length, void* stream){
123   ACC(Error_t) cErr;
124   ACC(Stream_t)* acc_stream = (ACC(Stream_t)*) stream;
125   if(stream == NULL){
126       cErr = ACC(Memset)((void *) (((char *) dev_mem) + offset), (int) 0, length);
127   } else {
128       cErr = ACC(MemsetAsync)((void *) (((char *) dev_mem) + offset), (int) 0, length, *acc_stream);
129   }
130 
131   if (verbose_print)
132     printf ("Zero at device address %p, offset %d, len %d\n",
133      dev_mem, (int) offset, (int) length);
134   if (acc_error_check(cErr))
135     return -1;
136   if (acc_error_check(ACC(GetLastError)()))
137     return -1;
138 
139   return 0;
140 }
141 
142 
143 /****************************************************************************/
acc_dev_mem_info(size_t * free,size_t * avail)144 extern "C" int acc_dev_mem_info(size_t* free, size_t* avail){
145   ACC_API_CALL(MemGetInfo, (free, avail));
146   return 0;
147 }
148