#include <iostream>

#include "CodeGen_C.h"
#include "CodeGen_PyTorch.h"
#include "IROperator.h"
#include "Param.h"
#include "Util.h"
#include "Var.h"

namespace Halide {
namespace Internal {

CodeGen_PyTorch::CodeGen_PyTorch(std::ostream &s)
    : IRPrinter(s) {
}

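// Emit a C++ PyTorch extension wrapper for every function in the module.
// CUDA targets additionally include the ATen/CUDA helpers and must be
// compiled with the UserContext feature so the wrapper can pass PyTorch's
// device, context, and stream to the Halide runtime.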
void CodeGen_PyTorch::compile(const Module &module) {
    const Target target = module.target();

    if (target.has_feature(Target::CUDA)) {
        if (!target.has_feature(Target::UserContext)) {
            user_error << "Compiling a PyTorch wrapper for a CUDA op requires the "
                          "UserContext feature to properly manage the GPU memory. "
                          "Please add \"-user_context\" to the generator's target options.\n";
        }
        stream << "#include \"ATen/cuda/CUDAContext.h\"\n";
        stream << "#include \"HalideBuffer.h\"\n";
        stream << "#include \"HalidePyTorchCudaHelpers.h\"\n";
        stream << "#include \"HalidePyTorchHelpers.h\"\n";
        stream << "#include \"torch/extension.h\"\n";
    } else {
        stream << "#include \"HalideBuffer.h\"\n";
        stream << "#include \"HalidePyTorchHelpers.h\"\n";
        stream << "#include \"torch/extension.h\"\n";
    }

    stream << "\n";

    // Emit extern decls of the Halide-generated functions we use directly
    // into this file, so that we don't have to #include the relevant .h
    // file directly; this simplifies certain compile/build setups (since
    // we don't have to build files in tandem and/or get include paths right),
    // and should be totally safe, since we are using the same codegen logic
    // that would be in the .h file anyway.
    {
        CodeGen_C extern_decl_gen(stream, module.target(), CodeGen_C::CPlusPlusExternDecl);
        extern_decl_gen.compile(module);
    }

    for (const auto &f : module.functions()) {
        if (target.has_feature(Target::CUDA)) {
            compile(f, true);
        } else {
            compile(f, false);
        }
    }
}

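// Emit a single "<name>_th_" wrapper for one lowered function: tensors arrive
// as at::Tensor references, are validated and wrapped as
// Halide::Runtime::Buffers, and the Halide pipeline is then invoked with those
// buffers plus any scalar arguments.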
void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
    // Don't put non-external function declarations in headers.
    std::vector<std::string> namespaces;
    std::string simple_name = extract_namespaces(f.name, namespaces);

    if (!namespaces.empty()) {
        for (const auto &ns : namespaces) {
            stream << "namespace " << ns << " {\n";
        }
        stream << "\n";
    }
    const std::vector<LoweredArgument> &args = f.args;
    std::vector<LoweredArgument> buffer_args;

    stream << "HALIDE_FUNCTION_ATTRS\n";
    stream << "inline int " << simple_name << "_th_(";
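    // Emit the parameter list: buffer arguments become at::Tensor references,
    // scalar arguments keep their C types, and the implicit __user_context
    // argument is skipped (the CUDA path constructs its own in the body).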
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].name == "__user_context") {
            continue;
        } else if (args[i].is_buffer()) {
            buffer_args.push_back(args[i]);
            stream
                << "at::Tensor &"
                << c_print_name(args[i].name);
        } else {
            stream
                << type_to_c_type(args[i].type, true)
                << c_print_name(args[i].name);
        }

        if (i < args.size() - 1)
            stream << ", ";
    }

    stream << ") {\n";
    indent += 4;

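    // For CUDA targets, emit code that captures PyTorch's current device, CUDA
    // context, and stream in a Halide::PyTorch::UserContext, so the pipeline
    // runs against the same device state as the surrounding PyTorch code.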
    if (is_cuda) {
        stream << get_indent() << "// Setup CUDA\n";
        stream << get_indent() << "int device_id = at::cuda::current_device();\n";
        stream << get_indent() << "CUcontext ctx = 0;\n";
        stream << get_indent() << "CUresult res = cuCtxGetCurrent(&ctx);\n";
        stream << get_indent() << "AT_ASSERTM(res == 0, \"Could not acquire CUDA context\");\n";
        stream << get_indent() << "cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_id);\n";
        stream << get_indent() << "Halide::PyTorch::UserContext user_ctx(device_id, &ctx, &stream);\n";
        stream << get_indent() << "void* __user_context = (void*) &user_ctx;\n\n";
    }

    stream << get_indent() << "// Check tensors have contiguous memory and are on the correct device\n";
    for (size_t i = 0; i < buffer_args.size(); i++) {
        stream << get_indent();
        stream
            << "HLPT_CHECK_CONTIGUOUS("
            << c_print_name(buffer_args[i].name)
            << ");\n";

        if (is_cuda) {
            stream << get_indent();
            stream
                << "HLPT_CHECK_DEVICE("
                << c_print_name(buffer_args[i].name)
                << ", device_id);\n";
        }
    }
    stream << "\n";

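    // Emit one Halide::PyTorch::wrap<>() call per tensor argument, producing a
    // Halide::Runtime::Buffer of the matching element type for the pipeline.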
    stream << get_indent() << "// Wrap tensors in Halide buffers\n";
    for (size_t i = 0; i < buffer_args.size(); i++) {
        if (!buffer_args[i].is_buffer())
            continue;

        stream << get_indent();
        std::string tp = type_to_c_type(buffer_args[i].type, false);
        stream
            << "Halide::Runtime::Buffer<" << tp << "> "
            << c_print_name(buffer_args[i].name)
            << "_buffer = Halide::PyTorch::wrap<" << tp << ">("
            << c_print_name(buffer_args[i].name)
            << ");\n";
    }
    stream << "\n";

    stream << get_indent() << "// Run Halide pipeline\n";

    stream << get_indent() << "int err = " << simple_name << "(";
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer()) {
            stream
                << c_print_name(args[i].name)
                << "_buffer";
        } else {
            stream << c_print_name(args[i].name);
        }
        if (i < args.size() - 1)
            stream << ", ";
    }
    stream << ");\n";

    stream << "\n";

    stream << get_indent() << "AT_ASSERTM(err == 0, \"Halide call failed\");\n";

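    // On CUDA targets the results must already live on the device: a buffer
    // left host_dirty means some stage was computed on the CPU, so the emitted
    // code asserts on it and then detaches the native device handle before the
    // tensors are returned to PyTorch.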
    if (is_cuda) {
        stream << get_indent() << "// Make sure data is on device\n";
        for (size_t i = 0; i < buffer_args.size(); i++) {
            if (buffer_args[i].is_buffer()) {
                stream << get_indent();
                stream
                    << "AT_ASSERTM(!"
                    << c_print_name(buffer_args[i].name) << "_buffer.host_dirty(),"
                    << "\"device not synchronized for buffer "
                    << c_print_name(buffer_args[i].name)
                    << ", make sure all update stages are explicitly computed on GPU."
                    << "\");\n";
                stream << get_indent();
                stream
                    << c_print_name(buffer_args[i].name) << "_buffer"
                    << ".device_detach_native();\n";
            }
        }
        stream << "\n";
    }

    // TODO(mgharbi): this is not very well documented. If FLUSH_MEMOIZE_CACHE=1
    // is set in the environment when this wrapper is generated, emit a call that
    // flushes Halide's memoization cache after every pipeline invocation.
    if (get_env_variable("FLUSH_MEMOIZE_CACHE") == "1") {
        stream << get_indent() << "// Flush cache\n";
        if (is_cuda) {
            stream << get_indent() << "halide_memoization_cache_cleanup(__user_context);\n";
        } else {
            stream << get_indent() << "halide_memoization_cache_cleanup(NULL);\n";
        }
    }

    stream << get_indent() << "return 0;\n";

    indent -= 4;
    stream << "}\n";

    if (!namespaces.empty()) {
        stream << "\n";
        for (size_t i = namespaces.size(); i > 0; i--) {
            stream << "}  // namespace " << namespaces[i - 1] << "\n";
        }
        stream << "\n";
    }
}

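// Self-test: build a trivial lowered pipeline, generate wrappers for a plain
// host target and (when available) a CUDA target with user_context, and
// compare the emitted source against golden strings.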
void CodeGen_PyTorch::test() {
    // Dummy Halide pipeline
    LoweredArgument buffer_arg("buf", Argument::OutputBuffer, Int(32), 3, ArgumentEstimates{});
    LoweredArgument float_arg("alpha", Argument::InputScalar, Float(32), 0, ArgumentEstimates{});
    LoweredArgument int_arg("beta", Argument::InputScalar, Int(32), 0, ArgumentEstimates{});
    std::vector<LoweredArgument> args = {buffer_arg, float_arg, int_arg};
    Var x("x");
    Param<float> alpha("alpha");
    Param<int> beta("beta");
    Expr e = Add::make(alpha, Cast::make(Float(32), beta));
    Stmt s = Store::make("buf", e, x, Parameter(), const_true(), ModulusRemainder());
    Expr buf = Variable::make(Handle(), "buf.buffer");
    s = LetStmt::make("buf", Call::make(Handle(), Call::buffer_get_host, {buf}, Call::Extern), s);

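    // Compare generated source against the golden string and, on mismatch,
    // report the line that contains the first differing character.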
    const auto compare_src = [&](const std::string &src, const std::string &correct_src) {
        if (src != correct_src) {
            int diff = 0;
            while (src[diff] == correct_src[diff]) {
                diff++;
            }
            int diff_end = diff + 1;
            while (diff > 0 && src[diff] != '\n') {
                diff--;
            }
            while (diff_end < (int)src.size() && src[diff_end] != '\n') {
                diff_end++;
            }

            internal_error
                << "Correct source code:\n"
                << correct_src
                << "Actual source code:\n"
                << src
                << "Difference starts at:" << diff << "\n"
                << "Correct: " << correct_src.substr(diff, diff_end - diff) << "\n"
                << "Actual: " << src.substr(diff, diff_end - diff) << "\n";
        }
    };

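    // Plain host target: the generated wrapper should contain no CUDA setup code.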
    {
        // TODO(mgharbi): test that Target("host-cuda") raises an exception since
        // we require the "user_context" feature when using CUDA

        Module m("", Target("host"));
        m.append(LoweredFunc("test1", args, s, LinkageType::External));

        std::ostringstream src;
        CodeGen_PyTorch(src).compile(m);

        std::string correct_src =
            R"GOLDEN_CODE(#include "HalideBuffer.h"
#include "HalidePyTorchHelpers.h"
#include "torch/extension.h"

struct halide_buffer_t;
struct halide_filter_metadata_t;

#ifndef HALIDE_MUST_USE_RESULT
#ifdef __has_attribute
#if __has_attribute(nodiscard)
#define HALIDE_MUST_USE_RESULT [[nodiscard]]
#elif __has_attribute(warn_unused_result)
#define HALIDE_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define HALIDE_MUST_USE_RESULT
#endif
#else
#define HALIDE_MUST_USE_RESULT
#endif
#endif

#ifndef HALIDE_FUNCTION_ATTRS
#define HALIDE_FUNCTION_ATTRS
#endif



#ifdef __cplusplus
extern "C" {
#endif

HALIDE_FUNCTION_ATTRS
int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta);

#ifdef __cplusplus
}  // extern "C"
#endif

HALIDE_FUNCTION_ATTRS
inline int test1_th_(at::Tensor &_buf, float _alpha, int32_t _beta) {
    // Check tensors have contiguous memory and are on the correct device
    HLPT_CHECK_CONTIGUOUS(_buf);

    // Wrap tensors in Halide buffers
    Halide::Runtime::Buffer<int32_t> _buf_buffer = Halide::PyTorch::wrap<int32_t>(_buf);

    // Run Halide pipeline
    int err = test1(_buf_buffer, _alpha, _beta);

    AT_ASSERTM(err == 0, "Halide call failed");
    return 0;
}
)GOLDEN_CODE";

        compare_src(src.str(), correct_src);
    }

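    // CUDA target with user_context: the generated wrapper must set up the
    // device, context, and stream, check that tensors are on the right device,
    // and verify the outputs stayed on the GPU.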
    Target host_cuda("host-cuda-user_context");
    if (host_supports_target_device(host_cuda)) {
        Module m("", host_cuda);
        m.append(LoweredFunc("test1", args, s, LinkageType::External));

        std::ostringstream src;
        CodeGen_PyTorch(src).compile(m);

        std::string correct_src =
            R"GOLDEN_CODE(#include "ATen/cuda/CUDAContext.h"
#include "HalideBuffer.h"
#include "HalidePyTorchCudaHelpers.h"
#include "HalidePyTorchHelpers.h"
#include "torch/extension.h"

struct halide_buffer_t;
struct halide_filter_metadata_t;

#ifndef HALIDE_MUST_USE_RESULT
#ifdef __has_attribute
#if __has_attribute(nodiscard)
#define HALIDE_MUST_USE_RESULT [[nodiscard]]
#elif __has_attribute(warn_unused_result)
#define HALIDE_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define HALIDE_MUST_USE_RESULT
#endif
#else
#define HALIDE_MUST_USE_RESULT
#endif
#endif

#ifndef HALIDE_FUNCTION_ATTRS
#define HALIDE_FUNCTION_ATTRS
#endif



#ifdef __cplusplus
extern "C" {
#endif

HALIDE_FUNCTION_ATTRS
int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta);

#ifdef __cplusplus
}  // extern "C"
#endif

HALIDE_FUNCTION_ATTRS
inline int test1_th_(at::Tensor &_buf, float _alpha, int32_t _beta) {
    // Setup CUDA
    int device_id = at::cuda::current_device();
    CUcontext ctx = 0;
    CUresult res = cuCtxGetCurrent(&ctx);
    AT_ASSERTM(res == 0, "Could not acquire CUDA context");
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_id);
    Halide::PyTorch::UserContext user_ctx(device_id, &ctx, &stream);
    void* __user_context = (void*) &user_ctx;

    // Check tensors have contiguous memory and are on the correct device
    HLPT_CHECK_CONTIGUOUS(_buf);
    HLPT_CHECK_DEVICE(_buf, device_id);

    // Wrap tensors in Halide buffers
    Halide::Runtime::Buffer<int32_t> _buf_buffer = Halide::PyTorch::wrap<int32_t>(_buf);

    // Run Halide pipeline
    int err = test1(_buf_buffer, _alpha, _beta);

    AT_ASSERTM(err == 0, "Halide call failed");
    // Make sure data is on device
    AT_ASSERTM(!_buf_buffer.host_dirty(),"device not synchronized for buffer _buf, make sure all update stages are explicitly computed on GPU.");
    _buf_buffer.device_detach_native();

    return 0;
}
)GOLDEN_CODE";

        compare_src(src.str(), correct_src);
    } else {
        user_warning << "Host does not support " << host_cuda << ", skipping part of test";
    }

    std::cout << "CodeGen_PyTorch test passed\n";
}

}  // namespace Internal
}  // namespace Halide