1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModules.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/AffineMap.h"
15 #include "mlir-c/Bindings/Python/Interop.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/IntegerSet.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
// Docstring for the context-level type parser: describes the accepted input
// (type assembly) and the ValueError raised when parsing fails.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring for building a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring for parsing a module from its textual assembly; documents the
// ValueError raised on failure.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for the generic operation constructor. The Args section is the
// user-facing list of keyword arguments.
// NOTE(review): keep it in sync with the py::arg names of the binding that
// uses it (defined outside this view).
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
71 
// Docstring for the operation `print` method. The keyword names listed under
// Args are user-facing and must match the snake_case py::arg names of the
// binding (e.g. `use_local_scope` for the C++ `useLocalScope` parameter).
// NOTE(review): the binding itself is outside this view; confirm the names.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Elides elements attributes that have more than this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
90 
// Docstring for the operation `get_asm` method. It defers to the print()
// docstring for the shared formatting keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for the operation's `__str__` (default formatting options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Docstring for `dump` methods, which write a debug form to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (registered in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for the value's `__str__`; the output format depends on whether
// the value is a block argument or an operation result.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 // Helper for creating an @classmethod.
135 template <class Func, typename... Args>
classmethod(Func f,Args...args)136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 /// Checks whether the given type is an integer or float type.
mlirTypeIsAIntegerOrFloat(MlirType type)142 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
143   return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
144          mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
145 }
146 
147 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)148 createCustomDialectWrapper(const std::string &dialectNamespace,
149                            py::object dialectDescriptor) {
150   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
151   if (!dialectClass) {
152     // Use the base class.
153     return py::cast(PyDialect(std::move(dialectDescriptor)));
154   }
155 
156   // Create the custom implementation.
157   return (*dialectClass)(std::move(dialectDescriptor));
158 }
159 
toMlirStringRef(const std::string & s)160 static MlirStringRef toMlirStringRef(const std::string &s) {
161   return mlirStringRefCreate(s.data(), s.size());
162 }
163 
164 template <typename PermutationTy>
isPermutation(std::vector<PermutationTy> permutation)165 static bool isPermutation(std::vector<PermutationTy> permutation) {
166   llvm::SmallVector<bool, 8> seen(permutation.size(), false);
167   for (auto val : permutation) {
168     if (val < permutation.size()) {
169       if (seen[val])
170         return false;
171       seen[val] = true;
172       continue;
173     }
174     return false;
175   }
176   return true;
177 }
178 
179 //------------------------------------------------------------------------------
180 // Collections.
181 //------------------------------------------------------------------------------
182 
183 namespace {
184 
185 class PyRegionIterator {
186 public:
PyRegionIterator(PyOperationRef operation)187   PyRegionIterator(PyOperationRef operation)
188       : operation(std::move(operation)) {}
189 
dunderIter()190   PyRegionIterator &dunderIter() { return *this; }
191 
dunderNext()192   PyRegion dunderNext() {
193     operation->checkValid();
194     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195       throw py::stop_iteration();
196     }
197     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198     return PyRegion(operation, region);
199   }
200 
bind(py::module & m)201   static void bind(py::module &m) {
202     py::class_<PyRegionIterator>(m, "RegionIterator")
203         .def("__iter__", &PyRegionIterator::dunderIter)
204         .def("__next__", &PyRegionIterator::dunderNext);
205   }
206 
207 private:
208   PyOperationRef operation;
209   int nextIndex = 0;
210 };
211 
212 /// Regions of an op are fixed length and indexed numerically so are represented
213 /// with a sequence-like container.
214 class PyRegionList {
215 public:
PyRegionList(PyOperationRef operation)216   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217 
dunderLen()218   intptr_t dunderLen() {
219     operation->checkValid();
220     return mlirOperationGetNumRegions(operation->get());
221   }
222 
dunderGetItem(intptr_t index)223   PyRegion dunderGetItem(intptr_t index) {
224     // dunderLen checks validity.
225     if (index < 0 || index >= dunderLen()) {
226       throw SetPyError(PyExc_IndexError,
227                        "attempt to access out of bounds region");
228     }
229     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230     return PyRegion(operation, region);
231   }
232 
bind(py::module & m)233   static void bind(py::module &m) {
234     py::class_<PyRegionList>(m, "RegionSequence")
235         .def("__len__", &PyRegionList::dunderLen)
236         .def("__getitem__", &PyRegionList::dunderGetItem);
237   }
238 
239 private:
240   PyOperationRef operation;
241 };
242 
243 class PyBlockIterator {
244 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)245   PyBlockIterator(PyOperationRef operation, MlirBlock next)
246       : operation(std::move(operation)), next(next) {}
247 
dunderIter()248   PyBlockIterator &dunderIter() { return *this; }
249 
dunderNext()250   PyBlock dunderNext() {
251     operation->checkValid();
252     if (mlirBlockIsNull(next)) {
253       throw py::stop_iteration();
254     }
255 
256     PyBlock returnBlock(operation, next);
257     next = mlirBlockGetNextInRegion(next);
258     return returnBlock;
259   }
260 
bind(py::module & m)261   static void bind(py::module &m) {
262     py::class_<PyBlockIterator>(m, "BlockIterator")
263         .def("__iter__", &PyBlockIterator::dunderIter)
264         .def("__next__", &PyBlockIterator::dunderNext);
265   }
266 
267 private:
268   PyOperationRef operation;
269   MlirBlock next;
270 };
271 
272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273 /// we present them as a more full-featured list-like container but optimize
274 /// it for forward iteration. Blocks are always owned by a region.
275 class PyBlockList {
276 public:
PyBlockList(PyOperationRef operation,MlirRegion region)277   PyBlockList(PyOperationRef operation, MlirRegion region)
278       : operation(std::move(operation)), region(region) {}
279 
dunderIter()280   PyBlockIterator dunderIter() {
281     operation->checkValid();
282     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283   }
284 
dunderLen()285   intptr_t dunderLen() {
286     operation->checkValid();
287     intptr_t count = 0;
288     MlirBlock block = mlirRegionGetFirstBlock(region);
289     while (!mlirBlockIsNull(block)) {
290       count += 1;
291       block = mlirBlockGetNextInRegion(block);
292     }
293     return count;
294   }
295 
dunderGetItem(intptr_t index)296   PyBlock dunderGetItem(intptr_t index) {
297     operation->checkValid();
298     if (index < 0) {
299       throw SetPyError(PyExc_IndexError,
300                        "attempt to access out of bounds block");
301     }
302     MlirBlock block = mlirRegionGetFirstBlock(region);
303     while (!mlirBlockIsNull(block)) {
304       if (index == 0) {
305         return PyBlock(operation, block);
306       }
307       block = mlirBlockGetNextInRegion(block);
308       index -= 1;
309     }
310     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311   }
312 
appendBlock(py::args pyArgTypes)313   PyBlock appendBlock(py::args pyArgTypes) {
314     operation->checkValid();
315     llvm::SmallVector<MlirType, 4> argTypes;
316     argTypes.reserve(pyArgTypes.size());
317     for (auto &pyArg : pyArgTypes) {
318       argTypes.push_back(pyArg.cast<PyType &>());
319     }
320 
321     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322     mlirRegionAppendOwnedBlock(region, block);
323     return PyBlock(operation, block);
324   }
325 
bind(py::module & m)326   static void bind(py::module &m) {
327     py::class_<PyBlockList>(m, "BlockList")
328         .def("__getitem__", &PyBlockList::dunderGetItem)
329         .def("__iter__", &PyBlockList::dunderIter)
330         .def("__len__", &PyBlockList::dunderLen)
331         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332   }
333 
334 private:
335   PyOperationRef operation;
336   MlirRegion region;
337 };
338 
339 class PyOperationIterator {
340 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)341   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342       : parentOperation(std::move(parentOperation)), next(next) {}
343 
dunderIter()344   PyOperationIterator &dunderIter() { return *this; }
345 
dunderNext()346   py::object dunderNext() {
347     parentOperation->checkValid();
348     if (mlirOperationIsNull(next)) {
349       throw py::stop_iteration();
350     }
351 
352     PyOperationRef returnOperation =
353         PyOperation::forOperation(parentOperation->getContext(), next);
354     next = mlirOperationGetNextInBlock(next);
355     return returnOperation->createOpView();
356   }
357 
bind(py::module & m)358   static void bind(py::module &m) {
359     py::class_<PyOperationIterator>(m, "OperationIterator")
360         .def("__iter__", &PyOperationIterator::dunderIter)
361         .def("__next__", &PyOperationIterator::dunderNext);
362   }
363 
364 private:
365   PyOperationRef parentOperation;
366   MlirOperation next;
367 };
368 
369 /// Operations are exposed by the C-API as a forward-only linked list. In
370 /// Python, we present them as a more full-featured list-like container but
371 /// optimize it for forward iteration. Iterable operations are always owned
372 /// by a block.
373 class PyOperationList {
374 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)375   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376       : parentOperation(std::move(parentOperation)), block(block) {}
377 
dunderIter()378   PyOperationIterator dunderIter() {
379     parentOperation->checkValid();
380     return PyOperationIterator(parentOperation,
381                                mlirBlockGetFirstOperation(block));
382   }
383 
dunderLen()384   intptr_t dunderLen() {
385     parentOperation->checkValid();
386     intptr_t count = 0;
387     MlirOperation childOp = mlirBlockGetFirstOperation(block);
388     while (!mlirOperationIsNull(childOp)) {
389       count += 1;
390       childOp = mlirOperationGetNextInBlock(childOp);
391     }
392     return count;
393   }
394 
dunderGetItem(intptr_t index)395   py::object dunderGetItem(intptr_t index) {
396     parentOperation->checkValid();
397     if (index < 0) {
398       throw SetPyError(PyExc_IndexError,
399                        "attempt to access out of bounds operation");
400     }
401     MlirOperation childOp = mlirBlockGetFirstOperation(block);
402     while (!mlirOperationIsNull(childOp)) {
403       if (index == 0) {
404         return PyOperation::forOperation(parentOperation->getContext(), childOp)
405             ->createOpView();
406       }
407       childOp = mlirOperationGetNextInBlock(childOp);
408       index -= 1;
409     }
410     throw SetPyError(PyExc_IndexError,
411                      "attempt to access out of bounds operation");
412   }
413 
bind(py::module & m)414   static void bind(py::module &m) {
415     py::class_<PyOperationList>(m, "OperationList")
416         .def("__getitem__", &PyOperationList::dunderGetItem)
417         .def("__iter__", &PyOperationList::dunderIter)
418         .def("__len__", &PyOperationList::dunderLen);
419   }
420 
421 private:
422   PyOperationRef parentOperation;
423   MlirBlock block;
424 };
425 
426 } // namespace
427 
428 //------------------------------------------------------------------------------
429 // PyMlirContext
430 //------------------------------------------------------------------------------
431 
PyMlirContext(MlirContext context)432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433   py::gil_scoped_acquire acquire;
434   auto &liveContexts = getLiveContexts();
435   liveContexts[context.ptr] = this;
436 }
437 
~PyMlirContext()438 PyMlirContext::~PyMlirContext() {
439   // Note that the only public way to construct an instance is via the
440   // forContext method, which always puts the associated handle into
441   // liveContexts.
442   py::gil_scoped_acquire acquire;
443   getLiveContexts().erase(context.ptr);
444   mlirContextDestroy(context);
445 }
446 
getCapsule()447 py::object PyMlirContext::getCapsule() {
448   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449 }
450 
createFromCapsule(py::object capsule)451 py::object PyMlirContext::createFromCapsule(py::object capsule) {
452   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453   if (mlirContextIsNull(rawContext))
454     throw py::error_already_set();
455   return forContext(rawContext).releaseObject();
456 }
457 
createNewContextForInit()458 PyMlirContext *PyMlirContext::createNewContextForInit() {
459   MlirContext context = mlirContextCreate();
460   mlirRegisterAllDialects(context);
461   return new PyMlirContext(context);
462 }
463 
forContext(MlirContext context)464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
465   py::gil_scoped_acquire acquire;
466   auto &liveContexts = getLiveContexts();
467   auto it = liveContexts.find(context.ptr);
468   if (it == liveContexts.end()) {
469     // Create.
470     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
471     py::object pyRef = py::cast(unownedContextWrapper);
472     assert(pyRef && "cast to py::object failed");
473     liveContexts[context.ptr] = unownedContextWrapper;
474     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
475   }
476   // Use existing.
477   py::object pyRef = py::cast(it->second);
478   return PyMlirContextRef(it->second, std::move(pyRef));
479 }
480 
getLiveContexts()481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482   static LiveContextMap liveContexts;
483   return liveContexts;
484 }
485 
getLiveCount()486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487 
getLiveOperationCount()488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489 
getLiveModuleCount()490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491 
contextEnter()492 pybind11::object PyMlirContext::contextEnter() {
493   return PyThreadContextEntry::pushContext(*this);
494 }
495 
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)496 void PyMlirContext::contextExit(pybind11::object excType,
497                                 pybind11::object excVal,
498                                 pybind11::object excTb) {
499   PyThreadContextEntry::popContext(*this);
500 }
501 
resolve()502 PyMlirContext &DefaultingPyMlirContext::resolve() {
503   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504   if (!context) {
505     throw SetPyError(
506         PyExc_RuntimeError,
507         "An MLIR function requires a Context but none was provided in the call "
508         "or from the surrounding environment. Either pass to the function with "
509         "a 'context=' argument or establish a default using 'with Context():'");
510   }
511   return *context;
512 }
513 
514 //------------------------------------------------------------------------------
515 // PyThreadContextEntry management
516 //------------------------------------------------------------------------------
517 
getStack()518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519   static thread_local std::vector<PyThreadContextEntry> stack;
520   return stack;
521 }
522 
getTopOfStack()523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524   auto &stack = getStack();
525   if (stack.empty())
526     return nullptr;
527   return &stack.back();
528 }
529 
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531                                 py::object insertionPoint,
532                                 py::object location) {
533   auto &stack = getStack();
534   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535                      std::move(location));
536   // If the new stack has more than one entry and the context of the new top
537   // entry matches the previous, copy the insertionPoint and location from the
538   // previous entry if missing from the new top entry.
539   if (stack.size() > 1) {
540     auto &prev = *(stack.rbegin() + 1);
541     auto &current = stack.back();
542     if (current.context.is(prev.context)) {
543       // Default non-context objects from the previous entry.
544       if (!current.insertionPoint)
545         current.insertionPoint = prev.insertionPoint;
546       if (!current.location)
547         current.location = prev.location;
548     }
549   }
550 }
551 
getContext()552 PyMlirContext *PyThreadContextEntry::getContext() {
553   if (!context)
554     return nullptr;
555   return py::cast<PyMlirContext *>(context);
556 }
557 
getInsertionPoint()558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559   if (!insertionPoint)
560     return nullptr;
561   return py::cast<PyInsertionPoint *>(insertionPoint);
562 }
563 
getLocation()564 PyLocation *PyThreadContextEntry::getLocation() {
565   if (!location)
566     return nullptr;
567   return py::cast<PyLocation *>(location);
568 }
569 
getDefaultContext()570 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571   auto *tos = getTopOfStack();
572   return tos ? tos->getContext() : nullptr;
573 }
574 
getDefaultInsertionPoint()575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576   auto *tos = getTopOfStack();
577   return tos ? tos->getInsertionPoint() : nullptr;
578 }
579 
getDefaultLocation()580 PyLocation *PyThreadContextEntry::getDefaultLocation() {
581   auto *tos = getTopOfStack();
582   return tos ? tos->getLocation() : nullptr;
583 }
584 
pushContext(PyMlirContext & context)585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586   py::object contextObj = py::cast(context);
587   push(FrameKind::Context, /*context=*/contextObj,
588        /*insertionPoint=*/py::object(),
589        /*location=*/py::object());
590   return contextObj;
591 }
592 
popContext(PyMlirContext & context)593 void PyThreadContextEntry::popContext(PyMlirContext &context) {
594   auto &stack = getStack();
595   if (stack.empty())
596     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600   stack.pop_back();
601 }
602 
603 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605   py::object contextObj =
606       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607   py::object insertionPointObj = py::cast(insertionPoint);
608   push(FrameKind::InsertionPoint,
609        /*context=*/contextObj,
610        /*insertionPoint=*/insertionPointObj,
611        /*location=*/py::object());
612   return insertionPointObj;
613 }
614 
popInsertionPoint(PyInsertionPoint & insertionPoint)615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616   auto &stack = getStack();
617   if (stack.empty())
618     throw SetPyError(PyExc_RuntimeError,
619                      "Unbalanced InsertionPoint enter/exit");
620   auto &tos = stack.back();
621   if (tos.frameKind != FrameKind::InsertionPoint &&
622       tos.getInsertionPoint() != &insertionPoint)
623     throw SetPyError(PyExc_RuntimeError,
624                      "Unbalanced InsertionPoint enter/exit");
625   stack.pop_back();
626 }
627 
pushLocation(PyLocation & location)628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629   py::object contextObj = location.getContext().getObject();
630   py::object locationObj = py::cast(location);
631   push(FrameKind::Location, /*context=*/contextObj,
632        /*insertionPoint=*/py::object(),
633        /*location=*/locationObj);
634   return locationObj;
635 }
636 
popLocation(PyLocation & location)637 void PyThreadContextEntry::popLocation(PyLocation &location) {
638   auto &stack = getStack();
639   if (stack.empty())
640     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641   auto &tos = stack.back();
642   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644   stack.pop_back();
645 }
646 
647 //------------------------------------------------------------------------------
648 // PyDialect, PyDialectDescriptor, PyDialects
649 //------------------------------------------------------------------------------
650 
getDialectForKey(const std::string & key,bool attrError)651 MlirDialect PyDialects::getDialectForKey(const std::string &key,
652                                          bool attrError) {
653   // If the "std" dialect was asked for, substitute the empty namespace :(
654   static const std::string emptyKey;
655   const std::string *canonKey = key == "std" ? &emptyKey : &key;
656   MlirDialect dialect = mlirContextGetOrLoadDialect(
657       getContext()->get(), {canonKey->data(), canonKey->size()});
658   if (mlirDialectIsNull(dialect)) {
659     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
660                      Twine("Dialect '") + key + "' not found");
661   }
662   return dialect;
663 }
664 
665 //------------------------------------------------------------------------------
666 // PyLocation
667 //------------------------------------------------------------------------------
668 
getCapsule()669 py::object PyLocation::getCapsule() {
670   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
671 }
672 
createFromCapsule(py::object capsule)673 PyLocation PyLocation::createFromCapsule(py::object capsule) {
674   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
675   if (mlirLocationIsNull(rawLoc))
676     throw py::error_already_set();
677   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
678                     rawLoc);
679 }
680 
contextEnter()681 py::object PyLocation::contextEnter() {
682   return PyThreadContextEntry::pushLocation(*this);
683 }
684 
contextExit(py::object excType,py::object excVal,py::object excTb)685 void PyLocation::contextExit(py::object excType, py::object excVal,
686                              py::object excTb) {
687   PyThreadContextEntry::popLocation(*this);
688 }
689 
resolve()690 PyLocation &DefaultingPyLocation::resolve() {
691   auto *location = PyThreadContextEntry::getDefaultLocation();
692   if (!location) {
693     throw SetPyError(
694         PyExc_RuntimeError,
695         "An MLIR function requires a Location but none was provided in the "
696         "call or from the surrounding environment. Either pass to the function "
697         "with a 'loc=' argument or establish a default using 'with loc:'");
698   }
699   return *location;
700 }
701 
702 //------------------------------------------------------------------------------
703 // PyModule
704 //------------------------------------------------------------------------------
705 
PyModule(PyMlirContextRef contextRef,MlirModule module)706 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
707     : BaseContextObject(std::move(contextRef)), module(module) {}
708 
~PyModule()709 PyModule::~PyModule() {
710   py::gil_scoped_acquire acquire;
711   auto &liveModules = getContext()->liveModules;
712   assert(liveModules.count(module.ptr) == 1 &&
713          "destroying module not in live map");
714   liveModules.erase(module.ptr);
715   mlirModuleDestroy(module);
716 }
717 
forModule(MlirModule module)718 PyModuleRef PyModule::forModule(MlirModule module) {
719   MlirContext context = mlirModuleGetContext(module);
720   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
721 
722   py::gil_scoped_acquire acquire;
723   auto &liveModules = contextRef->liveModules;
724   auto it = liveModules.find(module.ptr);
725   if (it == liveModules.end()) {
726     // Create.
727     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
728     // Note that the default return value policy on cast is automatic_reference,
729     // which does not take ownership (delete will not be called).
730     // Just be explicit.
731     py::object pyRef =
732         py::cast(unownedModule, py::return_value_policy::take_ownership);
733     unownedModule->handle = pyRef;
734     liveModules[module.ptr] =
735         std::make_pair(unownedModule->handle, unownedModule);
736     return PyModuleRef(unownedModule, std::move(pyRef));
737   }
738   // Use existing.
739   PyModule *existing = it->second.second;
740   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
741   return PyModuleRef(existing, std::move(pyRef));
742 }
743 
createFromCapsule(py::object capsule)744 py::object PyModule::createFromCapsule(py::object capsule) {
745   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
746   if (mlirModuleIsNull(rawModule))
747     throw py::error_already_set();
748   return forModule(rawModule).releaseObject();
749 }
750 
getCapsule()751 py::object PyModule::getCapsule() {
752   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
753 }
754 
755 //------------------------------------------------------------------------------
756 // PyOperation
757 //------------------------------------------------------------------------------
758 
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)759 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
760     : BaseContextObject(std::move(contextRef)), operation(operation) {}
761 
~PyOperation()762 PyOperation::~PyOperation() {
763   auto &liveOperations = getContext()->liveOperations;
764   assert(liveOperations.count(operation.ptr) == 1 &&
765          "destroying operation not in live map");
766   liveOperations.erase(operation.ptr);
767   if (!isAttached()) {
768     mlirOperationDestroy(operation);
769   }
770 }
771 
createInstance(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
773                                            MlirOperation operation,
774                                            py::object parentKeepAlive) {
775   auto &liveOperations = contextRef->liveOperations;
776   // Create.
777   PyOperation *unownedOperation =
778       new PyOperation(std::move(contextRef), operation);
779   // Note that the default return value policy on cast is automatic_reference,
780   // which does not take ownership (delete will not be called).
781   // Just be explicit.
782   py::object pyRef =
783       py::cast(unownedOperation, py::return_value_policy::take_ownership);
784   unownedOperation->handle = pyRef;
785   if (parentKeepAlive) {
786     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
787   }
788   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
789   return PyOperationRef(unownedOperation, std::move(pyRef));
790 }
791 
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793                                          MlirOperation operation,
794                                          py::object parentKeepAlive) {
795   auto &liveOperations = contextRef->liveOperations;
796   auto it = liveOperations.find(operation.ptr);
797   if (it == liveOperations.end()) {
798     // Create.
799     return createInstance(std::move(contextRef), operation,
800                           std::move(parentKeepAlive));
801   }
802   // Use existing.
803   PyOperation *existing = it->second.second;
804   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805   return PyOperationRef(existing, std::move(pyRef));
806 }
807 
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809                                            MlirOperation operation,
810                                            py::object parentKeepAlive) {
811   auto &liveOperations = contextRef->liveOperations;
812   assert(liveOperations.count(operation.ptr) == 0 &&
813          "cannot create detached operation that already exists");
814   (void)liveOperations;
815 
816   PyOperationRef created = createInstance(std::move(contextRef), operation,
817                                           std::move(parentKeepAlive));
818   created->attached = false;
819   return created;
820 }
821 
checkValid() const822 void PyOperation::checkValid() const {
823   if (!valid) {
824     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825   }
826 }
827 
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)828 void PyOperationBase::print(py::object fileObject, bool binary,
829                             llvm::Optional<int64_t> largeElementsLimit,
830                             bool enableDebugInfo, bool prettyDebugInfo,
831                             bool printGenericOpForm, bool useLocalScope) {
832   PyOperation &operation = getOperation();
833   operation.checkValid();
834   if (fileObject.is_none())
835     fileObject = py::module::import("sys").attr("stdout");
836 
837   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838     fileObject.attr("write")("// Verification failed, printing generic form\n");
839     printGenericOpForm = true;
840   }
841 
842   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843   if (largeElementsLimit)
844     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845   if (enableDebugInfo)
846     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847   if (printGenericOpForm)
848     mlirOpPrintingFlagsPrintGenericOpForm(flags);
849 
850   PyFileAccumulator accum(fileObject, binary);
851   py::gil_scoped_release();
852   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853                               accum.getUserData());
854   mlirOpPrintingFlagsDestroy(flags);
855 }
856 
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)857 py::object PyOperationBase::getAsm(bool binary,
858                                    llvm::Optional<int64_t> largeElementsLimit,
859                                    bool enableDebugInfo, bool prettyDebugInfo,
860                                    bool printGenericOpForm,
861                                    bool useLocalScope) {
862   py::object fileObject;
863   if (binary) {
864     fileObject = py::module::import("io").attr("BytesIO")();
865   } else {
866     fileObject = py::module::import("io").attr("StringIO")();
867   }
868   print(fileObject, /*binary=*/binary,
869         /*largeElementsLimit=*/largeElementsLimit,
870         /*enableDebugInfo=*/enableDebugInfo,
871         /*prettyDebugInfo=*/prettyDebugInfo,
872         /*printGenericOpForm=*/printGenericOpForm,
873         /*useLocalScope=*/useLocalScope);
874 
875   return fileObject.attr("getvalue")();
876 }
877 
getParentOperation()878 PyOperationRef PyOperation::getParentOperation() {
879   if (!isAttached())
880     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
881   MlirOperation operation = mlirOperationGetParentOperation(get());
882   if (mlirOperationIsNull(operation))
883     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
884   return PyOperation::forOperation(getContext(), operation);
885 }
886 
getBlock()887 PyBlock PyOperation::getBlock() {
888   PyOperationRef parentOperation = getParentOperation();
889   MlirBlock block = mlirOperationGetBlock(get());
890   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
891   return PyBlock{std::move(parentOperation), block};
892 }
893 
create(std::string name,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,py::object maybeIp)894 py::object PyOperation::create(
895     std::string name, llvm::Optional<std::vector<PyType *>> results,
896     llvm::Optional<std::vector<PyValue *>> operands,
897     llvm::Optional<py::dict> attributes,
898     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
899     DefaultingPyLocation location, py::object maybeIp) {
900   llvm::SmallVector<MlirValue, 4> mlirOperands;
901   llvm::SmallVector<MlirType, 4> mlirResults;
902   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
903   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
904 
905   // General parameter validation.
906   if (regions < 0)
907     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
908 
909   // Unpack/validate operands.
910   if (operands) {
911     mlirOperands.reserve(operands->size());
912     for (PyValue *operand : *operands) {
913       if (!operand)
914         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
915       mlirOperands.push_back(operand->get());
916     }
917   }
918 
919   // Unpack/validate results.
920   if (results) {
921     mlirResults.reserve(results->size());
922     for (PyType *result : *results) {
923       // TODO: Verify result type originate from the same context.
924       if (!result)
925         throw SetPyError(PyExc_ValueError, "result type cannot be None");
926       mlirResults.push_back(*result);
927     }
928   }
929   // Unpack/validate attributes.
930   if (attributes) {
931     mlirAttributes.reserve(attributes->size());
932     for (auto &it : *attributes) {
933       std::string key;
934       try {
935         key = it.first.cast<std::string>();
936       } catch (py::cast_error &err) {
937         std::string msg = "Invalid attribute key (not a string) when "
938                           "attempting to create the operation \"" +
939                           name + "\" (" + err.what() + ")";
940         throw py::cast_error(msg);
941       }
942       try {
943         auto &attribute = it.second.cast<PyAttribute &>();
944         // TODO: Verify attribute originates from the same context.
945         mlirAttributes.emplace_back(std::move(key), attribute);
946       } catch (py::reference_cast_error &) {
947         // This exception seems thrown when the value is "None".
948         std::string msg =
949             "Found an invalid (`None`?) attribute value for the key \"" + key +
950             "\" when attempting to create the operation \"" + name + "\"";
951         throw py::cast_error(msg);
952       } catch (py::cast_error &err) {
953         std::string msg = "Invalid attribute value for the key \"" + key +
954                           "\" when attempting to create the operation \"" +
955                           name + "\" (" + err.what() + ")";
956         throw py::cast_error(msg);
957       }
958     }
959   }
960   // Unpack/validate successors.
961   if (successors) {
962     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
963     mlirSuccessors.reserve(successors->size());
964     for (auto *successor : *successors) {
965       // TODO: Verify successor originate from the same context.
966       if (!successor)
967         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
968       mlirSuccessors.push_back(successor->get());
969     }
970   }
971 
972   // Apply unpacked/validated to the operation state. Beyond this
973   // point, exceptions cannot be thrown or else the state will leak.
974   MlirOperationState state =
975       mlirOperationStateGet(toMlirStringRef(name), location);
976   if (!mlirOperands.empty())
977     mlirOperationStateAddOperands(&state, mlirOperands.size(),
978                                   mlirOperands.data());
979   if (!mlirResults.empty())
980     mlirOperationStateAddResults(&state, mlirResults.size(),
981                                  mlirResults.data());
982   if (!mlirAttributes.empty()) {
983     // Note that the attribute names directly reference bytes in
984     // mlirAttributes, so that vector must not be changed from here
985     // on.
986     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
987     mlirNamedAttributes.reserve(mlirAttributes.size());
988     for (auto &it : mlirAttributes)
989       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
990           mlirIdentifierGet(mlirAttributeGetContext(it.second),
991                             toMlirStringRef(it.first)),
992           it.second));
993     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
994                                     mlirNamedAttributes.data());
995   }
996   if (!mlirSuccessors.empty())
997     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
998                                     mlirSuccessors.data());
999   if (regions) {
1000     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1001     mlirRegions.resize(regions);
1002     for (int i = 0; i < regions; ++i)
1003       mlirRegions[i] = mlirRegionCreate();
1004     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1005                                       mlirRegions.data());
1006   }
1007 
1008   // Construct the operation.
1009   MlirOperation operation = mlirOperationCreate(&state);
1010   PyOperationRef created =
1011       PyOperation::createDetached(location->getContext(), operation);
1012 
1013   // InsertPoint active?
1014   if (!maybeIp.is(py::cast(false))) {
1015     PyInsertionPoint *ip;
1016     if (maybeIp.is_none()) {
1017       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1018     } else {
1019       ip = py::cast<PyInsertionPoint *>(maybeIp);
1020     }
1021     if (ip)
1022       ip->insert(*created.get());
1023   }
1024 
1025   return created->createOpView();
1026 }
1027 
createOpView()1028 py::object PyOperation::createOpView() {
1029   MlirIdentifier ident = mlirOperationGetName(get());
1030   MlirStringRef identStr = mlirIdentifierStr(ident);
1031   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1032       StringRef(identStr.data, identStr.length));
1033   if (opViewClass)
1034     return (*opViewClass)(getRef().getObject());
1035   return py::cast(PyOpView(getRef().getObject()));
1036 }
1037 
1038 //------------------------------------------------------------------------------
1039 // PyOpView
1040 //------------------------------------------------------------------------------
1041 
1042 py::object
buildGeneric(py::object cls,py::list resultTypeList,py::list operandList,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,llvm::Optional<int> regions,DefaultingPyLocation location,py::object maybeIp)1043 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1044                        py::list operandList,
1045                        llvm::Optional<py::dict> attributes,
1046                        llvm::Optional<std::vector<PyBlock *>> successors,
1047                        llvm::Optional<int> regions,
1048                        DefaultingPyLocation location, py::object maybeIp) {
1049   PyMlirContextRef context = location->getContext();
1050   // Class level operation construction metadata.
1051   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1052   // Operand and result segment specs are either none, which does no
1053   // variadic unpacking, or a list of ints with segment sizes, where each
1054   // element is either a positive number (typically 1 for a scalar) or -1 to
1055   // indicate that it is derived from the length of the same-indexed operand
1056   // or result (implying that it is a list at that position).
1057   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1058   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1059 
1060   std::vector<uint64_t> operandSegmentLengths;
1061   std::vector<uint64_t> resultSegmentLengths;
1062 
1063   // Validate/determine region count.
1064   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1065   int opMinRegionCount = std::get<0>(opRegionSpec);
1066   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1067   if (!regions) {
1068     regions = opMinRegionCount;
1069   }
1070   if (*regions < opMinRegionCount) {
1071     throw py::value_error(
1072         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1073          llvm::Twine(opMinRegionCount) +
1074          " regions but was built with regions=" + llvm::Twine(*regions))
1075             .str());
1076   }
1077   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1078     throw py::value_error(
1079         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1080          llvm::Twine(opMinRegionCount) +
1081          " regions but was built with regions=" + llvm::Twine(*regions))
1082             .str());
1083   }
1084 
1085   // Unpack results.
1086   std::vector<PyType *> resultTypes;
1087   resultTypes.reserve(resultTypeList.size());
1088   if (resultSegmentSpecObj.is_none()) {
1089     // Non-variadic result unpacking.
1090     for (auto it : llvm::enumerate(resultTypeList)) {
1091       try {
1092         resultTypes.push_back(py::cast<PyType *>(it.value()));
1093         if (!resultTypes.back())
1094           throw py::cast_error();
1095       } catch (py::cast_error &err) {
1096         throw py::value_error((llvm::Twine("Result ") +
1097                                llvm::Twine(it.index()) + " of operation \"" +
1098                                name + "\" must be a Type (" + err.what() + ")")
1099                                   .str());
1100       }
1101     }
1102   } else {
1103     // Sized result unpacking.
1104     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1105     if (resultSegmentSpec.size() != resultTypeList.size()) {
1106       throw py::value_error((llvm::Twine("Operation \"") + name +
1107                              "\" requires " +
1108                              llvm::Twine(resultSegmentSpec.size()) +
1109                              "result segments but was provided " +
1110                              llvm::Twine(resultTypeList.size()))
1111                                 .str());
1112     }
1113     resultSegmentLengths.reserve(resultTypeList.size());
1114     for (auto it :
1115          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1116       int segmentSpec = std::get<1>(it.value());
1117       if (segmentSpec == 1 || segmentSpec == 0) {
1118         // Unpack unary element.
1119         try {
1120           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1121           if (resultType) {
1122             resultTypes.push_back(resultType);
1123             resultSegmentLengths.push_back(1);
1124           } else if (segmentSpec == 0) {
1125             // Allowed to be optional.
1126             resultSegmentLengths.push_back(0);
1127           } else {
1128             throw py::cast_error("was None and result is not optional");
1129           }
1130         } catch (py::cast_error &err) {
1131           throw py::value_error((llvm::Twine("Result ") +
1132                                  llvm::Twine(it.index()) + " of operation \"" +
1133                                  name + "\" must be a Type (" + err.what() +
1134                                  ")")
1135                                     .str());
1136         }
1137       } else if (segmentSpec == -1) {
1138         // Unpack sequence by appending.
1139         try {
1140           if (std::get<0>(it.value()).is_none()) {
1141             // Treat it as an empty list.
1142             resultSegmentLengths.push_back(0);
1143           } else {
1144             // Unpack the list.
1145             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1146             for (py::object segmentItem : segment) {
1147               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1148               if (!resultTypes.back()) {
1149                 throw py::cast_error("contained a None item");
1150               }
1151             }
1152             resultSegmentLengths.push_back(segment.size());
1153           }
1154         } catch (std::exception &err) {
1155           // NOTE: Sloppy to be using a catch-all here, but there are at least
1156           // three different unrelated exceptions that can be thrown in the
1157           // above "casts". Just keep the scope above small and catch them all.
1158           throw py::value_error((llvm::Twine("Result ") +
1159                                  llvm::Twine(it.index()) + " of operation \"" +
1160                                  name + "\" must be a Sequence of Types (" +
1161                                  err.what() + ")")
1162                                     .str());
1163         }
1164       } else {
1165         throw py::value_error("Unexpected segment spec");
1166       }
1167     }
1168   }
1169 
1170   // Unpack operands.
1171   std::vector<PyValue *> operands;
1172   operands.reserve(operands.size());
1173   if (operandSegmentSpecObj.is_none()) {
1174     // Non-sized operand unpacking.
1175     for (auto it : llvm::enumerate(operandList)) {
1176       try {
1177         operands.push_back(py::cast<PyValue *>(it.value()));
1178         if (!operands.back())
1179           throw py::cast_error();
1180       } catch (py::cast_error &err) {
1181         throw py::value_error((llvm::Twine("Operand ") +
1182                                llvm::Twine(it.index()) + " of operation \"" +
1183                                name + "\" must be a Value (" + err.what() + ")")
1184                                   .str());
1185       }
1186     }
1187   } else {
1188     // Sized operand unpacking.
1189     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1190     if (operandSegmentSpec.size() != operandList.size()) {
1191       throw py::value_error((llvm::Twine("Operation \"") + name +
1192                              "\" requires " +
1193                              llvm::Twine(operandSegmentSpec.size()) +
1194                              "operand segments but was provided " +
1195                              llvm::Twine(operandList.size()))
1196                                 .str());
1197     }
1198     operandSegmentLengths.reserve(operandList.size());
1199     for (auto it :
1200          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1201       int segmentSpec = std::get<1>(it.value());
1202       if (segmentSpec == 1 || segmentSpec == 0) {
1203         // Unpack unary element.
1204         try {
1205           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1206           if (operandValue) {
1207             operands.push_back(operandValue);
1208             operandSegmentLengths.push_back(1);
1209           } else if (segmentSpec == 0) {
1210             // Allowed to be optional.
1211             operandSegmentLengths.push_back(0);
1212           } else {
1213             throw py::cast_error("was None and operand is not optional");
1214           }
1215         } catch (py::cast_error &err) {
1216           throw py::value_error((llvm::Twine("Operand ") +
1217                                  llvm::Twine(it.index()) + " of operation \"" +
1218                                  name + "\" must be a Value (" + err.what() +
1219                                  ")")
1220                                     .str());
1221         }
1222       } else if (segmentSpec == -1) {
1223         // Unpack sequence by appending.
1224         try {
1225           if (std::get<0>(it.value()).is_none()) {
1226             // Treat it as an empty list.
1227             operandSegmentLengths.push_back(0);
1228           } else {
1229             // Unpack the list.
1230             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1231             for (py::object segmentItem : segment) {
1232               operands.push_back(py::cast<PyValue *>(segmentItem));
1233               if (!operands.back()) {
1234                 throw py::cast_error("contained a None item");
1235               }
1236             }
1237             operandSegmentLengths.push_back(segment.size());
1238           }
1239         } catch (std::exception &err) {
1240           // NOTE: Sloppy to be using a catch-all here, but there are at least
1241           // three different unrelated exceptions that can be thrown in the
1242           // above "casts". Just keep the scope above small and catch them all.
1243           throw py::value_error((llvm::Twine("Operand ") +
1244                                  llvm::Twine(it.index()) + " of operation \"" +
1245                                  name + "\" must be a Sequence of Values (" +
1246                                  err.what() + ")")
1247                                     .str());
1248         }
1249       } else {
1250         throw py::value_error("Unexpected segment spec");
1251       }
1252     }
1253   }
1254 
1255   // Merge operand/result segment lengths into attributes if needed.
1256   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1257     // Dup.
1258     if (attributes) {
1259       attributes = py::dict(*attributes);
1260     } else {
1261       attributes = py::dict();
1262     }
1263     if (attributes->contains("result_segment_sizes") ||
1264         attributes->contains("operand_segment_sizes")) {
1265       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1266                             "'operand_segment_sizes' attribute is unsupported. "
1267                             "Use Operation.create for such low-level access.");
1268     }
1269 
1270     // Add result_segment_sizes attribute.
1271     if (!resultSegmentLengths.empty()) {
1272       int64_t size = resultSegmentLengths.size();
1273       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
1274           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
1275           resultSegmentLengths.size(), resultSegmentLengths.data());
1276       (*attributes)["result_segment_sizes"] =
1277           PyAttribute(context, segmentLengthAttr);
1278     }
1279 
1280     // Add operand_segment_sizes attribute.
1281     if (!operandSegmentLengths.empty()) {
1282       int64_t size = operandSegmentLengths.size();
1283       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
1284           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
1285           operandSegmentLengths.size(), operandSegmentLengths.data());
1286       (*attributes)["operand_segment_sizes"] =
1287           PyAttribute(context, segmentLengthAttr);
1288     }
1289   }
1290 
1291   // Delegate to create.
1292   return PyOperation::create(std::move(name),
1293                              /*results=*/std::move(resultTypes),
1294                              /*operands=*/std::move(operands),
1295                              /*attributes=*/std::move(attributes),
1296                              /*successors=*/std::move(successors),
1297                              /*regions=*/*regions, location, maybeIp);
1298 }
1299 
PyOpView(py::object operationObject)1300 PyOpView::PyOpView(py::object operationObject)
1301     // Casting through the PyOperationBase base-class and then back to the
1302     // Operation lets us accept any PyOperationBase subclass.
1303     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1304       operationObject(operation.getRef().getObject()) {}
1305 
createRawSubclass(py::object userClass)1306 py::object PyOpView::createRawSubclass(py::object userClass) {
1307   // This is... a little gross. The typical pattern is to have a pure python
1308   // class that extends OpView like:
1309   //   class AddFOp(_cext.ir.OpView):
1310   //     def __init__(self, loc, lhs, rhs):
1311   //       operation = loc.context.create_operation(
1312   //           "addf", lhs, rhs, results=[lhs.type])
1313   //       super().__init__(operation)
1314   //
1315   // I.e. The goal of the user facing type is to provide a nice constructor
1316   // that has complete freedom for the op under construction. This is at odds
1317   // with our other desire to sometimes create this object by just passing an
1318   // operation (to initialize the base class). We could do *arg and **kwargs
1319   // munging to try to make it work, but instead, we synthesize a new class
1320   // on the fly which extends this user class (AddFOp in this example) and
1321   // *give it* the base class's __init__ method, thus bypassing the
1322   // intermediate subclass's __init__ method entirely. While slightly,
1323   // underhanded, this is safe/legal because the type hierarchy has not changed
1324   // (we just added a new leaf) and we aren't mucking around with __new__.
1325   // Typically, this new class will be stored on the original as "_Raw" and will
1326   // be used for casts and other things that need a variant of the class that
1327   // is initialized purely from an operation.
1328   py::object parentMetaclass =
1329       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1330   py::dict attributes;
1331   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1332   // now.
1333   //   auto opViewType = py::type::of<PyOpView>();
1334   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1335   attributes["__init__"] = opViewType.attr("__init__");
1336   py::str origName = userClass.attr("__name__");
1337   py::str newName = py::str("_") + origName;
1338   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1339 }
1340 
1341 //------------------------------------------------------------------------------
1342 // PyInsertionPoint.
1343 //------------------------------------------------------------------------------
1344 
PyInsertionPoint(PyBlock & block)1345 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1346 
PyInsertionPoint(PyOperationBase & beforeOperationBase)1347 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1348     : refOperation(beforeOperationBase.getOperation().getRef()),
1349       block((*refOperation)->getBlock()) {}
1350 
insert(PyOperationBase & operationBase)1351 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1352   PyOperation &operation = operationBase.getOperation();
1353   if (operation.isAttached())
1354     throw SetPyError(PyExc_ValueError,
1355                      "Attempt to insert operation that is already attached");
1356   block.getParentOperation()->checkValid();
1357   MlirOperation beforeOp = {nullptr};
1358   if (refOperation) {
1359     // Insert before operation.
1360     (*refOperation)->checkValid();
1361     beforeOp = (*refOperation)->get();
1362   } else {
1363     // Insert at end (before null) is only valid if the block does not
1364     // already end in a known terminator (violating this will cause assertion
1365     // failures later).
1366     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1367       throw py::index_error("Cannot insert operation at the end of a block "
1368                             "that already has a terminator. Did you mean to "
1369                             "use 'InsertionPoint.at_block_terminator(block)' "
1370                             "versus 'InsertionPoint(block)'?");
1371     }
1372   }
1373   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1374   operation.setAttached();
1375 }
1376 
atBlockBegin(PyBlock & block)1377 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1378   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1379   if (mlirOperationIsNull(firstOp)) {
1380     // Just insert at end.
1381     return PyInsertionPoint(block);
1382   }
1383 
1384   // Insert before first op.
1385   PyOperationRef firstOpRef = PyOperation::forOperation(
1386       block.getParentOperation()->getContext(), firstOp);
1387   return PyInsertionPoint{block, std::move(firstOpRef)};
1388 }
1389 
atBlockTerminator(PyBlock & block)1390 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1391   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1392   if (mlirOperationIsNull(terminator))
1393     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1394   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1395       block.getParentOperation()->getContext(), terminator);
1396   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1397 }
1398 
contextEnter()1399 py::object PyInsertionPoint::contextEnter() {
1400   return PyThreadContextEntry::pushInsertionPoint(*this);
1401 }
1402 
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)1403 void PyInsertionPoint::contextExit(pybind11::object excType,
1404                                    pybind11::object excVal,
1405                                    pybind11::object excTb) {
1406   PyThreadContextEntry::popInsertionPoint(*this);
1407 }
1408 
1409 //------------------------------------------------------------------------------
1410 // PyAttribute.
1411 //------------------------------------------------------------------------------
1412 
operator ==(const PyAttribute & other)1413 bool PyAttribute::operator==(const PyAttribute &other) {
1414   return mlirAttributeEqual(attr, other.attr);
1415 }
1416 
getCapsule()1417 py::object PyAttribute::getCapsule() {
1418   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1419 }
1420 
createFromCapsule(py::object capsule)1421 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1422   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1423   if (mlirAttributeIsNull(rawAttr))
1424     throw py::error_already_set();
1425   return PyAttribute(
1426       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1427 }
1428 
1429 //------------------------------------------------------------------------------
1430 // PyNamedAttribute.
1431 //------------------------------------------------------------------------------
1432 
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1433 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1434     : ownedName(new std::string(std::move(ownedName))) {
1435   namedAttr = mlirNamedAttributeGet(
1436       mlirIdentifierGet(mlirAttributeGetContext(attr),
1437                         toMlirStringRef(*this->ownedName)),
1438       attr);
1439 }
1440 
1441 //------------------------------------------------------------------------------
1442 // PyType.
1443 //------------------------------------------------------------------------------
1444 
operator ==(const PyType & other)1445 bool PyType::operator==(const PyType &other) {
1446   return mlirTypeEqual(type, other.type);
1447 }
1448 
getCapsule()1449 py::object PyType::getCapsule() {
1450   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1451 }
1452 
createFromCapsule(py::object capsule)1453 PyType PyType::createFromCapsule(py::object capsule) {
1454   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1455   if (mlirTypeIsNull(rawType))
1456     throw py::error_already_set();
1457   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1458                 rawType);
1459 }
1460 
1461 //------------------------------------------------------------------------------
// PyValue and subclasses.
1463 //------------------------------------------------------------------------------
1464 
1465 namespace {
1466 /// CRTP base class for Python MLIR values that subclass Value and should be
1467 /// castable from it. The value hierarchy is one level deep and is not supposed
1468 /// to accommodate other levels unless core MLIR changes.
1469 template <typename DerivedTy>
1470 class PyConcreteValue : public PyValue {
1471 public:
1472   // Derived classes must define statics for:
1473   //   IsAFunctionTy isaFunction
1474   //   const char *pyClassName
1475   // and redefine bindDerived.
1476   using ClassTy = py::class_<DerivedTy, PyValue>;
1477   using IsAFunctionTy = bool (*)(MlirValue);
1478 
1479   PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1480   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1481       : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1482   PyConcreteValue(PyValue &orig)
1483       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1484 
1485   /// Attempts to cast the original value to the derived type and throws on
1486   /// type mismatches.
castFrom(PyValue & orig)1487   static MlirValue castFrom(PyValue &orig) {
1488     if (!DerivedTy::isaFunction(orig.get())) {
1489       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1490       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1491                                              DerivedTy::pyClassName +
1492                                              " (from " + origRepr + ")");
1493     }
1494     return orig.get();
1495   }
1496 
1497   /// Binds the Python module objects to functions of this class.
bind(py::module & m)1498   static void bind(py::module &m) {
1499     auto cls = ClassTy(m, DerivedTy::pyClassName);
1500     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1501     DerivedTy::bindDerived(cls);
1502   }
1503 
1504   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1505   static void bindDerived(ClassTy &m) {}
1506 };
1507 
1508 /// Python wrapper for MlirBlockArgument.
1509 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1510 public:
1511   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1512   static constexpr const char *pyClassName = "BlockArgument";
1513   using PyConcreteValue::PyConcreteValue;
1514 
bindDerived(ClassTy & c)1515   static void bindDerived(ClassTy &c) {
1516     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1517       return PyBlock(self.getParentOperation(),
1518                      mlirBlockArgumentGetOwner(self.get()));
1519     });
1520     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1521       return mlirBlockArgumentGetArgNumber(self.get());
1522     });
1523     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1524       return mlirBlockArgumentSetType(self.get(), type);
1525     });
1526   }
1527 };
1528 
1529 /// Python wrapper for MlirOpResult.
1530 class PyOpResult : public PyConcreteValue<PyOpResult> {
1531 public:
1532   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1533   static constexpr const char *pyClassName = "OpResult";
1534   using PyConcreteValue::PyConcreteValue;
1535 
bindDerived(ClassTy & c)1536   static void bindDerived(ClassTy &c) {
1537     c.def_property_readonly("owner", [](PyOpResult &self) {
1538       assert(
1539           mlirOperationEqual(self.getParentOperation()->get(),
1540                              mlirOpResultGetOwner(self.get())) &&
1541           "expected the owner of the value in Python to match that in the IR");
1542       return self.getParentOperation();
1543     });
1544     c.def_property_readonly("result_number", [](PyOpResult &self) {
1545       return mlirOpResultGetResultNumber(self.get());
1546     });
1547   }
1548 };
1549 
1550 /// A list of block arguments. Internally, these are stored as consecutive
1551 /// elements, random access is cheap. The argument list is associated with the
1552 /// operation that contains the block (detached blocks are not allowed in
1553 /// Python bindings) and extends its lifetime.
1554 class PyBlockArgumentList {
1555 public:
PyBlockArgumentList(PyOperationRef operation,MlirBlock block)1556   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1557       : operation(std::move(operation)), block(block) {}
1558 
1559   /// Returns the length of the block argument list.
dunderLen()1560   intptr_t dunderLen() {
1561     operation->checkValid();
1562     return mlirBlockGetNumArguments(block);
1563   }
1564 
1565   /// Returns `index`-th element of the block argument list.
dunderGetItem(intptr_t index)1566   PyBlockArgument dunderGetItem(intptr_t index) {
1567     if (index < 0 || index >= dunderLen()) {
1568       throw SetPyError(PyExc_IndexError,
1569                        "attempt to access out of bounds region");
1570     }
1571     PyValue value(operation, mlirBlockGetArgument(block, index));
1572     return PyBlockArgument(value);
1573   }
1574 
1575   /// Defines a Python class in the bindings.
bind(py::module & m)1576   static void bind(py::module &m) {
1577     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1578         .def("__len__", &PyBlockArgumentList::dunderLen)
1579         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1580   }
1581 
1582 private:
1583   PyOperationRef operation;
1584   MlirBlock block;
1585 };
1586 
1587 /// A list of operation operands. Internally, these are stored as consecutive
1588 /// elements, random access is cheap. The result list is associated with the
1589 /// operation whose results these are, and extends the lifetime of this
1590 /// operation.
1591 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1592 public:
1593   static constexpr const char *pyClassName = "OpOperandList";
1594 
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1595   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1596                   intptr_t length = -1, intptr_t step = 1)
1597       : Sliceable(startIndex,
1598                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1599                                : length,
1600                   step),
1601         operation(operation) {}
1602 
getNumElements()1603   intptr_t getNumElements() {
1604     operation->checkValid();
1605     return mlirOperationGetNumOperands(operation->get());
1606   }
1607 
getElement(intptr_t pos)1608   PyValue getElement(intptr_t pos) {
1609     return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1610   }
1611 
slice(intptr_t startIndex,intptr_t length,intptr_t step)1612   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1613     return PyOpOperandList(operation, startIndex, length, step);
1614   }
1615 
1616 private:
1617   PyOperationRef operation;
1618 };
1619 
1620 /// A list of operation results. Internally, these are stored as consecutive
1621 /// elements, random access is cheap. The result list is associated with the
1622 /// operation whose results these are, and extends the lifetime of this
1623 /// operation.
1624 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1625 public:
1626   static constexpr const char *pyClassName = "OpResultList";
1627 
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1628   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1629                  intptr_t length = -1, intptr_t step = 1)
1630       : Sliceable(startIndex,
1631                   length == -1 ? mlirOperationGetNumResults(operation->get())
1632                                : length,
1633                   step),
1634         operation(operation) {}
1635 
getNumElements()1636   intptr_t getNumElements() {
1637     operation->checkValid();
1638     return mlirOperationGetNumResults(operation->get());
1639   }
1640 
getElement(intptr_t index)1641   PyOpResult getElement(intptr_t index) {
1642     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1643     return PyOpResult(value);
1644   }
1645 
slice(intptr_t startIndex,intptr_t length,intptr_t step)1646   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1647     return PyOpResultList(operation, startIndex, length, step);
1648   }
1649 
1650 private:
1651   PyOperationRef operation;
1652 };
1653 
1654 /// A list of operation attributes. Can be indexed by name, producing
1655 /// attributes, or by index, producing named attributes.
1656 class PyOpAttributeMap {
1657 public:
PyOpAttributeMap(PyOperationRef operation)1658   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1659 
dunderGetItemNamed(const std::string & name)1660   PyAttribute dunderGetItemNamed(const std::string &name) {
1661     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1662                                                          toMlirStringRef(name));
1663     if (mlirAttributeIsNull(attr)) {
1664       throw SetPyError(PyExc_KeyError,
1665                        "attempt to access a non-existent attribute");
1666     }
1667     return PyAttribute(operation->getContext(), attr);
1668   }
1669 
dunderGetItemIndexed(intptr_t index)1670   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1671     if (index < 0 || index >= dunderLen()) {
1672       throw SetPyError(PyExc_IndexError,
1673                        "attempt to access out of bounds attribute");
1674     }
1675     MlirNamedAttribute namedAttr =
1676         mlirOperationGetAttribute(operation->get(), index);
1677     return PyNamedAttribute(
1678         namedAttr.attribute,
1679         std::string(mlirIdentifierStr(namedAttr.name).data));
1680   }
1681 
dunderSetItem(const std::string & name,PyAttribute attr)1682   void dunderSetItem(const std::string &name, PyAttribute attr) {
1683     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1684                                     attr);
1685   }
1686 
dunderDelItem(const std::string & name)1687   void dunderDelItem(const std::string &name) {
1688     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1689                                                      toMlirStringRef(name));
1690     if (!removed)
1691       throw SetPyError(PyExc_KeyError,
1692                        "attempt to delete a non-existent attribute");
1693   }
1694 
dunderLen()1695   intptr_t dunderLen() {
1696     return mlirOperationGetNumAttributes(operation->get());
1697   }
1698 
dunderContains(const std::string & name)1699   bool dunderContains(const std::string &name) {
1700     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1701         operation->get(), toMlirStringRef(name)));
1702   }
1703 
bind(py::module & m)1704   static void bind(py::module &m) {
1705     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1706         .def("__contains__", &PyOpAttributeMap::dunderContains)
1707         .def("__len__", &PyOpAttributeMap::dunderLen)
1708         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1709         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1710         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1711         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1712   }
1713 
1714 private:
1715   PyOperationRef operation;
1716 };
1717 
1718 } // end namespace
1719 
1720 //------------------------------------------------------------------------------
1721 // Builtin attribute subclasses.
1722 //------------------------------------------------------------------------------
1723 
1724 namespace {
1725 
1726 /// CRTP base classes for Python attributes that subclass Attribute and should
1727 /// be castable from it (i.e. via something like StringAttr(attr)).
1728 /// By default, attribute class hierarchies are one level deep (i.e. a
1729 /// concrete attribute class extends PyAttribute); however, intermediate
1730 /// python-visible base classes can be modeled by specifying a BaseTy.
1731 template <typename DerivedTy, typename BaseTy = PyAttribute>
1732 class PyConcreteAttribute : public BaseTy {
1733 public:
1734   // Derived classes must define statics for:
1735   //   IsAFunctionTy isaFunction
1736   //   const char *pyClassName
1737   using ClassTy = py::class_<DerivedTy, BaseTy>;
1738   using IsAFunctionTy = bool (*)(MlirAttribute);
1739 
1740   PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)1741   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
1742       : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute & orig)1743   PyConcreteAttribute(PyAttribute &orig)
1744       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
1745 
castFrom(PyAttribute & orig)1746   static MlirAttribute castFrom(PyAttribute &orig) {
1747     if (!DerivedTy::isaFunction(orig)) {
1748       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1749       throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
1750                                              DerivedTy::pyClassName +
1751                                              " (from " + origRepr + ")");
1752     }
1753     return orig;
1754   }
1755 
bind(py::module & m)1756   static void bind(py::module &m) {
1757     auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
1758     cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
1759     DerivedTy::bindDerived(cls);
1760   }
1761 
1762   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1763   static void bindDerived(ClassTy &m) {}
1764 };
1765 
1766 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
1767 public:
1768   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
1769   static constexpr const char *pyClassName = "ArrayAttr";
1770   using PyConcreteAttribute::PyConcreteAttribute;
1771 
1772   class PyArrayAttributeIterator {
1773   public:
PyArrayAttributeIterator(PyAttribute attr)1774     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
1775 
dunderIter()1776     PyArrayAttributeIterator &dunderIter() { return *this; }
1777 
dunderNext()1778     PyAttribute dunderNext() {
1779       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
1780         throw py::stop_iteration();
1781       }
1782       return PyAttribute(attr.getContext(),
1783                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
1784     }
1785 
bind(py::module & m)1786     static void bind(py::module &m) {
1787       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
1788           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
1789           .def("__next__", &PyArrayAttributeIterator::dunderNext);
1790     }
1791 
1792   private:
1793     PyAttribute attr;
1794     int nextIndex = 0;
1795   };
1796 
bindDerived(ClassTy & c)1797   static void bindDerived(ClassTy &c) {
1798     c.def_static(
1799         "get",
1800         [](py::list attributes, DefaultingPyMlirContext context) {
1801           SmallVector<MlirAttribute> mlirAttributes;
1802           mlirAttributes.reserve(py::len(attributes));
1803           for (auto attribute : attributes) {
1804             try {
1805               mlirAttributes.push_back(attribute.cast<PyAttribute>());
1806             } catch (py::cast_error &err) {
1807               std::string msg = std::string("Invalid attribute when attempting "
1808                                             "to create an ArrayAttribute (") +
1809                                 err.what() + ")";
1810               throw py::cast_error(msg);
1811             } catch (py::reference_cast_error &err) {
1812               // This exception seems thrown when the value is "None".
1813               std::string msg =
1814                   std::string("Invalid attribute (None?) when attempting to "
1815                               "create an ArrayAttribute (") +
1816                   err.what() + ")";
1817               throw py::cast_error(msg);
1818             }
1819           }
1820           MlirAttribute attr = mlirArrayAttrGet(
1821               context->get(), mlirAttributes.size(), mlirAttributes.data());
1822           return PyArrayAttribute(context->getRef(), attr);
1823         },
1824         py::arg("attributes"), py::arg("context") = py::none(),
1825         "Gets a uniqued Array attribute");
1826     c.def("__getitem__",
1827           [](PyArrayAttribute &arr, intptr_t i) {
1828             if (i >= mlirArrayAttrGetNumElements(arr))
1829               throw py::index_error("ArrayAttribute index out of range");
1830             return PyAttribute(arr.getContext(),
1831                                mlirArrayAttrGetElement(arr, i));
1832           })
1833         .def("__len__",
1834              [](const PyArrayAttribute &arr) {
1835                return mlirArrayAttrGetNumElements(arr);
1836              })
1837         .def("__iter__", [](const PyArrayAttribute &arr) {
1838           return PyArrayAttributeIterator(arr);
1839         });
1840   }
1841 };
1842 
1843 /// Float Point Attribute subclass - FloatAttr.
1844 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
1845 public:
1846   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
1847   static constexpr const char *pyClassName = "FloatAttr";
1848   using PyConcreteAttribute::PyConcreteAttribute;
1849 
bindDerived(ClassTy & c)1850   static void bindDerived(ClassTy &c) {
1851     c.def_static(
1852         "get",
1853         [](PyType &type, double value, DefaultingPyLocation loc) {
1854           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
1855           // TODO: Rework error reporting once diagnostic engine is exposed
1856           // in C API.
1857           if (mlirAttributeIsNull(attr)) {
1858             throw SetPyError(PyExc_ValueError,
1859                              Twine("invalid '") +
1860                                  py::repr(py::cast(type)).cast<std::string>() +
1861                                  "' and expected floating point type.");
1862           }
1863           return PyFloatAttribute(type.getContext(), attr);
1864         },
1865         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
1866         "Gets an uniqued float point attribute associated to a type");
1867     c.def_static(
1868         "get_f32",
1869         [](double value, DefaultingPyMlirContext context) {
1870           MlirAttribute attr = mlirFloatAttrDoubleGet(
1871               context->get(), mlirF32TypeGet(context->get()), value);
1872           return PyFloatAttribute(context->getRef(), attr);
1873         },
1874         py::arg("value"), py::arg("context") = py::none(),
1875         "Gets an uniqued float point attribute associated to a f32 type");
1876     c.def_static(
1877         "get_f64",
1878         [](double value, DefaultingPyMlirContext context) {
1879           MlirAttribute attr = mlirFloatAttrDoubleGet(
1880               context->get(), mlirF64TypeGet(context->get()), value);
1881           return PyFloatAttribute(context->getRef(), attr);
1882         },
1883         py::arg("value"), py::arg("context") = py::none(),
1884         "Gets an uniqued float point attribute associated to a f64 type");
1885     c.def_property_readonly(
1886         "value",
1887         [](PyFloatAttribute &self) {
1888           return mlirFloatAttrGetValueDouble(self);
1889         },
1890         "Returns the value of the float point attribute");
1891   }
1892 };
1893 
1894 /// Integer Attribute subclass - IntegerAttr.
1895 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
1896 public:
1897   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
1898   static constexpr const char *pyClassName = "IntegerAttr";
1899   using PyConcreteAttribute::PyConcreteAttribute;
1900 
bindDerived(ClassTy & c)1901   static void bindDerived(ClassTy &c) {
1902     c.def_static(
1903         "get",
1904         [](PyType &type, int64_t value) {
1905           MlirAttribute attr = mlirIntegerAttrGet(type, value);
1906           return PyIntegerAttribute(type.getContext(), attr);
1907         },
1908         py::arg("type"), py::arg("value"),
1909         "Gets an uniqued integer attribute associated to a type");
1910     c.def_property_readonly(
1911         "value",
1912         [](PyIntegerAttribute &self) {
1913           return mlirIntegerAttrGetValueInt(self);
1914         },
1915         "Returns the value of the integer attribute");
1916   }
1917 };
1918 
1919 /// Bool Attribute subclass - BoolAttr.
1920 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
1921 public:
1922   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
1923   static constexpr const char *pyClassName = "BoolAttr";
1924   using PyConcreteAttribute::PyConcreteAttribute;
1925 
bindDerived(ClassTy & c)1926   static void bindDerived(ClassTy &c) {
1927     c.def_static(
1928         "get",
1929         [](bool value, DefaultingPyMlirContext context) {
1930           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
1931           return PyBoolAttribute(context->getRef(), attr);
1932         },
1933         py::arg("value"), py::arg("context") = py::none(),
1934         "Gets an uniqued bool attribute");
1935     c.def_property_readonly(
1936         "value",
1937         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
1938         "Returns the value of the bool attribute");
1939   }
1940 };
1941 
1942 class PyFlatSymbolRefAttribute
1943     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
1944 public:
1945   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
1946   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
1947   using PyConcreteAttribute::PyConcreteAttribute;
1948 
bindDerived(ClassTy & c)1949   static void bindDerived(ClassTy &c) {
1950     c.def_static(
1951         "get",
1952         [](std::string value, DefaultingPyMlirContext context) {
1953           MlirAttribute attr =
1954               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
1955           return PyFlatSymbolRefAttribute(context->getRef(), attr);
1956         },
1957         py::arg("value"), py::arg("context") = py::none(),
1958         "Gets a uniqued FlatSymbolRef attribute");
1959     c.def_property_readonly(
1960         "value",
1961         [](PyFlatSymbolRefAttribute &self) {
1962           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
1963           return py::str(stringRef.data, stringRef.length);
1964         },
1965         "Returns the value of the FlatSymbolRef attribute as a string");
1966   }
1967 };
1968 
1969 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
1970 public:
1971   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
1972   static constexpr const char *pyClassName = "StringAttr";
1973   using PyConcreteAttribute::PyConcreteAttribute;
1974 
bindDerived(ClassTy & c)1975   static void bindDerived(ClassTy &c) {
1976     c.def_static(
1977         "get",
1978         [](std::string value, DefaultingPyMlirContext context) {
1979           MlirAttribute attr =
1980               mlirStringAttrGet(context->get(), toMlirStringRef(value));
1981           return PyStringAttribute(context->getRef(), attr);
1982         },
1983         py::arg("value"), py::arg("context") = py::none(),
1984         "Gets a uniqued string attribute");
1985     c.def_static(
1986         "get_typed",
1987         [](PyType &type, std::string value) {
1988           MlirAttribute attr =
1989               mlirStringAttrTypedGet(type, toMlirStringRef(value));
1990           return PyStringAttribute(type.getContext(), attr);
1991         },
1992 
1993         "Gets a uniqued string attribute associated to a type");
1994     c.def_property_readonly(
1995         "value",
1996         [](PyStringAttribute &self) {
1997           MlirStringRef stringRef = mlirStringAttrGetValue(self);
1998           return py::str(stringRef.data, stringRef.length);
1999         },
2000         "Returns the value of the string attribute");
2001   }
2002 };
2003 
2004 // TODO: Support construction of bool elements.
2005 // TODO: Support construction of string elements.
2006 class PyDenseElementsAttribute
2007     : public PyConcreteAttribute<PyDenseElementsAttribute> {
2008 public:
2009   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
2010   static constexpr const char *pyClassName = "DenseElementsAttr";
2011   using PyConcreteAttribute::PyConcreteAttribute;
2012 
2013   static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,DefaultingPyMlirContext contextWrapper)2014   getFromBuffer(py::buffer array, bool signless,
2015                 DefaultingPyMlirContext contextWrapper) {
2016     // Request a contiguous view. In exotic cases, this will cause a copy.
2017     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
2018     Py_buffer *view = new Py_buffer();
2019     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
2020       delete view;
2021       throw py::error_already_set();
2022     }
2023     py::buffer_info arrayInfo(view);
2024 
2025     MlirContext context = contextWrapper->get();
2026     // Switch on the types that can be bulk loaded between the Python and
2027     // MLIR-C APIs.
2028     // See: https://docs.python.org/3/library/struct.html#format-characters
2029     if (arrayInfo.format == "f") {
2030       // f32
2031       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
2032       return PyDenseElementsAttribute(
2033           contextWrapper->getRef(),
2034           bulkLoad(context, mlirDenseElementsAttrFloatGet,
2035                    mlirF32TypeGet(context), arrayInfo));
2036     } else if (arrayInfo.format == "d") {
2037       // f64
2038       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
2039       return PyDenseElementsAttribute(
2040           contextWrapper->getRef(),
2041           bulkLoad(context, mlirDenseElementsAttrDoubleGet,
2042                    mlirF64TypeGet(context), arrayInfo));
2043     } else if (isSignedIntegerFormat(arrayInfo.format)) {
2044       if (arrayInfo.itemsize == 4) {
2045         // i32
2046         MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
2047                                         : mlirIntegerTypeSignedGet(context, 32);
2048         return PyDenseElementsAttribute(contextWrapper->getRef(),
2049                                         bulkLoad(context,
2050                                                  mlirDenseElementsAttrInt32Get,
2051                                                  elementType, arrayInfo));
2052       } else if (arrayInfo.itemsize == 8) {
2053         // i64
2054         MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
2055                                         : mlirIntegerTypeSignedGet(context, 64);
2056         return PyDenseElementsAttribute(contextWrapper->getRef(),
2057                                         bulkLoad(context,
2058                                                  mlirDenseElementsAttrInt64Get,
2059                                                  elementType, arrayInfo));
2060       }
2061     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
2062       if (arrayInfo.itemsize == 4) {
2063         // unsigned i32
2064         MlirType elementType = signless
2065                                    ? mlirIntegerTypeGet(context, 32)
2066                                    : mlirIntegerTypeUnsignedGet(context, 32);
2067         return PyDenseElementsAttribute(contextWrapper->getRef(),
2068                                         bulkLoad(context,
2069                                                  mlirDenseElementsAttrUInt32Get,
2070                                                  elementType, arrayInfo));
2071       } else if (arrayInfo.itemsize == 8) {
2072         // unsigned i64
2073         MlirType elementType = signless
2074                                    ? mlirIntegerTypeGet(context, 64)
2075                                    : mlirIntegerTypeUnsignedGet(context, 64);
2076         return PyDenseElementsAttribute(contextWrapper->getRef(),
2077                                         bulkLoad(context,
2078                                                  mlirDenseElementsAttrUInt64Get,
2079                                                  elementType, arrayInfo));
2080       }
2081     }
2082 
2083     // TODO: Fall back to string-based get.
2084     std::string message = "unimplemented array format conversion from format: ";
2085     message.append(arrayInfo.format);
2086     throw SetPyError(PyExc_ValueError, message);
2087   }
2088 
getSplat(PyType shapedType,PyAttribute & elementAttr)2089   static PyDenseElementsAttribute getSplat(PyType shapedType,
2090                                            PyAttribute &elementAttr) {
2091     auto contextWrapper =
2092         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
2093     if (!mlirAttributeIsAInteger(elementAttr) &&
2094         !mlirAttributeIsAFloat(elementAttr)) {
2095       std::string message = "Illegal element type for DenseElementsAttr: ";
2096       message.append(py::repr(py::cast(elementAttr)));
2097       throw SetPyError(PyExc_ValueError, message);
2098     }
2099     if (!mlirTypeIsAShaped(shapedType) ||
2100         !mlirShapedTypeHasStaticShape(shapedType)) {
2101       std::string message =
2102           "Expected a static ShapedType for the shaped_type parameter: ";
2103       message.append(py::repr(py::cast(shapedType)));
2104       throw SetPyError(PyExc_ValueError, message);
2105     }
2106     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
2107     MlirType attrType = mlirAttributeGetType(elementAttr);
2108     if (!mlirTypeEqual(shapedElementType, attrType)) {
2109       std::string message =
2110           "Shaped element type and attribute type must be equal: shaped=";
2111       message.append(py::repr(py::cast(shapedType)));
2112       message.append(", element=");
2113       message.append(py::repr(py::cast(elementAttr)));
2114       throw SetPyError(PyExc_ValueError, message);
2115     }
2116 
2117     MlirAttribute elements =
2118         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
2119     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
2120   }
2121 
dunderLen()2122   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
2123 
accessBuffer()2124   py::buffer_info accessBuffer() {
2125     MlirType shapedType = mlirAttributeGetType(*this);
2126     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
2127 
2128     if (mlirTypeIsAF32(elementType)) {
2129       // f32
2130       return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
2131     } else if (mlirTypeIsAF64(elementType)) {
2132       // f64
2133       return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
2134     } else if (mlirTypeIsAInteger(elementType) &&
2135                mlirIntegerTypeGetWidth(elementType) == 32) {
2136       if (mlirIntegerTypeIsSignless(elementType) ||
2137           mlirIntegerTypeIsSigned(elementType)) {
2138         // i32
2139         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
2140       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
2141         // unsigned i32
2142         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
2143       }
2144     } else if (mlirTypeIsAInteger(elementType) &&
2145                mlirIntegerTypeGetWidth(elementType) == 64) {
2146       if (mlirIntegerTypeIsSignless(elementType) ||
2147           mlirIntegerTypeIsSigned(elementType)) {
2148         // i64
2149         return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
2150       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
2151         // unsigned i64
2152         return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
2153       }
2154     }
2155 
2156     std::string message = "unimplemented array format.";
2157     throw SetPyError(PyExc_ValueError, message);
2158   }
2159 
bindDerived(ClassTy & c)2160   static void bindDerived(ClassTy &c) {
2161     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
2162         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
2163                     py::arg("array"), py::arg("signless") = true,
2164                     py::arg("context") = py::none(),
2165                     "Gets from a buffer or ndarray")
2166         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
2167                     py::arg("shaped_type"), py::arg("element_attr"),
2168                     "Gets a DenseElementsAttr where all values are the same")
2169         .def_property_readonly("is_splat",
2170                                [](PyDenseElementsAttribute &self) -> bool {
2171                                  return mlirDenseElementsAttrIsSplat(self);
2172                                })
2173         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
2174   }
2175 
2176 private:
2177   template <typename ElementTy>
2178   static MlirAttribute
bulkLoad(MlirContext context,MlirAttribute (* ctor)(MlirType,intptr_t,ElementTy *),MlirType mlirElementType,py::buffer_info & arrayInfo)2179   bulkLoad(MlirContext context,
2180            MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
2181            MlirType mlirElementType, py::buffer_info &arrayInfo) {
2182     SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
2183                                   arrayInfo.shape.begin() + arrayInfo.ndim);
2184     auto shapedType =
2185         mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
2186     intptr_t numElements = arrayInfo.size;
2187     const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
2188     return ctor(shapedType, numElements, contents);
2189   }
2190 
isUnsignedIntegerFormat(const std::string & format)2191   static bool isUnsignedIntegerFormat(const std::string &format) {
2192     if (format.empty())
2193       return false;
2194     char code = format[0];
2195     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
2196            code == 'Q';
2197   }
2198 
isSignedIntegerFormat(const std::string & format)2199   static bool isSignedIntegerFormat(const std::string &format) {
2200     if (format.empty())
2201       return false;
2202     char code = format[0];
2203     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
2204            code == 'q';
2205   }
2206 
2207   template <typename Type>
bufferInfo(MlirType shapedType,Type (* value)(MlirAttribute,intptr_t))2208   py::buffer_info bufferInfo(MlirType shapedType,
2209                              Type (*value)(MlirAttribute, intptr_t)) {
2210     intptr_t rank = mlirShapedTypeGetRank(shapedType);
2211     // Prepare the data for the buffer_info.
2212     // Buffer is configured for read-only access below.
2213     Type *data = static_cast<Type *>(
2214         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
2215     // Prepare the shape for the buffer_info.
2216     SmallVector<intptr_t, 4> shape;
2217     for (intptr_t i = 0; i < rank; ++i)
2218       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
2219     // Prepare the strides for the buffer_info.
2220     SmallVector<intptr_t, 4> strides;
2221     intptr_t strideFactor = 1;
2222     for (intptr_t i = 1; i < rank; ++i) {
2223       strideFactor = 1;
2224       for (intptr_t j = i; j < rank; ++j) {
2225         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
2226       }
2227       strides.push_back(sizeof(Type) * strideFactor);
2228     }
2229     strides.push_back(sizeof(Type));
2230     return py::buffer_info(data, sizeof(Type),
2231                            py::format_descriptor<Type>::format(), rank, shape,
2232                            strides, /*readonly=*/true);
2233   }
2234 }; // namespace
2235 
2236 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
2237 /// (and boolean) values. Supports element access.
2238 class PyDenseIntElementsAttribute
2239     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
2240                                  PyDenseElementsAttribute> {
2241 public:
2242   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
2243   static constexpr const char *pyClassName = "DenseIntElementsAttr";
2244   using PyConcreteAttribute::PyConcreteAttribute;
2245 
2246   /// Returns the element at the given linear position. Asserts if the index is
2247   /// out of range.
dunderGetItem(intptr_t pos)2248   py::int_ dunderGetItem(intptr_t pos) {
2249     if (pos < 0 || pos >= dunderLen()) {
2250       throw SetPyError(PyExc_IndexError,
2251                        "attempt to access out of bounds element");
2252     }
2253 
2254     MlirType type = mlirAttributeGetType(*this);
2255     type = mlirShapedTypeGetElementType(type);
2256     assert(mlirTypeIsAInteger(type) &&
2257            "expected integer element type in dense int elements attribute");
2258     // Dispatch element extraction to an appropriate C function based on the
2259     // elemental type of the attribute. py::int_ is implicitly constructible
2260     // from any C++ integral type and handles bitwidth correctly.
2261     // TODO: consider caching the type properties in the constructor to avoid
2262     // querying them on each element access.
2263     unsigned width = mlirIntegerTypeGetWidth(type);
2264     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
2265     if (isUnsigned) {
2266       if (width == 1) {
2267         return mlirDenseElementsAttrGetBoolValue(*this, pos);
2268       }
2269       if (width == 32) {
2270         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
2271       }
2272       if (width == 64) {
2273         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
2274       }
2275     } else {
2276       if (width == 1) {
2277         return mlirDenseElementsAttrGetBoolValue(*this, pos);
2278       }
2279       if (width == 32) {
2280         return mlirDenseElementsAttrGetInt32Value(*this, pos);
2281       }
2282       if (width == 64) {
2283         return mlirDenseElementsAttrGetInt64Value(*this, pos);
2284       }
2285     }
2286     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
2287   }
2288 
bindDerived(ClassTy & c)2289   static void bindDerived(ClassTy &c) {
2290     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
2291   }
2292 };
2293 
2294 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
2295 public:
2296   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
2297   static constexpr const char *pyClassName = "DictAttr";
2298   using PyConcreteAttribute::PyConcreteAttribute;
2299 
dunderLen()2300   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
2301 
bindDerived(ClassTy & c)2302   static void bindDerived(ClassTy &c) {
2303     c.def("__len__", &PyDictAttribute::dunderLen);
2304     c.def_static(
2305         "get",
2306         [](py::dict attributes, DefaultingPyMlirContext context) {
2307           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
2308           mlirNamedAttributes.reserve(attributes.size());
2309           for (auto &it : attributes) {
2310             auto &mlir_attr = it.second.cast<PyAttribute &>();
2311             auto name = it.first.cast<std::string>();
2312             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
2313                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
2314                                   toMlirStringRef(name)),
2315                 mlir_attr));
2316           }
2317           MlirAttribute attr =
2318               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
2319                                     mlirNamedAttributes.data());
2320           return PyDictAttribute(context->getRef(), attr);
2321         },
2322         py::arg("value"), py::arg("context") = py::none(),
2323         "Gets an uniqued dict attribute");
2324     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
2325       MlirAttribute attr =
2326           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
2327       if (mlirAttributeIsNull(attr)) {
2328         throw SetPyError(PyExc_KeyError,
2329                          "attempt to access a non-existent attribute");
2330       }
2331       return PyAttribute(self.getContext(), attr);
2332     });
2333     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
2334       if (index < 0 || index >= self.dunderLen()) {
2335         throw SetPyError(PyExc_IndexError,
2336                          "attempt to access out of bounds attribute");
2337       }
2338       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
2339       return PyNamedAttribute(
2340           namedAttr.attribute,
2341           std::string(mlirIdentifierStr(namedAttr.name).data));
2342     });
2343   }
2344 };
2345 
2346 /// Refinement of PyDenseElementsAttribute for attributes containing
2347 /// floating-point values. Supports element access.
2348 class PyDenseFPElementsAttribute
2349     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
2350                                  PyDenseElementsAttribute> {
2351 public:
2352   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
2353   static constexpr const char *pyClassName = "DenseFPElementsAttr";
2354   using PyConcreteAttribute::PyConcreteAttribute;
2355 
dunderGetItem(intptr_t pos)2356   py::float_ dunderGetItem(intptr_t pos) {
2357     if (pos < 0 || pos >= dunderLen()) {
2358       throw SetPyError(PyExc_IndexError,
2359                        "attempt to access out of bounds element");
2360     }
2361 
2362     MlirType type = mlirAttributeGetType(*this);
2363     type = mlirShapedTypeGetElementType(type);
2364     // Dispatch element extraction to an appropriate C function based on the
2365     // elemental type of the attribute. py::float_ is implicitly constructible
2366     // from float and double.
2367     // TODO: consider caching the type properties in the constructor to avoid
2368     // querying them on each element access.
2369     if (mlirTypeIsAF32(type)) {
2370       return mlirDenseElementsAttrGetFloatValue(*this, pos);
2371     }
2372     if (mlirTypeIsAF64(type)) {
2373       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
2374     }
2375     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
2376   }
2377 
bindDerived(ClassTy & c)2378   static void bindDerived(ClassTy &c) {
2379     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
2380   }
2381 };
2382 
2383 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
2384 public:
2385   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
2386   static constexpr const char *pyClassName = "TypeAttr";
2387   using PyConcreteAttribute::PyConcreteAttribute;
2388 
bindDerived(ClassTy & c)2389   static void bindDerived(ClassTy &c) {
2390     c.def_static(
2391         "get",
2392         [](PyType value, DefaultingPyMlirContext context) {
2393           MlirAttribute attr = mlirTypeAttrGet(value.get());
2394           return PyTypeAttribute(context->getRef(), attr);
2395         },
2396         py::arg("value"), py::arg("context") = py::none(),
2397         "Gets a uniqued Type attribute");
2398     c.def_property_readonly("value", [](PyTypeAttribute &self) {
2399       return PyType(self.getContext()->getRef(),
2400                     mlirTypeAttrGetValue(self.get()));
2401     });
2402   }
2403 };
2404 
2405 /// Unit Attribute subclass. Unit attributes don't have values.
2406 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
2407 public:
2408   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
2409   static constexpr const char *pyClassName = "UnitAttr";
2410   using PyConcreteAttribute::PyConcreteAttribute;
2411 
bindDerived(ClassTy & c)2412   static void bindDerived(ClassTy &c) {
2413     c.def_static(
2414         "get",
2415         [](DefaultingPyMlirContext context) {
2416           return PyUnitAttribute(context->getRef(),
2417                                  mlirUnitAttrGet(context->get()));
2418         },
2419         py::arg("context") = py::none(), "Create a Unit attribute.");
2420   }
2421 };
2422 
2423 } // namespace
2424 
2425 //------------------------------------------------------------------------------
2426 // Builtin type subclasses.
2427 //------------------------------------------------------------------------------
2428 
2429 namespace {
2430 
2431 /// CRTP base classes for Python types that subclass Type and should be
2432 /// castable from it (i.e. via something like IntegerType(t)).
2433 /// By default, type class hierarchies are one level deep (i.e. a
2434 /// concrete type class extends PyType); however, intermediate python-visible
2435 /// base classes can be modeled by specifying a BaseTy.
2436 template <typename DerivedTy, typename BaseTy = PyType>
2437 class PyConcreteType : public BaseTy {
2438 public:
2439   // Derived classes must define statics for:
2440   //   IsAFunctionTy isaFunction
2441   //   const char *pyClassName
2442   using ClassTy = py::class_<DerivedTy, BaseTy>;
2443   using IsAFunctionTy = bool (*)(MlirType);
2444 
2445   PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef,MlirType t)2446   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
2447       : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType & orig)2448   PyConcreteType(PyType &orig)
2449       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
2450 
castFrom(PyType & orig)2451   static MlirType castFrom(PyType &orig) {
2452     if (!DerivedTy::isaFunction(orig)) {
2453       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2454       throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
2455                                              DerivedTy::pyClassName +
2456                                              " (from " + origRepr + ")");
2457     }
2458     return orig;
2459   }
2460 
bind(py::module & m)2461   static void bind(py::module &m) {
2462     auto cls = ClassTy(m, DerivedTy::pyClassName);
2463     cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
2464     DerivedTy::bindDerived(cls);
2465   }
2466 
2467   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)2468   static void bindDerived(ClassTy &m) {}
2469 };
2470 
2471 class PyIntegerType : public PyConcreteType<PyIntegerType> {
2472 public:
2473   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
2474   static constexpr const char *pyClassName = "IntegerType";
2475   using PyConcreteType::PyConcreteType;
2476 
bindDerived(ClassTy & c)2477   static void bindDerived(ClassTy &c) {
2478     c.def_static(
2479         "get_signless",
2480         [](unsigned width, DefaultingPyMlirContext context) {
2481           MlirType t = mlirIntegerTypeGet(context->get(), width);
2482           return PyIntegerType(context->getRef(), t);
2483         },
2484         py::arg("width"), py::arg("context") = py::none(),
2485         "Create a signless integer type");
2486     c.def_static(
2487         "get_signed",
2488         [](unsigned width, DefaultingPyMlirContext context) {
2489           MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
2490           return PyIntegerType(context->getRef(), t);
2491         },
2492         py::arg("width"), py::arg("context") = py::none(),
2493         "Create a signed integer type");
2494     c.def_static(
2495         "get_unsigned",
2496         [](unsigned width, DefaultingPyMlirContext context) {
2497           MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
2498           return PyIntegerType(context->getRef(), t);
2499         },
2500         py::arg("width"), py::arg("context") = py::none(),
2501         "Create an unsigned integer type");
2502     c.def_property_readonly(
2503         "width",
2504         [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
2505         "Returns the width of the integer type");
2506     c.def_property_readonly(
2507         "is_signless",
2508         [](PyIntegerType &self) -> bool {
2509           return mlirIntegerTypeIsSignless(self);
2510         },
2511         "Returns whether this is a signless integer");
2512     c.def_property_readonly(
2513         "is_signed",
2514         [](PyIntegerType &self) -> bool {
2515           return mlirIntegerTypeIsSigned(self);
2516         },
2517         "Returns whether this is a signed integer");
2518     c.def_property_readonly(
2519         "is_unsigned",
2520         [](PyIntegerType &self) -> bool {
2521           return mlirIntegerTypeIsUnsigned(self);
2522         },
2523         "Returns whether this is an unsigned integer");
2524   }
2525 };
2526 
2527 /// Index Type subclass - IndexType.
2528 class PyIndexType : public PyConcreteType<PyIndexType> {
2529 public:
2530   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
2531   static constexpr const char *pyClassName = "IndexType";
2532   using PyConcreteType::PyConcreteType;
2533 
bindDerived(ClassTy & c)2534   static void bindDerived(ClassTy &c) {
2535     c.def_static(
2536         "get",
2537         [](DefaultingPyMlirContext context) {
2538           MlirType t = mlirIndexTypeGet(context->get());
2539           return PyIndexType(context->getRef(), t);
2540         },
2541         py::arg("context") = py::none(), "Create a index type.");
2542   }
2543 };
2544 
2545 /// Floating Point Type subclass - BF16Type.
2546 class PyBF16Type : public PyConcreteType<PyBF16Type> {
2547 public:
2548   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
2549   static constexpr const char *pyClassName = "BF16Type";
2550   using PyConcreteType::PyConcreteType;
2551 
bindDerived(ClassTy & c)2552   static void bindDerived(ClassTy &c) {
2553     c.def_static(
2554         "get",
2555         [](DefaultingPyMlirContext context) {
2556           MlirType t = mlirBF16TypeGet(context->get());
2557           return PyBF16Type(context->getRef(), t);
2558         },
2559         py::arg("context") = py::none(), "Create a bf16 type.");
2560   }
2561 };
2562 
2563 /// Floating Point Type subclass - F16Type.
2564 class PyF16Type : public PyConcreteType<PyF16Type> {
2565 public:
2566   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
2567   static constexpr const char *pyClassName = "F16Type";
2568   using PyConcreteType::PyConcreteType;
2569 
bindDerived(ClassTy & c)2570   static void bindDerived(ClassTy &c) {
2571     c.def_static(
2572         "get",
2573         [](DefaultingPyMlirContext context) {
2574           MlirType t = mlirF16TypeGet(context->get());
2575           return PyF16Type(context->getRef(), t);
2576         },
2577         py::arg("context") = py::none(), "Create a f16 type.");
2578   }
2579 };
2580 
2581 /// Floating Point Type subclass - F32Type.
2582 class PyF32Type : public PyConcreteType<PyF32Type> {
2583 public:
2584   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
2585   static constexpr const char *pyClassName = "F32Type";
2586   using PyConcreteType::PyConcreteType;
2587 
bindDerived(ClassTy & c)2588   static void bindDerived(ClassTy &c) {
2589     c.def_static(
2590         "get",
2591         [](DefaultingPyMlirContext context) {
2592           MlirType t = mlirF32TypeGet(context->get());
2593           return PyF32Type(context->getRef(), t);
2594         },
2595         py::arg("context") = py::none(), "Create a f32 type.");
2596   }
2597 };
2598 
2599 /// Floating Point Type subclass - F64Type.
2600 class PyF64Type : public PyConcreteType<PyF64Type> {
2601 public:
2602   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
2603   static constexpr const char *pyClassName = "F64Type";
2604   using PyConcreteType::PyConcreteType;
2605 
bindDerived(ClassTy & c)2606   static void bindDerived(ClassTy &c) {
2607     c.def_static(
2608         "get",
2609         [](DefaultingPyMlirContext context) {
2610           MlirType t = mlirF64TypeGet(context->get());
2611           return PyF64Type(context->getRef(), t);
2612         },
2613         py::arg("context") = py::none(), "Create a f64 type.");
2614   }
2615 };
2616 
2617 /// None Type subclass - NoneType.
2618 class PyNoneType : public PyConcreteType<PyNoneType> {
2619 public:
2620   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
2621   static constexpr const char *pyClassName = "NoneType";
2622   using PyConcreteType::PyConcreteType;
2623 
bindDerived(ClassTy & c)2624   static void bindDerived(ClassTy &c) {
2625     c.def_static(
2626         "get",
2627         [](DefaultingPyMlirContext context) {
2628           MlirType t = mlirNoneTypeGet(context->get());
2629           return PyNoneType(context->getRef(), t);
2630         },
2631         py::arg("context") = py::none(), "Create a none type.");
2632   }
2633 };
2634 
2635 /// Complex Type subclass - ComplexType.
2636 class PyComplexType : public PyConcreteType<PyComplexType> {
2637 public:
2638   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
2639   static constexpr const char *pyClassName = "ComplexType";
2640   using PyConcreteType::PyConcreteType;
2641 
bindDerived(ClassTy & c)2642   static void bindDerived(ClassTy &c) {
2643     c.def_static(
2644         "get",
2645         [](PyType &elementType) {
2646           // The element must be a floating point or integer scalar type.
2647           if (mlirTypeIsAIntegerOrFloat(elementType)) {
2648             MlirType t = mlirComplexTypeGet(elementType);
2649             return PyComplexType(elementType.getContext(), t);
2650           }
2651           throw SetPyError(
2652               PyExc_ValueError,
2653               Twine("invalid '") +
2654                   py::repr(py::cast(elementType)).cast<std::string>() +
2655                   "' and expected floating point or integer type.");
2656         },
2657         "Create a complex type");
2658     c.def_property_readonly(
2659         "element_type",
2660         [](PyComplexType &self) -> PyType {
2661           MlirType t = mlirComplexTypeGetElementType(self);
2662           return PyType(self.getContext(), t);
2663         },
2664         "Returns element type.");
2665   }
2666 };
2667 
2668 class PyShapedType : public PyConcreteType<PyShapedType> {
2669 public:
2670   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
2671   static constexpr const char *pyClassName = "ShapedType";
2672   using PyConcreteType::PyConcreteType;
2673 
bindDerived(ClassTy & c)2674   static void bindDerived(ClassTy &c) {
2675     c.def_property_readonly(
2676         "element_type",
2677         [](PyShapedType &self) {
2678           MlirType t = mlirShapedTypeGetElementType(self);
2679           return PyType(self.getContext(), t);
2680         },
2681         "Returns the element type of the shaped type.");
2682     c.def_property_readonly(
2683         "has_rank",
2684         [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
2685         "Returns whether the given shaped type is ranked.");
2686     c.def_property_readonly(
2687         "rank",
2688         [](PyShapedType &self) {
2689           self.requireHasRank();
2690           return mlirShapedTypeGetRank(self);
2691         },
2692         "Returns the rank of the given ranked shaped type.");
2693     c.def_property_readonly(
2694         "has_static_shape",
2695         [](PyShapedType &self) -> bool {
2696           return mlirShapedTypeHasStaticShape(self);
2697         },
2698         "Returns whether the given shaped type has a static shape.");
2699     c.def(
2700         "is_dynamic_dim",
2701         [](PyShapedType &self, intptr_t dim) -> bool {
2702           self.requireHasRank();
2703           return mlirShapedTypeIsDynamicDim(self, dim);
2704         },
2705         "Returns whether the dim-th dimension of the given shaped type is "
2706         "dynamic.");
2707     c.def(
2708         "get_dim_size",
2709         [](PyShapedType &self, intptr_t dim) {
2710           self.requireHasRank();
2711           return mlirShapedTypeGetDimSize(self, dim);
2712         },
2713         "Returns the dim-th dimension of the given ranked shaped type.");
2714     c.def_static(
2715         "is_dynamic_size",
2716         [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
2717         "Returns whether the given dimension size indicates a dynamic "
2718         "dimension.");
2719     c.def(
2720         "is_dynamic_stride_or_offset",
2721         [](PyShapedType &self, int64_t val) -> bool {
2722           self.requireHasRank();
2723           return mlirShapedTypeIsDynamicStrideOrOffset(val);
2724         },
2725         "Returns whether the given value is used as a placeholder for dynamic "
2726         "strides and offsets in shaped types.");
2727   }
2728 
2729 private:
requireHasRank()2730   void requireHasRank() {
2731     if (!mlirShapedTypeHasRank(*this)) {
2732       throw SetPyError(
2733           PyExc_ValueError,
2734           "calling this method requires that the type has a rank.");
2735     }
2736   }
2737 };
2738 
2739 /// Vector Type subclass - VectorType.
2740 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
2741 public:
2742   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
2743   static constexpr const char *pyClassName = "VectorType";
2744   using PyConcreteType::PyConcreteType;
2745 
bindDerived(ClassTy & c)2746   static void bindDerived(ClassTy &c) {
2747     c.def_static(
2748         "get",
2749         [](std::vector<int64_t> shape, PyType &elementType,
2750            DefaultingPyLocation loc) {
2751           MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
2752                                                 elementType, loc);
2753           // TODO: Rework error reporting once diagnostic engine is exposed
2754           // in C API.
2755           if (mlirTypeIsNull(t)) {
2756             throw SetPyError(
2757                 PyExc_ValueError,
2758                 Twine("invalid '") +
2759                     py::repr(py::cast(elementType)).cast<std::string>() +
2760                     "' and expected floating point or integer type.");
2761           }
2762           return PyVectorType(elementType.getContext(), t);
2763         },
2764         py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
2765         "Create a vector type");
2766   }
2767 };
2768 
2769 /// Ranked Tensor Type subclass - RankedTensorType.
2770 class PyRankedTensorType
2771     : public PyConcreteType<PyRankedTensorType, PyShapedType> {
2772 public:
2773   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2774   static constexpr const char *pyClassName = "RankedTensorType";
2775   using PyConcreteType::PyConcreteType;
2776 
bindDerived(ClassTy & c)2777   static void bindDerived(ClassTy &c) {
2778     c.def_static(
2779         "get",
2780         [](std::vector<int64_t> shape, PyType &elementType,
2781            DefaultingPyLocation loc) {
2782           MlirType t = mlirRankedTensorTypeGetChecked(
2783               shape.size(), shape.data(), elementType, loc);
2784           // TODO: Rework error reporting once diagnostic engine is exposed
2785           // in C API.
2786           if (mlirTypeIsNull(t)) {
2787             throw SetPyError(
2788                 PyExc_ValueError,
2789                 Twine("invalid '") +
2790                     py::repr(py::cast(elementType)).cast<std::string>() +
2791                     "' and expected floating point, integer, vector or "
2792                     "complex "
2793                     "type.");
2794           }
2795           return PyRankedTensorType(elementType.getContext(), t);
2796         },
2797         py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
2798         "Create a ranked tensor type");
2799   }
2800 };
2801 
2802 /// Unranked Tensor Type subclass - UnrankedTensorType.
2803 class PyUnrankedTensorType
2804     : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
2805 public:
2806   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
2807   static constexpr const char *pyClassName = "UnrankedTensorType";
2808   using PyConcreteType::PyConcreteType;
2809 
bindDerived(ClassTy & c)2810   static void bindDerived(ClassTy &c) {
2811     c.def_static(
2812         "get",
2813         [](PyType &elementType, DefaultingPyLocation loc) {
2814           MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
2815           // TODO: Rework error reporting once diagnostic engine is exposed
2816           // in C API.
2817           if (mlirTypeIsNull(t)) {
2818             throw SetPyError(
2819                 PyExc_ValueError,
2820                 Twine("invalid '") +
2821                     py::repr(py::cast(elementType)).cast<std::string>() +
2822                     "' and expected floating point, integer, vector or "
2823                     "complex "
2824                     "type.");
2825           }
2826           return PyUnrankedTensorType(elementType.getContext(), t);
2827         },
2828         py::arg("element_type"), py::arg("loc") = py::none(),
2829         "Create a unranked tensor type");
2830   }
2831 };
2832 
2833 class PyMemRefLayoutMapList;
2834 
2835 /// Ranked MemRef Type subclass - MemRefType.
2836 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
2837 public:
2838   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2839   static constexpr const char *pyClassName = "MemRefType";
2840   using PyConcreteType::PyConcreteType;
2841 
2842   PyMemRefLayoutMapList getLayout();
2843 
bindDerived(ClassTy & c)2844   static void bindDerived(ClassTy &c) {
2845     c.def_static(
2846          "get",
2847          [](std::vector<int64_t> shape, PyType &elementType,
2848             std::vector<PyAffineMap> layout, unsigned memorySpace,
2849             DefaultingPyLocation loc) {
2850            SmallVector<MlirAffineMap> maps;
2851            maps.reserve(layout.size());
2852            for (PyAffineMap &map : layout)
2853              maps.push_back(map);
2854 
2855            MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(),
2856                                                  shape.data(), maps.size(),
2857                                                  maps.data(), memorySpace, loc);
2858            // TODO: Rework error reporting once diagnostic engine is exposed
2859            // in C API.
2860            if (mlirTypeIsNull(t)) {
2861              throw SetPyError(
2862                  PyExc_ValueError,
2863                  Twine("invalid '") +
2864                      py::repr(py::cast(elementType)).cast<std::string>() +
2865                      "' and expected floating point, integer, vector or "
2866                      "complex "
2867                      "type.");
2868            }
2869            return PyMemRefType(elementType.getContext(), t);
2870          },
2871          py::arg("shape"), py::arg("element_type"),
2872          py::arg("layout") = py::list(), py::arg("memory_space") = 0,
2873          py::arg("loc") = py::none(), "Create a memref type")
2874         .def_property_readonly("layout", &PyMemRefType::getLayout,
2875                                "The list of layout maps of the MemRef type.")
2876         .def_property_readonly(
2877             "memory_space",
2878             [](PyMemRefType &self) -> unsigned {
2879               return mlirMemRefTypeGetMemorySpace(self);
2880             },
2881             "Returns the memory space of the given MemRef type.");
2882   }
2883 };
2884 
2885 /// A list of affine layout maps in a memref type. Internally, these are stored
2886 /// as consecutive elements, random access is cheap. Both the type and the maps
2887 /// are owned by the context, no need to worry about lifetime extension.
2888 class PyMemRefLayoutMapList
2889     : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
2890 public:
2891   static constexpr const char *pyClassName = "MemRefLayoutMapList";
2892 
PyMemRefLayoutMapList(PyMemRefType type,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)2893   PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
2894                         intptr_t length = -1, intptr_t step = 1)
2895       : Sliceable(startIndex,
2896                   length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
2897                   step),
2898         memref(type) {}
2899 
getNumElements()2900   intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
2901 
getElement(intptr_t index)2902   PyAffineMap getElement(intptr_t index) {
2903     return PyAffineMap(memref.getContext(),
2904                        mlirMemRefTypeGetAffineMap(memref, index));
2905   }
2906 
slice(intptr_t startIndex,intptr_t length,intptr_t step)2907   PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
2908                               intptr_t step) {
2909     return PyMemRefLayoutMapList(memref, startIndex, length, step);
2910   }
2911 
2912 private:
2913   PyMemRefType memref;
2914 };
2915 
getLayout()2916 PyMemRefLayoutMapList PyMemRefType::getLayout() {
2917   return PyMemRefLayoutMapList(*this);
2918 }
2919 
2920 /// Unranked MemRef Type subclass - UnrankedMemRefType.
2921 class PyUnrankedMemRefType
2922     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
2923 public:
2924   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
2925   static constexpr const char *pyClassName = "UnrankedMemRefType";
2926   using PyConcreteType::PyConcreteType;
2927 
bindDerived(ClassTy & c)2928   static void bindDerived(ClassTy &c) {
2929     c.def_static(
2930          "get",
2931          [](PyType &elementType, unsigned memorySpace,
2932             DefaultingPyLocation loc) {
2933            MlirType t =
2934                mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
2935            // TODO: Rework error reporting once diagnostic engine is exposed
2936            // in C API.
2937            if (mlirTypeIsNull(t)) {
2938              throw SetPyError(
2939                  PyExc_ValueError,
2940                  Twine("invalid '") +
2941                      py::repr(py::cast(elementType)).cast<std::string>() +
2942                      "' and expected floating point, integer, vector or "
2943                      "complex "
2944                      "type.");
2945            }
2946            return PyUnrankedMemRefType(elementType.getContext(), t);
2947          },
2948          py::arg("element_type"), py::arg("memory_space"),
2949          py::arg("loc") = py::none(), "Create a unranked memref type")
2950         .def_property_readonly(
2951             "memory_space",
2952             [](PyUnrankedMemRefType &self) -> unsigned {
2953               return mlirUnrankedMemrefGetMemorySpace(self);
2954             },
2955             "Returns the memory space of the given Unranked MemRef type.");
2956   }
2957 };
2958 
2959 /// Tuple Type subclass - TupleType.
2960 class PyTupleType : public PyConcreteType<PyTupleType> {
2961 public:
2962   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
2963   static constexpr const char *pyClassName = "TupleType";
2964   using PyConcreteType::PyConcreteType;
2965 
bindDerived(ClassTy & c)2966   static void bindDerived(ClassTy &c) {
2967     c.def_static(
2968         "get_tuple",
2969         [](py::list elementList, DefaultingPyMlirContext context) {
2970           intptr_t num = py::len(elementList);
2971           // Mapping py::list to SmallVector.
2972           SmallVector<MlirType, 4> elements;
2973           for (auto element : elementList)
2974             elements.push_back(element.cast<PyType>());
2975           MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
2976           return PyTupleType(context->getRef(), t);
2977         },
2978         py::arg("elements"), py::arg("context") = py::none(),
2979         "Create a tuple type");
2980     c.def(
2981         "get_type",
2982         [](PyTupleType &self, intptr_t pos) -> PyType {
2983           MlirType t = mlirTupleTypeGetType(self, pos);
2984           return PyType(self.getContext(), t);
2985         },
2986         "Returns the pos-th type in the tuple type.");
2987     c.def_property_readonly(
2988         "num_types",
2989         [](PyTupleType &self) -> intptr_t {
2990           return mlirTupleTypeGetNumTypes(self);
2991         },
2992         "Returns the number of types contained in a tuple.");
2993   }
2994 };
2995 
2996 /// Function type.
2997 class PyFunctionType : public PyConcreteType<PyFunctionType> {
2998 public:
2999   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
3000   static constexpr const char *pyClassName = "FunctionType";
3001   using PyConcreteType::PyConcreteType;
3002 
bindDerived(ClassTy & c)3003   static void bindDerived(ClassTy &c) {
3004     c.def_static(
3005         "get",
3006         [](std::vector<PyType> inputs, std::vector<PyType> results,
3007            DefaultingPyMlirContext context) {
3008           SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
3009           SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
3010           MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
3011                                            inputsRaw.data(), resultsRaw.size(),
3012                                            resultsRaw.data());
3013           return PyFunctionType(context->getRef(), t);
3014         },
3015         py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
3016         "Gets a FunctionType from a list of input and result types");
3017     c.def_property_readonly(
3018         "inputs",
3019         [](PyFunctionType &self) {
3020           MlirType t = self;
3021           auto contextRef = self.getContext();
3022           py::list types;
3023           for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
3024                ++i) {
3025             types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
3026           }
3027           return types;
3028         },
3029         "Returns the list of input types in the FunctionType.");
3030     c.def_property_readonly(
3031         "results",
3032         [](PyFunctionType &self) {
3033           auto contextRef = self.getContext();
3034           py::list types;
3035           for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
3036                ++i) {
3037             types.append(
3038                 PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
3039           }
3040           return types;
3041         },
3042         "Returns the list of result types in the FunctionType.");
3043   }
3044 };
3045 
3046 } // namespace
3047 
3048 //------------------------------------------------------------------------------
3049 // PyAffineExpr and subclasses.
3050 //------------------------------------------------------------------------------
3051 
3052 namespace {
3053 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
3054 /// and should be castable from it. Intermediate hierarchy classes can be
3055 /// modeled by specifying BaseTy.
3056 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
3057 class PyConcreteAffineExpr : public BaseTy {
3058 public:
3059   // Derived classes must define statics for:
3060   //   IsAFunctionTy isaFunction
3061   //   const char *pyClassName
3062   // and redefine bindDerived.
3063   using ClassTy = py::class_<DerivedTy, BaseTy>;
3064   using IsAFunctionTy = bool (*)(MlirAffineExpr);
3065 
3066   PyConcreteAffineExpr() = default;
PyConcreteAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)3067   PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
3068       : BaseTy(std::move(contextRef), affineExpr) {}
PyConcreteAffineExpr(PyAffineExpr & orig)3069   PyConcreteAffineExpr(PyAffineExpr &orig)
3070       : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
3071 
castFrom(PyAffineExpr & orig)3072   static MlirAffineExpr castFrom(PyAffineExpr &orig) {
3073     if (!DerivedTy::isaFunction(orig)) {
3074       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
3075       throw SetPyError(PyExc_ValueError,
3076                        Twine("Cannot cast affine expression to ") +
3077                            DerivedTy::pyClassName + " (from " + origRepr + ")");
3078     }
3079     return orig;
3080   }
3081 
bind(py::module & m)3082   static void bind(py::module &m) {
3083     auto cls = ClassTy(m, DerivedTy::pyClassName);
3084     cls.def(py::init<PyAffineExpr &>());
3085     DerivedTy::bindDerived(cls);
3086   }
3087 
3088   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)3089   static void bindDerived(ClassTy &m) {}
3090 };
3091 
3092 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
3093 public:
3094   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
3095   static constexpr const char *pyClassName = "AffineConstantExpr";
3096   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3097 
get(intptr_t value,DefaultingPyMlirContext context)3098   static PyAffineConstantExpr get(intptr_t value,
3099                                   DefaultingPyMlirContext context) {
3100     MlirAffineExpr affineExpr =
3101         mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
3102     return PyAffineConstantExpr(context->getRef(), affineExpr);
3103   }
3104 
bindDerived(ClassTy & c)3105   static void bindDerived(ClassTy &c) {
3106     c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
3107                  py::arg("context") = py::none());
3108     c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
3109       return mlirAffineConstantExprGetValue(self);
3110     });
3111   }
3112 };
3113 
3114 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
3115 public:
3116   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
3117   static constexpr const char *pyClassName = "AffineDimExpr";
3118   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3119 
get(intptr_t pos,DefaultingPyMlirContext context)3120   static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
3121     MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
3122     return PyAffineDimExpr(context->getRef(), affineExpr);
3123   }
3124 
bindDerived(ClassTy & c)3125   static void bindDerived(ClassTy &c) {
3126     c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
3127                  py::arg("context") = py::none());
3128     c.def_property_readonly("position", [](PyAffineDimExpr &self) {
3129       return mlirAffineDimExprGetPosition(self);
3130     });
3131   }
3132 };
3133 
3134 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
3135 public:
3136   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
3137   static constexpr const char *pyClassName = "AffineSymbolExpr";
3138   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3139 
get(intptr_t pos,DefaultingPyMlirContext context)3140   static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
3141     MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
3142     return PyAffineSymbolExpr(context->getRef(), affineExpr);
3143   }
3144 
bindDerived(ClassTy & c)3145   static void bindDerived(ClassTy &c) {
3146     c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
3147                  py::arg("context") = py::none());
3148     c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
3149       return mlirAffineSymbolExprGetPosition(self);
3150     });
3151   }
3152 };
3153 
3154 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
3155 public:
3156   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
3157   static constexpr const char *pyClassName = "AffineBinaryExpr";
3158   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3159 
lhs()3160   PyAffineExpr lhs() {
3161     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
3162     return PyAffineExpr(getContext(), lhsExpr);
3163   }
3164 
rhs()3165   PyAffineExpr rhs() {
3166     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
3167     return PyAffineExpr(getContext(), rhsExpr);
3168   }
3169 
bindDerived(ClassTy & c)3170   static void bindDerived(ClassTy &c) {
3171     c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
3172     c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
3173   }
3174 };
3175 
3176 class PyAffineAddExpr
3177     : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
3178 public:
3179   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
3180   static constexpr const char *pyClassName = "AffineAddExpr";
3181   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3182 
get(PyAffineExpr lhs,PyAffineExpr rhs)3183   static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3184     MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
3185     return PyAffineAddExpr(lhs.getContext(), expr);
3186   }
3187 
bindDerived(ClassTy & c)3188   static void bindDerived(ClassTy &c) {
3189     c.def_static("get", &PyAffineAddExpr::get);
3190   }
3191 };
3192 
3193 class PyAffineMulExpr
3194     : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
3195 public:
3196   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
3197   static constexpr const char *pyClassName = "AffineMulExpr";
3198   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3199 
get(PyAffineExpr lhs,PyAffineExpr rhs)3200   static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3201     MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
3202     return PyAffineMulExpr(lhs.getContext(), expr);
3203   }
3204 
bindDerived(ClassTy & c)3205   static void bindDerived(ClassTy &c) {
3206     c.def_static("get", &PyAffineMulExpr::get);
3207   }
3208 };
3209 
3210 class PyAffineModExpr
3211     : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
3212 public:
3213   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
3214   static constexpr const char *pyClassName = "AffineModExpr";
3215   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3216 
get(PyAffineExpr lhs,PyAffineExpr rhs)3217   static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3218     MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
3219     return PyAffineModExpr(lhs.getContext(), expr);
3220   }
3221 
bindDerived(ClassTy & c)3222   static void bindDerived(ClassTy &c) {
3223     c.def_static("get", &PyAffineModExpr::get);
3224   }
3225 };
3226 
3227 class PyAffineFloorDivExpr
3228     : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
3229 public:
3230   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
3231   static constexpr const char *pyClassName = "AffineFloorDivExpr";
3232   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3233 
get(PyAffineExpr lhs,PyAffineExpr rhs)3234   static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3235     MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
3236     return PyAffineFloorDivExpr(lhs.getContext(), expr);
3237   }
3238 
bindDerived(ClassTy & c)3239   static void bindDerived(ClassTy &c) {
3240     c.def_static("get", &PyAffineFloorDivExpr::get);
3241   }
3242 };
3243 
3244 class PyAffineCeilDivExpr
3245     : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
3246 public:
3247   static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
3248   static constexpr const char *pyClassName = "AffineCeilDivExpr";
3249   using PyConcreteAffineExpr::PyConcreteAffineExpr;
3250 
get(PyAffineExpr lhs,PyAffineExpr rhs)3251   static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3252     MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
3253     return PyAffineCeilDivExpr(lhs.getContext(), expr);
3254   }
3255 
bindDerived(ClassTy & c)3256   static void bindDerived(ClassTy &c) {
3257     c.def_static("get", &PyAffineCeilDivExpr::get);
3258   }
3259 };
3260 } // namespace
3261 
operator ==(const PyAffineExpr & other)3262 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
3263   return mlirAffineExprEqual(affineExpr, other.affineExpr);
3264 }
3265 
getCapsule()3266 py::object PyAffineExpr::getCapsule() {
3267   return py::reinterpret_steal<py::object>(
3268       mlirPythonAffineExprToCapsule(*this));
3269 }
3270 
createFromCapsule(py::object capsule)3271 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
3272   MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
3273   if (mlirAffineExprIsNull(rawAffineExpr))
3274     throw py::error_already_set();
3275   return PyAffineExpr(
3276       PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
3277       rawAffineExpr);
3278 }
3279 
3280 //------------------------------------------------------------------------------
3281 // PyAffineMap and utilities.
3282 //------------------------------------------------------------------------------
3283 
3284 namespace {
3285 /// A list of expressions contained in an affine map. Internally these are
3286 /// stored as a consecutive array leading to inexpensive random access. Both
3287 /// the map and the expression are owned by the context so we need not bother
3288 /// with lifetime extension.
3289 class PyAffineMapExprList
3290     : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
3291 public:
3292   static constexpr const char *pyClassName = "AffineExprList";
3293 
PyAffineMapExprList(PyAffineMap map,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)3294   PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
3295                       intptr_t length = -1, intptr_t step = 1)
3296       : Sliceable(startIndex,
3297                   length == -1 ? mlirAffineMapGetNumResults(map) : length,
3298                   step),
3299         affineMap(map) {}
3300 
getNumElements()3301   intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
3302 
getElement(intptr_t pos)3303   PyAffineExpr getElement(intptr_t pos) {
3304     return PyAffineExpr(affineMap.getContext(),
3305                         mlirAffineMapGetResult(affineMap, pos));
3306   }
3307 
slice(intptr_t startIndex,intptr_t length,intptr_t step)3308   PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
3309                             intptr_t step) {
3310     return PyAffineMapExprList(affineMap, startIndex, length, step);
3311   }
3312 
3313 private:
3314   PyAffineMap affineMap;
3315 };
3316 } // end namespace
3317 
operator ==(const PyAffineMap & other)3318 bool PyAffineMap::operator==(const PyAffineMap &other) {
3319   return mlirAffineMapEqual(affineMap, other.affineMap);
3320 }
3321 
getCapsule()3322 py::object PyAffineMap::getCapsule() {
3323   return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
3324 }
3325 
createFromCapsule(py::object capsule)3326 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
3327   MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
3328   if (mlirAffineMapIsNull(rawAffineMap))
3329     throw py::error_already_set();
3330   return PyAffineMap(
3331       PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
3332       rawAffineMap);
3333 }
3334 
3335 //------------------------------------------------------------------------------
3336 // PyIntegerSet and utilities.
3337 //------------------------------------------------------------------------------
3338 
3339 class PyIntegerSetConstraint {
3340 public:
PyIntegerSetConstraint(PyIntegerSet set,intptr_t pos)3341   PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
3342 
getExpr()3343   PyAffineExpr getExpr() {
3344     return PyAffineExpr(set.getContext(),
3345                         mlirIntegerSetGetConstraint(set, pos));
3346   }
3347 
isEq()3348   bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
3349 
bind(py::module & m)3350   static void bind(py::module &m) {
3351     py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
3352         .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
3353         .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
3354   }
3355 
3356 private:
3357   PyIntegerSet set;
3358   intptr_t pos;
3359 };
3360 
3361 class PyIntegerSetConstraintList
3362     : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
3363 public:
3364   static constexpr const char *pyClassName = "IntegerSetConstraintList";
3365 
PyIntegerSetConstraintList(PyIntegerSet set,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)3366   PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
3367                              intptr_t length = -1, intptr_t step = 1)
3368       : Sliceable(startIndex,
3369                   length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
3370                   step),
3371         set(set) {}
3372 
getNumElements()3373   intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
3374 
getElement(intptr_t pos)3375   PyIntegerSetConstraint getElement(intptr_t pos) {
3376     return PyIntegerSetConstraint(set, pos);
3377   }
3378 
slice(intptr_t startIndex,intptr_t length,intptr_t step)3379   PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
3380                                    intptr_t step) {
3381     return PyIntegerSetConstraintList(set, startIndex, length, step);
3382   }
3383 
3384 private:
3385   PyIntegerSet set;
3386 };
3387 
operator ==(const PyIntegerSet & other)3388 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
3389   return mlirIntegerSetEqual(integerSet, other.integerSet);
3390 }
3391 
getCapsule()3392 py::object PyIntegerSet::getCapsule() {
3393   return py::reinterpret_steal<py::object>(
3394       mlirPythonIntegerSetToCapsule(*this));
3395 }
3396 
createFromCapsule(py::object capsule)3397 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
3398   MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
3399   if (mlirIntegerSetIsNull(rawIntegerSet))
3400     throw py::error_already_set();
3401   return PyIntegerSet(
3402       PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
3403       rawIntegerSet);
3404 }
3405 
3406 /// Attempts to populate `result` with the content of `list` casted to the
3407 /// appropriate type (Python and C types are provided as template arguments).
3408 /// Throws errors in case of failure, using "action" to describe what the caller
3409 /// was attempting to do.
3410 template <typename PyType, typename CType>
pyListToVector(py::list list,llvm::SmallVectorImpl<CType> & result,StringRef action)3411 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
3412                            StringRef action) {
3413   result.reserve(py::len(list));
3414   for (py::handle item : list) {
3415     try {
3416       result.push_back(item.cast<PyType>());
3417     } catch (py::cast_error &err) {
3418       std::string msg = (llvm::Twine("Invalid expression when ") + action +
3419                          " (" + err.what() + ")")
3420                             .str();
3421       throw py::cast_error(msg);
3422     } catch (py::reference_cast_error &err) {
3423       std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
3424                          action + " (" + err.what() + ")")
3425                             .str();
3426       throw py::cast_error(msg);
3427     }
3428   }
3429 }
3430 
3431 //------------------------------------------------------------------------------
3432 // Populates the pybind11 IR submodule.
3433 //------------------------------------------------------------------------------
3434 
populateIRSubmodule(py::module & m)3435 void mlir::python::populateIRSubmodule(py::module &m) {
3436   //----------------------------------------------------------------------------
3437   // Mapping of MlirContext
3438   //----------------------------------------------------------------------------
3439   py::class_<PyMlirContext>(m, "Context")
3440       .def(py::init<>(&PyMlirContext::createNewContextForInit))
3441       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
3442       .def("_get_context_again",
3443            [](PyMlirContext &self) {
3444              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
3445              return ref.releaseObject();
3446            })
3447       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
3448       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
3449       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3450                              &PyMlirContext::getCapsule)
3451       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
3452       .def("__enter__", &PyMlirContext::contextEnter)
3453       .def("__exit__", &PyMlirContext::contextExit)
3454       .def_property_readonly_static(
3455           "current",
3456           [](py::object & /*class*/) {
3457             auto *context = PyThreadContextEntry::getDefaultContext();
3458             if (!context)
3459               throw SetPyError(PyExc_ValueError, "No current Context");
3460             return context;
3461           },
3462           "Gets the Context bound to the current thread or raises ValueError")
3463       .def_property_readonly(
3464           "dialects",
3465           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3466           "Gets a container for accessing dialects by name")
3467       .def_property_readonly(
3468           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3469           "Alias for 'dialect'")
3470       .def(
3471           "get_dialect_descriptor",
3472           [=](PyMlirContext &self, std::string &name) {
3473             MlirDialect dialect = mlirContextGetOrLoadDialect(
3474                 self.get(), {name.data(), name.size()});
3475             if (mlirDialectIsNull(dialect)) {
3476               throw SetPyError(PyExc_ValueError,
3477                                Twine("Dialect '") + name + "' not found");
3478             }
3479             return PyDialectDescriptor(self.getRef(), dialect);
3480           },
3481           "Gets or loads a dialect by name, returning its descriptor object")
3482       .def_property(
3483           "allow_unregistered_dialects",
3484           [](PyMlirContext &self) -> bool {
3485             return mlirContextGetAllowUnregisteredDialects(self.get());
3486           },
3487           [](PyMlirContext &self, bool value) {
3488             mlirContextSetAllowUnregisteredDialects(self.get(), value);
3489           });
3490 
3491   //----------------------------------------------------------------------------
3492   // Mapping of PyDialectDescriptor
3493   //----------------------------------------------------------------------------
3494   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3495       .def_property_readonly("namespace",
3496                              [](PyDialectDescriptor &self) {
3497                                MlirStringRef ns =
3498                                    mlirDialectGetNamespace(self.get());
3499                                return py::str(ns.data, ns.length);
3500                              })
3501       .def("__repr__", [](PyDialectDescriptor &self) {
3502         MlirStringRef ns = mlirDialectGetNamespace(self.get());
3503         std::string repr("<DialectDescriptor ");
3504         repr.append(ns.data, ns.length);
3505         repr.append(">");
3506         return repr;
3507       });
3508 
3509   //----------------------------------------------------------------------------
3510   // Mapping of PyDialects
3511   //----------------------------------------------------------------------------
3512   py::class_<PyDialects>(m, "Dialects")
3513       .def("__getitem__",
3514            [=](PyDialects &self, std::string keyName) {
3515              MlirDialect dialect =
3516                  self.getDialectForKey(keyName, /*attrError=*/false);
3517              py::object descriptor =
3518                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
3519              return createCustomDialectWrapper(keyName, std::move(descriptor));
3520            })
3521       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
3522         MlirDialect dialect =
3523             self.getDialectForKey(attrName, /*attrError=*/true);
3524         py::object descriptor =
3525             py::cast(PyDialectDescriptor{self.getContext(), dialect});
3526         return createCustomDialectWrapper(attrName, std::move(descriptor));
3527       });
3528 
3529   //----------------------------------------------------------------------------
3530   // Mapping of PyDialect
3531   //----------------------------------------------------------------------------
3532   py::class_<PyDialect>(m, "Dialect")
3533       .def(py::init<py::object>(), "descriptor")
3534       .def_property_readonly(
3535           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
3536       .def("__repr__", [](py::object self) {
3537         auto clazz = self.attr("__class__");
3538         return py::str("<Dialect ") +
3539                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
3540                clazz.attr("__module__") + py::str(".") +
3541                clazz.attr("__name__") + py::str(")>");
3542       });
3543 
3544   //----------------------------------------------------------------------------
3545   // Mapping of Location
3546   //----------------------------------------------------------------------------
3547   py::class_<PyLocation>(m, "Location")
3548       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
3549       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
3550       .def("__enter__", &PyLocation::contextEnter)
3551       .def("__exit__", &PyLocation::contextExit)
3552       .def("__eq__",
3553            [](PyLocation &self, PyLocation &other) -> bool {
3554              return mlirLocationEqual(self, other);
3555            })
3556       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
3557       .def_property_readonly_static(
3558           "current",
3559           [](py::object & /*class*/) {
3560             auto *loc = PyThreadContextEntry::getDefaultLocation();
3561             if (!loc)
3562               throw SetPyError(PyExc_ValueError, "No current Location");
3563             return loc;
3564           },
3565           "Gets the Location bound to the current thread or raises ValueError")
3566       .def_static(
3567           "unknown",
3568           [](DefaultingPyMlirContext context) {
3569             return PyLocation(context->getRef(),
3570                               mlirLocationUnknownGet(context->get()));
3571           },
3572           py::arg("context") = py::none(),
3573           "Gets a Location representing an unknown location")
3574       .def_static(
3575           "file",
3576           [](std::string filename, int line, int col,
3577              DefaultingPyMlirContext context) {
3578             return PyLocation(
3579                 context->getRef(),
3580                 mlirLocationFileLineColGet(
3581                     context->get(), toMlirStringRef(filename), line, col));
3582           },
3583           py::arg("filename"), py::arg("line"), py::arg("col"),
3584           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
3585       .def_property_readonly(
3586           "context",
3587           [](PyLocation &self) { return self.getContext().getObject(); },
3588           "Context that owns the Location")
3589       .def("__repr__", [](PyLocation &self) {
3590         PyPrintAccumulator printAccum;
3591         mlirLocationPrint(self, printAccum.getCallback(),
3592                           printAccum.getUserData());
3593         return printAccum.join();
3594       });
3595 
3596   //----------------------------------------------------------------------------
3597   // Mapping of Module
3598   //----------------------------------------------------------------------------
3599   py::class_<PyModule>(m, "Module")
3600       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3601       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3602       .def_static(
3603           "parse",
3604           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
3605             MlirModule module = mlirModuleCreateParse(
3606                 context->get(), toMlirStringRef(moduleAsm));
3607             // TODO: Rework error reporting once diagnostic engine is exposed
3608             // in C API.
3609             if (mlirModuleIsNull(module)) {
3610               throw SetPyError(
3611                   PyExc_ValueError,
3612                   "Unable to parse module assembly (see diagnostics)");
3613             }
3614             return PyModule::forModule(module).releaseObject();
3615           },
3616           py::arg("asm"), py::arg("context") = py::none(),
3617           kModuleParseDocstring)
3618       .def_static(
3619           "create",
3620           [](DefaultingPyLocation loc) {
3621             MlirModule module = mlirModuleCreateEmpty(loc);
3622             return PyModule::forModule(module).releaseObject();
3623           },
3624           py::arg("loc") = py::none(), "Creates an empty module")
3625       .def_property_readonly(
3626           "context",
3627           [](PyModule &self) { return self.getContext().getObject(); },
3628           "Context that created the Module")
3629       .def_property_readonly(
3630           "operation",
3631           [](PyModule &self) {
3632             return PyOperation::forOperation(self.getContext(),
3633                                              mlirModuleGetOperation(self.get()),
3634                                              self.getRef().releaseObject())
3635                 .releaseObject();
3636           },
3637           "Accesses the module as an operation")
3638       .def_property_readonly(
3639           "body",
3640           [](PyModule &self) {
3641             PyOperationRef module_op = PyOperation::forOperation(
3642                 self.getContext(), mlirModuleGetOperation(self.get()),
3643                 self.getRef().releaseObject());
3644             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
3645             return returnBlock;
3646           },
3647           "Return the block for this module")
3648       .def(
3649           "dump",
3650           [](PyModule &self) {
3651             mlirOperationDump(mlirModuleGetOperation(self.get()));
3652           },
3653           kDumpDocstring)
3654       .def(
3655           "__str__",
3656           [](PyModule &self) {
3657             MlirOperation operation = mlirModuleGetOperation(self.get());
3658             PyPrintAccumulator printAccum;
3659             mlirOperationPrint(operation, printAccum.getCallback(),
3660                                printAccum.getUserData());
3661             return printAccum.join();
3662           },
3663           kOperationStrDunderDocstring);
3664 
3665   //----------------------------------------------------------------------------
3666   // Mapping of Operation.
3667   //----------------------------------------------------------------------------
3668   py::class_<PyOperationBase>(m, "_OperationBase")
3669       .def("__eq__",
3670            [](PyOperationBase &self, PyOperationBase &other) {
3671              return &self.getOperation() == &other.getOperation();
3672            })
3673       .def("__eq__",
3674            [](PyOperationBase &self, py::object other) { return false; })
3675       .def_property_readonly("attributes",
3676                              [](PyOperationBase &self) {
3677                                return PyOpAttributeMap(
3678                                    self.getOperation().getRef());
3679                              })
3680       .def_property_readonly("operands",
3681                              [](PyOperationBase &self) {
3682                                return PyOpOperandList(
3683                                    self.getOperation().getRef());
3684                              })
3685       .def_property_readonly("regions",
3686                              [](PyOperationBase &self) {
3687                                return PyRegionList(
3688                                    self.getOperation().getRef());
3689                              })
3690       .def_property_readonly(
3691           "results",
3692           [](PyOperationBase &self) {
3693             return PyOpResultList(self.getOperation().getRef());
3694           },
3695           "Returns the list of Operation results.")
3696       .def_property_readonly(
3697           "result",
3698           [](PyOperationBase &self) {
3699             auto &operation = self.getOperation();
3700             auto numResults = mlirOperationGetNumResults(operation);
3701             if (numResults != 1) {
3702               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
3703               throw SetPyError(
3704                   PyExc_ValueError,
3705                   Twine("Cannot call .result on operation ") +
3706                       StringRef(name.data, name.length) + " which has " +
3707                       Twine(numResults) +
3708                       " results (it is only valid for operations with a "
3709                       "single result)");
3710             }
3711             return PyOpResult(operation.getRef(),
3712                               mlirOperationGetResult(operation, 0));
3713           },
3714           "Shortcut to get an op result if it has only one (throws an error "
3715           "otherwise).")
3716       .def("__iter__",
3717            [](PyOperationBase &self) {
3718              return PyRegionIterator(self.getOperation().getRef());
3719            })
3720       .def(
3721           "__str__",
3722           [](PyOperationBase &self) {
3723             return self.getAsm(/*binary=*/false,
3724                                /*largeElementsLimit=*/llvm::None,
3725                                /*enableDebugInfo=*/false,
3726                                /*prettyDebugInfo=*/false,
3727                                /*printGenericOpForm=*/false,
3728                                /*useLocalScope=*/false);
3729           },
3730           "Returns the assembly form of the operation.")
3731       .def("print", &PyOperationBase::print,
3732            // Careful: Lots of arguments must match up with print method.
3733            py::arg("file") = py::none(), py::arg("binary") = false,
3734            py::arg("large_elements_limit") = py::none(),
3735            py::arg("enable_debug_info") = false,
3736            py::arg("pretty_debug_info") = false,
3737            py::arg("print_generic_op_form") = false,
3738            py::arg("use_local_scope") = false, kOperationPrintDocstring)
3739       .def("get_asm", &PyOperationBase::getAsm,
3740            // Careful: Lots of arguments must match up with get_asm method.
3741            py::arg("binary") = false,
3742            py::arg("large_elements_limit") = py::none(),
3743            py::arg("enable_debug_info") = false,
3744            py::arg("pretty_debug_info") = false,
3745            py::arg("print_generic_op_form") = false,
3746            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
3747       .def(
3748           "verify",
3749           [](PyOperationBase &self) {
3750             return mlirOperationVerify(self.getOperation());
3751           },
3752           "Verify the operation and return true if it passes, false if it "
3753           "fails.");
3754 
3755   py::class_<PyOperation, PyOperationBase>(m, "Operation")
3756       .def_static("create", &PyOperation::create, py::arg("name"),
3757                   py::arg("results") = py::none(),
3758                   py::arg("operands") = py::none(),
3759                   py::arg("attributes") = py::none(),
3760                   py::arg("successors") = py::none(), py::arg("regions") = 0,
3761                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3762                   kOperationCreateDocstring)
3763       .def_property_readonly("name",
3764                              [](PyOperation &self) {
3765                                MlirOperation operation = self.get();
3766                                MlirStringRef name = mlirIdentifierStr(
3767                                    mlirOperationGetName(operation));
3768                                return py::str(name.data, name.length);
3769                              })
3770       .def_property_readonly(
3771           "context",
3772           [](PyOperation &self) { return self.getContext().getObject(); },
3773           "Context that owns the Operation")
3774       .def_property_readonly("opview", &PyOperation::createOpView);
3775 
3776   auto opViewClass =
3777       py::class_<PyOpView, PyOperationBase>(m, "OpView")
3778           .def(py::init<py::object>())
3779           .def_property_readonly("operation", &PyOpView::getOperationObject)
3780           .def_property_readonly(
3781               "context",
3782               [](PyOpView &self) {
3783                 return self.getOperation().getContext().getObject();
3784               },
3785               "Context that owns the Operation")
3786           .def("__str__", [](PyOpView &self) {
3787             return py::str(self.getOperationObject());
3788           });
3789   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3790   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3791   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3792   opViewClass.attr("build_generic") = classmethod(
3793       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3794       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3795       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3796       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3797       "Builds a specific, generated OpView based on class level attributes.");
3798 
3799   //----------------------------------------------------------------------------
3800   // Mapping of PyRegion.
3801   //----------------------------------------------------------------------------
3802   py::class_<PyRegion>(m, "Region")
3803       .def_property_readonly(
3804           "blocks",
3805           [](PyRegion &self) {
3806             return PyBlockList(self.getParentOperation(), self.get());
3807           },
3808           "Returns a forward-optimized sequence of blocks.")
3809       .def(
3810           "__iter__",
3811           [](PyRegion &self) {
3812             self.checkValid();
3813             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3814             return PyBlockIterator(self.getParentOperation(), firstBlock);
3815           },
3816           "Iterates over blocks in the region.")
3817       .def("__eq__",
3818            [](PyRegion &self, PyRegion &other) {
3819              return self.get().ptr == other.get().ptr;
3820            })
3821       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3822 
3823   //----------------------------------------------------------------------------
3824   // Mapping of PyBlock.
3825   //----------------------------------------------------------------------------
3826   py::class_<PyBlock>(m, "Block")
3827       .def_property_readonly(
3828           "arguments",
3829           [](PyBlock &self) {
3830             return PyBlockArgumentList(self.getParentOperation(), self.get());
3831           },
3832           "Returns a list of block arguments.")
3833       .def_property_readonly(
3834           "operations",
3835           [](PyBlock &self) {
3836             return PyOperationList(self.getParentOperation(), self.get());
3837           },
3838           "Returns a forward-optimized sequence of operations.")
3839       .def(
3840           "__iter__",
3841           [](PyBlock &self) {
3842             self.checkValid();
3843             MlirOperation firstOperation =
3844                 mlirBlockGetFirstOperation(self.get());
3845             return PyOperationIterator(self.getParentOperation(),
3846                                        firstOperation);
3847           },
3848           "Iterates over operations in the block.")
3849       .def("__eq__",
3850            [](PyBlock &self, PyBlock &other) {
3851              return self.get().ptr == other.get().ptr;
3852            })
3853       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3854       .def(
3855           "__str__",
3856           [](PyBlock &self) {
3857             self.checkValid();
3858             PyPrintAccumulator printAccum;
3859             mlirBlockPrint(self.get(), printAccum.getCallback(),
3860                            printAccum.getUserData());
3861             return printAccum.join();
3862           },
3863           "Returns the assembly form of the block.");
3864 
3865   //----------------------------------------------------------------------------
3866   // Mapping of PyInsertionPoint.
3867   //----------------------------------------------------------------------------
3868 
3869   py::class_<PyInsertionPoint>(m, "InsertionPoint")
3870       .def(py::init<PyBlock &>(), py::arg("block"),
3871            "Inserts after the last operation but still inside the block.")
3872       .def("__enter__", &PyInsertionPoint::contextEnter)
3873       .def("__exit__", &PyInsertionPoint::contextExit)
3874       .def_property_readonly_static(
3875           "current",
3876           [](py::object & /*class*/) {
3877             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3878             if (!ip)
3879               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
3880             return ip;
3881           },
3882           "Gets the InsertionPoint bound to the current thread or raises "
3883           "ValueError if none has been set")
3884       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3885            "Inserts before a referenced operation.")
3886       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3887                   py::arg("block"), "Inserts at the beginning of the block.")
3888       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3889                   py::arg("block"), "Inserts before the block terminator.")
3890       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3891            "Inserts an operation.");
3892 
3893   //----------------------------------------------------------------------------
3894   // Mapping of PyAttribute.
3895   //----------------------------------------------------------------------------
3896   py::class_<PyAttribute>(m, "Attribute")
3897       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3898                              &PyAttribute::getCapsule)
3899       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3900       .def_static(
3901           "parse",
3902           [](std::string attrSpec, DefaultingPyMlirContext context) {
3903             MlirAttribute type = mlirAttributeParseGet(
3904                 context->get(), toMlirStringRef(attrSpec));
3905             // TODO: Rework error reporting once diagnostic engine is exposed
3906             // in C API.
3907             if (mlirAttributeIsNull(type)) {
3908               throw SetPyError(PyExc_ValueError,
3909                                Twine("Unable to parse attribute: '") +
3910                                    attrSpec + "'");
3911             }
3912             return PyAttribute(context->getRef(), type);
3913           },
3914           py::arg("asm"), py::arg("context") = py::none(),
3915           "Parses an attribute from an assembly form")
3916       .def_property_readonly(
3917           "context",
3918           [](PyAttribute &self) { return self.getContext().getObject(); },
3919           "Context that owns the Attribute")
3920       .def_property_readonly("type",
3921                              [](PyAttribute &self) {
3922                                return PyType(self.getContext()->getRef(),
3923                                              mlirAttributeGetType(self));
3924                              })
3925       .def(
3926           "get_named",
3927           [](PyAttribute &self, std::string name) {
3928             return PyNamedAttribute(self, std::move(name));
3929           },
3930           py::keep_alive<0, 1>(), "Binds a name to the attribute")
3931       .def("__eq__",
3932            [](PyAttribute &self, PyAttribute &other) { return self == other; })
3933       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3934       .def(
3935           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3936           kDumpDocstring)
3937       .def(
3938           "__str__",
3939           [](PyAttribute &self) {
3940             PyPrintAccumulator printAccum;
3941             mlirAttributePrint(self, printAccum.getCallback(),
3942                                printAccum.getUserData());
3943             return printAccum.join();
3944           },
3945           "Returns the assembly form of the Attribute.")
3946       .def("__repr__", [](PyAttribute &self) {
3947         // Generally, assembly formats are not printed for __repr__ because
3948         // this can cause exceptionally long debug output and exceptions.
3949         // However, attribute values are generally considered useful and are
3950         // printed. This may need to be re-evaluated if debug dumps end up
3951         // being excessive.
3952         PyPrintAccumulator printAccum;
3953         printAccum.parts.append("Attribute(");
3954         mlirAttributePrint(self, printAccum.getCallback(),
3955                            printAccum.getUserData());
3956         printAccum.parts.append(")");
3957         return printAccum.join();
3958       });
3959 
3960   //----------------------------------------------------------------------------
3961   // Mapping of PyNamedAttribute
3962   //----------------------------------------------------------------------------
3963   py::class_<PyNamedAttribute>(m, "NamedAttribute")
3964       .def("__repr__",
3965            [](PyNamedAttribute &self) {
3966              PyPrintAccumulator printAccum;
3967              printAccum.parts.append("NamedAttribute(");
3968              printAccum.parts.append(
3969                  mlirIdentifierStr(self.namedAttr.name).data);
3970              printAccum.parts.append("=");
3971              mlirAttributePrint(self.namedAttr.attribute,
3972                                 printAccum.getCallback(),
3973                                 printAccum.getUserData());
3974              printAccum.parts.append(")");
3975              return printAccum.join();
3976            })
3977       .def_property_readonly(
3978           "name",
3979           [](PyNamedAttribute &self) {
3980             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3981                            mlirIdentifierStr(self.namedAttr.name).length);
3982           },
3983           "The name of the NamedAttribute binding")
3984       .def_property_readonly(
3985           "attr",
3986           [](PyNamedAttribute &self) {
3987             // TODO: When named attribute is removed/refactored, also remove
3988             // this constructor (it does an inefficient table lookup).
3989             auto contextRef = PyMlirContext::forContext(
3990                 mlirAttributeGetContext(self.namedAttr.attribute));
3991             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3992           },
3993           py::keep_alive<0, 1>(),
3994           "The underlying generic attribute of the NamedAttribute binding");
3995 
  // Builtin attribute bindings. Each `::bind(m)` call presumably adds one
  // builtin attribute class (or, for PyArrayAttributeIterator, the array
  // attribute's iterator helper) to module `m`. These calls come after the
  // generic `Attribute` class is registered above. Keep that order, since the
  // subclasses presumably name `Attribute` as their pybind11 base (confirm in
  // each ::bind implementation).
  PyFloatAttribute::bind(m);
  PyArrayAttribute::bind(m);
  // Iterator helper, bound right after its owning array attribute.
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyIntegerAttribute::bind(m);
  PyBoolAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyStringAttribute::bind(m);
  // The generic dense-elements class is bound before the Int/FP variants,
  // which presumably derive from it.
  PyDenseElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
4010 
4011   //----------------------------------------------------------------------------
4012   // Mapping of PyType.
4013   //----------------------------------------------------------------------------
4014   py::class_<PyType>(m, "Type")
4015       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
4016       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
4017       .def_static(
4018           "parse",
4019           [](std::string typeSpec, DefaultingPyMlirContext context) {
4020             MlirType type =
4021                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4022             // TODO: Rework error reporting once diagnostic engine is exposed
4023             // in C API.
4024             if (mlirTypeIsNull(type)) {
4025               throw SetPyError(PyExc_ValueError,
4026                                Twine("Unable to parse type: '") + typeSpec +
4027                                    "'");
4028             }
4029             return PyType(context->getRef(), type);
4030           },
4031           py::arg("asm"), py::arg("context") = py::none(),
4032           kContextParseTypeDocstring)
4033       .def_property_readonly(
4034           "context", [](PyType &self) { return self.getContext().getObject(); },
4035           "Context that owns the Type")
4036       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
4037       .def("__eq__", [](PyType &self, py::object &other) { return false; })
4038       .def(
4039           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4040       .def(
4041           "__str__",
4042           [](PyType &self) {
4043             PyPrintAccumulator printAccum;
4044             mlirTypePrint(self, printAccum.getCallback(),
4045                           printAccum.getUserData());
4046             return printAccum.join();
4047           },
4048           "Returns the assembly form of the type.")
4049       .def("__repr__", [](PyType &self) {
4050         // Generally, assembly formats are not printed for __repr__ because
4051         // this can cause exceptionally long debug output and exceptions.
4052         // However, types are an exception as they typically have compact
4053         // assembly forms and printing them is useful.
4054         PyPrintAccumulator printAccum;
4055         printAccum.parts.append("Type(");
4056         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
4057         printAccum.parts.append(")");
4058         return printAccum.join();
4059       });
4060 
  // Builtin type bindings. Each `::bind(m)` call presumably adds one builtin
  // type class (or, for PyMemRefLayoutMapList, a helper list class) to module
  // `m`. These calls come after the generic `Type` class is registered above.
  // Keep that order, since the subclasses presumably name `Type` as their
  // pybind11 base (confirm in each ::bind implementation).
  PyIntegerType::bind(m);
  PyIndexType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // Shaped type is bound before vector/tensor/memref, which presumably derive
  // from it.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyMemRefLayoutMapList::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
4079 
4080   //----------------------------------------------------------------------------
4081   // Mapping of Value.
4082   //----------------------------------------------------------------------------
4083   py::class_<PyValue>(m, "Value")
4084       .def_property_readonly(
4085           "context",
4086           [](PyValue &self) { return self.getParentOperation()->getContext(); },
4087           "Context in which the value lives.")
4088       .def(
4089           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4090           kDumpDocstring)
4091       .def("__eq__",
4092            [](PyValue &self, PyValue &other) {
4093              return self.get().ptr == other.get().ptr;
4094            })
4095       .def("__eq__", [](PyValue &self, py::object other) { return false; })
4096       .def(
4097           "__str__",
4098           [](PyValue &self) {
4099             PyPrintAccumulator printAccum;
4100             printAccum.parts.append("Value(");
4101             mlirValuePrint(self.get(), printAccum.getCallback(),
4102                            printAccum.getUserData());
4103             printAccum.parts.append(")");
4104             return printAccum.join();
4105           },
4106           kValueDunderStrDocstring)
4107       .def_property_readonly("type", [](PyValue &self) {
4108         return PyType(self.getParentOperation()->getContext(),
4109                       mlirValueGetType(self.get()));
4110       });
  // Value subclasses. They are registered after the `Value` class above,
  // which is presumably their pybind11 base (confirm in each ::bind).
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: the list, iterator and map helper types returned by
  // the `arguments`, `blocks`, `operations`, `attributes`, `operands`,
  // `results` and `regions` accessors and by the Block/Region/Operation
  // `__iter__` methods bound above.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
4125 
4126   //----------------------------------------------------------------------------
4127   // Mapping of PyAffineExpr and derived classes.
4128   //----------------------------------------------------------------------------
4129   py::class_<PyAffineExpr>(m, "AffineExpr")
4130       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4131                              &PyAffineExpr::getCapsule)
4132       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
4133       .def("__add__",
4134            [](PyAffineExpr &self, PyAffineExpr &other) {
4135              return PyAffineAddExpr::get(self, other);
4136            })
4137       .def("__mul__",
4138            [](PyAffineExpr &self, PyAffineExpr &other) {
4139              return PyAffineMulExpr::get(self, other);
4140            })
4141       .def("__mod__",
4142            [](PyAffineExpr &self, PyAffineExpr &other) {
4143              return PyAffineModExpr::get(self, other);
4144            })
4145       .def("__sub__",
4146            [](PyAffineExpr &self, PyAffineExpr &other) {
4147              auto negOne =
4148                  PyAffineConstantExpr::get(-1, *self.getContext().get());
4149              return PyAffineAddExpr::get(self,
4150                                          PyAffineMulExpr::get(negOne, other));
4151            })
4152       .def("__eq__", [](PyAffineExpr &self,
4153                         PyAffineExpr &other) { return self == other; })
4154       .def("__eq__",
4155            [](PyAffineExpr &self, py::object &other) { return false; })
4156       .def("__str__",
4157            [](PyAffineExpr &self) {
4158              PyPrintAccumulator printAccum;
4159              mlirAffineExprPrint(self, printAccum.getCallback(),
4160                                  printAccum.getUserData());
4161              return printAccum.join();
4162            })
4163       .def("__repr__",
4164            [](PyAffineExpr &self) {
4165              PyPrintAccumulator printAccum;
4166              printAccum.parts.append("AffineExpr(");
4167              mlirAffineExprPrint(self, printAccum.getCallback(),
4168                                  printAccum.getUserData());
4169              printAccum.parts.append(")");
4170              return printAccum.join();
4171            })
4172       .def_property_readonly(
4173           "context",
4174           [](PyAffineExpr &self) { return self.getContext().getObject(); })
4175       .def_static(
4176           "get_add", &PyAffineAddExpr::get,
4177           "Gets an affine expression containing a sum of two expressions.")
4178       .def_static(
4179           "get_mul", &PyAffineMulExpr::get,
4180           "Gets an affine expression containing a product of two expressions.")
4181       .def_static("get_mod", &PyAffineModExpr::get,
4182                   "Gets an affine expression containing the modulo of dividing "
4183                   "one expression by another.")
4184       .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
4185                   "Gets an affine expression containing the rounded-down "
4186                   "result of dividing one expression by another.")
4187       .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
4188                   "Gets an affine expression containing the rounded-up result "
4189                   "of dividing one expression by another.")
4190       .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
4191                   py::arg("context") = py::none(),
4192                   "Gets a constant affine expression with the given value.")
4193       .def_static(
4194           "get_dim", &PyAffineDimExpr::get, py::arg("position"),
4195           py::arg("context") = py::none(),
4196           "Gets an affine expression of a dimension at the given position.")
4197       .def_static(
4198           "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
4199           py::arg("context") = py::none(),
4200           "Gets an affine expression of a symbol at the given position.")
4201       .def(
4202           "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
4203           kDumpDocstring);
  // Register the Python classes for each concrete affine expression kind on
  // module `m`.
  // NOTE(review): PyAffineBinaryExpr is bound before the add/mul/mod/floordiv/
  // ceildiv expression classes. Presumably those classes derive from it, and
  // pybind11 requires a base class to be registered before its subclasses.
  // Confirm against the class declarations before reordering these calls.
  PyAffineConstantExpr::bind(m);
  PyAffineDimExpr::bind(m);
  PyAffineSymbolExpr::bind(m);
  PyAffineBinaryExpr::bind(m);
  PyAffineAddExpr::bind(m);
  PyAffineMulExpr::bind(m);
  PyAffineModExpr::bind(m);
  PyAffineFloorDivExpr::bind(m);
  PyAffineCeilDivExpr::bind(m);
4213 
4214   //----------------------------------------------------------------------------
4215   // Mapping of PyAffineMap.
4216   //----------------------------------------------------------------------------
4217   py::class_<PyAffineMap>(m, "AffineMap")
4218       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4219                              &PyAffineMap::getCapsule)
4220       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
4221       .def("__eq__",
4222            [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
4223       .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
4224       .def("__str__",
4225            [](PyAffineMap &self) {
4226              PyPrintAccumulator printAccum;
4227              mlirAffineMapPrint(self, printAccum.getCallback(),
4228                                 printAccum.getUserData());
4229              return printAccum.join();
4230            })
4231       .def("__repr__",
4232            [](PyAffineMap &self) {
4233              PyPrintAccumulator printAccum;
4234              printAccum.parts.append("AffineMap(");
4235              mlirAffineMapPrint(self, printAccum.getCallback(),
4236                                 printAccum.getUserData());
4237              printAccum.parts.append(")");
4238              return printAccum.join();
4239            })
4240       .def_property_readonly(
4241           "context",
4242           [](PyAffineMap &self) { return self.getContext().getObject(); },
4243           "Context that owns the Affine Map")
4244       .def(
4245           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
4246           kDumpDocstring)
4247       .def_static(
4248           "get",
4249           [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
4250              DefaultingPyMlirContext context) {
4251             SmallVector<MlirAffineExpr> affineExprs;
4252             pyListToVector<PyAffineExpr, MlirAffineExpr>(
4253                 exprs, affineExprs, "attempting to create an AffineMap");
4254             MlirAffineMap map =
4255                 mlirAffineMapGet(context->get(), dimCount, symbolCount,
4256                                  affineExprs.size(), affineExprs.data());
4257             return PyAffineMap(context->getRef(), map);
4258           },
4259           py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
4260           py::arg("context") = py::none(),
4261           "Gets a map with the given expressions as results.")
4262       .def_static(
4263           "get_constant",
4264           [](intptr_t value, DefaultingPyMlirContext context) {
4265             MlirAffineMap affineMap =
4266                 mlirAffineMapConstantGet(context->get(), value);
4267             return PyAffineMap(context->getRef(), affineMap);
4268           },
4269           py::arg("value"), py::arg("context") = py::none(),
4270           "Gets an affine map with a single constant result")
4271       .def_static(
4272           "get_empty",
4273           [](DefaultingPyMlirContext context) {
4274             MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
4275             return PyAffineMap(context->getRef(), affineMap);
4276           },
4277           py::arg("context") = py::none(), "Gets an empty affine map.")
4278       .def_static(
4279           "get_identity",
4280           [](intptr_t nDims, DefaultingPyMlirContext context) {
4281             MlirAffineMap affineMap =
4282                 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
4283             return PyAffineMap(context->getRef(), affineMap);
4284           },
4285           py::arg("n_dims"), py::arg("context") = py::none(),
4286           "Gets an identity map with the given number of dimensions.")
4287       .def_static(
4288           "get_minor_identity",
4289           [](intptr_t nDims, intptr_t nResults,
4290              DefaultingPyMlirContext context) {
4291             MlirAffineMap affineMap =
4292                 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
4293             return PyAffineMap(context->getRef(), affineMap);
4294           },
4295           py::arg("n_dims"), py::arg("n_results"),
4296           py::arg("context") = py::none(),
4297           "Gets a minor identity map with the given number of dimensions and "
4298           "results.")
4299       .def_static(
4300           "get_permutation",
4301           [](std::vector<unsigned> permutation,
4302              DefaultingPyMlirContext context) {
4303             if (!isPermutation(permutation))
4304               throw py::cast_error("Invalid permutation when attempting to "
4305                                    "create an AffineMap");
4306             MlirAffineMap affineMap = mlirAffineMapPermutationGet(
4307                 context->get(), permutation.size(), permutation.data());
4308             return PyAffineMap(context->getRef(), affineMap);
4309           },
4310           py::arg("permutation"), py::arg("context") = py::none(),
4311           "Gets an affine map that permutes its inputs.")
4312       .def("get_submap",
4313            [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
4314              intptr_t numResults = mlirAffineMapGetNumResults(self);
4315              for (intptr_t pos : resultPos) {
4316                if (pos < 0 || pos >= numResults)
4317                  throw py::value_error("result position out of bounds");
4318              }
4319              MlirAffineMap affineMap = mlirAffineMapGetSubMap(
4320                  self, resultPos.size(), resultPos.data());
4321              return PyAffineMap(self.getContext(), affineMap);
4322            })
4323       .def("get_major_submap",
4324            [](PyAffineMap &self, intptr_t nResults) {
4325              if (nResults >= mlirAffineMapGetNumResults(self))
4326                throw py::value_error("number of results out of bounds");
4327              MlirAffineMap affineMap =
4328                  mlirAffineMapGetMajorSubMap(self, nResults);
4329              return PyAffineMap(self.getContext(), affineMap);
4330            })
4331       .def("get_minor_submap",
4332            [](PyAffineMap &self, intptr_t nResults) {
4333              if (nResults >= mlirAffineMapGetNumResults(self))
4334                throw py::value_error("number of results out of bounds");
4335              MlirAffineMap affineMap =
4336                  mlirAffineMapGetMinorSubMap(self, nResults);
4337              return PyAffineMap(self.getContext(), affineMap);
4338            })
4339       .def_property_readonly(
4340           "is_permutation",
4341           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
4342       .def_property_readonly("is_projected_permutation",
4343                              [](PyAffineMap &self) {
4344                                return mlirAffineMapIsProjectedPermutation(self);
4345                              })
4346       .def_property_readonly(
4347           "n_dims",
4348           [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
4349       .def_property_readonly(
4350           "n_inputs",
4351           [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
4352       .def_property_readonly(
4353           "n_symbols",
4354           [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
4355       .def_property_readonly("results", [](PyAffineMap &self) {
4356         return PyAffineMapExprList(self);
4357       });
4358   PyAffineMapExprList::bind(m);
4359 
4360   //----------------------------------------------------------------------------
4361   // Mapping of PyIntegerSet.
4362   //----------------------------------------------------------------------------
4363   py::class_<PyIntegerSet>(m, "IntegerSet")
4364       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4365                              &PyIntegerSet::getCapsule)
4366       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
4367       .def("__eq__", [](PyIntegerSet &self,
4368                         PyIntegerSet &other) { return self == other; })
4369       .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
4370       .def("__str__",
4371            [](PyIntegerSet &self) {
4372              PyPrintAccumulator printAccum;
4373              mlirIntegerSetPrint(self, printAccum.getCallback(),
4374                                  printAccum.getUserData());
4375              return printAccum.join();
4376            })
4377       .def("__repr__",
4378            [](PyIntegerSet &self) {
4379              PyPrintAccumulator printAccum;
4380              printAccum.parts.append("IntegerSet(");
4381              mlirIntegerSetPrint(self, printAccum.getCallback(),
4382                                  printAccum.getUserData());
4383              printAccum.parts.append(")");
4384              return printAccum.join();
4385            })
4386       .def_property_readonly(
4387           "context",
4388           [](PyIntegerSet &self) { return self.getContext().getObject(); })
4389       .def(
4390           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
4391           kDumpDocstring)
4392       .def_static(
4393           "get",
4394           [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
4395              std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
4396             if (exprs.size() != eqFlags.size())
4397               throw py::value_error(
4398                   "Expected the number of constraints to match "
4399                   "that of equality flags");
4400             if (exprs.empty())
4401               throw py::value_error("Expected non-empty list of constraints");
4402 
4403             // Copy over to a SmallVector because std::vector has a
4404             // specialization for booleans that packs data and does not
4405             // expose a `bool *`.
4406             SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
4407 
4408             SmallVector<MlirAffineExpr> affineExprs;
4409             pyListToVector<PyAffineExpr>(exprs, affineExprs,
4410                                          "attempting to create an IntegerSet");
4411             MlirIntegerSet set = mlirIntegerSetGet(
4412                 context->get(), numDims, numSymbols, exprs.size(),
4413                 affineExprs.data(), flags.data());
4414             return PyIntegerSet(context->getRef(), set);
4415           },
4416           py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
4417           py::arg("eq_flags"), py::arg("context") = py::none())
4418       .def_static(
4419           "get_empty",
4420           [](intptr_t numDims, intptr_t numSymbols,
4421              DefaultingPyMlirContext context) {
4422             MlirIntegerSet set =
4423                 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
4424             return PyIntegerSet(context->getRef(), set);
4425           },
4426           py::arg("num_dims"), py::arg("num_symbols"),
4427           py::arg("context") = py::none())
4428       .def("get_replaced",
4429            [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
4430               intptr_t numResultDims, intptr_t numResultSymbols) {
4431              if (static_cast<intptr_t>(dimExprs.size()) !=
4432                  mlirIntegerSetGetNumDims(self))
4433                throw py::value_error(
4434                    "Expected the number of dimension replacement expressions "
4435                    "to match that of dimensions");
4436              if (static_cast<intptr_t>(symbolExprs.size()) !=
4437                  mlirIntegerSetGetNumSymbols(self))
4438                throw py::value_error(
4439                    "Expected the number of symbol replacement expressions "
4440                    "to match that of symbols");
4441 
4442              SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
4443              pyListToVector<PyAffineExpr>(
4444                  dimExprs, dimAffineExprs,
4445                  "attempting to create an IntegerSet by replacing dimensions");
4446              pyListToVector<PyAffineExpr>(
4447                  symbolExprs, symbolAffineExprs,
4448                  "attempting to create an IntegerSet by replacing symbols");
4449              MlirIntegerSet set = mlirIntegerSetReplaceGet(
4450                  self, dimAffineExprs.data(), symbolAffineExprs.data(),
4451                  numResultDims, numResultSymbols);
4452              return PyIntegerSet(self.getContext(), set);
4453            })
4454       .def_property_readonly("is_canonical_empty",
4455                              [](PyIntegerSet &self) {
4456                                return mlirIntegerSetIsCanonicalEmpty(self);
4457                              })
4458       .def_property_readonly(
4459           "n_dims",
4460           [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
4461       .def_property_readonly(
4462           "n_symbols",
4463           [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
4464       .def_property_readonly(
4465           "n_inputs",
4466           [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
4467       .def_property_readonly("n_equalities",
4468                              [](PyIntegerSet &self) {
4469                                return mlirIntegerSetGetNumEqualities(self);
4470                              })
4471       .def_property_readonly("n_inequalities",
4472                              [](PyIntegerSet &self) {
4473                                return mlirIntegerSetGetNumInequalities(self);
4474                              })
4475       .def_property_readonly("constraints", [](PyIntegerSet &self) {
4476         return PyIntegerSetConstraintList(self);
4477       });
4478   PyIntegerSetConstraint::bind(m);
4479   PyIntegerSetConstraintList::bind(m);
4480 }
4481