1 #ifndef HL_PYTORCH_CUDA_HELPERS_H
2 #define HL_PYTORCH_CUDA_HELPERS_H
3 
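/** \file
 * Override Halide's CUDA hooks so that the Halide code called from PyTorch uses
 * the correct GPU device and stream.
 *
 * The hooks are only compiled when HL_PT_CUDA is defined. They are non-inline
 * function definitions placed in a header, so include this header in exactly
 * one translation unit; including it in more than one produces duplicate-symbol
 * link errors.
 */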
8 
9 #ifdef HL_PT_CUDA
10 #include "HalideRuntimeCuda.h"
11 #include "cuda.h"
12 
13 namespace Halide {
14 namespace PyTorch {
15 
16 typedef struct UserContext {
UserContextUserContext17     UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
18         : device_id(id), cuda_context(ctx), stream(stream){};
19 
20     int device_id;
21     CUcontext *cuda_context;
22     cudaStream_t *stream;
23 } UserContext;
24 
25 }  // namespace PyTorch
26 }  // namespace Halide
27 
28 // Replace Halide weakly-linked CUDA handles
29 extern "C" {
30 
31 int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
32     if (user_context != NULL) {
33         Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
34         *ctx = *user_ctx->cuda_context;
35     } else {
36         *ctx = NULL;
37     }
38     return 0;
39 }
40 
halide_cuda_get_stream(void * user_context,CUcontext ctx,CUstream * stream)41 int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
42     if (user_context != NULL) {
43         Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
44         *stream = *user_ctx->stream;
45     } else {
46         *stream = 0;
47     }
48     return 0;
49 }
50 
halide_get_gpu_device(void * user_context)51 int halide_get_gpu_device(void *user_context) {
52     if (user_context != NULL) {
53         Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
54         return user_ctx->device_id;
55     } else {
56         return 0;
57     }
58 }
59 
60 }  // extern "C"
61 
62 #endif  // HL_PT_CUDA
63 
64 #endif /* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
65