1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file cuda_common.h 22 * \brief Common utilities for CUDA 23 */ 24 #ifndef TVM_RUNTIME_CUDA_CUDA_COMMON_H_ 25 #define TVM_RUNTIME_CUDA_CUDA_COMMON_H_ 26 27 #include <cuda_runtime.h> 28 #include <tvm/runtime/packed_func.h> 29 30 #include <string> 31 32 #include "../workspace_pool.h" 33 34 namespace tvm { 35 namespace runtime { 36 37 #define CUDA_DRIVER_CALL(x) \ 38 { \ 39 CUresult result = x; \ 40 if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \ 41 const char* msg; \ 42 cuGetErrorName(result, &msg); \ 43 LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \ 44 } \ 45 } 46 47 #define CUDA_CALL(func) \ 48 { \ 49 cudaError_t e = (func); \ 50 CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \ 51 } 52 53 /*! \brief Thread local workspace */ 54 class CUDAThreadEntry { 55 public: 56 /*! \brief The cuda stream */ 57 cudaStream_t stream{nullptr}; 58 /*! \brief thread local pool*/ 59 WorkspacePool pool; 60 /*! \brief constructor */ 61 CUDAThreadEntry(); 62 // get the threadlocal workspace 63 static CUDAThreadEntry* ThreadLocal(); 64 }; 65 } // namespace runtime 66 } // namespace tvm 67 #endif // TVM_RUNTIME_CUDA_CUDA_COMMON_H_ 68