1 //===- IRModule.cpp - IR pybind module ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 #include "Globals.h"
11 #include "PybindUtils.h"
12 
13 #include <vector>
14 
15 namespace py = pybind11;
16 using namespace mlir;
17 using namespace mlir::python;
18 
19 // -----------------------------------------------------------------------------
20 // PyGlobals
21 // -----------------------------------------------------------------------------
22 
23 PyGlobals *PyGlobals::instance = nullptr;
24 
PyGlobals()25 PyGlobals::PyGlobals() {
26   assert(!instance && "PyGlobals already constructed");
27   instance = this;
28 }
29 
~PyGlobals()30 PyGlobals::~PyGlobals() { instance = nullptr; }
31 
loadDialectModule(llvm::StringRef dialectNamespace)32 void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
33   py::gil_scoped_acquire();
34   if (loadedDialectModulesCache.contains(dialectNamespace))
35     return;
36   // Since re-entrancy is possible, make a copy of the search prefixes.
37   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
38   py::object loaded;
39   for (std::string moduleName : localSearchPrefixes) {
40     moduleName.push_back('.');
41     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
42 
43     try {
44       py::gil_scoped_release();
45       loaded = py::module::import(moduleName.c_str());
46     } catch (py::error_already_set &e) {
47       if (e.matches(PyExc_ModuleNotFoundError)) {
48         continue;
49       } else {
50         throw;
51       }
52     }
53     break;
54   }
55 
56   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
57   // may have occurred, which may do anything.
58   loadedDialectModulesCache.insert(dialectNamespace);
59 }
60 
registerDialectImpl(const std::string & dialectNamespace,py::object pyClass)61 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
62                                     py::object pyClass) {
63   py::gil_scoped_acquire();
64   py::object &found = dialectClassMap[dialectNamespace];
65   if (found) {
66     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
67                                              dialectNamespace +
68                                              "' is already registered.");
69   }
70   found = std::move(pyClass);
71 }
72 
registerOperationImpl(const std::string & operationName,py::object pyClass,py::object rawOpViewClass)73 void PyGlobals::registerOperationImpl(const std::string &operationName,
74                                       py::object pyClass,
75                                       py::object rawOpViewClass) {
76   py::gil_scoped_acquire();
77   py::object &found = operationClassMap[operationName];
78   if (found) {
79     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
80                                              operationName +
81                                              "' is already registered.");
82   }
83   found = std::move(pyClass);
84   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
85 }
86 
87 llvm::Optional<py::object>
lookupDialectClass(const std::string & dialectNamespace)88 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
89   py::gil_scoped_acquire();
90   loadDialectModule(dialectNamespace);
91   // Fast match against the class map first (common case).
92   const auto foundIt = dialectClassMap.find(dialectNamespace);
93   if (foundIt != dialectClassMap.end()) {
94     if (foundIt->second.is_none())
95       return llvm::None;
96     assert(foundIt->second && "py::object is defined");
97     return foundIt->second;
98   }
99 
100   // Not found and loading did not yield a registration. Negative cache.
101   dialectClassMap[dialectNamespace] = py::none();
102   return llvm::None;
103 }
104 
105 llvm::Optional<pybind11::object>
lookupRawOpViewClass(llvm::StringRef operationName)106 PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
107   {
108     py::gil_scoped_acquire();
109     auto foundIt = rawOpViewClassMapCache.find(operationName);
110     if (foundIt != rawOpViewClassMapCache.end()) {
111       if (foundIt->second.is_none())
112         return llvm::None;
113       assert(foundIt->second && "py::object is defined");
114       return foundIt->second;
115     }
116   }
117 
118   // Not found. Load the dialect namespace.
119   auto split = operationName.split('.');
120   llvm::StringRef dialectNamespace = split.first;
121   loadDialectModule(dialectNamespace);
122 
123   // Attempt to find from the canonical map and cache.
124   {
125     py::gil_scoped_acquire();
126     auto foundIt = rawOpViewClassMap.find(operationName);
127     if (foundIt != rawOpViewClassMap.end()) {
128       if (foundIt->second.is_none())
129         return llvm::None;
130       assert(foundIt->second && "py::object is defined");
131       // Positive cache.
132       rawOpViewClassMapCache[operationName] = foundIt->second;
133       return foundIt->second;
134     } else {
135       // Negative cache.
136       rawOpViewClassMap[operationName] = py::none();
137       return llvm::None;
138     }
139   }
140 }
141 
clearImportCache()142 void PyGlobals::clearImportCache() {
143   py::gil_scoped_acquire();
144   loadedDialectModulesCache.clear();
145   rawOpViewClassMapCache.clear();
146 }
147