1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
// Docstring for parsing a type from its textual assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring for building a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring for parsing a module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for generic operation creation; lists every keyword argument
// (name, results, attributes, successors, regions, location, ip).
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
71 
// Docstring for printing an operation's assembly form to a file-like object.
// Keyword arguments are listed in snake_case, matching each other entry.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
90 
// Docstring for returning an operation's assembly form as str or bytes.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for the operation's default-options string conversion.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for dump() methods that write to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (see PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for the string conversion of a Value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
classmethod(Func f,Args...args)136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)142 createCustomDialectWrapper(const std::string &dialectNamespace,
143                            py::object dialectDescriptor) {
144   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145   if (!dialectClass) {
146     // Use the base class.
147     return py::cast(PyDialect(std::move(dialectDescriptor)));
148   }
149 
150   // Create the custom implementation.
151   return (*dialectClass)(std::move(dialectDescriptor));
152 }
153 
toMlirStringRef(const std::string & s)154 static MlirStringRef toMlirStringRef(const std::string &s) {
155   return mlirStringRefCreate(s.data(), s.size());
156 }
157 
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
setPyGlobalDebugFlag160   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161 
getPyGlobalDebugFlag162   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163 
bindPyGlobalDebugFlag164   static void bind(py::module &m) {
165     // Debug flags.
166     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
167         .def_property_static("flag", &PyGlobalDebugFlag::get,
168                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169   }
170 };
171 
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175 
176 namespace {
177 
178 class PyRegionIterator {
179 public:
PyRegionIterator(PyOperationRef operation)180   PyRegionIterator(PyOperationRef operation)
181       : operation(std::move(operation)) {}
182 
dunderIter()183   PyRegionIterator &dunderIter() { return *this; }
184 
dunderNext()185   PyRegion dunderNext() {
186     operation->checkValid();
187     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188       throw py::stop_iteration();
189     }
190     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191     return PyRegion(operation, region);
192   }
193 
bind(py::module & m)194   static void bind(py::module &m) {
195     py::class_<PyRegionIterator>(m, "RegionIterator")
196         .def("__iter__", &PyRegionIterator::dunderIter)
197         .def("__next__", &PyRegionIterator::dunderNext);
198   }
199 
200 private:
201   PyOperationRef operation;
202   int nextIndex = 0;
203 };
204 
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
PyRegionList(PyOperationRef operation)209   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210 
dunderLen()211   intptr_t dunderLen() {
212     operation->checkValid();
213     return mlirOperationGetNumRegions(operation->get());
214   }
215 
dunderGetItem(intptr_t index)216   PyRegion dunderGetItem(intptr_t index) {
217     // dunderLen checks validity.
218     if (index < 0 || index >= dunderLen()) {
219       throw SetPyError(PyExc_IndexError,
220                        "attempt to access out of bounds region");
221     }
222     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223     return PyRegion(operation, region);
224   }
225 
bind(py::module & m)226   static void bind(py::module &m) {
227     py::class_<PyRegionList>(m, "RegionSequence")
228         .def("__len__", &PyRegionList::dunderLen)
229         .def("__getitem__", &PyRegionList::dunderGetItem);
230   }
231 
232 private:
233   PyOperationRef operation;
234 };
235 
236 class PyBlockIterator {
237 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)238   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239       : operation(std::move(operation)), next(next) {}
240 
dunderIter()241   PyBlockIterator &dunderIter() { return *this; }
242 
dunderNext()243   PyBlock dunderNext() {
244     operation->checkValid();
245     if (mlirBlockIsNull(next)) {
246       throw py::stop_iteration();
247     }
248 
249     PyBlock returnBlock(operation, next);
250     next = mlirBlockGetNextInRegion(next);
251     return returnBlock;
252   }
253 
bind(py::module & m)254   static void bind(py::module &m) {
255     py::class_<PyBlockIterator>(m, "BlockIterator")
256         .def("__iter__", &PyBlockIterator::dunderIter)
257         .def("__next__", &PyBlockIterator::dunderNext);
258   }
259 
260 private:
261   PyOperationRef operation;
262   MlirBlock next;
263 };
264 
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
PyBlockList(PyOperationRef operation,MlirRegion region)270   PyBlockList(PyOperationRef operation, MlirRegion region)
271       : operation(std::move(operation)), region(region) {}
272 
dunderIter()273   PyBlockIterator dunderIter() {
274     operation->checkValid();
275     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276   }
277 
dunderLen()278   intptr_t dunderLen() {
279     operation->checkValid();
280     intptr_t count = 0;
281     MlirBlock block = mlirRegionGetFirstBlock(region);
282     while (!mlirBlockIsNull(block)) {
283       count += 1;
284       block = mlirBlockGetNextInRegion(block);
285     }
286     return count;
287   }
288 
dunderGetItem(intptr_t index)289   PyBlock dunderGetItem(intptr_t index) {
290     operation->checkValid();
291     if (index < 0) {
292       throw SetPyError(PyExc_IndexError,
293                        "attempt to access out of bounds block");
294     }
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       if (index == 0) {
298         return PyBlock(operation, block);
299       }
300       block = mlirBlockGetNextInRegion(block);
301       index -= 1;
302     }
303     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304   }
305 
appendBlock(py::args pyArgTypes)306   PyBlock appendBlock(py::args pyArgTypes) {
307     operation->checkValid();
308     llvm::SmallVector<MlirType, 4> argTypes;
309     argTypes.reserve(pyArgTypes.size());
310     for (auto &pyArg : pyArgTypes) {
311       argTypes.push_back(pyArg.cast<PyType &>());
312     }
313 
314     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315     mlirRegionAppendOwnedBlock(region, block);
316     return PyBlock(operation, block);
317   }
318 
bind(py::module & m)319   static void bind(py::module &m) {
320     py::class_<PyBlockList>(m, "BlockList")
321         .def("__getitem__", &PyBlockList::dunderGetItem)
322         .def("__iter__", &PyBlockList::dunderIter)
323         .def("__len__", &PyBlockList::dunderLen)
324         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325   }
326 
327 private:
328   PyOperationRef operation;
329   MlirRegion region;
330 };
331 
332 class PyOperationIterator {
333 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)334   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335       : parentOperation(std::move(parentOperation)), next(next) {}
336 
dunderIter()337   PyOperationIterator &dunderIter() { return *this; }
338 
dunderNext()339   py::object dunderNext() {
340     parentOperation->checkValid();
341     if (mlirOperationIsNull(next)) {
342       throw py::stop_iteration();
343     }
344 
345     PyOperationRef returnOperation =
346         PyOperation::forOperation(parentOperation->getContext(), next);
347     next = mlirOperationGetNextInBlock(next);
348     return returnOperation->createOpView();
349   }
350 
bind(py::module & m)351   static void bind(py::module &m) {
352     py::class_<PyOperationIterator>(m, "OperationIterator")
353         .def("__iter__", &PyOperationIterator::dunderIter)
354         .def("__next__", &PyOperationIterator::dunderNext);
355   }
356 
357 private:
358   PyOperationRef parentOperation;
359   MlirOperation next;
360 };
361 
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)368   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369       : parentOperation(std::move(parentOperation)), block(block) {}
370 
dunderIter()371   PyOperationIterator dunderIter() {
372     parentOperation->checkValid();
373     return PyOperationIterator(parentOperation,
374                                mlirBlockGetFirstOperation(block));
375   }
376 
dunderLen()377   intptr_t dunderLen() {
378     parentOperation->checkValid();
379     intptr_t count = 0;
380     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381     while (!mlirOperationIsNull(childOp)) {
382       count += 1;
383       childOp = mlirOperationGetNextInBlock(childOp);
384     }
385     return count;
386   }
387 
dunderGetItem(intptr_t index)388   py::object dunderGetItem(intptr_t index) {
389     parentOperation->checkValid();
390     if (index < 0) {
391       throw SetPyError(PyExc_IndexError,
392                        "attempt to access out of bounds operation");
393     }
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       if (index == 0) {
397         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398             ->createOpView();
399       }
400       childOp = mlirOperationGetNextInBlock(childOp);
401       index -= 1;
402     }
403     throw SetPyError(PyExc_IndexError,
404                      "attempt to access out of bounds operation");
405   }
406 
bind(py::module & m)407   static void bind(py::module &m) {
408     py::class_<PyOperationList>(m, "OperationList")
409         .def("__getitem__", &PyOperationList::dunderGetItem)
410         .def("__iter__", &PyOperationList::dunderIter)
411         .def("__len__", &PyOperationList::dunderLen);
412   }
413 
414 private:
415   PyOperationRef parentOperation;
416   MlirBlock block;
417 };
418 
419 } // namespace
420 
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424 
PyMlirContext(MlirContext context)425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426   py::gil_scoped_acquire acquire;
427   auto &liveContexts = getLiveContexts();
428   liveContexts[context.ptr] = this;
429 }
430 
~PyMlirContext()431 PyMlirContext::~PyMlirContext() {
432   // Note that the only public way to construct an instance is via the
433   // forContext method, which always puts the associated handle into
434   // liveContexts.
435   py::gil_scoped_acquire acquire;
436   getLiveContexts().erase(context.ptr);
437   mlirContextDestroy(context);
438 }
439 
getCapsule()440 py::object PyMlirContext::getCapsule() {
441   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443 
createFromCapsule(py::object capsule)444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446   if (mlirContextIsNull(rawContext))
447     throw py::error_already_set();
448   return forContext(rawContext).releaseObject();
449 }
450 
createNewContextForInit()451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452   MlirContext context = mlirContextCreate();
453   mlirRegisterAllDialects(context);
454   return new PyMlirContext(context);
455 }
456 
forContext(MlirContext context)457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458   py::gil_scoped_acquire acquire;
459   auto &liveContexts = getLiveContexts();
460   auto it = liveContexts.find(context.ptr);
461   if (it == liveContexts.end()) {
462     // Create.
463     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464     py::object pyRef = py::cast(unownedContextWrapper);
465     assert(pyRef && "cast to py::object failed");
466     liveContexts[context.ptr] = unownedContextWrapper;
467     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468   }
469   // Use existing.
470   py::object pyRef = py::cast(it->second);
471   return PyMlirContextRef(it->second, std::move(pyRef));
472 }
473 
getLiveContexts()474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475   static LiveContextMap liveContexts;
476   return liveContexts;
477 }
478 
getLiveCount()479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480 
getLiveOperationCount()481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482 
getLiveModuleCount()483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484 
contextEnter()485 pybind11::object PyMlirContext::contextEnter() {
486   return PyThreadContextEntry::pushContext(*this);
487 }
488 
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)489 void PyMlirContext::contextExit(pybind11::object excType,
490                                 pybind11::object excVal,
491                                 pybind11::object excTb) {
492   PyThreadContextEntry::popContext(*this);
493 }
494 
resolve()495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497   if (!context) {
498     throw SetPyError(
499         PyExc_RuntimeError,
500         "An MLIR function requires a Context but none was provided in the call "
501         "or from the surrounding environment. Either pass to the function with "
502         "a 'context=' argument or establish a default using 'with Context():'");
503   }
504   return *context;
505 }
506 
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510 
getStack()511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512   static thread_local std::vector<PyThreadContextEntry> stack;
513   return stack;
514 }
515 
getTopOfStack()516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517   auto &stack = getStack();
518   if (stack.empty())
519     return nullptr;
520   return &stack.back();
521 }
522 
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524                                 py::object insertionPoint,
525                                 py::object location) {
526   auto &stack = getStack();
527   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528                      std::move(location));
529   // If the new stack has more than one entry and the context of the new top
530   // entry matches the previous, copy the insertionPoint and location from the
531   // previous entry if missing from the new top entry.
532   if (stack.size() > 1) {
533     auto &prev = *(stack.rbegin() + 1);
534     auto &current = stack.back();
535     if (current.context.is(prev.context)) {
536       // Default non-context objects from the previous entry.
537       if (!current.insertionPoint)
538         current.insertionPoint = prev.insertionPoint;
539       if (!current.location)
540         current.location = prev.location;
541     }
542   }
543 }
544 
getContext()545 PyMlirContext *PyThreadContextEntry::getContext() {
546   if (!context)
547     return nullptr;
548   return py::cast<PyMlirContext *>(context);
549 }
550 
getInsertionPoint()551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552   if (!insertionPoint)
553     return nullptr;
554   return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556 
getLocation()557 PyLocation *PyThreadContextEntry::getLocation() {
558   if (!location)
559     return nullptr;
560   return py::cast<PyLocation *>(location);
561 }
562 
getDefaultContext()563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564   auto *tos = getTopOfStack();
565   return tos ? tos->getContext() : nullptr;
566 }
567 
getDefaultInsertionPoint()568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569   auto *tos = getTopOfStack();
570   return tos ? tos->getInsertionPoint() : nullptr;
571 }
572 
getDefaultLocation()573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574   auto *tos = getTopOfStack();
575   return tos ? tos->getLocation() : nullptr;
576 }
577 
pushContext(PyMlirContext & context)578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579   py::object contextObj = py::cast(context);
580   push(FrameKind::Context, /*context=*/contextObj,
581        /*insertionPoint=*/py::object(),
582        /*location=*/py::object());
583   return contextObj;
584 }
585 
popContext(PyMlirContext & context)586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587   auto &stack = getStack();
588   if (stack.empty())
589     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590   auto &tos = stack.back();
591   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   stack.pop_back();
594 }
595 
596 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598   py::object contextObj =
599       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600   py::object insertionPointObj = py::cast(insertionPoint);
601   push(FrameKind::InsertionPoint,
602        /*context=*/contextObj,
603        /*insertionPoint=*/insertionPointObj,
604        /*location=*/py::object());
605   return insertionPointObj;
606 }
607 
popInsertionPoint(PyInsertionPoint & insertionPoint)608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609   auto &stack = getStack();
610   if (stack.empty())
611     throw SetPyError(PyExc_RuntimeError,
612                      "Unbalanced InsertionPoint enter/exit");
613   auto &tos = stack.back();
614   if (tos.frameKind != FrameKind::InsertionPoint &&
615       tos.getInsertionPoint() != &insertionPoint)
616     throw SetPyError(PyExc_RuntimeError,
617                      "Unbalanced InsertionPoint enter/exit");
618   stack.pop_back();
619 }
620 
pushLocation(PyLocation & location)621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622   py::object contextObj = location.getContext().getObject();
623   py::object locationObj = py::cast(location);
624   push(FrameKind::Location, /*context=*/contextObj,
625        /*insertionPoint=*/py::object(),
626        /*location=*/locationObj);
627   return locationObj;
628 }
629 
popLocation(PyLocation & location)630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631   auto &stack = getStack();
632   if (stack.empty())
633     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634   auto &tos = stack.back();
635   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   stack.pop_back();
638 }
639 
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643 
getDialectForKey(const std::string & key,bool attrError)644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645                                          bool attrError) {
646   // If the "std" dialect was asked for, substitute the empty namespace :(
647   static const std::string emptyKey;
648   const std::string *canonKey = key == "std" ? &emptyKey : &key;
649   MlirDialect dialect = mlirContextGetOrLoadDialect(
650       getContext()->get(), {canonKey->data(), canonKey->size()});
651   if (mlirDialectIsNull(dialect)) {
652     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
653                      Twine("Dialect '") + key + "' not found");
654   }
655   return dialect;
656 }
657 
658 //------------------------------------------------------------------------------
659 // PyLocation
660 //------------------------------------------------------------------------------
661 
getCapsule()662 py::object PyLocation::getCapsule() {
663   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
664 }
665 
createFromCapsule(py::object capsule)666 PyLocation PyLocation::createFromCapsule(py::object capsule) {
667   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
668   if (mlirLocationIsNull(rawLoc))
669     throw py::error_already_set();
670   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
671                     rawLoc);
672 }
673 
contextEnter()674 py::object PyLocation::contextEnter() {
675   return PyThreadContextEntry::pushLocation(*this);
676 }
677 
contextExit(py::object excType,py::object excVal,py::object excTb)678 void PyLocation::contextExit(py::object excType, py::object excVal,
679                              py::object excTb) {
680   PyThreadContextEntry::popLocation(*this);
681 }
682 
resolve()683 PyLocation &DefaultingPyLocation::resolve() {
684   auto *location = PyThreadContextEntry::getDefaultLocation();
685   if (!location) {
686     throw SetPyError(
687         PyExc_RuntimeError,
688         "An MLIR function requires a Location but none was provided in the "
689         "call or from the surrounding environment. Either pass to the function "
690         "with a 'loc=' argument or establish a default using 'with loc:'");
691   }
692   return *location;
693 }
694 
695 //------------------------------------------------------------------------------
696 // PyModule
697 //------------------------------------------------------------------------------
698 
PyModule(PyMlirContextRef contextRef,MlirModule module)699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
700     : BaseContextObject(std::move(contextRef)), module(module) {}
701 
~PyModule()702 PyModule::~PyModule() {
703   py::gil_scoped_acquire acquire;
704   auto &liveModules = getContext()->liveModules;
705   assert(liveModules.count(module.ptr) == 1 &&
706          "destroying module not in live map");
707   liveModules.erase(module.ptr);
708   mlirModuleDestroy(module);
709 }
710 
forModule(MlirModule module)711 PyModuleRef PyModule::forModule(MlirModule module) {
712   MlirContext context = mlirModuleGetContext(module);
713   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
714 
715   py::gil_scoped_acquire acquire;
716   auto &liveModules = contextRef->liveModules;
717   auto it = liveModules.find(module.ptr);
718   if (it == liveModules.end()) {
719     // Create.
720     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
721     // Note that the default return value policy on cast is automatic_reference,
722     // which does not take ownership (delete will not be called).
723     // Just be explicit.
724     py::object pyRef =
725         py::cast(unownedModule, py::return_value_policy::take_ownership);
726     unownedModule->handle = pyRef;
727     liveModules[module.ptr] =
728         std::make_pair(unownedModule->handle, unownedModule);
729     return PyModuleRef(unownedModule, std::move(pyRef));
730   }
731   // Use existing.
732   PyModule *existing = it->second.second;
733   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
734   return PyModuleRef(existing, std::move(pyRef));
735 }
736 
createFromCapsule(py::object capsule)737 py::object PyModule::createFromCapsule(py::object capsule) {
738   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
739   if (mlirModuleIsNull(rawModule))
740     throw py::error_already_set();
741   return forModule(rawModule).releaseObject();
742 }
743 
getCapsule()744 py::object PyModule::getCapsule() {
745   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
746 }
747 
748 //------------------------------------------------------------------------------
749 // PyOperation
750 //------------------------------------------------------------------------------
751 
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
753     : BaseContextObject(std::move(contextRef)), operation(operation) {}
754 
~PyOperation()755 PyOperation::~PyOperation() {
756   // If the operation has already been invalidated there is nothing to do.
757   if (!valid)
758     return;
759   auto &liveOperations = getContext()->liveOperations;
760   assert(liveOperations.count(operation.ptr) == 1 &&
761          "destroying operation not in live map");
762   liveOperations.erase(operation.ptr);
763   if (!isAttached()) {
764     mlirOperationDestroy(operation);
765   }
766 }
767 
createInstance(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)768 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
769                                            MlirOperation operation,
770                                            py::object parentKeepAlive) {
771   auto &liveOperations = contextRef->liveOperations;
772   // Create.
773   PyOperation *unownedOperation =
774       new PyOperation(std::move(contextRef), operation);
775   // Note that the default return value policy on cast is automatic_reference,
776   // which does not take ownership (delete will not be called).
777   // Just be explicit.
778   py::object pyRef =
779       py::cast(unownedOperation, py::return_value_policy::take_ownership);
780   unownedOperation->handle = pyRef;
781   if (parentKeepAlive) {
782     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
783   }
784   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
785   return PyOperationRef(unownedOperation, std::move(pyRef));
786 }
787 
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)788 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
789                                          MlirOperation operation,
790                                          py::object parentKeepAlive) {
791   auto &liveOperations = contextRef->liveOperations;
792   auto it = liveOperations.find(operation.ptr);
793   if (it == liveOperations.end()) {
794     // Create.
795     return createInstance(std::move(contextRef), operation,
796                           std::move(parentKeepAlive));
797   }
798   // Use existing.
799   PyOperation *existing = it->second.second;
800   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
801   return PyOperationRef(existing, std::move(pyRef));
802 }
803 
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)804 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
805                                            MlirOperation operation,
806                                            py::object parentKeepAlive) {
807   auto &liveOperations = contextRef->liveOperations;
808   assert(liveOperations.count(operation.ptr) == 0 &&
809          "cannot create detached operation that already exists");
810   (void)liveOperations;
811 
812   PyOperationRef created = createInstance(std::move(contextRef), operation,
813                                           std::move(parentKeepAlive));
814   created->attached = false;
815   return created;
816 }
817 
checkValid() const818 void PyOperation::checkValid() const {
819   if (!valid) {
820     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
821   }
822 }
823 
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)824 void PyOperationBase::print(py::object fileObject, bool binary,
825                             llvm::Optional<int64_t> largeElementsLimit,
826                             bool enableDebugInfo, bool prettyDebugInfo,
827                             bool printGenericOpForm, bool useLocalScope) {
828   PyOperation &operation = getOperation();
829   operation.checkValid();
830   if (fileObject.is_none())
831     fileObject = py::module::import("sys").attr("stdout");
832 
833   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
834     fileObject.attr("write")("// Verification failed, printing generic form\n");
835     printGenericOpForm = true;
836   }
837 
838   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
839   if (largeElementsLimit)
840     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
841   if (enableDebugInfo)
842     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
843   if (printGenericOpForm)
844     mlirOpPrintingFlagsPrintGenericOpForm(flags);
845 
846   PyFileAccumulator accum(fileObject, binary);
847   py::gil_scoped_release();
848   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
849                               accum.getUserData());
850   mlirOpPrintingFlagsDestroy(flags);
851 }
852 
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)853 py::object PyOperationBase::getAsm(bool binary,
854                                    llvm::Optional<int64_t> largeElementsLimit,
855                                    bool enableDebugInfo, bool prettyDebugInfo,
856                                    bool printGenericOpForm,
857                                    bool useLocalScope) {
858   py::object fileObject;
859   if (binary) {
860     fileObject = py::module::import("io").attr("BytesIO")();
861   } else {
862     fileObject = py::module::import("io").attr("StringIO")();
863   }
864   print(fileObject, /*binary=*/binary,
865         /*largeElementsLimit=*/largeElementsLimit,
866         /*enableDebugInfo=*/enableDebugInfo,
867         /*prettyDebugInfo=*/prettyDebugInfo,
868         /*printGenericOpForm=*/printGenericOpForm,
869         /*useLocalScope=*/useLocalScope);
870 
871   return fileObject.attr("getvalue")();
872 }
873 
getParentOperation()874 PyOperationRef PyOperation::getParentOperation() {
875   checkValid();
876   if (!isAttached())
877     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
878   MlirOperation operation = mlirOperationGetParentOperation(get());
879   if (mlirOperationIsNull(operation))
880     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
881   return PyOperation::forOperation(getContext(), operation);
882 }
883 
getBlock()884 PyBlock PyOperation::getBlock() {
885   checkValid();
886   PyOperationRef parentOperation = getParentOperation();
887   MlirBlock block = mlirOperationGetBlock(get());
888   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
889   return PyBlock{std::move(parentOperation), block};
890 }
891 
getCapsule()892 py::object PyOperation::getCapsule() {
893   checkValid();
894   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
895 }
896 
createFromCapsule(py::object capsule)897 py::object PyOperation::createFromCapsule(py::object capsule) {
898   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
899   if (mlirOperationIsNull(rawOperation))
900     throw py::error_already_set();
901   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
902   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
903       .releaseObject();
904 }
905 
create(std::string name,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,py::object maybeIp)906 py::object PyOperation::create(
907     std::string name, llvm::Optional<std::vector<PyType *>> results,
908     llvm::Optional<std::vector<PyValue *>> operands,
909     llvm::Optional<py::dict> attributes,
910     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
911     DefaultingPyLocation location, py::object maybeIp) {
912   llvm::SmallVector<MlirValue, 4> mlirOperands;
913   llvm::SmallVector<MlirType, 4> mlirResults;
914   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
915   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
916 
917   // General parameter validation.
918   if (regions < 0)
919     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
920 
921   // Unpack/validate operands.
922   if (operands) {
923     mlirOperands.reserve(operands->size());
924     for (PyValue *operand : *operands) {
925       if (!operand)
926         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
927       mlirOperands.push_back(operand->get());
928     }
929   }
930 
931   // Unpack/validate results.
932   if (results) {
933     mlirResults.reserve(results->size());
934     for (PyType *result : *results) {
935       // TODO: Verify result type originate from the same context.
936       if (!result)
937         throw SetPyError(PyExc_ValueError, "result type cannot be None");
938       mlirResults.push_back(*result);
939     }
940   }
941   // Unpack/validate attributes.
942   if (attributes) {
943     mlirAttributes.reserve(attributes->size());
944     for (auto &it : *attributes) {
945       std::string key;
946       try {
947         key = it.first.cast<std::string>();
948       } catch (py::cast_error &err) {
949         std::string msg = "Invalid attribute key (not a string) when "
950                           "attempting to create the operation \"" +
951                           name + "\" (" + err.what() + ")";
952         throw py::cast_error(msg);
953       }
954       try {
955         auto &attribute = it.second.cast<PyAttribute &>();
956         // TODO: Verify attribute originates from the same context.
957         mlirAttributes.emplace_back(std::move(key), attribute);
958       } catch (py::reference_cast_error &) {
959         // This exception seems thrown when the value is "None".
960         std::string msg =
961             "Found an invalid (`None`?) attribute value for the key \"" + key +
962             "\" when attempting to create the operation \"" + name + "\"";
963         throw py::cast_error(msg);
964       } catch (py::cast_error &err) {
965         std::string msg = "Invalid attribute value for the key \"" + key +
966                           "\" when attempting to create the operation \"" +
967                           name + "\" (" + err.what() + ")";
968         throw py::cast_error(msg);
969       }
970     }
971   }
972   // Unpack/validate successors.
973   if (successors) {
974     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
975     mlirSuccessors.reserve(successors->size());
976     for (auto *successor : *successors) {
977       // TODO: Verify successor originate from the same context.
978       if (!successor)
979         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
980       mlirSuccessors.push_back(successor->get());
981     }
982   }
983 
984   // Apply unpacked/validated to the operation state. Beyond this
985   // point, exceptions cannot be thrown or else the state will leak.
986   MlirOperationState state =
987       mlirOperationStateGet(toMlirStringRef(name), location);
988   if (!mlirOperands.empty())
989     mlirOperationStateAddOperands(&state, mlirOperands.size(),
990                                   mlirOperands.data());
991   if (!mlirResults.empty())
992     mlirOperationStateAddResults(&state, mlirResults.size(),
993                                  mlirResults.data());
994   if (!mlirAttributes.empty()) {
995     // Note that the attribute names directly reference bytes in
996     // mlirAttributes, so that vector must not be changed from here
997     // on.
998     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
999     mlirNamedAttributes.reserve(mlirAttributes.size());
1000     for (auto &it : mlirAttributes)
1001       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1002           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1003                             toMlirStringRef(it.first)),
1004           it.second));
1005     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1006                                     mlirNamedAttributes.data());
1007   }
1008   if (!mlirSuccessors.empty())
1009     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1010                                     mlirSuccessors.data());
1011   if (regions) {
1012     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1013     mlirRegions.resize(regions);
1014     for (int i = 0; i < regions; ++i)
1015       mlirRegions[i] = mlirRegionCreate();
1016     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1017                                       mlirRegions.data());
1018   }
1019 
1020   // Construct the operation.
1021   MlirOperation operation = mlirOperationCreate(&state);
1022   PyOperationRef created =
1023       PyOperation::createDetached(location->getContext(), operation);
1024 
1025   // InsertPoint active?
1026   if (!maybeIp.is(py::cast(false))) {
1027     PyInsertionPoint *ip;
1028     if (maybeIp.is_none()) {
1029       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1030     } else {
1031       ip = py::cast<PyInsertionPoint *>(maybeIp);
1032     }
1033     if (ip)
1034       ip->insert(*created.get());
1035   }
1036 
1037   return created->createOpView();
1038 }
1039 
createOpView()1040 py::object PyOperation::createOpView() {
1041   checkValid();
1042   MlirIdentifier ident = mlirOperationGetName(get());
1043   MlirStringRef identStr = mlirIdentifierStr(ident);
1044   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1045       StringRef(identStr.data, identStr.length));
1046   if (opViewClass)
1047     return (*opViewClass)(getRef().getObject());
1048   return py::cast(PyOpView(getRef().getObject()));
1049 }
1050 
erase()1051 void PyOperation::erase() {
1052   checkValid();
1053   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1054   // Python reference to a child operation is live. All children should also
1055   // have their `valid` bit set to false.
1056   auto &liveOperations = getContext()->liveOperations;
1057   if (liveOperations.count(operation.ptr))
1058     liveOperations.erase(operation.ptr);
1059   mlirOperationDestroy(operation);
1060   valid = false;
1061 }
1062 
1063 //------------------------------------------------------------------------------
1064 // PyOpView
1065 //------------------------------------------------------------------------------
1066 
1067 py::object
buildGeneric(py::object cls,py::list resultTypeList,py::list operandList,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,llvm::Optional<int> regions,DefaultingPyLocation location,py::object maybeIp)1068 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1069                        py::list operandList,
1070                        llvm::Optional<py::dict> attributes,
1071                        llvm::Optional<std::vector<PyBlock *>> successors,
1072                        llvm::Optional<int> regions,
1073                        DefaultingPyLocation location, py::object maybeIp) {
1074   PyMlirContextRef context = location->getContext();
1075   // Class level operation construction metadata.
1076   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1077   // Operand and result segment specs are either none, which does no
1078   // variadic unpacking, or a list of ints with segment sizes, where each
1079   // element is either a positive number (typically 1 for a scalar) or -1 to
1080   // indicate that it is derived from the length of the same-indexed operand
1081   // or result (implying that it is a list at that position).
1082   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1083   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1084 
1085   std::vector<uint32_t> operandSegmentLengths;
1086   std::vector<uint32_t> resultSegmentLengths;
1087 
1088   // Validate/determine region count.
1089   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1090   int opMinRegionCount = std::get<0>(opRegionSpec);
1091   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1092   if (!regions) {
1093     regions = opMinRegionCount;
1094   }
1095   if (*regions < opMinRegionCount) {
1096     throw py::value_error(
1097         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1098          llvm::Twine(opMinRegionCount) +
1099          " regions but was built with regions=" + llvm::Twine(*regions))
1100             .str());
1101   }
1102   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1103     throw py::value_error(
1104         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1105          llvm::Twine(opMinRegionCount) +
1106          " regions but was built with regions=" + llvm::Twine(*regions))
1107             .str());
1108   }
1109 
1110   // Unpack results.
1111   std::vector<PyType *> resultTypes;
1112   resultTypes.reserve(resultTypeList.size());
1113   if (resultSegmentSpecObj.is_none()) {
1114     // Non-variadic result unpacking.
1115     for (auto it : llvm::enumerate(resultTypeList)) {
1116       try {
1117         resultTypes.push_back(py::cast<PyType *>(it.value()));
1118         if (!resultTypes.back())
1119           throw py::cast_error();
1120       } catch (py::cast_error &err) {
1121         throw py::value_error((llvm::Twine("Result ") +
1122                                llvm::Twine(it.index()) + " of operation \"" +
1123                                name + "\" must be a Type (" + err.what() + ")")
1124                                   .str());
1125       }
1126     }
1127   } else {
1128     // Sized result unpacking.
1129     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1130     if (resultSegmentSpec.size() != resultTypeList.size()) {
1131       throw py::value_error((llvm::Twine("Operation \"") + name +
1132                              "\" requires " +
1133                              llvm::Twine(resultSegmentSpec.size()) +
1134                              "result segments but was provided " +
1135                              llvm::Twine(resultTypeList.size()))
1136                                 .str());
1137     }
1138     resultSegmentLengths.reserve(resultTypeList.size());
1139     for (auto it :
1140          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1141       int segmentSpec = std::get<1>(it.value());
1142       if (segmentSpec == 1 || segmentSpec == 0) {
1143         // Unpack unary element.
1144         try {
1145           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1146           if (resultType) {
1147             resultTypes.push_back(resultType);
1148             resultSegmentLengths.push_back(1);
1149           } else if (segmentSpec == 0) {
1150             // Allowed to be optional.
1151             resultSegmentLengths.push_back(0);
1152           } else {
1153             throw py::cast_error("was None and result is not optional");
1154           }
1155         } catch (py::cast_error &err) {
1156           throw py::value_error((llvm::Twine("Result ") +
1157                                  llvm::Twine(it.index()) + " of operation \"" +
1158                                  name + "\" must be a Type (" + err.what() +
1159                                  ")")
1160                                     .str());
1161         }
1162       } else if (segmentSpec == -1) {
1163         // Unpack sequence by appending.
1164         try {
1165           if (std::get<0>(it.value()).is_none()) {
1166             // Treat it as an empty list.
1167             resultSegmentLengths.push_back(0);
1168           } else {
1169             // Unpack the list.
1170             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1171             for (py::object segmentItem : segment) {
1172               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1173               if (!resultTypes.back()) {
1174                 throw py::cast_error("contained a None item");
1175               }
1176             }
1177             resultSegmentLengths.push_back(segment.size());
1178           }
1179         } catch (std::exception &err) {
1180           // NOTE: Sloppy to be using a catch-all here, but there are at least
1181           // three different unrelated exceptions that can be thrown in the
1182           // above "casts". Just keep the scope above small and catch them all.
1183           throw py::value_error((llvm::Twine("Result ") +
1184                                  llvm::Twine(it.index()) + " of operation \"" +
1185                                  name + "\" must be a Sequence of Types (" +
1186                                  err.what() + ")")
1187                                     .str());
1188         }
1189       } else {
1190         throw py::value_error("Unexpected segment spec");
1191       }
1192     }
1193   }
1194 
1195   // Unpack operands.
1196   std::vector<PyValue *> operands;
1197   operands.reserve(operands.size());
1198   if (operandSegmentSpecObj.is_none()) {
1199     // Non-sized operand unpacking.
1200     for (auto it : llvm::enumerate(operandList)) {
1201       try {
1202         operands.push_back(py::cast<PyValue *>(it.value()));
1203         if (!operands.back())
1204           throw py::cast_error();
1205       } catch (py::cast_error &err) {
1206         throw py::value_error((llvm::Twine("Operand ") +
1207                                llvm::Twine(it.index()) + " of operation \"" +
1208                                name + "\" must be a Value (" + err.what() + ")")
1209                                   .str());
1210       }
1211     }
1212   } else {
1213     // Sized operand unpacking.
1214     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1215     if (operandSegmentSpec.size() != operandList.size()) {
1216       throw py::value_error((llvm::Twine("Operation \"") + name +
1217                              "\" requires " +
1218                              llvm::Twine(operandSegmentSpec.size()) +
1219                              "operand segments but was provided " +
1220                              llvm::Twine(operandList.size()))
1221                                 .str());
1222     }
1223     operandSegmentLengths.reserve(operandList.size());
1224     for (auto it :
1225          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1226       int segmentSpec = std::get<1>(it.value());
1227       if (segmentSpec == 1 || segmentSpec == 0) {
1228         // Unpack unary element.
1229         try {
1230           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1231           if (operandValue) {
1232             operands.push_back(operandValue);
1233             operandSegmentLengths.push_back(1);
1234           } else if (segmentSpec == 0) {
1235             // Allowed to be optional.
1236             operandSegmentLengths.push_back(0);
1237           } else {
1238             throw py::cast_error("was None and operand is not optional");
1239           }
1240         } catch (py::cast_error &err) {
1241           throw py::value_error((llvm::Twine("Operand ") +
1242                                  llvm::Twine(it.index()) + " of operation \"" +
1243                                  name + "\" must be a Value (" + err.what() +
1244                                  ")")
1245                                     .str());
1246         }
1247       } else if (segmentSpec == -1) {
1248         // Unpack sequence by appending.
1249         try {
1250           if (std::get<0>(it.value()).is_none()) {
1251             // Treat it as an empty list.
1252             operandSegmentLengths.push_back(0);
1253           } else {
1254             // Unpack the list.
1255             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1256             for (py::object segmentItem : segment) {
1257               operands.push_back(py::cast<PyValue *>(segmentItem));
1258               if (!operands.back()) {
1259                 throw py::cast_error("contained a None item");
1260               }
1261             }
1262             operandSegmentLengths.push_back(segment.size());
1263           }
1264         } catch (std::exception &err) {
1265           // NOTE: Sloppy to be using a catch-all here, but there are at least
1266           // three different unrelated exceptions that can be thrown in the
1267           // above "casts". Just keep the scope above small and catch them all.
1268           throw py::value_error((llvm::Twine("Operand ") +
1269                                  llvm::Twine(it.index()) + " of operation \"" +
1270                                  name + "\" must be a Sequence of Values (" +
1271                                  err.what() + ")")
1272                                     .str());
1273         }
1274       } else {
1275         throw py::value_error("Unexpected segment spec");
1276       }
1277     }
1278   }
1279 
1280   // Merge operand/result segment lengths into attributes if needed.
1281   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1282     // Dup.
1283     if (attributes) {
1284       attributes = py::dict(*attributes);
1285     } else {
1286       attributes = py::dict();
1287     }
1288     if (attributes->contains("result_segment_sizes") ||
1289         attributes->contains("operand_segment_sizes")) {
1290       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1291                             "'operand_segment_sizes' attribute is unsupported. "
1292                             "Use Operation.create for such low-level access.");
1293     }
1294 
1295     // Add result_segment_sizes attribute.
1296     if (!resultSegmentLengths.empty()) {
1297       int64_t size = resultSegmentLengths.size();
1298       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1299           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1300           resultSegmentLengths.size(), resultSegmentLengths.data());
1301       (*attributes)["result_segment_sizes"] =
1302           PyAttribute(context, segmentLengthAttr);
1303     }
1304 
1305     // Add operand_segment_sizes attribute.
1306     if (!operandSegmentLengths.empty()) {
1307       int64_t size = operandSegmentLengths.size();
1308       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1309           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1310           operandSegmentLengths.size(), operandSegmentLengths.data());
1311       (*attributes)["operand_segment_sizes"] =
1312           PyAttribute(context, segmentLengthAttr);
1313     }
1314   }
1315 
1316   // Delegate to create.
1317   return PyOperation::create(std::move(name),
1318                              /*results=*/std::move(resultTypes),
1319                              /*operands=*/std::move(operands),
1320                              /*attributes=*/std::move(attributes),
1321                              /*successors=*/std::move(successors),
1322                              /*regions=*/*regions, location, maybeIp);
1323 }
1324 
PyOpView(py::object operationObject)1325 PyOpView::PyOpView(py::object operationObject)
1326     // Casting through the PyOperationBase base-class and then back to the
1327     // Operation lets us accept any PyOperationBase subclass.
1328     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1329       operationObject(operation.getRef().getObject()) {}
1330 
createRawSubclass(py::object userClass)1331 py::object PyOpView::createRawSubclass(py::object userClass) {
1332   // This is... a little gross. The typical pattern is to have a pure python
1333   // class that extends OpView like:
1334   //   class AddFOp(_cext.ir.OpView):
1335   //     def __init__(self, loc, lhs, rhs):
1336   //       operation = loc.context.create_operation(
1337   //           "addf", lhs, rhs, results=[lhs.type])
1338   //       super().__init__(operation)
1339   //
1340   // I.e. The goal of the user facing type is to provide a nice constructor
1341   // that has complete freedom for the op under construction. This is at odds
1342   // with our other desire to sometimes create this object by just passing an
1343   // operation (to initialize the base class). We could do *arg and **kwargs
1344   // munging to try to make it work, but instead, we synthesize a new class
1345   // on the fly which extends this user class (AddFOp in this example) and
1346   // *give it* the base class's __init__ method, thus bypassing the
1347   // intermediate subclass's __init__ method entirely. While slightly,
1348   // underhanded, this is safe/legal because the type hierarchy has not changed
1349   // (we just added a new leaf) and we aren't mucking around with __new__.
1350   // Typically, this new class will be stored on the original as "_Raw" and will
1351   // be used for casts and other things that need a variant of the class that
1352   // is initialized purely from an operation.
1353   py::object parentMetaclass =
1354       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1355   py::dict attributes;
1356   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1357   // now.
1358   //   auto opViewType = py::type::of<PyOpView>();
1359   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1360   attributes["__init__"] = opViewType.attr("__init__");
1361   py::str origName = userClass.attr("__name__");
1362   py::str newName = py::str("_") + origName;
1363   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1364 }
1365 
1366 //------------------------------------------------------------------------------
1367 // PyInsertionPoint.
1368 //------------------------------------------------------------------------------
1369 
PyInsertionPoint(PyBlock & block)1370 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1371 
PyInsertionPoint(PyOperationBase & beforeOperationBase)1372 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1373     : refOperation(beforeOperationBase.getOperation().getRef()),
1374       block((*refOperation)->getBlock()) {}
1375 
insert(PyOperationBase & operationBase)1376 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1377   PyOperation &operation = operationBase.getOperation();
1378   if (operation.isAttached())
1379     throw SetPyError(PyExc_ValueError,
1380                      "Attempt to insert operation that is already attached");
1381   block.getParentOperation()->checkValid();
1382   MlirOperation beforeOp = {nullptr};
1383   if (refOperation) {
1384     // Insert before operation.
1385     (*refOperation)->checkValid();
1386     beforeOp = (*refOperation)->get();
1387   } else {
1388     // Insert at end (before null) is only valid if the block does not
1389     // already end in a known terminator (violating this will cause assertion
1390     // failures later).
1391     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1392       throw py::index_error("Cannot insert operation at the end of a block "
1393                             "that already has a terminator. Did you mean to "
1394                             "use 'InsertionPoint.at_block_terminator(block)' "
1395                             "versus 'InsertionPoint(block)'?");
1396     }
1397   }
1398   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1399   operation.setAttached();
1400 }
1401 
atBlockBegin(PyBlock & block)1402 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1403   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1404   if (mlirOperationIsNull(firstOp)) {
1405     // Just insert at end.
1406     return PyInsertionPoint(block);
1407   }
1408 
1409   // Insert before first op.
1410   PyOperationRef firstOpRef = PyOperation::forOperation(
1411       block.getParentOperation()->getContext(), firstOp);
1412   return PyInsertionPoint{block, std::move(firstOpRef)};
1413 }
1414 
atBlockTerminator(PyBlock & block)1415 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1416   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1417   if (mlirOperationIsNull(terminator))
1418     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1419   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1420       block.getParentOperation()->getContext(), terminator);
1421   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1422 }
1423 
contextEnter()1424 py::object PyInsertionPoint::contextEnter() {
1425   return PyThreadContextEntry::pushInsertionPoint(*this);
1426 }
1427 
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)1428 void PyInsertionPoint::contextExit(pybind11::object excType,
1429                                    pybind11::object excVal,
1430                                    pybind11::object excTb) {
1431   PyThreadContextEntry::popInsertionPoint(*this);
1432 }
1433 
1434 //------------------------------------------------------------------------------
1435 // PyAttribute.
1436 //------------------------------------------------------------------------------
1437 
operator ==(const PyAttribute & other)1438 bool PyAttribute::operator==(const PyAttribute &other) {
1439   return mlirAttributeEqual(attr, other.attr);
1440 }
1441 
getCapsule()1442 py::object PyAttribute::getCapsule() {
1443   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1444 }
1445 
createFromCapsule(py::object capsule)1446 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1447   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1448   if (mlirAttributeIsNull(rawAttr))
1449     throw py::error_already_set();
1450   return PyAttribute(
1451       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1452 }
1453 
1454 //------------------------------------------------------------------------------
1455 // PyNamedAttribute.
1456 //------------------------------------------------------------------------------
1457 
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1458 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1459     : ownedName(new std::string(std::move(ownedName))) {
1460   namedAttr = mlirNamedAttributeGet(
1461       mlirIdentifierGet(mlirAttributeGetContext(attr),
1462                         toMlirStringRef(*this->ownedName)),
1463       attr);
1464 }
1465 
1466 //------------------------------------------------------------------------------
1467 // PyType.
1468 //------------------------------------------------------------------------------
1469 
operator ==(const PyType & other)1470 bool PyType::operator==(const PyType &other) {
1471   return mlirTypeEqual(type, other.type);
1472 }
1473 
getCapsule()1474 py::object PyType::getCapsule() {
1475   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1476 }
1477 
createFromCapsule(py::object capsule)1478 PyType PyType::createFromCapsule(py::object capsule) {
1479   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1480   if (mlirTypeIsNull(rawType))
1481     throw py::error_already_set();
1482   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1483                 rawType);
1484 }
1485 
1486 //------------------------------------------------------------------------------
// PyValue and subclasses.
1488 //------------------------------------------------------------------------------
1489 
getCapsule()1490 pybind11::object PyValue::getCapsule() {
1491   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1492 }
1493 
createFromCapsule(pybind11::object capsule)1494 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1495   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1496   if (mlirValueIsNull(value))
1497     throw py::error_already_set();
1498   MlirOperation owner;
1499   if (mlirValueIsAOpResult(value))
1500     owner = mlirOpResultGetOwner(value);
1501   if (mlirValueIsABlockArgument(value))
1502     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1503   if (mlirOperationIsNull(owner))
1504     throw py::error_already_set();
1505   MlirContext ctx = mlirOperationGetContext(owner);
1506   PyOperationRef ownerRef =
1507       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1508   return PyValue(ownerRef, value);
1509 }
1510 
1511 namespace {
1512 /// CRTP base class for Python MLIR values that subclass Value and should be
1513 /// castable from it. The value hierarchy is one level deep and is not supposed
1514 /// to accommodate other levels unless core MLIR changes.
1515 template <typename DerivedTy>
1516 class PyConcreteValue : public PyValue {
1517 public:
1518   // Derived classes must define statics for:
1519   //   IsAFunctionTy isaFunction
1520   //   const char *pyClassName
1521   // and redefine bindDerived.
1522   using ClassTy = py::class_<DerivedTy, PyValue>;
1523   using IsAFunctionTy = bool (*)(MlirValue);
1524 
1525   PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1526   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1527       : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1528   PyConcreteValue(PyValue &orig)
1529       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1530 
1531   /// Attempts to cast the original value to the derived type and throws on
1532   /// type mismatches.
castFrom(PyValue & orig)1533   static MlirValue castFrom(PyValue &orig) {
1534     if (!DerivedTy::isaFunction(orig.get())) {
1535       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1536       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1537                                              DerivedTy::pyClassName +
1538                                              " (from " + origRepr + ")");
1539     }
1540     return orig.get();
1541   }
1542 
1543   /// Binds the Python module objects to functions of this class.
bind(py::module & m)1544   static void bind(py::module &m) {
1545     auto cls = ClassTy(m, DerivedTy::pyClassName);
1546     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1547     DerivedTy::bindDerived(cls);
1548   }
1549 
1550   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1551   static void bindDerived(ClassTy &m) {}
1552 };
1553 
1554 /// Python wrapper for MlirBlockArgument.
1555 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1556 public:
1557   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1558   static constexpr const char *pyClassName = "BlockArgument";
1559   using PyConcreteValue::PyConcreteValue;
1560 
bindDerived(ClassTy & c)1561   static void bindDerived(ClassTy &c) {
1562     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1563       return PyBlock(self.getParentOperation(),
1564                      mlirBlockArgumentGetOwner(self.get()));
1565     });
1566     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1567       return mlirBlockArgumentGetArgNumber(self.get());
1568     });
1569     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1570       return mlirBlockArgumentSetType(self.get(), type);
1571     });
1572   }
1573 };
1574 
1575 /// Python wrapper for MlirOpResult.
1576 class PyOpResult : public PyConcreteValue<PyOpResult> {
1577 public:
1578   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1579   static constexpr const char *pyClassName = "OpResult";
1580   using PyConcreteValue::PyConcreteValue;
1581 
bindDerived(ClassTy & c)1582   static void bindDerived(ClassTy &c) {
1583     c.def_property_readonly("owner", [](PyOpResult &self) {
1584       assert(
1585           mlirOperationEqual(self.getParentOperation()->get(),
1586                              mlirOpResultGetOwner(self.get())) &&
1587           "expected the owner of the value in Python to match that in the IR");
1588       return self.getParentOperation().getObject();
1589     });
1590     c.def_property_readonly("result_number", [](PyOpResult &self) {
1591       return mlirOpResultGetResultNumber(self.get());
1592     });
1593   }
1594 };
1595 
1596 /// A list of block arguments. Internally, these are stored as consecutive
1597 /// elements, random access is cheap. The argument list is associated with the
1598 /// operation that contains the block (detached blocks are not allowed in
1599 /// Python bindings) and extends its lifetime.
1600 class PyBlockArgumentList {
1601 public:
PyBlockArgumentList(PyOperationRef operation,MlirBlock block)1602   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1603       : operation(std::move(operation)), block(block) {}
1604 
1605   /// Returns the length of the block argument list.
dunderLen()1606   intptr_t dunderLen() {
1607     operation->checkValid();
1608     return mlirBlockGetNumArguments(block);
1609   }
1610 
1611   /// Returns `index`-th element of the block argument list.
dunderGetItem(intptr_t index)1612   PyBlockArgument dunderGetItem(intptr_t index) {
1613     if (index < 0 || index >= dunderLen()) {
1614       throw SetPyError(PyExc_IndexError,
1615                        "attempt to access out of bounds region");
1616     }
1617     PyValue value(operation, mlirBlockGetArgument(block, index));
1618     return PyBlockArgument(value);
1619   }
1620 
1621   /// Defines a Python class in the bindings.
bind(py::module & m)1622   static void bind(py::module &m) {
1623     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1624         .def("__len__", &PyBlockArgumentList::dunderLen)
1625         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1626   }
1627 
1628 private:
1629   PyOperationRef operation;
1630   MlirBlock block;
1631 };
1632 
1633 /// A list of operation operands. Internally, these are stored as consecutive
1634 /// elements, random access is cheap. The result list is associated with the
1635 /// operation whose results these are, and extends the lifetime of this
1636 /// operation.
1637 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1638 public:
1639   static constexpr const char *pyClassName = "OpOperandList";
1640 
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1641   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1642                   intptr_t length = -1, intptr_t step = 1)
1643       : Sliceable(startIndex,
1644                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1645                                : length,
1646                   step),
1647         operation(operation) {}
1648 
getNumElements()1649   intptr_t getNumElements() {
1650     operation->checkValid();
1651     return mlirOperationGetNumOperands(operation->get());
1652   }
1653 
getElement(intptr_t pos)1654   PyValue getElement(intptr_t pos) {
1655     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1656     MlirOperation owner;
1657     if (mlirValueIsAOpResult(operand))
1658       owner = mlirOpResultGetOwner(operand);
1659     else if (mlirValueIsABlockArgument(operand))
1660       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1661     else
1662       assert(false && "Value must be an block arg or op result.");
1663     PyOperationRef pyOwner =
1664         PyOperation::forOperation(operation->getContext(), owner);
1665     return PyValue(pyOwner, operand);
1666   }
1667 
slice(intptr_t startIndex,intptr_t length,intptr_t step)1668   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1669     return PyOpOperandList(operation, startIndex, length, step);
1670   }
1671 
dunderSetItem(intptr_t index,PyValue value)1672   void dunderSetItem(intptr_t index, PyValue value) {
1673     index = wrapIndex(index);
1674     mlirOperationSetOperand(operation->get(), index, value.get());
1675   }
1676 
bindDerived(ClassTy & c)1677   static void bindDerived(ClassTy &c) {
1678     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1679   }
1680 
1681 private:
1682   PyOperationRef operation;
1683 };
1684 
1685 /// A list of operation results. Internally, these are stored as consecutive
1686 /// elements, random access is cheap. The result list is associated with the
1687 /// operation whose results these are, and extends the lifetime of this
1688 /// operation.
1689 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1690 public:
1691   static constexpr const char *pyClassName = "OpResultList";
1692 
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1693   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1694                  intptr_t length = -1, intptr_t step = 1)
1695       : Sliceable(startIndex,
1696                   length == -1 ? mlirOperationGetNumResults(operation->get())
1697                                : length,
1698                   step),
1699         operation(operation) {}
1700 
getNumElements()1701   intptr_t getNumElements() {
1702     operation->checkValid();
1703     return mlirOperationGetNumResults(operation->get());
1704   }
1705 
getElement(intptr_t index)1706   PyOpResult getElement(intptr_t index) {
1707     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1708     return PyOpResult(value);
1709   }
1710 
slice(intptr_t startIndex,intptr_t length,intptr_t step)1711   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1712     return PyOpResultList(operation, startIndex, length, step);
1713   }
1714 
1715 private:
1716   PyOperationRef operation;
1717 };
1718 
1719 /// A list of operation attributes. Can be indexed by name, producing
1720 /// attributes, or by index, producing named attributes.
1721 class PyOpAttributeMap {
1722 public:
PyOpAttributeMap(PyOperationRef operation)1723   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1724 
dunderGetItemNamed(const std::string & name)1725   PyAttribute dunderGetItemNamed(const std::string &name) {
1726     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1727                                                          toMlirStringRef(name));
1728     if (mlirAttributeIsNull(attr)) {
1729       throw SetPyError(PyExc_KeyError,
1730                        "attempt to access a non-existent attribute");
1731     }
1732     return PyAttribute(operation->getContext(), attr);
1733   }
1734 
dunderGetItemIndexed(intptr_t index)1735   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1736     if (index < 0 || index >= dunderLen()) {
1737       throw SetPyError(PyExc_IndexError,
1738                        "attempt to access out of bounds attribute");
1739     }
1740     MlirNamedAttribute namedAttr =
1741         mlirOperationGetAttribute(operation->get(), index);
1742     return PyNamedAttribute(
1743         namedAttr.attribute,
1744         std::string(mlirIdentifierStr(namedAttr.name).data));
1745   }
1746 
dunderSetItem(const std::string & name,PyAttribute attr)1747   void dunderSetItem(const std::string &name, PyAttribute attr) {
1748     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1749                                     attr);
1750   }
1751 
dunderDelItem(const std::string & name)1752   void dunderDelItem(const std::string &name) {
1753     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1754                                                      toMlirStringRef(name));
1755     if (!removed)
1756       throw SetPyError(PyExc_KeyError,
1757                        "attempt to delete a non-existent attribute");
1758   }
1759 
dunderLen()1760   intptr_t dunderLen() {
1761     return mlirOperationGetNumAttributes(operation->get());
1762   }
1763 
dunderContains(const std::string & name)1764   bool dunderContains(const std::string &name) {
1765     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1766         operation->get(), toMlirStringRef(name)));
1767   }
1768 
bind(py::module & m)1769   static void bind(py::module &m) {
1770     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1771         .def("__contains__", &PyOpAttributeMap::dunderContains)
1772         .def("__len__", &PyOpAttributeMap::dunderLen)
1773         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1774         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1775         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1776         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1777   }
1778 
1779 private:
1780   PyOperationRef operation;
1781 };
1782 
1783 } // end namespace
1784 
1785 //------------------------------------------------------------------------------
1786 // Populates the core exports of the 'ir' submodule.
1787 //------------------------------------------------------------------------------
1788 
populateIRCore(py::module & m)1789 void mlir::python::populateIRCore(py::module &m) {
1790   //----------------------------------------------------------------------------
1791   // Mapping of MlirContext.
1792   //----------------------------------------------------------------------------
1793   py::class_<PyMlirContext>(m, "Context")
1794       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1795       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1796       .def("_get_context_again",
1797            [](PyMlirContext &self) {
1798              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1799              return ref.releaseObject();
1800            })
1801       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1802       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1803       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1804                              &PyMlirContext::getCapsule)
1805       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1806       .def("__enter__", &PyMlirContext::contextEnter)
1807       .def("__exit__", &PyMlirContext::contextExit)
1808       .def_property_readonly_static(
1809           "current",
1810           [](py::object & /*class*/) {
1811             auto *context = PyThreadContextEntry::getDefaultContext();
1812             if (!context)
1813               throw SetPyError(PyExc_ValueError, "No current Context");
1814             return context;
1815           },
1816           "Gets the Context bound to the current thread or raises ValueError")
1817       .def_property_readonly(
1818           "dialects",
1819           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1820           "Gets a container for accessing dialects by name")
1821       .def_property_readonly(
1822           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1823           "Alias for 'dialect'")
1824       .def(
1825           "get_dialect_descriptor",
1826           [=](PyMlirContext &self, std::string &name) {
1827             MlirDialect dialect = mlirContextGetOrLoadDialect(
1828                 self.get(), {name.data(), name.size()});
1829             if (mlirDialectIsNull(dialect)) {
1830               throw SetPyError(PyExc_ValueError,
1831                                Twine("Dialect '") + name + "' not found");
1832             }
1833             return PyDialectDescriptor(self.getRef(), dialect);
1834           },
1835           "Gets or loads a dialect by name, returning its descriptor object")
1836       .def_property(
1837           "allow_unregistered_dialects",
1838           [](PyMlirContext &self) -> bool {
1839             return mlirContextGetAllowUnregisteredDialects(self.get());
1840           },
1841           [](PyMlirContext &self, bool value) {
1842             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1843           })
1844       .def("enable_multithreading",
1845            [](PyMlirContext &self, bool enable) {
1846              mlirContextEnableMultithreading(self.get(), enable);
1847            })
1848       .def("is_registered_operation",
1849            [](PyMlirContext &self, std::string &name) {
1850              return mlirContextIsRegisteredOperation(
1851                  self.get(), MlirStringRef{name.data(), name.size()});
1852            });
1853 
1854   //----------------------------------------------------------------------------
1855   // Mapping of PyDialectDescriptor
1856   //----------------------------------------------------------------------------
1857   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1858       .def_property_readonly("namespace",
1859                              [](PyDialectDescriptor &self) {
1860                                MlirStringRef ns =
1861                                    mlirDialectGetNamespace(self.get());
1862                                return py::str(ns.data, ns.length);
1863                              })
1864       .def("__repr__", [](PyDialectDescriptor &self) {
1865         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1866         std::string repr("<DialectDescriptor ");
1867         repr.append(ns.data, ns.length);
1868         repr.append(">");
1869         return repr;
1870       });
1871 
1872   //----------------------------------------------------------------------------
1873   // Mapping of PyDialects
1874   //----------------------------------------------------------------------------
1875   py::class_<PyDialects>(m, "Dialects")
1876       .def("__getitem__",
1877            [=](PyDialects &self, std::string keyName) {
1878              MlirDialect dialect =
1879                  self.getDialectForKey(keyName, /*attrError=*/false);
1880              py::object descriptor =
1881                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1882              return createCustomDialectWrapper(keyName, std::move(descriptor));
1883            })
1884       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1885         MlirDialect dialect =
1886             self.getDialectForKey(attrName, /*attrError=*/true);
1887         py::object descriptor =
1888             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1889         return createCustomDialectWrapper(attrName, std::move(descriptor));
1890       });
1891 
1892   //----------------------------------------------------------------------------
1893   // Mapping of PyDialect
1894   //----------------------------------------------------------------------------
1895   py::class_<PyDialect>(m, "Dialect")
1896       .def(py::init<py::object>(), "descriptor")
1897       .def_property_readonly(
1898           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1899       .def("__repr__", [](py::object self) {
1900         auto clazz = self.attr("__class__");
1901         return py::str("<Dialect ") +
1902                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1903                clazz.attr("__module__") + py::str(".") +
1904                clazz.attr("__name__") + py::str(")>");
1905       });
1906 
1907   //----------------------------------------------------------------------------
1908   // Mapping of Location
1909   //----------------------------------------------------------------------------
1910   py::class_<PyLocation>(m, "Location")
1911       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1912       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1913       .def("__enter__", &PyLocation::contextEnter)
1914       .def("__exit__", &PyLocation::contextExit)
1915       .def("__eq__",
1916            [](PyLocation &self, PyLocation &other) -> bool {
1917              return mlirLocationEqual(self, other);
1918            })
1919       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1920       .def_property_readonly_static(
1921           "current",
1922           [](py::object & /*class*/) {
1923             auto *loc = PyThreadContextEntry::getDefaultLocation();
1924             if (!loc)
1925               throw SetPyError(PyExc_ValueError, "No current Location");
1926             return loc;
1927           },
1928           "Gets the Location bound to the current thread or raises ValueError")
1929       .def_static(
1930           "unknown",
1931           [](DefaultingPyMlirContext context) {
1932             return PyLocation(context->getRef(),
1933                               mlirLocationUnknownGet(context->get()));
1934           },
1935           py::arg("context") = py::none(),
1936           "Gets a Location representing an unknown location")
1937       .def_static(
1938           "file",
1939           [](std::string filename, int line, int col,
1940              DefaultingPyMlirContext context) {
1941             return PyLocation(
1942                 context->getRef(),
1943                 mlirLocationFileLineColGet(
1944                     context->get(), toMlirStringRef(filename), line, col));
1945           },
1946           py::arg("filename"), py::arg("line"), py::arg("col"),
1947           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1948       .def_property_readonly(
1949           "context",
1950           [](PyLocation &self) { return self.getContext().getObject(); },
1951           "Context that owns the Location")
1952       .def("__repr__", [](PyLocation &self) {
1953         PyPrintAccumulator printAccum;
1954         mlirLocationPrint(self, printAccum.getCallback(),
1955                           printAccum.getUserData());
1956         return printAccum.join();
1957       });
1958 
1959   //----------------------------------------------------------------------------
1960   // Mapping of Module
1961   //----------------------------------------------------------------------------
1962   py::class_<PyModule>(m, "Module")
1963       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1964       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1965       .def_static(
1966           "parse",
1967           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1968             MlirModule module = mlirModuleCreateParse(
1969                 context->get(), toMlirStringRef(moduleAsm));
1970             // TODO: Rework error reporting once diagnostic engine is exposed
1971             // in C API.
1972             if (mlirModuleIsNull(module)) {
1973               throw SetPyError(
1974                   PyExc_ValueError,
1975                   "Unable to parse module assembly (see diagnostics)");
1976             }
1977             return PyModule::forModule(module).releaseObject();
1978           },
1979           py::arg("asm"), py::arg("context") = py::none(),
1980           kModuleParseDocstring)
1981       .def_static(
1982           "create",
1983           [](DefaultingPyLocation loc) {
1984             MlirModule module = mlirModuleCreateEmpty(loc);
1985             return PyModule::forModule(module).releaseObject();
1986           },
1987           py::arg("loc") = py::none(), "Creates an empty module")
1988       .def_property_readonly(
1989           "context",
1990           [](PyModule &self) { return self.getContext().getObject(); },
1991           "Context that created the Module")
1992       .def_property_readonly(
1993           "operation",
1994           [](PyModule &self) {
1995             return PyOperation::forOperation(self.getContext(),
1996                                              mlirModuleGetOperation(self.get()),
1997                                              self.getRef().releaseObject())
1998                 .releaseObject();
1999           },
2000           "Accesses the module as an operation")
2001       .def_property_readonly(
2002           "body",
2003           [](PyModule &self) {
2004             PyOperationRef module_op = PyOperation::forOperation(
2005                 self.getContext(), mlirModuleGetOperation(self.get()),
2006                 self.getRef().releaseObject());
2007             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2008             return returnBlock;
2009           },
2010           "Return the block for this module")
2011       .def(
2012           "dump",
2013           [](PyModule &self) {
2014             mlirOperationDump(mlirModuleGetOperation(self.get()));
2015           },
2016           kDumpDocstring)
2017       .def(
2018           "__str__",
2019           [](PyModule &self) {
2020             MlirOperation operation = mlirModuleGetOperation(self.get());
2021             PyPrintAccumulator printAccum;
2022             mlirOperationPrint(operation, printAccum.getCallback(),
2023                                printAccum.getUserData());
2024             return printAccum.join();
2025           },
2026           kOperationStrDunderDocstring);
2027 
2028   //----------------------------------------------------------------------------
2029   // Mapping of Operation.
2030   //----------------------------------------------------------------------------
2031   py::class_<PyOperationBase>(m, "_OperationBase")
2032       .def("__eq__",
2033            [](PyOperationBase &self, PyOperationBase &other) {
2034              return &self.getOperation() == &other.getOperation();
2035            })
2036       .def("__eq__",
2037            [](PyOperationBase &self, py::object other) { return false; })
2038       .def_property_readonly("attributes",
2039                              [](PyOperationBase &self) {
2040                                return PyOpAttributeMap(
2041                                    self.getOperation().getRef());
2042                              })
2043       .def_property_readonly("operands",
2044                              [](PyOperationBase &self) {
2045                                return PyOpOperandList(
2046                                    self.getOperation().getRef());
2047                              })
2048       .def_property_readonly("regions",
2049                              [](PyOperationBase &self) {
2050                                return PyRegionList(
2051                                    self.getOperation().getRef());
2052                              })
2053       .def_property_readonly(
2054           "results",
2055           [](PyOperationBase &self) {
2056             return PyOpResultList(self.getOperation().getRef());
2057           },
2058           "Returns the list of Operation results.")
2059       .def_property_readonly(
2060           "result",
2061           [](PyOperationBase &self) {
2062             auto &operation = self.getOperation();
2063             auto numResults = mlirOperationGetNumResults(operation);
2064             if (numResults != 1) {
2065               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2066               throw SetPyError(
2067                   PyExc_ValueError,
2068                   Twine("Cannot call .result on operation ") +
2069                       StringRef(name.data, name.length) + " which has " +
2070                       Twine(numResults) +
2071                       " results (it is only valid for operations with a "
2072                       "single result)");
2073             }
2074             return PyOpResult(operation.getRef(),
2075                               mlirOperationGetResult(operation, 0));
2076           },
2077           "Shortcut to get an op result if it has only one (throws an error "
2078           "otherwise).")
2079       .def("__iter__",
2080            [](PyOperationBase &self) {
2081              return PyRegionIterator(self.getOperation().getRef());
2082            })
2083       .def(
2084           "__str__",
2085           [](PyOperationBase &self) {
2086             return self.getAsm(/*binary=*/false,
2087                                /*largeElementsLimit=*/llvm::None,
2088                                /*enableDebugInfo=*/false,
2089                                /*prettyDebugInfo=*/false,
2090                                /*printGenericOpForm=*/false,
2091                                /*useLocalScope=*/false);
2092           },
2093           "Returns the assembly form of the operation.")
2094       .def("print", &PyOperationBase::print,
2095            // Careful: Lots of arguments must match up with print method.
2096            py::arg("file") = py::none(), py::arg("binary") = false,
2097            py::arg("large_elements_limit") = py::none(),
2098            py::arg("enable_debug_info") = false,
2099            py::arg("pretty_debug_info") = false,
2100            py::arg("print_generic_op_form") = false,
2101            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2102       .def("get_asm", &PyOperationBase::getAsm,
2103            // Careful: Lots of arguments must match up with get_asm method.
2104            py::arg("binary") = false,
2105            py::arg("large_elements_limit") = py::none(),
2106            py::arg("enable_debug_info") = false,
2107            py::arg("pretty_debug_info") = false,
2108            py::arg("print_generic_op_form") = false,
2109            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2110       .def(
2111           "verify",
2112           [](PyOperationBase &self) {
2113             return mlirOperationVerify(self.getOperation());
2114           },
2115           "Verify the operation and return true if it passes, false if it "
2116           "fails.");
2117 
2118   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2119       .def_static("create", &PyOperation::create, py::arg("name"),
2120                   py::arg("results") = py::none(),
2121                   py::arg("operands") = py::none(),
2122                   py::arg("attributes") = py::none(),
2123                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2124                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2125                   kOperationCreateDocstring)
2126       .def_property_readonly("parent",
2127                              [](PyOperation &self) {
2128                                return self.getParentOperation().getObject();
2129                              })
2130       .def("erase", &PyOperation::erase)
2131       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2132                              &PyOperation::getCapsule)
2133       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2134       .def_property_readonly("name",
2135                              [](PyOperation &self) {
2136                                self.checkValid();
2137                                MlirOperation operation = self.get();
2138                                MlirStringRef name = mlirIdentifierStr(
2139                                    mlirOperationGetName(operation));
2140                                return py::str(name.data, name.length);
2141                              })
2142       .def_property_readonly(
2143           "context",
2144           [](PyOperation &self) {
2145             self.checkValid();
2146             return self.getContext().getObject();
2147           },
2148           "Context that owns the Operation")
2149       .def_property_readonly("opview", &PyOperation::createOpView);
2150 
2151   auto opViewClass =
2152       py::class_<PyOpView, PyOperationBase>(m, "OpView")
2153           .def(py::init<py::object>())
2154           .def_property_readonly("operation", &PyOpView::getOperationObject)
2155           .def_property_readonly(
2156               "context",
2157               [](PyOpView &self) {
2158                 return self.getOperation().getContext().getObject();
2159               },
2160               "Context that owns the Operation")
2161           .def("__str__", [](PyOpView &self) {
2162             return py::str(self.getOperationObject());
2163           });
2164   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2165   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2166   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2167   opViewClass.attr("build_generic") = classmethod(
2168       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2169       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2170       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2171       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2172       "Builds a specific, generated OpView based on class level attributes.");
2173 
2174   //----------------------------------------------------------------------------
2175   // Mapping of PyRegion.
2176   //----------------------------------------------------------------------------
2177   py::class_<PyRegion>(m, "Region")
2178       .def_property_readonly(
2179           "blocks",
2180           [](PyRegion &self) {
2181             return PyBlockList(self.getParentOperation(), self.get());
2182           },
2183           "Returns a forward-optimized sequence of blocks.")
2184       .def(
2185           "__iter__",
2186           [](PyRegion &self) {
2187             self.checkValid();
2188             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2189             return PyBlockIterator(self.getParentOperation(), firstBlock);
2190           },
2191           "Iterates over blocks in the region.")
2192       .def("__eq__",
2193            [](PyRegion &self, PyRegion &other) {
2194              return self.get().ptr == other.get().ptr;
2195            })
2196       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2197 
2198   //----------------------------------------------------------------------------
2199   // Mapping of PyBlock.
2200   //----------------------------------------------------------------------------
2201   py::class_<PyBlock>(m, "Block")
2202       .def_property_readonly(
2203           "arguments",
2204           [](PyBlock &self) {
2205             return PyBlockArgumentList(self.getParentOperation(), self.get());
2206           },
2207           "Returns a list of block arguments.")
2208       .def_property_readonly(
2209           "operations",
2210           [](PyBlock &self) {
2211             return PyOperationList(self.getParentOperation(), self.get());
2212           },
2213           "Returns a forward-optimized sequence of operations.")
2214       .def(
2215           "__iter__",
2216           [](PyBlock &self) {
2217             self.checkValid();
2218             MlirOperation firstOperation =
2219                 mlirBlockGetFirstOperation(self.get());
2220             return PyOperationIterator(self.getParentOperation(),
2221                                        firstOperation);
2222           },
2223           "Iterates over operations in the block.")
2224       .def("__eq__",
2225            [](PyBlock &self, PyBlock &other) {
2226              return self.get().ptr == other.get().ptr;
2227            })
2228       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2229       .def(
2230           "__str__",
2231           [](PyBlock &self) {
2232             self.checkValid();
2233             PyPrintAccumulator printAccum;
2234             mlirBlockPrint(self.get(), printAccum.getCallback(),
2235                            printAccum.getUserData());
2236             return printAccum.join();
2237           },
2238           "Returns the assembly form of the block.");
2239 
2240   //----------------------------------------------------------------------------
2241   // Mapping of PyInsertionPoint.
2242   //----------------------------------------------------------------------------
2243 
2244   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2245       .def(py::init<PyBlock &>(), py::arg("block"),
2246            "Inserts after the last operation but still inside the block.")
2247       .def("__enter__", &PyInsertionPoint::contextEnter)
2248       .def("__exit__", &PyInsertionPoint::contextExit)
2249       .def_property_readonly_static(
2250           "current",
2251           [](py::object & /*class*/) {
2252             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2253             if (!ip)
2254               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2255             return ip;
2256           },
2257           "Gets the InsertionPoint bound to the current thread or raises "
2258           "ValueError if none has been set")
2259       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2260            "Inserts before a referenced operation.")
2261       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2262                   py::arg("block"), "Inserts at the beginning of the block.")
2263       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2264                   py::arg("block"), "Inserts before the block terminator.")
2265       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2266            "Inserts an operation.");
2267 
2268   //----------------------------------------------------------------------------
2269   // Mapping of PyAttribute.
2270   //----------------------------------------------------------------------------
2271   py::class_<PyAttribute>(m, "Attribute")
2272       // Delegate to the PyAttribute copy constructor, which will also lifetime
2273       // extend the backing context which owns the MlirAttribute.
2274       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2275            "Casts the passed attribute to the generic Attribute")
2276       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2277                              &PyAttribute::getCapsule)
2278       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2279       .def_static(
2280           "parse",
2281           [](std::string attrSpec, DefaultingPyMlirContext context) {
2282             MlirAttribute type = mlirAttributeParseGet(
2283                 context->get(), toMlirStringRef(attrSpec));
2284             // TODO: Rework error reporting once diagnostic engine is exposed
2285             // in C API.
2286             if (mlirAttributeIsNull(type)) {
2287               throw SetPyError(PyExc_ValueError,
2288                                Twine("Unable to parse attribute: '") +
2289                                    attrSpec + "'");
2290             }
2291             return PyAttribute(context->getRef(), type);
2292           },
2293           py::arg("asm"), py::arg("context") = py::none(),
2294           "Parses an attribute from an assembly form")
2295       .def_property_readonly(
2296           "context",
2297           [](PyAttribute &self) { return self.getContext().getObject(); },
2298           "Context that owns the Attribute")
2299       .def_property_readonly("type",
2300                              [](PyAttribute &self) {
2301                                return PyType(self.getContext()->getRef(),
2302                                              mlirAttributeGetType(self));
2303                              })
2304       .def(
2305           "get_named",
2306           [](PyAttribute &self, std::string name) {
2307             return PyNamedAttribute(self, std::move(name));
2308           },
2309           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2310       .def("__eq__",
2311            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2312       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2313       .def(
2314           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2315           kDumpDocstring)
2316       .def(
2317           "__str__",
2318           [](PyAttribute &self) {
2319             PyPrintAccumulator printAccum;
2320             mlirAttributePrint(self, printAccum.getCallback(),
2321                                printAccum.getUserData());
2322             return printAccum.join();
2323           },
2324           "Returns the assembly form of the Attribute.")
2325       .def("__repr__", [](PyAttribute &self) {
2326         // Generally, assembly formats are not printed for __repr__ because
2327         // this can cause exceptionally long debug output and exceptions.
2328         // However, attribute values are generally considered useful and are
2329         // printed. This may need to be re-evaluated if debug dumps end up
2330         // being excessive.
2331         PyPrintAccumulator printAccum;
2332         printAccum.parts.append("Attribute(");
2333         mlirAttributePrint(self, printAccum.getCallback(),
2334                            printAccum.getUserData());
2335         printAccum.parts.append(")");
2336         return printAccum.join();
2337       });
2338 
2339   //----------------------------------------------------------------------------
2340   // Mapping of PyNamedAttribute
2341   //----------------------------------------------------------------------------
2342   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2343       .def("__repr__",
2344            [](PyNamedAttribute &self) {
2345              PyPrintAccumulator printAccum;
2346              printAccum.parts.append("NamedAttribute(");
2347              printAccum.parts.append(
2348                  mlirIdentifierStr(self.namedAttr.name).data);
2349              printAccum.parts.append("=");
2350              mlirAttributePrint(self.namedAttr.attribute,
2351                                 printAccum.getCallback(),
2352                                 printAccum.getUserData());
2353              printAccum.parts.append(")");
2354              return printAccum.join();
2355            })
2356       .def_property_readonly(
2357           "name",
2358           [](PyNamedAttribute &self) {
2359             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2360                            mlirIdentifierStr(self.namedAttr.name).length);
2361           },
2362           "The name of the NamedAttribute binding")
2363       .def_property_readonly(
2364           "attr",
2365           [](PyNamedAttribute &self) {
2366             // TODO: When named attribute is removed/refactored, also remove
2367             // this constructor (it does an inefficient table lookup).
2368             auto contextRef = PyMlirContext::forContext(
2369                 mlirAttributeGetContext(self.namedAttr.attribute));
2370             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2371           },
2372           py::keep_alive<0, 1>(),
2373           "The underlying generic attribute of the NamedAttribute binding");
2374 
2375   //----------------------------------------------------------------------------
2376   // Mapping of PyType.
2377   //----------------------------------------------------------------------------
2378   py::class_<PyType>(m, "Type")
2379       // Delegate to the PyType copy constructor, which will also lifetime
2380       // extend the backing context which owns the MlirType.
2381       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2382            "Casts the passed type to the generic Type")
2383       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2384       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2385       .def_static(
2386           "parse",
2387           [](std::string typeSpec, DefaultingPyMlirContext context) {
2388             MlirType type =
2389                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2390             // TODO: Rework error reporting once diagnostic engine is exposed
2391             // in C API.
2392             if (mlirTypeIsNull(type)) {
2393               throw SetPyError(PyExc_ValueError,
2394                                Twine("Unable to parse type: '") + typeSpec +
2395                                    "'");
2396             }
2397             return PyType(context->getRef(), type);
2398           },
2399           py::arg("asm"), py::arg("context") = py::none(),
2400           kContextParseTypeDocstring)
2401       .def_property_readonly(
2402           "context", [](PyType &self) { return self.getContext().getObject(); },
2403           "Context that owns the Type")
2404       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2405       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2406       .def(
2407           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2408       .def(
2409           "__str__",
2410           [](PyType &self) {
2411             PyPrintAccumulator printAccum;
2412             mlirTypePrint(self, printAccum.getCallback(),
2413                           printAccum.getUserData());
2414             return printAccum.join();
2415           },
2416           "Returns the assembly form of the type.")
2417       .def("__repr__", [](PyType &self) {
2418         // Generally, assembly formats are not printed for __repr__ because
2419         // this can cause exceptionally long debug output and exceptions.
2420         // However, types are an exception as they typically have compact
2421         // assembly forms and printing them is useful.
2422         PyPrintAccumulator printAccum;
2423         printAccum.parts.append("Type(");
2424         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2425         printAccum.parts.append(")");
2426         return printAccum.join();
2427       });
2428 
2429   //----------------------------------------------------------------------------
2430   // Mapping of Value.
2431   //----------------------------------------------------------------------------
2432   py::class_<PyValue>(m, "Value")
2433       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2434       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2435       .def_property_readonly(
2436           "context",
2437           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2438           "Context in which the value lives.")
2439       .def(
2440           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2441           kDumpDocstring)
2442       .def_property_readonly(
2443           "owner",
2444           [](PyValue &self) {
2445             assert(mlirOperationEqual(self.getParentOperation()->get(),
2446                                       mlirOpResultGetOwner(self.get())) &&
2447                    "expected the owner of the value in Python to match that in "
2448                    "the IR");
2449             return self.getParentOperation().getObject();
2450           })
2451       .def("__eq__",
2452            [](PyValue &self, PyValue &other) {
2453              return self.get().ptr == other.get().ptr;
2454            })
2455       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2456       .def(
2457           "__str__",
2458           [](PyValue &self) {
2459             PyPrintAccumulator printAccum;
2460             printAccum.parts.append("Value(");
2461             mlirValuePrint(self.get(), printAccum.getCallback(),
2462                            printAccum.getUserData());
2463             printAccum.parts.append(")");
2464             return printAccum.join();
2465           },
2466           kValueDunderStrDocstring)
2467       .def_property_readonly("type", [](PyValue &self) {
2468         return PyType(self.getParentOperation()->getContext(),
2469                       mlirValueGetType(self.get()));
2470       });
  // Value-derived wrappers (block arguments and op results). They are
  // registered after the `Value` class binding above; presumably they are
  // Python subclasses of it, so keep this ordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: the sequence/iterator types returned by properties
  // such as Region.blocks, Region.__iter__, Block.arguments and
  // Block.operations.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (see PyGlobalDebugFlag for what is exposed).
  PyGlobalDebugFlag::bind(m);
2488 }
2489