1 #ifndef HL_PYTORCH_CUDA_HELPERS_H
2 #define HL_PYTORCH_CUDA_HELPERS_H
3
4 /** \file
5 * Override Halide's CUDA hooks so that the Halide code called from PyTorch uses
6 * the correct GPU device and stream.
7 */
8
9 #ifdef HL_PT_CUDA
10 #include "HalideRuntimeCuda.h"
11 #include "cuda.h"
12
13 namespace Halide {
14 namespace PyTorch {
15
16 typedef struct UserContext {
UserContextUserContext17 UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
18 : device_id(id), cuda_context(ctx), stream(stream){};
19
20 int device_id;
21 CUcontext *cuda_context;
22 cudaStream_t *stream;
23 } UserContext;
24
25 } // namespace PyTorch
26 } // namespace Halide
27
// Replace Halide's weakly-linked CUDA hooks
29 extern "C" {
30
31 int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
32 if (user_context != NULL) {
33 Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
34 *ctx = *user_ctx->cuda_context;
35 } else {
36 *ctx = NULL;
37 }
38 return 0;
39 }
40
halide_cuda_get_stream(void * user_context,CUcontext ctx,CUstream * stream)41 int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
42 if (user_context != NULL) {
43 Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
44 *stream = *user_ctx->stream;
45 } else {
46 *stream = 0;
47 }
48 return 0;
49 }
50
halide_get_gpu_device(void * user_context)51 int halide_get_gpu_device(void *user_context) {
52 if (user_context != NULL) {
53 Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
54 return user_ctx->device_id;
55 } else {
56 return 0;
57 }
58 }
59
60 } // extern "C"
61
62 #endif // HL_PT_CUDA
63
64 #endif /* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
65