1 #include <iostream>
2
3 #include "CodeGen_C.h"
4 #include "CodeGen_PyTorch.h"
5 #include "IROperator.h"
6 #include "Param.h"
7 #include "Util.h"
8 #include "Var.h"
9
10 namespace Halide {
11 namespace Internal {
12
/** Construct a PyTorch wrapper code generator that emits C++ source to `s`. */
CodeGen_PyTorch::CodeGen_PyTorch(std::ostream &s)
    : IRPrinter(s) {
}
16
compile(const Module & module)17 void CodeGen_PyTorch::compile(const Module &module) {
18 const Target target = module.target();
19
20 if (target.has_feature(Target::CUDA)) {
21 if (!target.has_feature(Target::UserContext)) {
22 user_error << "Compile a PyTorch wrapper for a CUDA op requires the "
23 "UserContext feature to properly manage the GPU memory. "
24 "Please add \"-user_context\" to the generator's target options.\n";
25 }
26 stream << "#include \"ATen/cuda/CUDAContext.h\"\n";
27 stream << "#include \"HalideBuffer.h\"\n";
28 stream << "#include \"HalidePyTorchCudaHelpers.h\"\n";
29 stream << "#include \"HalidePyTorchHelpers.h\"\n";
30 stream << "#include \"torch/extension.h\"\n";
31 } else {
32 stream << "#include \"HalideBuffer.h\"\n";
33 stream << "#include \"HalidePyTorchHelpers.h\"\n";
34 stream << "#include \"torch/extension.h\"\n";
35 }
36
37 stream << "\n";
38
39 // Emit extern decls of the Halide-generated functions we use directly
40 // into this file, so that we don't have to #include the relevant .h
41 // file directly; this simplifies certain compile/build setups (since
42 // we don't have to build files in tandem and/or get include paths right),
43 // and should be totally safe, since we are using the same codegen logic
44 // that would be in the .h file anyway.
45 {
46 CodeGen_C extern_decl_gen(stream, module.target(), CodeGen_C::CPlusPlusExternDecl);
47 extern_decl_gen.compile(module);
48 }
49
50 for (const auto &f : module.functions()) {
51 if (target.has_feature(Target::CUDA)) {
52 compile(f, true);
53 } else {
54 compile(f, false);
55 }
56 }
57 }
58
/** Emit one PyTorch wrapper function, named "<simple_name>_th_", for the
 * lowered function `f`.
 *
 * The wrapper takes at::Tensor references for buffer arguments (and plain
 * C types for scalars), checks tensor contiguity (and device, for CUDA),
 * wraps each tensor in a Halide::Runtime::Buffer, invokes the generated
 * pipeline, and asserts on failure. When `is_cuda` is true it also sets up
 * a Halide::PyTorch::UserContext from the current ATen device/stream and
 * verifies outputs were left synchronized on the device.
 *
 * NOTE: the emitted text is pinned byte-for-byte by the golden strings in
 * CodeGen_PyTorch::test() below — any change here must update those golden
 * strings in the same commit.
 */
void CodeGen_PyTorch::compile(const LoweredFunc &f, bool is_cuda) {
    // Don't put non-external function declarations in headers.
    std::vector<std::string> namespaces;
    std::string simple_name = extract_namespaces(f.name, namespaces);

    // Re-open the function's original namespaces so the wrapper sits
    // next to the generated pipeline symbol.
    if (!namespaces.empty()) {
        for (const auto &ns : namespaces) {
            stream << "namespace " << ns << " {\n";
        }
        stream << "\n";
    }
    const std::vector<LoweredArgument> &args = f.args;
    std::vector<LoweredArgument> buffer_args;

    // Emit the wrapper signature: tensors by reference, scalars by value.
    // __user_context is dropped from the signature; for CUDA it is
    // materialized locally below instead.
    stream << "HALIDE_FUNCTION_ATTRS\n";
    stream << "inline int " << simple_name << "_th_(";
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].name == "__user_context") {
            continue;
        } else if (args[i].is_buffer()) {
            buffer_args.push_back(args[i]);
            stream
                << "at::Tensor &"
                << c_print_name(args[i].name);
        } else {
            stream
                << type_to_c_type(args[i].type, true)
                << c_print_name(args[i].name);
        }

        if (i < args.size() - 1)
            stream << ", ";
    }

    stream << ") {\n";
    indent += 4;

    if (is_cuda) {
        // Bind Halide to PyTorch's current CUDA device, context, and stream
        // so the pipeline runs on the same stream as surrounding ATen ops.
        stream << get_indent() << "// Setup CUDA\n";
        stream << get_indent() << "int device_id = at::cuda::current_device();\n";
        stream << get_indent() << "CUcontext ctx = 0;\n";
        stream << get_indent() << "CUresult res = cuCtxGetCurrent(&ctx);\n";
        stream << get_indent() << "AT_ASSERTM(res == 0, \"Could not acquire CUDA context\");\n";
        stream << get_indent() << "cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_id);\n";
        stream << get_indent() << "Halide::PyTorch::UserContext user_ctx(device_id, &ctx, &stream);\n";
        stream << get_indent() << "void* __user_context = (void*) &user_ctx;\n\n";
    }

    stream << get_indent() << "// Check tensors have contiguous memory and are on the correct device\n";
    for (size_t i = 0; i < buffer_args.size(); i++) {
        stream << get_indent();
        stream
            << "HLPT_CHECK_CONTIGUOUS("
            << c_print_name(buffer_args[i].name)
            << ");\n";

        if (is_cuda) {
            stream << get_indent();
            stream
                << "HLPT_CHECK_DEVICE("
                << c_print_name(buffer_args[i].name)
                << ", device_id);\n";
        }
    }
    stream << "\n";

    // Each tensor "<name>" becomes a Halide buffer "<name>_buffer" of the
    // matching element type.
    stream << get_indent() << "// Wrap tensors in Halide buffers\n";
    for (size_t i = 0; i < buffer_args.size(); i++) {
        if (!buffer_args[i].is_buffer())
            continue;

        stream << get_indent();
        std::string tp = type_to_c_type(buffer_args[i].type, false);
        stream
            << "Halide::Runtime::Buffer<" << tp << "> "
            << c_print_name(buffer_args[i].name)
            << "_buffer = Halide::PyTorch::wrap<" << tp << ">("
            << c_print_name(buffer_args[i].name)
            << ");\n";
    }
    stream << "\n";

    stream << get_indent() << "// Run Halide pipeline\n";

    // Call the generated pipeline; buffer args use the wrapped "_buffer"
    // locals, scalars are forwarded as-is (including __user_context for CUDA).
    stream << get_indent() << "int err = " << simple_name << "(";
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i].is_buffer()) {
            stream
                << c_print_name(args[i].name)
                << "_buffer";
        } else {
            stream << c_print_name(args[i].name);
        }
        if (i < args.size() - 1)
            stream << ", ";
    }
    stream << ");\n";

    stream << "\n";

    stream << get_indent() << "AT_ASSERTM(err == 0, \"Halide call failed\");\n";

    if (is_cuda) {
        // Host-dirty output would mean some stage ran on the CPU and the
        // device copy is stale; detach so Halide doesn't free device memory
        // that PyTorch owns.
        stream << get_indent() << "// Make sure data is on device\n";
        for (size_t i = 0; i < buffer_args.size(); i++) {
            if (buffer_args[i].is_buffer()) {
                stream << get_indent();
                stream
                    << "AT_ASSERTM(!"
                    << c_print_name(buffer_args[i].name) << "_buffer.host_dirty(),"
                    << "\"device not synchronized for buffer "
                    << c_print_name(buffer_args[i].name)
                    // NOTE(review): "excplicitly" is a typo in the emitted
                    // message, but it is mirrored verbatim by the golden
                    // string in test(); fix both together, not here alone.
                    << ", make sure all update stages are excplicitly computed on GPU."
                    << "\");\n";
                stream << get_indent();
                stream
                    << c_print_name(buffer_args[i].name) << "_buffer"
                    << ".device_detach_native();\n";
            }
        }
        stream << "\n";
    }

    // TODO(mgharbi): this is not very well documented
    if (get_env_variable("FLUSH_MEMOIZE_CACHE") == "1") {
        stream << get_indent() << "// Flush cache\n";
        if (is_cuda) {
            stream << get_indent() << "halide_memoization_cache_cleanup(__user_context);\n";
        } else {
            stream << get_indent() << "halide_memoization_cache_cleanup(NULL);\n";
        }
    }

    stream << get_indent() << "return 0;\n";

    indent -= 4;
    stream << "}\n";

    // Close namespaces in reverse order of opening.
    if (!namespaces.empty()) {
        stream << "\n";
        for (size_t i = namespaces.size(); i > 0; i--) {
            stream << "} // namespace " << namespaces[i - 1] << "\n";
        }
        stream << "\n";
    }
}
205
/** Golden-output self-test: compiles a tiny dummy pipeline ("test1") through
 * CodeGen_PyTorch for a plain host target and, when available, a CUDA target,
 * and compares the generated wrapper source byte-for-byte against the
 * expected golden strings. Fails via internal_error with a pointer to the
 * first differing line.
 */
void CodeGen_PyTorch::test() {
    // Dummy Halide pipeline
    LoweredArgument buffer_arg("buf", Argument::OutputBuffer, Int(32), 3, ArgumentEstimates{});
    LoweredArgument float_arg("alpha", Argument::InputScalar, Float(32), 0, ArgumentEstimates{});
    LoweredArgument int_arg("beta", Argument::InputScalar, Int(32), 0, ArgumentEstimates{});
    std::vector<LoweredArgument> args = {buffer_arg, float_arg, int_arg};
    Var x("x");
    Param<float> alpha("alpha");
    Param<int> beta("beta");
    // buf[x] = alpha + (float)beta, with "buf" bound to the raw host pointer
    // of buf.buffer.
    Expr e = Add::make(alpha, Cast::make(Float(32), beta));
    Stmt s = Store::make("buf", e, x, Parameter(), const_true(), ModulusRemainder());
    Expr buf = Variable::make(Handle(), "buf.buffer");
    s = LetStmt::make("buf", Call::make(Handle(), Call::buffer_get_host, {buf}, Call::Extern), s);

    // Compare generated source with the golden string; on mismatch, report
    // the full line containing the first differing character.
    const auto compare_src = [&](const std::string &src, const std::string &correct_src) {
        if (src != correct_src) {
            // Find the first differing character...
            int diff = 0;
            while (src[diff] == correct_src[diff]) {
                diff++;
            }
            // ...then widen [diff, diff_end) to the enclosing line of src.
            int diff_end = diff + 1;
            while (diff > 0 && src[diff] != '\n') {
                diff--;
            }
            while (diff_end < (int)src.size() && src[diff_end] != '\n') {
                diff_end++;
            }

            internal_error
                << "Correct source code:\n"
                << correct_src
                << "Actual source code:\n"
                << src
                << "Difference starts at:" << diff << "\n"
                << "Correct: " << correct_src.substr(diff, diff_end - diff) << "\n"
                << "Actual: " << src.substr(diff, diff_end - diff) << "\n";
        }
    };

    {
        // TODO(mgharbi): test that Target("host-cuda") raises an exception since
        // we require the "user_context" feature when using CUDA

        // CPU case: no CUDA headers, no device checks in the wrapper.
        Module m("", Target("host"));
        m.append(LoweredFunc("test1", args, s, LinkageType::External));

        std::ostringstream src;
        CodeGen_PyTorch(src).compile(m);

        // Golden string: must stay byte-identical to what
        // CodeGen_PyTorch::compile emits (including whitespace).
        std::string correct_src =
            R"GOLDEN_CODE(#include "HalideBuffer.h"
#include "HalidePyTorchHelpers.h"
#include "torch/extension.h"

struct halide_buffer_t;
struct halide_filter_metadata_t;

#ifndef HALIDE_MUST_USE_RESULT
#ifdef __has_attribute
#if __has_attribute(nodiscard)
#define HALIDE_MUST_USE_RESULT [[nodiscard]]
#elif __has_attribute(warn_unused_result)
#define HALIDE_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define HALIDE_MUST_USE_RESULT
#endif
#else
#define HALIDE_MUST_USE_RESULT
#endif
#endif

#ifndef HALIDE_FUNCTION_ATTRS
#define HALIDE_FUNCTION_ATTRS
#endif



#ifdef __cplusplus
extern "C" {
#endif

HALIDE_FUNCTION_ATTRS
int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta);

#ifdef __cplusplus
}  // extern "C"
#endif

HALIDE_FUNCTION_ATTRS
inline int test1_th_(at::Tensor &_buf, float _alpha, int32_t _beta) {
    // Check tensors have contiguous memory and are on the correct device
    HLPT_CHECK_CONTIGUOUS(_buf);

    // Wrap tensors in Halide buffers
    Halide::Runtime::Buffer<int32_t> _buf_buffer = Halide::PyTorch::wrap<int32_t>(_buf);

    // Run Halide pipeline
    int err = test1(_buf_buffer, _alpha, _beta);

    AT_ASSERTM(err == 0, "Halide call failed");
    return 0;
}
)GOLDEN_CODE";

        compare_src(src.str(), correct_src);
    }

    // CUDA case: only runs when the host actually has a usable CUDA device.
    Target host_cuda("host-cuda-user_context");
    if (host_supports_target_device(host_cuda)) {
        Module m("", host_cuda);
        m.append(LoweredFunc("test1", args, s, LinkageType::External));

        std::ostringstream src;
        CodeGen_PyTorch(src).compile(m);

        // Golden string for the CUDA wrapper. Note: "excplicitly" is a typo
        // mirrored from the emitter in compile(); fix both together.
        std::string correct_src =
            R"GOLDEN_CODE(#include "ATen/cuda/CUDAContext.h"
#include "HalideBuffer.h"
#include "HalidePyTorchCudaHelpers.h"
#include "HalidePyTorchHelpers.h"
#include "torch/extension.h"

struct halide_buffer_t;
struct halide_filter_metadata_t;

#ifndef HALIDE_MUST_USE_RESULT
#ifdef __has_attribute
#if __has_attribute(nodiscard)
#define HALIDE_MUST_USE_RESULT [[nodiscard]]
#elif __has_attribute(warn_unused_result)
#define HALIDE_MUST_USE_RESULT __attribute__((warn_unused_result))
#else
#define HALIDE_MUST_USE_RESULT
#endif
#else
#define HALIDE_MUST_USE_RESULT
#endif
#endif

#ifndef HALIDE_FUNCTION_ATTRS
#define HALIDE_FUNCTION_ATTRS
#endif



#ifdef __cplusplus
extern "C" {
#endif

HALIDE_FUNCTION_ATTRS
int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta);

#ifdef __cplusplus
}  // extern "C"
#endif

HALIDE_FUNCTION_ATTRS
inline int test1_th_(at::Tensor &_buf, float _alpha, int32_t _beta) {
    // Setup CUDA
    int device_id = at::cuda::current_device();
    CUcontext ctx = 0;
    CUresult res = cuCtxGetCurrent(&ctx);
    AT_ASSERTM(res == 0, "Could not acquire CUDA context");
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(device_id);
    Halide::PyTorch::UserContext user_ctx(device_id, &ctx, &stream);
    void* __user_context = (void*) &user_ctx;

    // Check tensors have contiguous memory and are on the correct device
    HLPT_CHECK_CONTIGUOUS(_buf);
    HLPT_CHECK_DEVICE(_buf, device_id);

    // Wrap tensors in Halide buffers
    Halide::Runtime::Buffer<int32_t> _buf_buffer = Halide::PyTorch::wrap<int32_t>(_buf);

    // Run Halide pipeline
    int err = test1(_buf_buffer, _alpha, _beta);

    AT_ASSERTM(err == 0, "Halide call failed");
    // Make sure data is on device
    AT_ASSERTM(!_buf_buffer.host_dirty(),"device not synchronized for buffer _buf, make sure all update stages are excplicitly computed on GPU.");
    _buf_buffer.device_detach_native();

    return 0;
}
)GOLDEN_CODE";

        compare_src(src.str(), correct_src);
    } else {
        user_warning << "Host does not support " << host_cuda << ", skipping part of test";
    }

    std::cout << "CodeGen_PyTorch test passed\n";
}
399
400 } // namespace Internal
401 } // namespace Halide
402