1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file cuda_common.h
22  * \brief Common utilities for CUDA
23  */
24 #ifndef TVM_RUNTIME_CUDA_CUDA_COMMON_H_
25 #define TVM_RUNTIME_CUDA_CUDA_COMMON_H_
26 
#include <cuda.h>
#include <cuda_runtime.h>
#include <tvm/runtime/packed_func.h>

#include <string>

#include "../workspace_pool.h"
33 
34 namespace tvm {
35 namespace runtime {
36 
/*!
 * \brief Evaluate a CUDA driver API call and LOG(FATAL) if it failed.
 *
 * Any result other than CUDA_SUCCESS or CUDA_ERROR_DEINITIALIZED is fatal.
 * CUDA_ERROR_DEINITIALIZED is tolerated, presumably so that calls made while
 * the driver is shutting down (e.g. from static destructors) do not abort.
 *
 * The error name comes from cuGetErrorName. That function reports failure and
 * yields a NULL string for unrecognized codes, so that case falls back to
 * printing the numeric code instead of streaming a null pointer.
 *
 * The macro-local identifiers carry a tvm_-prefixed, underscore-suffixed name
 * so they cannot capture identifiers that appear inside the argument.
 *
 * \param x A CUresult-valued expression; it is evaluated exactly once.
 */
#define CUDA_DRIVER_CALL(x)                                                               \
  {                                                                                       \
    CUresult tvm_cu_result_ = (x);                                                        \
    if (tvm_cu_result_ != CUDA_SUCCESS && tvm_cu_result_ != CUDA_ERROR_DEINITIALIZED) {   \
      const char* tvm_cu_msg_ = nullptr;                                                  \
      if (cuGetErrorName(tvm_cu_result_, &tvm_cu_msg_) == CUDA_SUCCESS &&                 \
          tvm_cu_msg_ != nullptr) {                                                       \
        LOG(FATAL) << "CUDAError: " #x " failed with error: " << tvm_cu_msg_;             \
      } else {                                                                            \
        LOG(FATAL) << "CUDAError: " #x " failed with unrecognized error code "            \
                   << static_cast<int>(tvm_cu_result_);                                   \
      }                                                                                   \
    }                                                                                     \
  }
46 
/*!
 * \brief Evaluate a CUDA runtime API call and CHECK-fail if it returned an error.
 *
 * cudaSuccess passes, and so does cudaErrorCudartUnloading. The latter is
 * presumably tolerated so that calls issued during process teardown, after the
 * runtime has begun unloading, do not abort.
 *
 * The macro-local error variable uses a tvm_-prefixed, underscore-suffixed
 * name so that an identifier spelled `e` inside the argument binds to the
 * caller's variable and not to this macro's temporary.
 *
 * \param func A cudaError_t-valued expression; it is evaluated exactly once.
 * NOTE(review): this expands to a plain `{ ... }` block, not
 * `do { ... } while (0)`, so `if (c) CUDA_CALL(x); else ...` will not compile.
 */
#define CUDA_CALL(func)                                                                 \
  {                                                                                     \
    cudaError_t tvm_cuda_err_ = (func);                                                 \
    CHECK(tvm_cuda_err_ == cudaSuccess || tvm_cuda_err_ == cudaErrorCudartUnloading)    \
        << "CUDA: " << cudaGetErrorString(tvm_cuda_err_);                               \
  }
52 
53 /*! \brief Thread local workspace */
54 class CUDAThreadEntry {
55  public:
56   /*! \brief The cuda stream */
57   cudaStream_t stream{nullptr};
58   /*! \brief thread local pool*/
59   WorkspacePool pool;
60   /*! \brief constructor */
61   CUDAThreadEntry();
62   // get the threadlocal workspace
63   static CUDAThreadEntry* ThreadLocal();
64 };
65 }  // namespace runtime
66 }  // namespace tvm
67 #endif  // TVM_RUNTIME_CUDA_CUDA_COMMON_H_
68