1 // Copyright (C) 2017  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_
4 #define DLIB_DNN_CuDA_DATA_PTR_CPP_
5 
6 #ifdef DLIB_USE_CUDA
7 
8 #include "cuda_data_ptr.h"
9 #include "cuda_utils.h"
10 
11 namespace dlib
12 {
13     namespace cuda
14     {
15 
16     // ----------------------------------------------------------------------------------------
17 
18         weak_cuda_data_void_ptr::
weak_cuda_data_void_ptr(const cuda_data_void_ptr & ptr)19         weak_cuda_data_void_ptr(
20             const cuda_data_void_ptr& ptr
21         ) : num(ptr.num), pdata(ptr.pdata)
22         {
23 
24         }
25 
26     // ----------------------------------------------------------------------------------------
27 
28         cuda_data_void_ptr weak_cuda_data_void_ptr::
lock() const29         lock() const
30         {
31             auto ptr = pdata.lock();
32             if (ptr)
33             {
34                 cuda_data_void_ptr temp;
35                 temp.pdata = ptr;
36                 temp.num = num;
37                 return temp;
38             }
39             else
40             {
41                 return cuda_data_void_ptr();
42             }
43         }
44 
45     // -----------------------------------------------------------------------------------
46     // -----------------------------------------------------------------------------------
47 
48         cuda_data_void_ptr::
cuda_data_void_ptr(size_t n)49         cuda_data_void_ptr(
50             size_t n
51         ) : num(n)
52         {
53             if (n == 0)
54                 return;
55 
56             void* data = nullptr;
57 
58             CHECK_CUDA(cudaMalloc(&data, n));
59             pdata.reset(data, [](void* ptr){
60                 auto err = cudaFree(ptr);
61                 if(err!=cudaSuccess)
62                 std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
63             });
64         }
65 
66     // ------------------------------------------------------------------------------------
67 
memcpy(void * dest,const cuda_data_void_ptr & src,const size_t num)68         void memcpy(
69             void* dest,
70             const cuda_data_void_ptr& src,
71             const size_t num
72         )
73         {
74             DLIB_ASSERT(num <= src.size());
75             if (src.size() != 0)
76             {
77                 CHECK_CUDA(cudaMemcpy(dest, src.data(),  num, cudaMemcpyDefault));
78             }
79         }
80 
81     // ------------------------------------------------------------------------------------
82 
memcpy(void * dest,const cuda_data_void_ptr & src)83         void memcpy(
84             void* dest,
85             const cuda_data_void_ptr& src
86         )
87         {
88             memcpy(dest, src, src.size());
89         }
90 
91     // ------------------------------------------------------------------------------------
92 
memcpy(cuda_data_void_ptr dest,const void * src,const size_t num)93         void memcpy(
94             cuda_data_void_ptr dest,
95             const void* src,
96             const size_t num
97         )
98         {
99             DLIB_ASSERT(num <= dest.size());
100             if (dest.size() != 0)
101             {
102                 CHECK_CUDA(cudaMemcpy(dest.data(), src, num, cudaMemcpyDefault));
103             }
104         }
105 
106     // ------------------------------------------------------------------------------------
107 
memcpy(cuda_data_void_ptr dest,const void * src)108         void memcpy(
109             cuda_data_void_ptr dest,
110             const void* src
111         )
112         {
113             memcpy(dest,src,dest.size());
114         }
115 
116     // ------------------------------------------------------------------------------------
117 
118         class cudnn_device_buffer
119         {
120         public:
121             // not copyable
122             cudnn_device_buffer(const cudnn_device_buffer&) = delete;
123             cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete;
124 
cudnn_device_buffer()125             cudnn_device_buffer()
126             {
127                 buffers.resize(16);
128             }
~cudnn_device_buffer()129             ~cudnn_device_buffer()
130             {
131             }
132 
get(size_t size)133             cuda_data_void_ptr get (
134                 size_t size
135             )
136             {
137                 int new_device_id;
138                 CHECK_CUDA(cudaGetDevice(&new_device_id));
139                 // make room for more devices if needed
140                 if (new_device_id >= (long)buffers.size())
141                     buffers.resize(new_device_id+16);
142 
143                 // If we don't have a buffer already for this device then make one, or if it's too
144                 // small, make a bigger one.
145                 cuda_data_void_ptr buff = buffers[new_device_id].lock();
146                 if (!buff || buff.size() < size)
147                 {
148                     buff = cuda_data_void_ptr(size);
149                     buffers[new_device_id] = buff;
150                 }
151 
152                 // Finally, return the buffer for the current device
153                 return buff;
154             }
155 
156         private:
157 
158             std::vector<weak_cuda_data_void_ptr> buffers;
159         };
160 
161     // ----------------------------------------------------------------------------------------
162 
device_global_buffer(size_t size)163         cuda_data_void_ptr device_global_buffer(size_t size)
164         {
165             thread_local cudnn_device_buffer buffer;
166             return buffer.get(size);
167         }
168 
169     // ------------------------------------------------------------------------------------
170 
171     }
172 }
173 
174 #endif // DLIB_USE_CUDA
175 
176 #endif // DLIB_DNN_CuDA_DATA_PTR_CPP_
177 
178 
179