1 // Note that this deliberately does *not* include PyHalide.h,
2 // or depend on any of the code in src: this is intended to be
3 // a minimal, generic wrapper to expose an arbitrary Generator
4 // for stub usage in Python.
5
6 #include <pybind11/pybind11.h>
7 #include <pybind11/stl.h>
8
9 #include <string>
10 #include <utility>
11
12 #include <vector>
13
14 #include "Halide.h"
15
16 namespace py = pybind11;
17
18 using FactoryFunc = std::unique_ptr<Halide::Internal::GeneratorBase> (*)(const Halide::GeneratorContext &context);
19
20 namespace Halide {
21 namespace PythonBindings {
22
23 using GeneratorParamsMap = Internal::GeneratorParamsMap;
24 using Stub = Internal::GeneratorStub;
25 using StubInput = Internal::StubInput;
26 using StubInputBuffer = Internal::StubInputBuffer<void>;
27
28 namespace {
29
// This seems redundant with the code in PyError.cpp, but is necessary
// in case the Stub builder links in a separate copy of libHalide rather
// than sharing the same halide.so that is built by default.
halide_python_error(void *,const char * msg)33 void halide_python_error(void *, const char *msg) {
34 throw Error(msg);
35 }
36
halide_python_print(void *,const char * msg)37 void halide_python_print(void *, const char *msg) {
38 py::print(msg, py::arg("end") = "");
39 }
40
41 class HalidePythonCompileTimeErrorReporter : public CompileTimeErrorReporter {
42 public:
warning(const char * msg)43 void warning(const char *msg) override {
44 py::print(msg, py::arg("end") = "");
45 }
46
error(const char * msg)47 void error(const char *msg) override {
48 throw Error(msg);
49 // This method must not return!
50 }
51 };
52
install_error_handlers(py::module & m)53 void install_error_handlers(py::module &m) {
54 static HalidePythonCompileTimeErrorReporter reporter;
55 set_custom_compile_time_error_reporter(&reporter);
56
57 Halide::Internal::JITHandlers handlers;
58 handlers.custom_error = halide_python_error;
59 handlers.custom_print = halide_python_print;
60 Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
61 }
62
// Anything that defines __getitem__ looks sequence-like to pybind,
// so also check for __len__ to avoid things like Buffer and Func here.
is_real_sequence(const py::object & o)65 bool is_real_sequence(const py::object &o) {
66 return py::isinstance<py::sequence>(o) && py::hasattr(o, "__len__");
67 }
68
to_stub_input(const py::object & o)69 StubInput to_stub_input(const py::object &o) {
70 // Don't use isinstance: we want to get things that
71 // can be implicitly converted as well (eg ImageParam -> Func)
72 try {
73 return StubInput(StubInputBuffer(o.cast<Buffer<>>()));
74 } catch (...) {
75 // Not convertible to Buffer. Fall thru and try next.
76 }
77
78 try {
79 return StubInput(o.cast<Func>());
80 } catch (...) {
81 // Not convertible to Func. Fall thru and try next.
82 }
83
84 return StubInput(o.cast<Expr>());
85 }
86
append_input(const py::object & value,std::vector<StubInput> & v)87 void append_input(const py::object &value, std::vector<StubInput> &v) {
88 if (is_real_sequence(value)) {
89 for (auto o : py::reinterpret_borrow<py::sequence>(value)) {
90 v.push_back(to_stub_input(o));
91 }
92 } else {
93 v.push_back(to_stub_input(value));
94 }
95 }
96
generate_impl(FactoryFunc factory,const GeneratorContext & context,const py::args & args,const py::kwargs & kwargs)97 py::object generate_impl(FactoryFunc factory, const GeneratorContext &context, const py::args &args, const py::kwargs &kwargs) {
98 Stub stub(context, [factory](const GeneratorContext &context) -> std::unique_ptr<Halide::Internal::GeneratorBase> {
99 return factory(context);
100 });
101 auto names = stub.get_names();
102 _halide_user_assert(!names.outputs.empty())
103 << "Generators that use build() (instead of generate()+Output<>) are not supported in the Python bindings.";
104 std::map<std::string, size_t> input_name_to_pos;
105 for (size_t i = 0; i < names.inputs.size(); ++i) {
106 input_name_to_pos[names.inputs[i]] = i;
107 }
108
109 // Inputs can be specified by either positional or named args,
110 // and must all be specified.
111 //
112 // GeneratorParams can only be specified by name, and are always optional.
113 //
114 std::vector<std::vector<StubInput>> inputs;
115 inputs.resize(names.inputs.size());
116
117 GeneratorParamsMap generator_params;
118
119 // Process the kwargs first.
120 for (auto kw : kwargs) {
121 // If the kwarg is the name of a known input, stick it in the input
122 // vector. If not, stick it in the GeneratorParamsMap (if it's invalid,
123 // an error will be reported further downstream).
124 std::string key = kw.first.cast<std::string>();
125 py::handle value = kw.second;
126 auto it = input_name_to_pos.find(key);
127 if (it != input_name_to_pos.end()) {
128 append_input(py::cast<py::object>(value), inputs[it->second]);
129 } else {
130 if (py::isinstance<LoopLevel>(value)) {
131 generator_params[key] = value.cast<LoopLevel>();
132 } else {
133 generator_params[key] = py::str(value).cast<std::string>();
134 }
135 }
136 }
137
138 // Now, the positional args.
139 _halide_user_assert(args.size() <= names.inputs.size())
140 << "Expected at most " << names.inputs.size() << " positional args, but saw " << args.size() << ".";
141 for (size_t i = 0; i < args.size(); ++i) {
142 _halide_user_assert(inputs[i].empty())
143 << "Generator Input named '" << names.inputs[i] << "' was specified by both position and keyword.";
144 append_input(args[i], inputs[i]);
145 }
146
147 for (size_t i = 0; i < inputs.size(); ++i) {
148 _halide_user_assert(!inputs[i].empty())
149 << "Generator Input named '" << names.inputs[i] << "' was not specified.";
150 }
151
152 const std::vector<std::vector<Func>> outputs = stub.generate(generator_params, inputs);
153
154 py::tuple py_outputs(outputs.size());
155 for (size_t i = 0; i < outputs.size(); i++) {
156 py::object o;
157 if (outputs[i].size() == 1) {
158 // convert list-of-1 into single element
159 o = py::cast(outputs[i][0]);
160 } else {
161 o = py::cast(outputs[i]);
162 }
163 if (outputs.size() == 1) {
164 // bail early, return the single object rather than a dict
165 return o;
166 }
167 py_outputs[i] = o;
168 }
169 // An explicit "std::move" is needed here because there's
170 // an implicit tuple->object conversion that inhibits it otherwise.
171 return std::move(py_outputs);
172 }
173
pystub_init(pybind11::module & m,FactoryFunc factory)174 void pystub_init(pybind11::module &m, FactoryFunc factory) {
175 m.def(
176 "generate", [factory](const Halide::Target &target, py::args args, py::kwargs kwargs) -> py::object {
177 return generate_impl(factory, Halide::GeneratorContext(target), args, kwargs);
178 },
179 py::arg("target"));
180 }
181
182 } // namespace
183 } // namespace PythonBindings
184 } // namespace Halide
185
_halide_pystub_impl(const char * module_name,FactoryFunc factory)186 extern "C" PyObject *_halide_pystub_impl(const char *module_name, FactoryFunc factory) {
187 int major, minor;
188 if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) {
189 PyErr_SetString(PyExc_ImportError, "Can't parse Python version.");
190 return nullptr;
191 } else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) {
192 PyErr_Format(PyExc_ImportError,
193 "Python version mismatch: module was compiled for "
194 "version %i.%i, while the interpreter is running "
195 "version %i.%i.",
196 PY_MAJOR_VERSION, PY_MINOR_VERSION,
197 major, minor);
198 return nullptr;
199 }
200 auto m = pybind11::module(module_name);
201 try {
202 Halide::PythonBindings::install_error_handlers(m);
203 Halide::PythonBindings::pystub_init(m, factory);
204 return m.ptr();
205 } catch (pybind11::error_already_set &e) {
206 PyErr_SetString(PyExc_ImportError, e.what());
207 return nullptr;
208 } catch (const std::exception &e) {
209 PyErr_SetString(PyExc_ImportError, e.what());
210 return nullptr;
211 }
212 }
213