1 //===- MainModule.cpp - Main pybind module --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <tuple>
10 
11 #include "PybindUtils.h"
12 
13 #include "Globals.h"
14 #include "IRModules.h"
15 #include "Pass.h"
16 
17 namespace py = pybind11;
18 using namespace mlir;
19 using namespace mlir::python;
20 
21 // -----------------------------------------------------------------------------
22 // PyGlobals
23 // -----------------------------------------------------------------------------
24 
// Process-wide pointer to the live PyGlobals object. It is set by the
// constructor and reset to nullptr by the destructor.
PyGlobals *PyGlobals::instance = nullptr;

/// Publishes this object as the process-wide instance. Expects that no other
/// instance is currently published; the assert checks this only in builds
/// with assertions enabled.
PyGlobals::PyGlobals() {
  assert(!instance && "PyGlobals already constructed");
  instance = this;
}

/// Clears the process-wide instance pointer.
PyGlobals::~PyGlobals() { instance = nullptr; }
33 
loadDialectModule(llvm::StringRef dialectNamespace)34 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
35   py::gil_scoped_acquire();
36   if (loadedDialectModulesCache.contains(dialectNamespace))
37     return;
38   // Since re-entrancy is possible, make a copy of the search prefixes.
39   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
40   py::object loaded;
41   for (std::string moduleName : localSearchPrefixes) {
42     moduleName.push_back('.');
43     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
44 
45     try {
46       py::gil_scoped_release();
47       loaded = py::module::import(moduleName.c_str());
48     } catch (py::error_already_set &e) {
49       if (e.matches(PyExc_ModuleNotFoundError)) {
50         continue;
51       } else {
52         throw;
53       }
54     }
55     break;
56   }
57 
58   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
59   // may have occurred, which may do anything.
60   loadedDialectModulesCache.insert(dialectNamespace);
61 }
62 
registerDialectImpl(const std::string & dialectNamespace,py::object pyClass)63 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
64                                     py::object pyClass) {
65   py::gil_scoped_acquire();
66   py::object &found = dialectClassMap[dialectNamespace];
67   if (found) {
68     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
69                                              dialectNamespace +
70                                              "' is already registered.");
71   }
72   found = std::move(pyClass);
73 }
74 
registerOperationImpl(const std::string & operationName,py::object pyClass,py::object rawOpViewClass)75 void PyGlobals::registerOperationImpl(const std::string &operationName,
76                                       py::object pyClass,
77                                       py::object rawOpViewClass) {
78   py::gil_scoped_acquire();
79   py::object &found = operationClassMap[operationName];
80   if (found) {
81     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
82                                              operationName +
83                                              "' is already registered.");
84   }
85   found = std::move(pyClass);
86   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
87 }
88 
89 llvm::Optional<py::object>
lookupDialectClass(const std::string & dialectNamespace)90 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
91   py::gil_scoped_acquire();
92   loadDialectModule(dialectNamespace);
93   // Fast match against the class map first (common case).
94   const auto foundIt = dialectClassMap.find(dialectNamespace);
95   if (foundIt != dialectClassMap.end()) {
96     if (foundIt->second.is_none())
97       return llvm::None;
98     assert(foundIt->second && "py::object is defined");
99     return foundIt->second;
100   }
101 
102   // Not found and loading did not yield a registration. Negative cache.
103   dialectClassMap[dialectNamespace] = py::none();
104   return llvm::None;
105 }
106 
107 llvm::Optional<pybind11::object>
lookupRawOpViewClass(llvm::StringRef operationName)108 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
109   {
110     py::gil_scoped_acquire();
111     auto foundIt = rawOpViewClassMapCache.find(operationName);
112     if (foundIt != rawOpViewClassMapCache.end()) {
113       if (foundIt->second.is_none())
114         return llvm::None;
115       assert(foundIt->second && "py::object is defined");
116       return foundIt->second;
117     }
118   }
119 
120   // Not found. Load the dialect namespace.
121   auto split = operationName.split('.');
122   llvm::StringRef dialectNamespace = split.first;
123   loadDialectModule(dialectNamespace);
124 
125   // Attempt to find from the canonical map and cache.
126   {
127     py::gil_scoped_acquire();
128     auto foundIt = rawOpViewClassMap.find(operationName);
129     if (foundIt != rawOpViewClassMap.end()) {
130       if (foundIt->second.is_none())
131         return llvm::None;
132       assert(foundIt->second && "py::object is defined");
133       // Positive cache.
134       rawOpViewClassMapCache[operationName] = foundIt->second;
135       return foundIt->second;
136     } else {
137       // Negative cache.
138       rawOpViewClassMap[operationName] = py::none();
139       return llvm::None;
140     }
141   }
142 }
143 
clearImportCache()144 void PyGlobals::clearImportCache() {
145   py::gil_scoped_acquire();
146   loadedDialectModulesCache.clear();
147   rawOpViewClassMapCache.clear();
148 }
149 
150 // -----------------------------------------------------------------------------
151 // Module initialization.
152 // -----------------------------------------------------------------------------
153 
PYBIND11_MODULE(_mlir,m)154 PYBIND11_MODULE(_mlir, m) {
155   m.doc() = "MLIR Python Native Extension";
156 
157   py::class_<PyGlobals>(m, "_Globals")
158       .def_property("dialect_search_modules",
159                     &PyGlobals::getDialectSearchPrefixes,
160                     &PyGlobals::setDialectSearchPrefixes)
161       .def("append_dialect_search_prefix",
162            [](PyGlobals &self, std::string moduleName) {
163              self.getDialectSearchPrefixes().push_back(std::move(moduleName));
164              self.clearImportCache();
165            })
166       .def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
167            "Testing hook for directly registering a dialect")
168       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
169            "Testing hook for directly registering an operation");
170 
171   // Aside from making the globals accessible to python, having python manage
172   // it is necessary to make sure it is destroyed (and releases its python
173   // resources) properly.
174   m.attr("globals") =
175       py::cast(new PyGlobals, py::return_value_policy::take_ownership);
176 
177   // Registration decorators.
178   m.def(
179       "register_dialect",
180       [](py::object pyClass) {
181         std::string dialectNamespace =
182             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
183         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
184         return pyClass;
185       },
186       "Class decorator for registering a custom Dialect wrapper");
187   m.def(
188       "register_operation",
189       [](py::object dialectClass) -> py::cpp_function {
190         return py::cpp_function(
191             [dialectClass](py::object opClass) -> py::object {
192               std::string operationName =
193                   opClass.attr("OPERATION_NAME").cast<std::string>();
194               auto rawSubclass = PyOpView::createRawSubclass(opClass);
195               PyGlobals::get().registerOperationImpl(operationName, opClass,
196                                                      rawSubclass);
197 
198               // Dict-stuff the new opClass by name onto the dialect class.
199               py::object opClassName = opClass.attr("__name__");
200               dialectClass.attr(opClassName) = opClass;
201 
202               // Now create a special "Raw" subclass that passes through
203               // construction to the OpView parent (bypasses the intermediate
204               // child's __init__).
205               opClass.attr("_Raw") = rawSubclass;
206               return opClass;
207             });
208       },
209       "Class decorator for registering a custom Operation wrapper");
210 
211   // Define and populate IR submodule.
212   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
213   populateIRSubmodule(irModule);
214 
215   // Define and populate PassManager submodule.
216   auto passModule =
217       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
218   populatePassManagerSubmodule(passModule);
219 }
220