1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #include <cuda_profiler_api.h>
9 #include <faiss/gpu/utils/DeviceUtils.h>
10 #include <faiss/impl/FaissAssert.h>
11 #include <faiss/gpu/utils/DeviceDefs.cuh>
12 #include <mutex>
13 #include <unordered_map>
14
15 namespace faiss {
16 namespace gpu {
17
// Returns the device ordinal reported by cudaGetDevice().
// The CUDA call is checked with CUDA_VERIFY, and the result is asserted to
// have been overwritten from its -1 sentinel.
int getCurrentDevice() {
    // Sentinel: cudaGetDevice() is expected to replace this value
    int dev = -1;

    cudaError_t status = cudaGetDevice(&dev);
    CUDA_VERIFY(status);
    FAISS_ASSERT(dev != -1);

    return dev;
}
25
// Passes `device` to cudaSetDevice() and checks the returned status with
// CUDA_VERIFY.
void setCurrentDevice(int device) {
    cudaError_t status = cudaSetDevice(device);
    CUDA_VERIFY(status);
}
29
// Returns the device count reported by cudaGetDeviceCount().
// A cudaErrorNoDevice result is not treated as a failure; it yields a count
// of 0. Any other status is checked with CUDA_VERIFY.
int getNumDevices() {
    int numDev = -1;
    const cudaError_t err = cudaGetDeviceCount(&numDev);

    if (err != cudaErrorNoDevice) {
        CUDA_VERIFY(err);
    } else {
        // No device present: report zero rather than failing
        numDev = 0;
    }

    // Either the runtime wrote a count or we set it to 0 above
    FAISS_ASSERT(numDev != -1);

    return numDev;
}
42
// Calls cudaProfilerStart() and checks the returned status with CUDA_VERIFY.
void profilerStart() {
    cudaError_t status = cudaProfilerStart();
    CUDA_VERIFY(status);
}
46
// Calls cudaProfilerStop() and checks the returned status with CUDA_VERIFY.
void profilerStop() {
    cudaError_t status = cudaProfilerStop();
    CUDA_VERIFY(status);
}
50
// Calls cudaDeviceSynchronize() once for every device index in
// [0, getNumDevices()), making each device current in turn via DeviceScope.
// Each synchronize call is checked with CUDA_VERIFY.
void synchronizeAllDevices() {
    // Query the device count once; it is loop-invariant, and getNumDevices()
    // performs a runtime API call plus assertions on every invocation.
    const int numDevices = getNumDevices();

    for (int i = 0; i < numDevices; ++i) {
        // Switches to device i for this iteration; the DeviceScope destructor
        // restores the previously current device when one was replaced
        DeviceScope scope(i);

        CUDA_VERIFY(cudaDeviceSynchronize());
    }
}
58
// Returns the cudaDeviceProp for `device`, querying
// cudaGetDeviceProperties() the first time a given device is requested and
// serving later requests from a function-local cache.
//
// Access to the cache is serialized with a function-local mutex. The
// returned reference points at the cached map entry; entries are never
// erased by this function.
const cudaDeviceProp& getDeviceProperties(int device) {
    static std::mutex mutex;
    static std::unordered_map<int, cudaDeviceProp> properties;

    std::lock_guard<std::mutex> guard(mutex);

    auto it = properties.find(device);
    if (it == properties.end()) {
        cudaDeviceProp prop;
        CUDA_VERIFY(cudaGetDeviceProperties(&prop, device));

        // Insert and obtain the iterator in one step, instead of
        // default-constructing via operator[] and looking the key up again
        it = properties.emplace(device, prop).first;
    }

    return it->second;
}
76
// Returns the properties of the device reported by getCurrentDevice().
const cudaDeviceProp& getCurrentDeviceProperties() {
    const int device = getCurrentDevice();
    return getDeviceProperties(device);
}
80
// Returns the maxThreadsPerBlock field of the properties for `device`.
int getMaxThreads(int device) {
    const cudaDeviceProp& props = getDeviceProperties(device);
    return props.maxThreadsPerBlock;
}
84
// Same as getMaxThreads(), evaluated for the device reported by
// getCurrentDevice().
int getMaxThreadsCurrentDevice() {
    const int device = getCurrentDevice();
    return getMaxThreads(device);
}
88
// Returns the sharedMemPerBlock field of the properties for `device`.
size_t getMaxSharedMemPerBlock(int device) {
    const cudaDeviceProp& props = getDeviceProperties(device);
    return props.sharedMemPerBlock;
}
92
// Same as getMaxSharedMemPerBlock(), evaluated for the device reported by
// getCurrentDevice().
size_t getMaxSharedMemPerBlockCurrentDevice() {
    const int device = getCurrentDevice();
    return getMaxSharedMemPerBlock(device);
}
96
// Returns the device recorded in the CUDA pointer attributes of `p`, or -1
// when no device is reported for it.
//
// -1 is returned when:
//  - `p` is null;
//  - cudaPointerGetAttributes() fails with cudaErrorInvalidValue;
//  - the attributes do not describe device memory (see the version-specific
//    checks at the end of the function).
// Any other failure status from cudaPointerGetAttributes() trips an
// assertion.
int getDeviceForAddress(const void* p) {
    if (p == nullptr) {
        return -1;
    }

    cudaPointerAttributes att;
    cudaError_t err = cudaPointerGetAttributes(&att, p);
    FAISS_ASSERT_FMT(
            err == cudaSuccess || err == cudaErrorInvalidValue,
            "unknown error %d",
            (int)err);

    if (err == cudaErrorInvalidValue) {
        // Read the thread's last error via cudaGetLastError() so that its
        // error status is reset; it is expected to be the same error that
        // was just returned
        err = cudaGetLastError();
        FAISS_ASSERT_FMT(
                err == cudaErrorInvalidValue, "unknown error %d", (int)err);
        return -1;
    }

    // memoryType is deprecated for CUDA 10.0+
#if CUDA_VERSION < 10000
    // Anything that is not host memory is attributed to att.device
    const bool hasDevice = (att.memoryType != cudaMemoryTypeHost);
#else
    // FIXME: what to use for managed memory?
    // Only memory typed as device memory is attributed to att.device
    const bool hasDevice = (att.type == cudaMemoryTypeDevice);
#endif

    return hasDevice ? att.device : -1;
}
133
// Returns true when the `major` field of the properties for `device` is at
// least 6.
bool getFullUnifiedMemSupport(int device) {
    const int majorVersion = getDeviceProperties(device).major;
    return majorVersion >= 6;
}
138
// Same as getFullUnifiedMemSupport(), evaluated for the device reported by
// getCurrentDevice().
bool getFullUnifiedMemSupportCurrentDevice() {
    const int device = getCurrentDevice();
    return getFullUnifiedMemSupport(device);
}
142
// Returns true when the `major` field of the properties for `device` is at
// least 7.
bool getTensorCoreSupport(int device) {
    const int majorVersion = getDeviceProperties(device).major;
    return majorVersion >= 7;
}
147
// Same as getTensorCoreSupport(), evaluated for the device reported by
// getCurrentDevice().
bool getTensorCoreSupportCurrentDevice() {
    const int device = getCurrentDevice();
    return getTensorCoreSupport(device);
}
151
// Returns GPU_MAX_SELECTION_K.
// The device is not consulted at the moment; the limit depends only on the
// CUDA SDK this code was compiled with.
int getMaxKSelection() {
    const int maxK = GPU_MAX_SELECTION_K;
    return maxK;
}
157
// Makes `device` the current device for the lifetime of this object.
//
// If `device` is negative, or is already the current device, nothing is
// changed and prevDevice_ is left at -1 so the destructor does nothing.
// Otherwise the current device is remembered in prevDevice_ and `device` is
// made current.
DeviceScope::DeviceScope(int device) {
    // Default: keep the current device, nothing to restore later
    prevDevice_ = -1;

    if (device < 0) {
        return;
    }

    const int curDevice = getCurrentDevice();
    if (curDevice == device) {
        return;
    }

    // Remember where we came from before switching
    prevDevice_ = curDevice;
    setCurrentDevice(device);
}
172
// Restores the device saved by the constructor, if one was saved.
DeviceScope::~DeviceScope() {
    // -1 marks "the constructor did not switch devices"
    if (prevDevice_ == -1) {
        return;
    }

    setCurrentDevice(prevDevice_);
}
178
// Creates a cuBLAS handle into blasHandle_ and asserts that creation
// succeeded.
CublasHandleScope::CublasHandleScope() {
    const cublasStatus_t blasStatus = cublasCreate(&blasHandle_);
    FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
}
183
// Destroys the cuBLAS handle held in blasHandle_ and asserts that
// destruction succeeded.
CublasHandleScope::~CublasHandleScope() {
    const cublasStatus_t blasStatus = cublasDestroy(blasHandle_);
    FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
}
188
// Creates a CUDA event and immediately records it on `stream`.
// When `timer` is true the event is created with cudaEventDefault,
// otherwise with cudaEventDisableTiming. Both CUDA calls are checked with
// CUDA_VERIFY.
CudaEvent::CudaEvent(cudaStream_t stream, bool timer) : event_(0) {
    unsigned int flags = cudaEventDisableTiming;
    if (timer) {
        flags = cudaEventDefault;
    }

    CUDA_VERIFY(cudaEventCreateWithFlags(&event_, flags));
    CUDA_VERIFY(cudaEventRecord(event_, stream));
}
194
// Move constructor: takes over the event handle held by `event` and leaves
// `event` holding 0, so that its destructor will not destroy the handle.
CudaEvent::CudaEvent(CudaEvent&& event) noexcept : event_(event.event_) {
    // The source no longer owns the handle
    event.event_ = 0;
}
199
// Destroys the held event, if any. A zero handle (e.g. after being moved
// from) is skipped.
CudaEvent::~CudaEvent() {
    if (!event_) {
        return;
    }

    CUDA_VERIFY(cudaEventDestroy(event_));
}
205
// Move assignment: takes over the event handle held by `event` and leaves
// `event` holding 0.
//
// Any event this object already owns is destroyed first; simply overwriting
// event_ would leak it, since nothing else holds the old handle.
// Self-assignment is a no-op, so the object keeps its event instead of
// zeroing (and thereby leaking) its own handle.
CudaEvent& CudaEvent::operator=(CudaEvent&& event) noexcept {
    if (this != &event) {
        // Release the event we currently own, mirroring the destructor
        if (event_) {
            CUDA_VERIFY(cudaEventDestroy(event_));
        }

        event_ = event.event_;
        // The source no longer owns the handle
        event.event_ = 0;
    }

    return *this;
}
212
// Calls cudaStreamWaitEvent() for `stream` on the held event, with the
// flags argument set to 0; the status is checked with CUDA_VERIFY.
void CudaEvent::streamWaitOnEvent(cudaStream_t stream) {
    cudaError_t status = cudaStreamWaitEvent(stream, event_, 0);
    CUDA_VERIFY(status);
}
216
// Calls cudaEventSynchronize() on the held event; the status is checked
// with CUDA_VERIFY.
void CudaEvent::cpuWaitOnEvent() {
    cudaError_t status = cudaEventSynchronize(event_);
    CUDA_VERIFY(status);
}
220
221 } // namespace gpu
222 } // namespace faiss
223