1 // Note that this deliberately does *not* include PyHalide.h,
2 // or depend on any of the code in src: this is intended to be
3 // a minimal, generic wrapper to expose an arbitrary Generator
4 // for stub usage in Python.
5 
6 #include <pybind11/pybind11.h>
7 #include <pybind11/stl.h>
8 
9 #include <string>
10 #include <utility>
11 
12 #include <vector>
13 
14 #include "Halide.h"
15 
16 namespace py = pybind11;
17 
18 using FactoryFunc = std::unique_ptr<Halide::Internal::GeneratorBase> (*)(const Halide::GeneratorContext &context);
19 
20 namespace Halide {
21 namespace PythonBindings {
22 
23 using GeneratorParamsMap = Internal::GeneratorParamsMap;
24 using Stub = Internal::GeneratorStub;
25 using StubInput = Internal::StubInput;
26 using StubInputBuffer = Internal::StubInputBuffer<void>;
27 
28 namespace {
29 
30 // This seems redundant to the code in PyError.cpp, but is necessary
31 // in case the Stub builder links in a separate copy of libHalide, rather
32 // sharing the same halide.so that is built by default.
halide_python_error(void *,const char * msg)33 void halide_python_error(void *, const char *msg) {
34     throw Error(msg);
35 }
36 
halide_python_print(void *,const char * msg)37 void halide_python_print(void *, const char *msg) {
38     py::print(msg, py::arg("end") = "");
39 }
40 
41 class HalidePythonCompileTimeErrorReporter : public CompileTimeErrorReporter {
42 public:
warning(const char * msg)43     void warning(const char *msg) override {
44         py::print(msg, py::arg("end") = "");
45     }
46 
error(const char * msg)47     void error(const char *msg) override {
48         throw Error(msg);
49         // This method must not return!
50     }
51 };
52 
install_error_handlers(py::module & m)53 void install_error_handlers(py::module &m) {
54     static HalidePythonCompileTimeErrorReporter reporter;
55     set_custom_compile_time_error_reporter(&reporter);
56 
57     Halide::Internal::JITHandlers handlers;
58     handlers.custom_error = halide_python_error;
59     handlers.custom_print = halide_python_print;
60     Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
61 }
62 
63 // Anything that defines __getitem__ looks sequencelike to pybind,
64 // so also check for __len_ to avoid things like Buffer and Func here.
is_real_sequence(const py::object & o)65 bool is_real_sequence(const py::object &o) {
66     return py::isinstance<py::sequence>(o) && py::hasattr(o, "__len__");
67 }
68 
// Converts a single Python object into a StubInput.
//
// The conversions are attempted in a fixed order -- Buffer<>, then Func,
// then Expr -- and the first one that succeeds wins. If none succeeds,
// the exception raised by the final Expr cast propagates to the caller.
StubInput to_stub_input(const py::object &o) {
    // Don't use isinstance: we want to get things that
    // can be implicitly converted as well (eg ImageParam -> Func)
    try {
        return StubInput(StubInputBuffer(o.cast<Buffer<>>()));
    } catch (...) {
        // Not convertible to Buffer. Fall thru and try next.
        // NOTE(review): the catch-all also swallows any exception thrown
        // while constructing the StubInput itself, not just cast failures.
    }

    try {
        return StubInput(o.cast<Func>());
    } catch (...) {
        // Not convertible to Func. Fall thru and try next.
    }

    // Last resort: deliberately unguarded, so a failure here is what the
    // caller sees.
    return StubInput(o.cast<Expr>());
}
86 
append_input(const py::object & value,std::vector<StubInput> & v)87 void append_input(const py::object &value, std::vector<StubInput> &v) {
88     if (is_real_sequence(value)) {
89         for (auto o : py::reinterpret_borrow<py::sequence>(value)) {
90             v.push_back(to_stub_input(o));
91         }
92     } else {
93         v.push_back(to_stub_input(value));
94     }
95 }
96 
generate_impl(FactoryFunc factory,const GeneratorContext & context,const py::args & args,const py::kwargs & kwargs)97 py::object generate_impl(FactoryFunc factory, const GeneratorContext &context, const py::args &args, const py::kwargs &kwargs) {
98     Stub stub(context, [factory](const GeneratorContext &context) -> std::unique_ptr<Halide::Internal::GeneratorBase> {
99         return factory(context);
100     });
101     auto names = stub.get_names();
102     _halide_user_assert(!names.outputs.empty())
103         << "Generators that use build() (instead of generate()+Output<>) are not supported in the Python bindings.";
104     std::map<std::string, size_t> input_name_to_pos;
105     for (size_t i = 0; i < names.inputs.size(); ++i) {
106         input_name_to_pos[names.inputs[i]] = i;
107     }
108 
109     // Inputs can be specified by either positional or named args,
110     // and must all be specified.
111     //
112     // GeneratorParams can only be specified by name, and are always optional.
113     //
114     std::vector<std::vector<StubInput>> inputs;
115     inputs.resize(names.inputs.size());
116 
117     GeneratorParamsMap generator_params;
118 
119     // Process the kwargs first.
120     for (auto kw : kwargs) {
121         // If the kwarg is the name of a known input, stick it in the input
122         // vector. If not, stick it in the GeneratorParamsMap (if it's invalid,
123         // an error will be reported further downstream).
124         std::string key = kw.first.cast<std::string>();
125         py::handle value = kw.second;
126         auto it = input_name_to_pos.find(key);
127         if (it != input_name_to_pos.end()) {
128             append_input(py::cast<py::object>(value), inputs[it->second]);
129         } else {
130             if (py::isinstance<LoopLevel>(value)) {
131                 generator_params[key] = value.cast<LoopLevel>();
132             } else {
133                 generator_params[key] = py::str(value).cast<std::string>();
134             }
135         }
136     }
137 
138     // Now, the positional args.
139     _halide_user_assert(args.size() <= names.inputs.size())
140         << "Expected at most " << names.inputs.size() << " positional args, but saw " << args.size() << ".";
141     for (size_t i = 0; i < args.size(); ++i) {
142         _halide_user_assert(inputs[i].empty())
143             << "Generator Input named '" << names.inputs[i] << "' was specified by both position and keyword.";
144         append_input(args[i], inputs[i]);
145     }
146 
147     for (size_t i = 0; i < inputs.size(); ++i) {
148         _halide_user_assert(!inputs[i].empty())
149             << "Generator Input named '" << names.inputs[i] << "' was not specified.";
150     }
151 
152     const std::vector<std::vector<Func>> outputs = stub.generate(generator_params, inputs);
153 
154     py::tuple py_outputs(outputs.size());
155     for (size_t i = 0; i < outputs.size(); i++) {
156         py::object o;
157         if (outputs[i].size() == 1) {
158             // convert list-of-1 into single element
159             o = py::cast(outputs[i][0]);
160         } else {
161             o = py::cast(outputs[i]);
162         }
163         if (outputs.size() == 1) {
164             // bail early, return the single object rather than a dict
165             return o;
166         }
167         py_outputs[i] = o;
168     }
169     // An explicit "std::move" is needed here because there's
170     // an implicit tuple->object conversion that inhibits it otherwise.
171     return std::move(py_outputs);
172 }
173 
pystub_init(pybind11::module & m,FactoryFunc factory)174 void pystub_init(pybind11::module &m, FactoryFunc factory) {
175     m.def(
176         "generate", [factory](const Halide::Target &target, py::args args, py::kwargs kwargs) -> py::object {
177             return generate_impl(factory, Halide::GeneratorContext(target), args, kwargs);
178         },
179         py::arg("target"));
180 }
181 
182 }  // namespace
183 }  // namespace PythonBindings
184 }  // namespace Halide
185 
_halide_pystub_impl(const char * module_name,FactoryFunc factory)186 extern "C" PyObject *_halide_pystub_impl(const char *module_name, FactoryFunc factory) {
187     int major, minor;
188     if (sscanf(Py_GetVersion(), "%i.%i", &major, &minor) != 2) {
189         PyErr_SetString(PyExc_ImportError, "Can't parse Python version.");
190         return nullptr;
191     } else if (major != PY_MAJOR_VERSION || minor != PY_MINOR_VERSION) {
192         PyErr_Format(PyExc_ImportError,
193                      "Python version mismatch: module was compiled for "
194                      "version %i.%i, while the interpreter is running "
195                      "version %i.%i.",
196                      PY_MAJOR_VERSION, PY_MINOR_VERSION,
197                      major, minor);
198         return nullptr;
199     }
200     auto m = pybind11::module(module_name);
201     try {
202         Halide::PythonBindings::install_error_handlers(m);
203         Halide::PythonBindings::pystub_init(m, factory);
204         return m.ptr();
205     } catch (pybind11::error_already_set &e) {
206         PyErr_SetString(PyExc_ImportError, e.what());
207         return nullptr;
208     } catch (const std::exception &e) {
209         PyErr_SetString(PyExc_ImportError, e.what());
210         return nullptr;
211     }
212 }
213