// Copyright (C) 2017  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_DATA_PTR_CPP_
#define DLIB_DNN_CuDA_DATA_PTR_CPP_

#ifdef DLIB_USE_CUDA

#include "cuda_data_ptr.h"
#include "cuda_utils.h"

#include <iostream>
#include <vector>

namespace dlib
{
    namespace cuda
    {

    // ----------------------------------------------------------------------------------------

        /*
            Builds a non-owning reference to the allocation held by ptr.  Only the byte
            count (num) and the pointer handle (pdata) are copied.  NOTE(review): pdata is
            presumably a std::weak_ptr here (lock() below calls pdata.lock()) -- confirm
            against cuda_data_ptr.h.
        */
        weak_cuda_data_void_ptr::
        weak_cuda_data_void_ptr(
            const cuda_data_void_ptr& ptr
        ) : num(ptr.num), pdata(ptr.pdata)
        {
        }

    // ----------------------------------------------------------------------------------------

        /*
            Attempts to recover an owning cuda_data_void_ptr for the tracked allocation.
            Returns:
                - if the allocation is still alive: a cuda_data_void_ptr sharing it, with
                  size() equal to the byte count recorded at construction.
                - otherwise: a default-constructed (empty) cuda_data_void_ptr.
        */
        cuda_data_void_ptr weak_cuda_data_void_ptr::
        lock() const
        {
            auto ptr = pdata.lock();
            if (ptr)
            {
                cuda_data_void_ptr temp;
                temp.pdata = ptr;
                temp.num = num;
                return temp;
            }
            else
            {
                return cuda_data_void_ptr();
            }
        }

    // -----------------------------------------------------------------------------------
    // -----------------------------------------------------------------------------------

        /*
            Allocates n bytes of device memory with cudaMalloc.
              - n == 0 performs no allocation and leaves the object empty.
              - A cudaMalloc failure is reported through CHECK_CUDA.
              - The memory is released by the custom deleter installed on pdata, which
                calls cudaFree.  A cudaFree failure is printed to std::cerr instead of
                being thrown, because the deleter runs during destruction.
        */
        cuda_data_void_ptr::
        cuda_data_void_ptr(
            size_t n
        ) : num(n)
        {
            if (n == 0)
                return;

            void* data = nullptr;

            CHECK_CUDA(cudaMalloc(&data, n));
            pdata.reset(data, [](void* ptr){
                auto err = cudaFree(ptr);
                if (err != cudaSuccess)
                    std::cerr << "cudaFree() failed. Reason: " << cudaGetErrorString(err) << std::endl;
            });
        }

    // ------------------------------------------------------------------------------------

        /*
            Copies the first num bytes of src into dest.
            Requires num <= src.size() (checked only by DLIB_ASSERT).
            cudaMemcpyDefault lets the CUDA runtime infer the transfer direction from the
            pointer values, so dest may be host or device memory (requires unified virtual
            addressing).  Copying from an empty src does nothing.
        */
        void memcpy(
            void* dest,
            const cuda_data_void_ptr& src,
            const size_t num
        )
        {
            DLIB_ASSERT(num <= src.size());
            if (src.size() != 0)
            {
                CHECK_CUDA(cudaMemcpy(dest, src.data(), num, cudaMemcpyDefault));
            }
        }

    // ------------------------------------------------------------------------------------

        /*
            Copies all src.size() bytes of src into dest.
        */
        void memcpy(
            void* dest,
            const cuda_data_void_ptr& src
        )
        {
            memcpy(dest, src, src.size());
        }

    // ------------------------------------------------------------------------------------

        /*
            Copies num bytes from src into the buffer owned by dest.
            Requires num <= dest.size() (checked only by DLIB_ASSERT).
            The direction is inferred via cudaMemcpyDefault.  Copying into an empty dest
            does nothing.  dest is taken by value; that copies the handle, not the
            device memory.
        */
        void memcpy(
            cuda_data_void_ptr dest,
            const void* src,
            const size_t num
        )
        {
            DLIB_ASSERT(num <= dest.size());
            if (dest.size() != 0)
            {
                CHECK_CUDA(cudaMemcpy(dest.data(), src, num, cudaMemcpyDefault));
            }
        }

    // ------------------------------------------------------------------------------------

        /*
            Fills the whole of dest (dest.size() bytes) from src.
        */
        void memcpy(
            cuda_data_void_ptr dest,
            const void* src
        )
        {
            memcpy(dest, src, dest.size());
        }

    // ------------------------------------------------------------------------------------

        /*
            Per-device cache of scratch buffers, indexed by the CUDA device id returned by
            cudaGetDevice().  Entries are weak references, so the cache never keeps a
            buffer alive by itself.  A buffer is reused only while some caller still holds
            a previously returned cuda_data_void_ptr.
        */
        class cudnn_device_buffer
        {
        public:
            // not copyable
            cudnn_device_buffer(const cudnn_device_buffer&) = delete;
            cudnn_device_buffer& operator=(const cudnn_device_buffer&) = delete;

            cudnn_device_buffer()
            {
                // Pre-size for 16 devices; get() grows this on demand.
                buffers.resize(16);
            }

            /*
                Returns a buffer for the current device holding at least size bytes.
                A new buffer is allocated when none is alive for this device, or when
                the live one is too small.
            */
            cuda_data_void_ptr get (
                size_t size
            )
            {
                int new_device_id;
                CHECK_CUDA(cudaGetDevice(&new_device_id));
                // make room for more devices if needed
                if (static_cast<size_t>(new_device_id) >= buffers.size())
                    buffers.resize(new_device_id+16);

                // If we don't have a buffer already for this device then make one, or if it's too
                // small, make a bigger one.
                cuda_data_void_ptr buff = buffers[new_device_id].lock();
                if (!buff || buff.size() < size)
                {
                    buff = cuda_data_void_ptr(size);
                    buffers[new_device_id] = buff;
                }

                // Finally, return the buffer for the current device
                return buff;
            }

        private:

            std::vector<weak_cuda_data_void_ptr> buffers;
        };

    // ----------------------------------------------------------------------------------------

        /*
            Returns a scratch device buffer of at least size bytes for the current CUDA
            device.  The cache is thread_local, so each calling thread gets its own
            cudnn_device_buffer and no locking is needed.
        */
        cuda_data_void_ptr device_global_buffer(size_t size)
        {
            thread_local cudnn_device_buffer buffer;
            return buffer.get(size);
        }

    // ------------------------------------------------------------------------------------

    }
}

#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuDA_DATA_PTR_CPP_