1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModules.h"
10
11 #include "Globals.h"
12 #include "PybindUtils.h"
13
14 #include "mlir-c/AffineMap.h"
15 #include "mlir-c/Bindings/Python/Interop.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/IntegerSet.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34
// Docstring for parsing a type from its textual assembly form (the constant
// name suggests it is attached to a Context method).
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring for building a file/line/column Location (presumably a
// Context-level factory, per the constant name).
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring for parsing a whole module from its textual assembly format.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for the generic operation-creation entry point; it documents every
// keyword argument that entry point accepts.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
name: Operation name (e.g. "dialect.operation").
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
location: A Location object (defaults to resolve from context manager).
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
71
// Python docstring for the operation print() binding. The keyword names listed
// here must match the py::arg names used where print() is bound (snake_case of
// the PyOperationBase::print parameter names).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable).
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.

If the operation fails verification, a comment noting the failure is written
first and the generic op form is printed regardless of print_generic_op_form.
)";
90
// Docstring for get_asm(); it defers to print() for the keyword arguments the
// two methods share.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the 'binary'
argument.
)";

// Docstring for an operation's __str__ (printing with default options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Docstring presumably shared by the various dump() bindings.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (attached in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
The created block.
)";

// Docstring for a Value's __str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133
134 // Helper for creating an @classmethod.
135 template <class Func, typename... Args>
classmethod(Func f,Args...args)136 py::object classmethod(Func f, Args... args) {
137 py::object cf = py::cpp_function(f, args...);
138 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140
141 /// Checks whether the given type is an integer or float type.
mlirTypeIsAIntegerOrFloat(MlirType type)142 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
143 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
144 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
145 }
146
147 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)148 createCustomDialectWrapper(const std::string &dialectNamespace,
149 py::object dialectDescriptor) {
150 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
151 if (!dialectClass) {
152 // Use the base class.
153 return py::cast(PyDialect(std::move(dialectDescriptor)));
154 }
155
156 // Create the custom implementation.
157 return (*dialectClass)(std::move(dialectDescriptor));
158 }
159
toMlirStringRef(const std::string & s)160 static MlirStringRef toMlirStringRef(const std::string &s) {
161 return mlirStringRefCreate(s.data(), s.size());
162 }
163
164 template <typename PermutationTy>
isPermutation(std::vector<PermutationTy> permutation)165 static bool isPermutation(std::vector<PermutationTy> permutation) {
166 llvm::SmallVector<bool, 8> seen(permutation.size(), false);
167 for (auto val : permutation) {
168 if (val < permutation.size()) {
169 if (seen[val])
170 return false;
171 seen[val] = true;
172 continue;
173 }
174 return false;
175 }
176 return true;
177 }
178
179 //------------------------------------------------------------------------------
180 // Collections.
181 //------------------------------------------------------------------------------
182
183 namespace {
184
185 class PyRegionIterator {
186 public:
PyRegionIterator(PyOperationRef operation)187 PyRegionIterator(PyOperationRef operation)
188 : operation(std::move(operation)) {}
189
dunderIter()190 PyRegionIterator &dunderIter() { return *this; }
191
dunderNext()192 PyRegion dunderNext() {
193 operation->checkValid();
194 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195 throw py::stop_iteration();
196 }
197 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198 return PyRegion(operation, region);
199 }
200
bind(py::module & m)201 static void bind(py::module &m) {
202 py::class_<PyRegionIterator>(m, "RegionIterator")
203 .def("__iter__", &PyRegionIterator::dunderIter)
204 .def("__next__", &PyRegionIterator::dunderNext);
205 }
206
207 private:
208 PyOperationRef operation;
209 int nextIndex = 0;
210 };
211
212 /// Regions of an op are fixed length and indexed numerically so are represented
213 /// with a sequence-like container.
214 class PyRegionList {
215 public:
PyRegionList(PyOperationRef operation)216 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217
dunderLen()218 intptr_t dunderLen() {
219 operation->checkValid();
220 return mlirOperationGetNumRegions(operation->get());
221 }
222
dunderGetItem(intptr_t index)223 PyRegion dunderGetItem(intptr_t index) {
224 // dunderLen checks validity.
225 if (index < 0 || index >= dunderLen()) {
226 throw SetPyError(PyExc_IndexError,
227 "attempt to access out of bounds region");
228 }
229 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230 return PyRegion(operation, region);
231 }
232
bind(py::module & m)233 static void bind(py::module &m) {
234 py::class_<PyRegionList>(m, "RegionSequence")
235 .def("__len__", &PyRegionList::dunderLen)
236 .def("__getitem__", &PyRegionList::dunderGetItem);
237 }
238
239 private:
240 PyOperationRef operation;
241 };
242
243 class PyBlockIterator {
244 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)245 PyBlockIterator(PyOperationRef operation, MlirBlock next)
246 : operation(std::move(operation)), next(next) {}
247
dunderIter()248 PyBlockIterator &dunderIter() { return *this; }
249
dunderNext()250 PyBlock dunderNext() {
251 operation->checkValid();
252 if (mlirBlockIsNull(next)) {
253 throw py::stop_iteration();
254 }
255
256 PyBlock returnBlock(operation, next);
257 next = mlirBlockGetNextInRegion(next);
258 return returnBlock;
259 }
260
bind(py::module & m)261 static void bind(py::module &m) {
262 py::class_<PyBlockIterator>(m, "BlockIterator")
263 .def("__iter__", &PyBlockIterator::dunderIter)
264 .def("__next__", &PyBlockIterator::dunderNext);
265 }
266
267 private:
268 PyOperationRef operation;
269 MlirBlock next;
270 };
271
272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273 /// we present them as a more full-featured list-like container but optimize
274 /// it for forward iteration. Blocks are always owned by a region.
275 class PyBlockList {
276 public:
PyBlockList(PyOperationRef operation,MlirRegion region)277 PyBlockList(PyOperationRef operation, MlirRegion region)
278 : operation(std::move(operation)), region(region) {}
279
dunderIter()280 PyBlockIterator dunderIter() {
281 operation->checkValid();
282 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283 }
284
dunderLen()285 intptr_t dunderLen() {
286 operation->checkValid();
287 intptr_t count = 0;
288 MlirBlock block = mlirRegionGetFirstBlock(region);
289 while (!mlirBlockIsNull(block)) {
290 count += 1;
291 block = mlirBlockGetNextInRegion(block);
292 }
293 return count;
294 }
295
dunderGetItem(intptr_t index)296 PyBlock dunderGetItem(intptr_t index) {
297 operation->checkValid();
298 if (index < 0) {
299 throw SetPyError(PyExc_IndexError,
300 "attempt to access out of bounds block");
301 }
302 MlirBlock block = mlirRegionGetFirstBlock(region);
303 while (!mlirBlockIsNull(block)) {
304 if (index == 0) {
305 return PyBlock(operation, block);
306 }
307 block = mlirBlockGetNextInRegion(block);
308 index -= 1;
309 }
310 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311 }
312
appendBlock(py::args pyArgTypes)313 PyBlock appendBlock(py::args pyArgTypes) {
314 operation->checkValid();
315 llvm::SmallVector<MlirType, 4> argTypes;
316 argTypes.reserve(pyArgTypes.size());
317 for (auto &pyArg : pyArgTypes) {
318 argTypes.push_back(pyArg.cast<PyType &>());
319 }
320
321 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322 mlirRegionAppendOwnedBlock(region, block);
323 return PyBlock(operation, block);
324 }
325
bind(py::module & m)326 static void bind(py::module &m) {
327 py::class_<PyBlockList>(m, "BlockList")
328 .def("__getitem__", &PyBlockList::dunderGetItem)
329 .def("__iter__", &PyBlockList::dunderIter)
330 .def("__len__", &PyBlockList::dunderLen)
331 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332 }
333
334 private:
335 PyOperationRef operation;
336 MlirRegion region;
337 };
338
339 class PyOperationIterator {
340 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)341 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342 : parentOperation(std::move(parentOperation)), next(next) {}
343
dunderIter()344 PyOperationIterator &dunderIter() { return *this; }
345
dunderNext()346 py::object dunderNext() {
347 parentOperation->checkValid();
348 if (mlirOperationIsNull(next)) {
349 throw py::stop_iteration();
350 }
351
352 PyOperationRef returnOperation =
353 PyOperation::forOperation(parentOperation->getContext(), next);
354 next = mlirOperationGetNextInBlock(next);
355 return returnOperation->createOpView();
356 }
357
bind(py::module & m)358 static void bind(py::module &m) {
359 py::class_<PyOperationIterator>(m, "OperationIterator")
360 .def("__iter__", &PyOperationIterator::dunderIter)
361 .def("__next__", &PyOperationIterator::dunderNext);
362 }
363
364 private:
365 PyOperationRef parentOperation;
366 MlirOperation next;
367 };
368
369 /// Operations are exposed by the C-API as a forward-only linked list. In
370 /// Python, we present them as a more full-featured list-like container but
371 /// optimize it for forward iteration. Iterable operations are always owned
372 /// by a block.
373 class PyOperationList {
374 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)375 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376 : parentOperation(std::move(parentOperation)), block(block) {}
377
dunderIter()378 PyOperationIterator dunderIter() {
379 parentOperation->checkValid();
380 return PyOperationIterator(parentOperation,
381 mlirBlockGetFirstOperation(block));
382 }
383
dunderLen()384 intptr_t dunderLen() {
385 parentOperation->checkValid();
386 intptr_t count = 0;
387 MlirOperation childOp = mlirBlockGetFirstOperation(block);
388 while (!mlirOperationIsNull(childOp)) {
389 count += 1;
390 childOp = mlirOperationGetNextInBlock(childOp);
391 }
392 return count;
393 }
394
dunderGetItem(intptr_t index)395 py::object dunderGetItem(intptr_t index) {
396 parentOperation->checkValid();
397 if (index < 0) {
398 throw SetPyError(PyExc_IndexError,
399 "attempt to access out of bounds operation");
400 }
401 MlirOperation childOp = mlirBlockGetFirstOperation(block);
402 while (!mlirOperationIsNull(childOp)) {
403 if (index == 0) {
404 return PyOperation::forOperation(parentOperation->getContext(), childOp)
405 ->createOpView();
406 }
407 childOp = mlirOperationGetNextInBlock(childOp);
408 index -= 1;
409 }
410 throw SetPyError(PyExc_IndexError,
411 "attempt to access out of bounds operation");
412 }
413
bind(py::module & m)414 static void bind(py::module &m) {
415 py::class_<PyOperationList>(m, "OperationList")
416 .def("__getitem__", &PyOperationList::dunderGetItem)
417 .def("__iter__", &PyOperationList::dunderIter)
418 .def("__len__", &PyOperationList::dunderLen);
419 }
420
421 private:
422 PyOperationRef parentOperation;
423 MlirBlock block;
424 };
425
426 } // namespace
427
428 //------------------------------------------------------------------------------
429 // PyMlirContext
430 //------------------------------------------------------------------------------
431
PyMlirContext(MlirContext context)432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433 py::gil_scoped_acquire acquire;
434 auto &liveContexts = getLiveContexts();
435 liveContexts[context.ptr] = this;
436 }
437
~PyMlirContext()438 PyMlirContext::~PyMlirContext() {
439 // Note that the only public way to construct an instance is via the
440 // forContext method, which always puts the associated handle into
441 // liveContexts.
442 py::gil_scoped_acquire acquire;
443 getLiveContexts().erase(context.ptr);
444 mlirContextDestroy(context);
445 }
446
getCapsule()447 py::object PyMlirContext::getCapsule() {
448 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449 }
450
createFromCapsule(py::object capsule)451 py::object PyMlirContext::createFromCapsule(py::object capsule) {
452 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453 if (mlirContextIsNull(rawContext))
454 throw py::error_already_set();
455 return forContext(rawContext).releaseObject();
456 }
457
createNewContextForInit()458 PyMlirContext *PyMlirContext::createNewContextForInit() {
459 MlirContext context = mlirContextCreate();
460 mlirRegisterAllDialects(context);
461 return new PyMlirContext(context);
462 }
463
/// Returns a reference to the Python wrapper for `context`, creating one if
/// this MlirContext has no live wrapper yet.
///
/// Wrappers are tracked in the process-wide liveContexts map (keyed by the raw
/// context pointer), so repeated calls for the same context reuse the same
/// PyMlirContext instance.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): unlike PyModule::forModule / PyOperation::createInstance,
    // this py::cast does not pass take_ownership. Per the note in those
    // functions, the default policy does not take ownership, so the Python
    // object presumably does not own this wrapper. Confirm this is intentional
    // (e.g. for contexts owned elsewhere) rather than a leak.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    // The constructor already registered the wrapper; this re-assigns the same
    // entry.
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
480
getLiveContexts()481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482 static LiveContextMap liveContexts;
483 return liveContexts;
484 }
485
getLiveCount()486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487
getLiveOperationCount()488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489
getLiveModuleCount()490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491
contextEnter()492 pybind11::object PyMlirContext::contextEnter() {
493 return PyThreadContextEntry::pushContext(*this);
494 }
495
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)496 void PyMlirContext::contextExit(pybind11::object excType,
497 pybind11::object excVal,
498 pybind11::object excTb) {
499 PyThreadContextEntry::popContext(*this);
500 }
501
resolve()502 PyMlirContext &DefaultingPyMlirContext::resolve() {
503 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504 if (!context) {
505 throw SetPyError(
506 PyExc_RuntimeError,
507 "An MLIR function requires a Context but none was provided in the call "
508 "or from the surrounding environment. Either pass to the function with "
509 "a 'context=' argument or establish a default using 'with Context():'");
510 }
511 return *context;
512 }
513
514 //------------------------------------------------------------------------------
515 // PyThreadContextEntry management
516 //------------------------------------------------------------------------------
517
getStack()518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519 static thread_local std::vector<PyThreadContextEntry> stack;
520 return stack;
521 }
522
getTopOfStack()523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524 auto &stack = getStack();
525 if (stack.empty())
526 return nullptr;
527 return &stack.back();
528 }
529
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531 py::object insertionPoint,
532 py::object location) {
533 auto &stack = getStack();
534 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535 std::move(location));
536 // If the new stack has more than one entry and the context of the new top
537 // entry matches the previous, copy the insertionPoint and location from the
538 // previous entry if missing from the new top entry.
539 if (stack.size() > 1) {
540 auto &prev = *(stack.rbegin() + 1);
541 auto ¤t = stack.back();
542 if (current.context.is(prev.context)) {
543 // Default non-context objects from the previous entry.
544 if (!current.insertionPoint)
545 current.insertionPoint = prev.insertionPoint;
546 if (!current.location)
547 current.location = prev.location;
548 }
549 }
550 }
551
getContext()552 PyMlirContext *PyThreadContextEntry::getContext() {
553 if (!context)
554 return nullptr;
555 return py::cast<PyMlirContext *>(context);
556 }
557
getInsertionPoint()558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559 if (!insertionPoint)
560 return nullptr;
561 return py::cast<PyInsertionPoint *>(insertionPoint);
562 }
563
getLocation()564 PyLocation *PyThreadContextEntry::getLocation() {
565 if (!location)
566 return nullptr;
567 return py::cast<PyLocation *>(location);
568 }
569
getDefaultContext()570 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571 auto *tos = getTopOfStack();
572 return tos ? tos->getContext() : nullptr;
573 }
574
getDefaultInsertionPoint()575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576 auto *tos = getTopOfStack();
577 return tos ? tos->getInsertionPoint() : nullptr;
578 }
579
getDefaultLocation()580 PyLocation *PyThreadContextEntry::getDefaultLocation() {
581 auto *tos = getTopOfStack();
582 return tos ? tos->getLocation() : nullptr;
583 }
584
pushContext(PyMlirContext & context)585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586 py::object contextObj = py::cast(context);
587 push(FrameKind::Context, /*context=*/contextObj,
588 /*insertionPoint=*/py::object(),
589 /*location=*/py::object());
590 return contextObj;
591 }
592
popContext(PyMlirContext & context)593 void PyThreadContextEntry::popContext(PyMlirContext &context) {
594 auto &stack = getStack();
595 if (stack.empty())
596 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597 auto &tos = stack.back();
598 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600 stack.pop_back();
601 }
602
603 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605 py::object contextObj =
606 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607 py::object insertionPointObj = py::cast(insertionPoint);
608 push(FrameKind::InsertionPoint,
609 /*context=*/contextObj,
610 /*insertionPoint=*/insertionPointObj,
611 /*location=*/py::object());
612 return insertionPointObj;
613 }
614
popInsertionPoint(PyInsertionPoint & insertionPoint)615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616 auto &stack = getStack();
617 if (stack.empty())
618 throw SetPyError(PyExc_RuntimeError,
619 "Unbalanced InsertionPoint enter/exit");
620 auto &tos = stack.back();
621 if (tos.frameKind != FrameKind::InsertionPoint &&
622 tos.getInsertionPoint() != &insertionPoint)
623 throw SetPyError(PyExc_RuntimeError,
624 "Unbalanced InsertionPoint enter/exit");
625 stack.pop_back();
626 }
627
pushLocation(PyLocation & location)628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629 py::object contextObj = location.getContext().getObject();
630 py::object locationObj = py::cast(location);
631 push(FrameKind::Location, /*context=*/contextObj,
632 /*insertionPoint=*/py::object(),
633 /*location=*/locationObj);
634 return locationObj;
635 }
636
popLocation(PyLocation & location)637 void PyThreadContextEntry::popLocation(PyLocation &location) {
638 auto &stack = getStack();
639 if (stack.empty())
640 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641 auto &tos = stack.back();
642 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644 stack.pop_back();
645 }
646
647 //------------------------------------------------------------------------------
648 // PyDialect, PyDialectDescriptor, PyDialects
649 //------------------------------------------------------------------------------
650
getDialectForKey(const std::string & key,bool attrError)651 MlirDialect PyDialects::getDialectForKey(const std::string &key,
652 bool attrError) {
653 // If the "std" dialect was asked for, substitute the empty namespace :(
654 static const std::string emptyKey;
655 const std::string *canonKey = key == "std" ? &emptyKey : &key;
656 MlirDialect dialect = mlirContextGetOrLoadDialect(
657 getContext()->get(), {canonKey->data(), canonKey->size()});
658 if (mlirDialectIsNull(dialect)) {
659 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
660 Twine("Dialect '") + key + "' not found");
661 }
662 return dialect;
663 }
664
665 //------------------------------------------------------------------------------
666 // PyLocation
667 //------------------------------------------------------------------------------
668
getCapsule()669 py::object PyLocation::getCapsule() {
670 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
671 }
672
createFromCapsule(py::object capsule)673 PyLocation PyLocation::createFromCapsule(py::object capsule) {
674 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
675 if (mlirLocationIsNull(rawLoc))
676 throw py::error_already_set();
677 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
678 rawLoc);
679 }
680
contextEnter()681 py::object PyLocation::contextEnter() {
682 return PyThreadContextEntry::pushLocation(*this);
683 }
684
contextExit(py::object excType,py::object excVal,py::object excTb)685 void PyLocation::contextExit(py::object excType, py::object excVal,
686 py::object excTb) {
687 PyThreadContextEntry::popLocation(*this);
688 }
689
resolve()690 PyLocation &DefaultingPyLocation::resolve() {
691 auto *location = PyThreadContextEntry::getDefaultLocation();
692 if (!location) {
693 throw SetPyError(
694 PyExc_RuntimeError,
695 "An MLIR function requires a Location but none was provided in the "
696 "call or from the surrounding environment. Either pass to the function "
697 "with a 'loc=' argument or establish a default using 'with loc:'");
698 }
699 return *location;
700 }
701
702 //------------------------------------------------------------------------------
703 // PyModule
704 //------------------------------------------------------------------------------
705
PyModule(PyMlirContextRef contextRef,MlirModule module)706 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
707 : BaseContextObject(std::move(contextRef)), module(module) {}
708
~PyModule()709 PyModule::~PyModule() {
710 py::gil_scoped_acquire acquire;
711 auto &liveModules = getContext()->liveModules;
712 assert(liveModules.count(module.ptr) == 1 &&
713 "destroying module not in live map");
714 liveModules.erase(module.ptr);
715 mlirModuleDestroy(module);
716 }
717
/// Returns a reference to the Python wrapper for `module`, creating and
/// registering one in its context's liveModules map if none exists.
///
/// Newly created wrappers are owned by their Python object (take_ownership).
/// The map stores the Python handle plus the raw wrapper pointer so later
/// lookups can return the same object.
PyModuleRef PyModule::forModule(MlirModule module) {
  MlirContext context = mlirModuleGetContext(module);
  PyMlirContextRef contextRef = PyMlirContext::forContext(context);

  py::gil_scoped_acquire acquire;
  // Reference into the context's own map; it presumably stays valid after
  // contextRef is moved into the new PyModule below.
  auto &liveModules = contextRef->liveModules;
  auto it = liveModules.find(module.ptr);
  if (it == liveModules.end()) {
    // Create.
    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
    // Note that the default return value policy on cast is automatic_reference,
    // which does not take ownership (delete will not be called).
    // Just be explicit.
    py::object pyRef =
        py::cast(unownedModule, py::return_value_policy::take_ownership);
    unownedModule->handle = pyRef;
    liveModules[module.ptr] =
        std::make_pair(unownedModule->handle, unownedModule);
    return PyModuleRef(unownedModule, std::move(pyRef));
  }
  // Use existing.
  PyModule *existing = it->second.second;
  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
  return PyModuleRef(existing, std::move(pyRef));
}
743
createFromCapsule(py::object capsule)744 py::object PyModule::createFromCapsule(py::object capsule) {
745 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
746 if (mlirModuleIsNull(rawModule))
747 throw py::error_already_set();
748 return forModule(rawModule).releaseObject();
749 }
750
getCapsule()751 py::object PyModule::getCapsule() {
752 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
753 }
754
755 //------------------------------------------------------------------------------
756 // PyOperation
757 //------------------------------------------------------------------------------
758
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)759 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
760 : BaseContextObject(std::move(contextRef)), operation(operation) {}
761
~PyOperation()762 PyOperation::~PyOperation() {
763 auto &liveOperations = getContext()->liveOperations;
764 assert(liveOperations.count(operation.ptr) == 1 &&
765 "destroying operation not in live map");
766 liveOperations.erase(operation.ptr);
767 if (!isAttached()) {
768 mlirOperationDestroy(operation);
769 }
770 }
771
/// Creates a new PyOperation wrapper for `operation` and registers it in the
/// context's liveOperations map. The caller is expected to have checked that
/// no wrapper exists yet (see forOperation / createDetached).
///
/// The returned Python object owns the wrapper (take_ownership). A non-null
/// `parentKeepAlive` is stored on the wrapper, presumably to keep the parent
/// object alive for as long as this operation's wrapper lives.
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  // Reference into the context's own map; taken before contextRef is moved.
  auto &liveOperations = contextRef->liveOperations;
  // Create.
  PyOperation *unownedOperation =
      new PyOperation(std::move(contextRef), operation);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit.
  py::object pyRef =
      py::cast(unownedOperation, py::return_value_policy::take_ownership);
  unownedOperation->handle = pyRef;
  if (parentKeepAlive) {
    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
  }
  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
  return PyOperationRef(unownedOperation, std::move(pyRef));
}
791
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793 MlirOperation operation,
794 py::object parentKeepAlive) {
795 auto &liveOperations = contextRef->liveOperations;
796 auto it = liveOperations.find(operation.ptr);
797 if (it == liveOperations.end()) {
798 // Create.
799 return createInstance(std::move(contextRef), operation,
800 std::move(parentKeepAlive));
801 }
802 // Use existing.
803 PyOperation *existing = it->second.second;
804 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805 return PyOperationRef(existing, std::move(pyRef));
806 }
807
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809 MlirOperation operation,
810 py::object parentKeepAlive) {
811 auto &liveOperations = contextRef->liveOperations;
812 assert(liveOperations.count(operation.ptr) == 0 &&
813 "cannot create detached operation that already exists");
814 (void)liveOperations;
815
816 PyOperationRef created = createInstance(std::move(contextRef), operation,
817 std::move(parentKeepAlive));
818 created->attached = false;
819 return created;
820 }
821
checkValid() const822 void PyOperation::checkValid() const {
823 if (!valid) {
824 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825 }
826 }
827
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)828 void PyOperationBase::print(py::object fileObject, bool binary,
829 llvm::Optional<int64_t> largeElementsLimit,
830 bool enableDebugInfo, bool prettyDebugInfo,
831 bool printGenericOpForm, bool useLocalScope) {
832 PyOperation &operation = getOperation();
833 operation.checkValid();
834 if (fileObject.is_none())
835 fileObject = py::module::import("sys").attr("stdout");
836
837 if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838 fileObject.attr("write")("// Verification failed, printing generic form\n");
839 printGenericOpForm = true;
840 }
841
842 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843 if (largeElementsLimit)
844 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845 if (enableDebugInfo)
846 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847 if (printGenericOpForm)
848 mlirOpPrintingFlagsPrintGenericOpForm(flags);
849
850 PyFileAccumulator accum(fileObject, binary);
851 py::gil_scoped_release();
852 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853 accum.getUserData());
854 mlirOpPrintingFlagsDestroy(flags);
855 }
856
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)857 py::object PyOperationBase::getAsm(bool binary,
858 llvm::Optional<int64_t> largeElementsLimit,
859 bool enableDebugInfo, bool prettyDebugInfo,
860 bool printGenericOpForm,
861 bool useLocalScope) {
862 py::object fileObject;
863 if (binary) {
864 fileObject = py::module::import("io").attr("BytesIO")();
865 } else {
866 fileObject = py::module::import("io").attr("StringIO")();
867 }
868 print(fileObject, /*binary=*/binary,
869 /*largeElementsLimit=*/largeElementsLimit,
870 /*enableDebugInfo=*/enableDebugInfo,
871 /*prettyDebugInfo=*/prettyDebugInfo,
872 /*printGenericOpForm=*/printGenericOpForm,
873 /*useLocalScope=*/useLocalScope);
874
875 return fileObject.attr("getvalue")();
876 }
877
getParentOperation()878 PyOperationRef PyOperation::getParentOperation() {
879 if (!isAttached())
880 throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
881 MlirOperation operation = mlirOperationGetParentOperation(get());
882 if (mlirOperationIsNull(operation))
883 throw SetPyError(PyExc_ValueError, "Operation has no parent.");
884 return PyOperation::forOperation(getContext(), operation);
885 }
886
getBlock()887 PyBlock PyOperation::getBlock() {
888 PyOperationRef parentOperation = getParentOperation();
889 MlirBlock block = mlirOperationGetBlock(get());
890 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
891 return PyBlock{std::move(parentOperation), block};
892 }
893
create(std::string name,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,py::object maybeIp)894 py::object PyOperation::create(
895 std::string name, llvm::Optional<std::vector<PyType *>> results,
896 llvm::Optional<std::vector<PyValue *>> operands,
897 llvm::Optional<py::dict> attributes,
898 llvm::Optional<std::vector<PyBlock *>> successors, int regions,
899 DefaultingPyLocation location, py::object maybeIp) {
900 llvm::SmallVector<MlirValue, 4> mlirOperands;
901 llvm::SmallVector<MlirType, 4> mlirResults;
902 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
903 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
904
905 // General parameter validation.
906 if (regions < 0)
907 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
908
909 // Unpack/validate operands.
910 if (operands) {
911 mlirOperands.reserve(operands->size());
912 for (PyValue *operand : *operands) {
913 if (!operand)
914 throw SetPyError(PyExc_ValueError, "operand value cannot be None");
915 mlirOperands.push_back(operand->get());
916 }
917 }
918
919 // Unpack/validate results.
920 if (results) {
921 mlirResults.reserve(results->size());
922 for (PyType *result : *results) {
923 // TODO: Verify result type originate from the same context.
924 if (!result)
925 throw SetPyError(PyExc_ValueError, "result type cannot be None");
926 mlirResults.push_back(*result);
927 }
928 }
929 // Unpack/validate attributes.
930 if (attributes) {
931 mlirAttributes.reserve(attributes->size());
932 for (auto &it : *attributes) {
933 std::string key;
934 try {
935 key = it.first.cast<std::string>();
936 } catch (py::cast_error &err) {
937 std::string msg = "Invalid attribute key (not a string) when "
938 "attempting to create the operation \"" +
939 name + "\" (" + err.what() + ")";
940 throw py::cast_error(msg);
941 }
942 try {
943 auto &attribute = it.second.cast<PyAttribute &>();
944 // TODO: Verify attribute originates from the same context.
945 mlirAttributes.emplace_back(std::move(key), attribute);
946 } catch (py::reference_cast_error &) {
947 // This exception seems thrown when the value is "None".
948 std::string msg =
949 "Found an invalid (`None`?) attribute value for the key \"" + key +
950 "\" when attempting to create the operation \"" + name + "\"";
951 throw py::cast_error(msg);
952 } catch (py::cast_error &err) {
953 std::string msg = "Invalid attribute value for the key \"" + key +
954 "\" when attempting to create the operation \"" +
955 name + "\" (" + err.what() + ")";
956 throw py::cast_error(msg);
957 }
958 }
959 }
960 // Unpack/validate successors.
961 if (successors) {
962 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
963 mlirSuccessors.reserve(successors->size());
964 for (auto *successor : *successors) {
965 // TODO: Verify successor originate from the same context.
966 if (!successor)
967 throw SetPyError(PyExc_ValueError, "successor block cannot be None");
968 mlirSuccessors.push_back(successor->get());
969 }
970 }
971
972 // Apply unpacked/validated to the operation state. Beyond this
973 // point, exceptions cannot be thrown or else the state will leak.
974 MlirOperationState state =
975 mlirOperationStateGet(toMlirStringRef(name), location);
976 if (!mlirOperands.empty())
977 mlirOperationStateAddOperands(&state, mlirOperands.size(),
978 mlirOperands.data());
979 if (!mlirResults.empty())
980 mlirOperationStateAddResults(&state, mlirResults.size(),
981 mlirResults.data());
982 if (!mlirAttributes.empty()) {
983 // Note that the attribute names directly reference bytes in
984 // mlirAttributes, so that vector must not be changed from here
985 // on.
986 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
987 mlirNamedAttributes.reserve(mlirAttributes.size());
988 for (auto &it : mlirAttributes)
989 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
990 mlirIdentifierGet(mlirAttributeGetContext(it.second),
991 toMlirStringRef(it.first)),
992 it.second));
993 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
994 mlirNamedAttributes.data());
995 }
996 if (!mlirSuccessors.empty())
997 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
998 mlirSuccessors.data());
999 if (regions) {
1000 llvm::SmallVector<MlirRegion, 4> mlirRegions;
1001 mlirRegions.resize(regions);
1002 for (int i = 0; i < regions; ++i)
1003 mlirRegions[i] = mlirRegionCreate();
1004 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1005 mlirRegions.data());
1006 }
1007
1008 // Construct the operation.
1009 MlirOperation operation = mlirOperationCreate(&state);
1010 PyOperationRef created =
1011 PyOperation::createDetached(location->getContext(), operation);
1012
1013 // InsertPoint active?
1014 if (!maybeIp.is(py::cast(false))) {
1015 PyInsertionPoint *ip;
1016 if (maybeIp.is_none()) {
1017 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1018 } else {
1019 ip = py::cast<PyInsertionPoint *>(maybeIp);
1020 }
1021 if (ip)
1022 ip->insert(*created.get());
1023 }
1024
1025 return created->createOpView();
1026 }
1027
createOpView()1028 py::object PyOperation::createOpView() {
1029 MlirIdentifier ident = mlirOperationGetName(get());
1030 MlirStringRef identStr = mlirIdentifierStr(ident);
1031 auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1032 StringRef(identStr.data, identStr.length));
1033 if (opViewClass)
1034 return (*opViewClass)(getRef().getObject());
1035 return py::cast(PyOpView(getRef().getObject()));
1036 }
1037
1038 //------------------------------------------------------------------------------
1039 // PyOpView
1040 //------------------------------------------------------------------------------
1041
1042 py::object
buildGeneric(py::object cls,py::list resultTypeList,py::list operandList,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,llvm::Optional<int> regions,DefaultingPyLocation location,py::object maybeIp)1043 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1044 py::list operandList,
1045 llvm::Optional<py::dict> attributes,
1046 llvm::Optional<std::vector<PyBlock *>> successors,
1047 llvm::Optional<int> regions,
1048 DefaultingPyLocation location, py::object maybeIp) {
1049 PyMlirContextRef context = location->getContext();
1050 // Class level operation construction metadata.
1051 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1052 // Operand and result segment specs are either none, which does no
1053 // variadic unpacking, or a list of ints with segment sizes, where each
1054 // element is either a positive number (typically 1 for a scalar) or -1 to
1055 // indicate that it is derived from the length of the same-indexed operand
1056 // or result (implying that it is a list at that position).
1057 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1058 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1059
1060 std::vector<uint64_t> operandSegmentLengths;
1061 std::vector<uint64_t> resultSegmentLengths;
1062
1063 // Validate/determine region count.
1064 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1065 int opMinRegionCount = std::get<0>(opRegionSpec);
1066 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1067 if (!regions) {
1068 regions = opMinRegionCount;
1069 }
1070 if (*regions < opMinRegionCount) {
1071 throw py::value_error(
1072 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1073 llvm::Twine(opMinRegionCount) +
1074 " regions but was built with regions=" + llvm::Twine(*regions))
1075 .str());
1076 }
1077 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1078 throw py::value_error(
1079 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1080 llvm::Twine(opMinRegionCount) +
1081 " regions but was built with regions=" + llvm::Twine(*regions))
1082 .str());
1083 }
1084
1085 // Unpack results.
1086 std::vector<PyType *> resultTypes;
1087 resultTypes.reserve(resultTypeList.size());
1088 if (resultSegmentSpecObj.is_none()) {
1089 // Non-variadic result unpacking.
1090 for (auto it : llvm::enumerate(resultTypeList)) {
1091 try {
1092 resultTypes.push_back(py::cast<PyType *>(it.value()));
1093 if (!resultTypes.back())
1094 throw py::cast_error();
1095 } catch (py::cast_error &err) {
1096 throw py::value_error((llvm::Twine("Result ") +
1097 llvm::Twine(it.index()) + " of operation \"" +
1098 name + "\" must be a Type (" + err.what() + ")")
1099 .str());
1100 }
1101 }
1102 } else {
1103 // Sized result unpacking.
1104 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1105 if (resultSegmentSpec.size() != resultTypeList.size()) {
1106 throw py::value_error((llvm::Twine("Operation \"") + name +
1107 "\" requires " +
1108 llvm::Twine(resultSegmentSpec.size()) +
1109 "result segments but was provided " +
1110 llvm::Twine(resultTypeList.size()))
1111 .str());
1112 }
1113 resultSegmentLengths.reserve(resultTypeList.size());
1114 for (auto it :
1115 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1116 int segmentSpec = std::get<1>(it.value());
1117 if (segmentSpec == 1 || segmentSpec == 0) {
1118 // Unpack unary element.
1119 try {
1120 auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1121 if (resultType) {
1122 resultTypes.push_back(resultType);
1123 resultSegmentLengths.push_back(1);
1124 } else if (segmentSpec == 0) {
1125 // Allowed to be optional.
1126 resultSegmentLengths.push_back(0);
1127 } else {
1128 throw py::cast_error("was None and result is not optional");
1129 }
1130 } catch (py::cast_error &err) {
1131 throw py::value_error((llvm::Twine("Result ") +
1132 llvm::Twine(it.index()) + " of operation \"" +
1133 name + "\" must be a Type (" + err.what() +
1134 ")")
1135 .str());
1136 }
1137 } else if (segmentSpec == -1) {
1138 // Unpack sequence by appending.
1139 try {
1140 if (std::get<0>(it.value()).is_none()) {
1141 // Treat it as an empty list.
1142 resultSegmentLengths.push_back(0);
1143 } else {
1144 // Unpack the list.
1145 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1146 for (py::object segmentItem : segment) {
1147 resultTypes.push_back(py::cast<PyType *>(segmentItem));
1148 if (!resultTypes.back()) {
1149 throw py::cast_error("contained a None item");
1150 }
1151 }
1152 resultSegmentLengths.push_back(segment.size());
1153 }
1154 } catch (std::exception &err) {
1155 // NOTE: Sloppy to be using a catch-all here, but there are at least
1156 // three different unrelated exceptions that can be thrown in the
1157 // above "casts". Just keep the scope above small and catch them all.
1158 throw py::value_error((llvm::Twine("Result ") +
1159 llvm::Twine(it.index()) + " of operation \"" +
1160 name + "\" must be a Sequence of Types (" +
1161 err.what() + ")")
1162 .str());
1163 }
1164 } else {
1165 throw py::value_error("Unexpected segment spec");
1166 }
1167 }
1168 }
1169
1170 // Unpack operands.
1171 std::vector<PyValue *> operands;
1172 operands.reserve(operands.size());
1173 if (operandSegmentSpecObj.is_none()) {
1174 // Non-sized operand unpacking.
1175 for (auto it : llvm::enumerate(operandList)) {
1176 try {
1177 operands.push_back(py::cast<PyValue *>(it.value()));
1178 if (!operands.back())
1179 throw py::cast_error();
1180 } catch (py::cast_error &err) {
1181 throw py::value_error((llvm::Twine("Operand ") +
1182 llvm::Twine(it.index()) + " of operation \"" +
1183 name + "\" must be a Value (" + err.what() + ")")
1184 .str());
1185 }
1186 }
1187 } else {
1188 // Sized operand unpacking.
1189 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1190 if (operandSegmentSpec.size() != operandList.size()) {
1191 throw py::value_error((llvm::Twine("Operation \"") + name +
1192 "\" requires " +
1193 llvm::Twine(operandSegmentSpec.size()) +
1194 "operand segments but was provided " +
1195 llvm::Twine(operandList.size()))
1196 .str());
1197 }
1198 operandSegmentLengths.reserve(operandList.size());
1199 for (auto it :
1200 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1201 int segmentSpec = std::get<1>(it.value());
1202 if (segmentSpec == 1 || segmentSpec == 0) {
1203 // Unpack unary element.
1204 try {
1205 auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1206 if (operandValue) {
1207 operands.push_back(operandValue);
1208 operandSegmentLengths.push_back(1);
1209 } else if (segmentSpec == 0) {
1210 // Allowed to be optional.
1211 operandSegmentLengths.push_back(0);
1212 } else {
1213 throw py::cast_error("was None and operand is not optional");
1214 }
1215 } catch (py::cast_error &err) {
1216 throw py::value_error((llvm::Twine("Operand ") +
1217 llvm::Twine(it.index()) + " of operation \"" +
1218 name + "\" must be a Value (" + err.what() +
1219 ")")
1220 .str());
1221 }
1222 } else if (segmentSpec == -1) {
1223 // Unpack sequence by appending.
1224 try {
1225 if (std::get<0>(it.value()).is_none()) {
1226 // Treat it as an empty list.
1227 operandSegmentLengths.push_back(0);
1228 } else {
1229 // Unpack the list.
1230 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1231 for (py::object segmentItem : segment) {
1232 operands.push_back(py::cast<PyValue *>(segmentItem));
1233 if (!operands.back()) {
1234 throw py::cast_error("contained a None item");
1235 }
1236 }
1237 operandSegmentLengths.push_back(segment.size());
1238 }
1239 } catch (std::exception &err) {
1240 // NOTE: Sloppy to be using a catch-all here, but there are at least
1241 // three different unrelated exceptions that can be thrown in the
1242 // above "casts". Just keep the scope above small and catch them all.
1243 throw py::value_error((llvm::Twine("Operand ") +
1244 llvm::Twine(it.index()) + " of operation \"" +
1245 name + "\" must be a Sequence of Values (" +
1246 err.what() + ")")
1247 .str());
1248 }
1249 } else {
1250 throw py::value_error("Unexpected segment spec");
1251 }
1252 }
1253 }
1254
1255 // Merge operand/result segment lengths into attributes if needed.
1256 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1257 // Dup.
1258 if (attributes) {
1259 attributes = py::dict(*attributes);
1260 } else {
1261 attributes = py::dict();
1262 }
1263 if (attributes->contains("result_segment_sizes") ||
1264 attributes->contains("operand_segment_sizes")) {
1265 throw py::value_error("Manually setting a 'result_segment_sizes' or "
1266 "'operand_segment_sizes' attribute is unsupported. "
1267 "Use Operation.create for such low-level access.");
1268 }
1269
1270 // Add result_segment_sizes attribute.
1271 if (!resultSegmentLengths.empty()) {
1272 int64_t size = resultSegmentLengths.size();
1273 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
1274 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
1275 resultSegmentLengths.size(), resultSegmentLengths.data());
1276 (*attributes)["result_segment_sizes"] =
1277 PyAttribute(context, segmentLengthAttr);
1278 }
1279
1280 // Add operand_segment_sizes attribute.
1281 if (!operandSegmentLengths.empty()) {
1282 int64_t size = operandSegmentLengths.size();
1283 MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
1284 mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
1285 operandSegmentLengths.size(), operandSegmentLengths.data());
1286 (*attributes)["operand_segment_sizes"] =
1287 PyAttribute(context, segmentLengthAttr);
1288 }
1289 }
1290
1291 // Delegate to create.
1292 return PyOperation::create(std::move(name),
1293 /*results=*/std::move(resultTypes),
1294 /*operands=*/std::move(operands),
1295 /*attributes=*/std::move(attributes),
1296 /*successors=*/std::move(successors),
1297 /*regions=*/*regions, location, maybeIp);
1298 }
1299
/// Wraps an existing operation object (any PyOperationBase subclass) in an
/// OpView.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    // NOTE(review): the `operationObject` member is initialized from the
    // `operation` member, so this relies on `operation` being declared before
    // `operationObject` in the class definition -- confirm in IRModules.h.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1305
createRawSubclass(py::object userClass)1306 py::object PyOpView::createRawSubclass(py::object userClass) {
1307 // This is... a little gross. The typical pattern is to have a pure python
1308 // class that extends OpView like:
1309 // class AddFOp(_cext.ir.OpView):
1310 // def __init__(self, loc, lhs, rhs):
1311 // operation = loc.context.create_operation(
1312 // "addf", lhs, rhs, results=[lhs.type])
1313 // super().__init__(operation)
1314 //
1315 // I.e. The goal of the user facing type is to provide a nice constructor
1316 // that has complete freedom for the op under construction. This is at odds
1317 // with our other desire to sometimes create this object by just passing an
1318 // operation (to initialize the base class). We could do *arg and **kwargs
1319 // munging to try to make it work, but instead, we synthesize a new class
1320 // on the fly which extends this user class (AddFOp in this example) and
1321 // *give it* the base class's __init__ method, thus bypassing the
1322 // intermediate subclass's __init__ method entirely. While slightly,
1323 // underhanded, this is safe/legal because the type hierarchy has not changed
1324 // (we just added a new leaf) and we aren't mucking around with __new__.
1325 // Typically, this new class will be stored on the original as "_Raw" and will
1326 // be used for casts and other things that need a variant of the class that
1327 // is initialized purely from an operation.
1328 py::object parentMetaclass =
1329 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1330 py::dict attributes;
1331 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1332 // now.
1333 // auto opViewType = py::type::of<PyOpView>();
1334 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1335 attributes["__init__"] = opViewType.attr("__init__");
1336 py::str origName = userClass.attr("__name__");
1337 py::str newName = py::str("_") + origName;
1338 return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1339 }
1340
1341 //------------------------------------------------------------------------------
1342 // PyInsertionPoint.
1343 //------------------------------------------------------------------------------
1344
PyInsertionPoint(PyBlock & block)1345 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1346
/// Creates an insertion point that inserts immediately before
/// `beforeOperationBase`, in that operation's containing block.
/// getBlock() raises ValueError if the operation is detached or has no parent.
/// NOTE(review): `block` is initialized from `refOperation`, which relies on
/// `refOperation` being declared before `block` in the class -- confirm in
/// IRModules.h.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1350
insert(PyOperationBase & operationBase)1351 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1352 PyOperation &operation = operationBase.getOperation();
1353 if (operation.isAttached())
1354 throw SetPyError(PyExc_ValueError,
1355 "Attempt to insert operation that is already attached");
1356 block.getParentOperation()->checkValid();
1357 MlirOperation beforeOp = {nullptr};
1358 if (refOperation) {
1359 // Insert before operation.
1360 (*refOperation)->checkValid();
1361 beforeOp = (*refOperation)->get();
1362 } else {
1363 // Insert at end (before null) is only valid if the block does not
1364 // already end in a known terminator (violating this will cause assertion
1365 // failures later).
1366 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1367 throw py::index_error("Cannot insert operation at the end of a block "
1368 "that already has a terminator. Did you mean to "
1369 "use 'InsertionPoint.at_block_terminator(block)' "
1370 "versus 'InsertionPoint(block)'?");
1371 }
1372 }
1373 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1374 operation.setAttached();
1375 }
1376
atBlockBegin(PyBlock & block)1377 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1378 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1379 if (mlirOperationIsNull(firstOp)) {
1380 // Just insert at end.
1381 return PyInsertionPoint(block);
1382 }
1383
1384 // Insert before first op.
1385 PyOperationRef firstOpRef = PyOperation::forOperation(
1386 block.getParentOperation()->getContext(), firstOp);
1387 return PyInsertionPoint{block, std::move(firstOpRef)};
1388 }
1389
atBlockTerminator(PyBlock & block)1390 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1391 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1392 if (mlirOperationIsNull(terminator))
1393 throw SetPyError(PyExc_ValueError, "Block has no terminator");
1394 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1395 block.getParentOperation()->getContext(), terminator);
1396 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1397 }
1398
contextEnter()1399 py::object PyInsertionPoint::contextEnter() {
1400 return PyThreadContextEntry::pushInsertionPoint(*this);
1401 }
1402
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)1403 void PyInsertionPoint::contextExit(pybind11::object excType,
1404 pybind11::object excVal,
1405 pybind11::object excTb) {
1406 PyThreadContextEntry::popInsertionPoint(*this);
1407 }
1408
1409 //------------------------------------------------------------------------------
1410 // PyAttribute.
1411 //------------------------------------------------------------------------------
1412
operator ==(const PyAttribute & other)1413 bool PyAttribute::operator==(const PyAttribute &other) {
1414 return mlirAttributeEqual(attr, other.attr);
1415 }
1416
getCapsule()1417 py::object PyAttribute::getCapsule() {
1418 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1419 }
1420
createFromCapsule(py::object capsule)1421 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1422 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1423 if (mlirAttributeIsNull(rawAttr))
1424 throw py::error_already_set();
1425 return PyAttribute(
1426 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1427 }
1428
1429 //------------------------------------------------------------------------------
1430 // PyNamedAttribute.
1431 //------------------------------------------------------------------------------
1432
/// Builds a named attribute pairing `attr` with the name `ownedName`, which is
/// moved into heap storage owned by this object.
/// NOTE(review): presumably the name is owned here because the identifier may
/// reference its bytes; `namedAttr` is assigned in the body (not the
/// initializer list), presumably so `ownedName` is initialized first
/// regardless of member declaration order -- confirm in IRModules.h.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  namedAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(mlirAttributeGetContext(attr),
                        toMlirStringRef(*this->ownedName)),
      attr);
}
1440
1441 //------------------------------------------------------------------------------
1442 // PyType.
1443 //------------------------------------------------------------------------------
1444
operator ==(const PyType & other)1445 bool PyType::operator==(const PyType &other) {
1446 return mlirTypeEqual(type, other.type);
1447 }
1448
getCapsule()1449 py::object PyType::getCapsule() {
1450 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1451 }
1452
createFromCapsule(py::object capsule)1453 PyType PyType::createFromCapsule(py::object capsule) {
1454 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1455 if (mlirTypeIsNull(rawType))
1456 throw py::error_already_set();
1457 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1458 rawType);
1459 }
1460
1461 //------------------------------------------------------------------------------
// PyValue and subclasses.
1463 //------------------------------------------------------------------------------
1464
1465 namespace {
1466 /// CRTP base class for Python MLIR values that subclass Value and should be
1467 /// castable from it. The value hierarchy is one level deep and is not supposed
1468 /// to accommodate other levels unless core MLIR changes.
1469 template <typename DerivedTy>
1470 class PyConcreteValue : public PyValue {
1471 public:
1472 // Derived classes must define statics for:
1473 // IsAFunctionTy isaFunction
1474 // const char *pyClassName
1475 // and redefine bindDerived.
1476 using ClassTy = py::class_<DerivedTy, PyValue>;
1477 using IsAFunctionTy = bool (*)(MlirValue);
1478
1479 PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1480 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1481 : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1482 PyConcreteValue(PyValue &orig)
1483 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1484
1485 /// Attempts to cast the original value to the derived type and throws on
1486 /// type mismatches.
castFrom(PyValue & orig)1487 static MlirValue castFrom(PyValue &orig) {
1488 if (!DerivedTy::isaFunction(orig.get())) {
1489 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1490 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1491 DerivedTy::pyClassName +
1492 " (from " + origRepr + ")");
1493 }
1494 return orig.get();
1495 }
1496
1497 /// Binds the Python module objects to functions of this class.
bind(py::module & m)1498 static void bind(py::module &m) {
1499 auto cls = ClassTy(m, DerivedTy::pyClassName);
1500 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1501 DerivedTy::bindDerived(cls);
1502 }
1503
1504 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1505 static void bindDerived(ClassTy &m) {}
1506 };
1507
1508 /// Python wrapper for MlirBlockArgument.
1509 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1510 public:
1511 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1512 static constexpr const char *pyClassName = "BlockArgument";
1513 using PyConcreteValue::PyConcreteValue;
1514
bindDerived(ClassTy & c)1515 static void bindDerived(ClassTy &c) {
1516 c.def_property_readonly("owner", [](PyBlockArgument &self) {
1517 return PyBlock(self.getParentOperation(),
1518 mlirBlockArgumentGetOwner(self.get()));
1519 });
1520 c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1521 return mlirBlockArgumentGetArgNumber(self.get());
1522 });
1523 c.def("set_type", [](PyBlockArgument &self, PyType type) {
1524 return mlirBlockArgumentSetType(self.get(), type);
1525 });
1526 }
1527 };
1528
1529 /// Python wrapper for MlirOpResult.
1530 class PyOpResult : public PyConcreteValue<PyOpResult> {
1531 public:
1532 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1533 static constexpr const char *pyClassName = "OpResult";
1534 using PyConcreteValue::PyConcreteValue;
1535
bindDerived(ClassTy & c)1536 static void bindDerived(ClassTy &c) {
1537 c.def_property_readonly("owner", [](PyOpResult &self) {
1538 assert(
1539 mlirOperationEqual(self.getParentOperation()->get(),
1540 mlirOpResultGetOwner(self.get())) &&
1541 "expected the owner of the value in Python to match that in the IR");
1542 return self.getParentOperation();
1543 });
1544 c.def_property_readonly("result_number", [](PyOpResult &self) {
1545 return mlirOpResultGetResultNumber(self.get());
1546 });
1547 }
1548 };
1549
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, so random access is cheap. The argument list is associated with
/// the operation that contains the block (detached blocks are not allowed in
/// Python bindings) and extends its lifetime.
1554 class PyBlockArgumentList {
1555 public:
PyBlockArgumentList(PyOperationRef operation,MlirBlock block)1556 PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1557 : operation(std::move(operation)), block(block) {}
1558
1559 /// Returns the length of the block argument list.
dunderLen()1560 intptr_t dunderLen() {
1561 operation->checkValid();
1562 return mlirBlockGetNumArguments(block);
1563 }
1564
1565 /// Returns `index`-th element of the block argument list.
dunderGetItem(intptr_t index)1566 PyBlockArgument dunderGetItem(intptr_t index) {
1567 if (index < 0 || index >= dunderLen()) {
1568 throw SetPyError(PyExc_IndexError,
1569 "attempt to access out of bounds region");
1570 }
1571 PyValue value(operation, mlirBlockGetArgument(block, index));
1572 return PyBlockArgument(value);
1573 }
1574
1575 /// Defines a Python class in the bindings.
bind(py::module & m)1576 static void bind(py::module &m) {
1577 py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1578 .def("__len__", &PyBlockArgumentList::dunderLen)
1579 .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1580 }
1581
1582 private:
1583 PyOperationRef operation;
1584 MlirBlock block;
1585 };
1586
/// A list of operation operands. Internally, these are stored as consecutive
/// elements, so random access is cheap. The operand list is associated with
/// the operation whose operands these are, and extends the lifetime of this
/// operation.
1591 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1592 public:
1593 static constexpr const char *pyClassName = "OpOperandList";
1594
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1595 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1596 intptr_t length = -1, intptr_t step = 1)
1597 : Sliceable(startIndex,
1598 length == -1 ? mlirOperationGetNumOperands(operation->get())
1599 : length,
1600 step),
1601 operation(operation) {}
1602
getNumElements()1603 intptr_t getNumElements() {
1604 operation->checkValid();
1605 return mlirOperationGetNumOperands(operation->get());
1606 }
1607
getElement(intptr_t pos)1608 PyValue getElement(intptr_t pos) {
1609 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1610 }
1611
slice(intptr_t startIndex,intptr_t length,intptr_t step)1612 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1613 return PyOpOperandList(operation, startIndex, length, step);
1614 }
1615
1616 private:
1617 PyOperationRef operation;
1618 };
1619
/// A list of operation results. Internally, these are stored as consecutive
/// elements, so random access is cheap. The result list is associated with the
/// operation whose results these are, and extends the lifetime of this
/// operation.
1624 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1625 public:
1626 static constexpr const char *pyClassName = "OpResultList";
1627
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1628 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1629 intptr_t length = -1, intptr_t step = 1)
1630 : Sliceable(startIndex,
1631 length == -1 ? mlirOperationGetNumResults(operation->get())
1632 : length,
1633 step),
1634 operation(operation) {}
1635
getNumElements()1636 intptr_t getNumElements() {
1637 operation->checkValid();
1638 return mlirOperationGetNumResults(operation->get());
1639 }
1640
getElement(intptr_t index)1641 PyOpResult getElement(intptr_t index) {
1642 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1643 return PyOpResult(value);
1644 }
1645
slice(intptr_t startIndex,intptr_t length,intptr_t step)1646 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1647 return PyOpResultList(operation, startIndex, length, step);
1648 }
1649
1650 private:
1651 PyOperationRef operation;
1652 };
1653
1654 /// A list of operation attributes. Can be indexed by name, producing
1655 /// attributes, or by index, producing named attributes.
1656 class PyOpAttributeMap {
1657 public:
PyOpAttributeMap(PyOperationRef operation)1658 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1659
dunderGetItemNamed(const std::string & name)1660 PyAttribute dunderGetItemNamed(const std::string &name) {
1661 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1662 toMlirStringRef(name));
1663 if (mlirAttributeIsNull(attr)) {
1664 throw SetPyError(PyExc_KeyError,
1665 "attempt to access a non-existent attribute");
1666 }
1667 return PyAttribute(operation->getContext(), attr);
1668 }
1669
dunderGetItemIndexed(intptr_t index)1670 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1671 if (index < 0 || index >= dunderLen()) {
1672 throw SetPyError(PyExc_IndexError,
1673 "attempt to access out of bounds attribute");
1674 }
1675 MlirNamedAttribute namedAttr =
1676 mlirOperationGetAttribute(operation->get(), index);
1677 return PyNamedAttribute(
1678 namedAttr.attribute,
1679 std::string(mlirIdentifierStr(namedAttr.name).data));
1680 }
1681
dunderSetItem(const std::string & name,PyAttribute attr)1682 void dunderSetItem(const std::string &name, PyAttribute attr) {
1683 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1684 attr);
1685 }
1686
dunderDelItem(const std::string & name)1687 void dunderDelItem(const std::string &name) {
1688 int removed = mlirOperationRemoveAttributeByName(operation->get(),
1689 toMlirStringRef(name));
1690 if (!removed)
1691 throw SetPyError(PyExc_KeyError,
1692 "attempt to delete a non-existent attribute");
1693 }
1694
dunderLen()1695 intptr_t dunderLen() {
1696 return mlirOperationGetNumAttributes(operation->get());
1697 }
1698
dunderContains(const std::string & name)1699 bool dunderContains(const std::string &name) {
1700 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1701 operation->get(), toMlirStringRef(name)));
1702 }
1703
bind(py::module & m)1704 static void bind(py::module &m) {
1705 py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1706 .def("__contains__", &PyOpAttributeMap::dunderContains)
1707 .def("__len__", &PyOpAttributeMap::dunderLen)
1708 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1709 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1710 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1711 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1712 }
1713
1714 private:
1715 PyOperationRef operation;
1716 };
1717
1718 } // end namespace
1719
1720 //------------------------------------------------------------------------------
1721 // Builtin attribute subclasses.
1722 //------------------------------------------------------------------------------
1723
1724 namespace {
1725
1726 /// CRTP base classes for Python attributes that subclass Attribute and should
1727 /// be castable from it (i.e. via something like StringAttr(attr)).
1728 /// By default, attribute class hierarchies are one level deep (i.e. a
1729 /// concrete attribute class extends PyAttribute); however, intermediate
1730 /// python-visible base classes can be modeled by specifying a BaseTy.
1731 template <typename DerivedTy, typename BaseTy = PyAttribute>
1732 class PyConcreteAttribute : public BaseTy {
1733 public:
1734 // Derived classes must define statics for:
1735 // IsAFunctionTy isaFunction
1736 // const char *pyClassName
1737 using ClassTy = py::class_<DerivedTy, BaseTy>;
1738 using IsAFunctionTy = bool (*)(MlirAttribute);
1739
1740 PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)1741 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
1742 : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute & orig)1743 PyConcreteAttribute(PyAttribute &orig)
1744 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
1745
castFrom(PyAttribute & orig)1746 static MlirAttribute castFrom(PyAttribute &orig) {
1747 if (!DerivedTy::isaFunction(orig)) {
1748 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1749 throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
1750 DerivedTy::pyClassName +
1751 " (from " + origRepr + ")");
1752 }
1753 return orig;
1754 }
1755
bind(py::module & m)1756 static void bind(py::module &m) {
1757 auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
1758 cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
1759 DerivedTy::bindDerived(cls);
1760 }
1761
1762 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1763 static void bindDerived(ClassTy &m) {}
1764 };
1765
1766 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
1767 public:
1768 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
1769 static constexpr const char *pyClassName = "ArrayAttr";
1770 using PyConcreteAttribute::PyConcreteAttribute;
1771
1772 class PyArrayAttributeIterator {
1773 public:
PyArrayAttributeIterator(PyAttribute attr)1774 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
1775
dunderIter()1776 PyArrayAttributeIterator &dunderIter() { return *this; }
1777
dunderNext()1778 PyAttribute dunderNext() {
1779 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
1780 throw py::stop_iteration();
1781 }
1782 return PyAttribute(attr.getContext(),
1783 mlirArrayAttrGetElement(attr.get(), nextIndex++));
1784 }
1785
bind(py::module & m)1786 static void bind(py::module &m) {
1787 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
1788 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
1789 .def("__next__", &PyArrayAttributeIterator::dunderNext);
1790 }
1791
1792 private:
1793 PyAttribute attr;
1794 int nextIndex = 0;
1795 };
1796
bindDerived(ClassTy & c)1797 static void bindDerived(ClassTy &c) {
1798 c.def_static(
1799 "get",
1800 [](py::list attributes, DefaultingPyMlirContext context) {
1801 SmallVector<MlirAttribute> mlirAttributes;
1802 mlirAttributes.reserve(py::len(attributes));
1803 for (auto attribute : attributes) {
1804 try {
1805 mlirAttributes.push_back(attribute.cast<PyAttribute>());
1806 } catch (py::cast_error &err) {
1807 std::string msg = std::string("Invalid attribute when attempting "
1808 "to create an ArrayAttribute (") +
1809 err.what() + ")";
1810 throw py::cast_error(msg);
1811 } catch (py::reference_cast_error &err) {
1812 // This exception seems thrown when the value is "None".
1813 std::string msg =
1814 std::string("Invalid attribute (None?) when attempting to "
1815 "create an ArrayAttribute (") +
1816 err.what() + ")";
1817 throw py::cast_error(msg);
1818 }
1819 }
1820 MlirAttribute attr = mlirArrayAttrGet(
1821 context->get(), mlirAttributes.size(), mlirAttributes.data());
1822 return PyArrayAttribute(context->getRef(), attr);
1823 },
1824 py::arg("attributes"), py::arg("context") = py::none(),
1825 "Gets a uniqued Array attribute");
1826 c.def("__getitem__",
1827 [](PyArrayAttribute &arr, intptr_t i) {
1828 if (i >= mlirArrayAttrGetNumElements(arr))
1829 throw py::index_error("ArrayAttribute index out of range");
1830 return PyAttribute(arr.getContext(),
1831 mlirArrayAttrGetElement(arr, i));
1832 })
1833 .def("__len__",
1834 [](const PyArrayAttribute &arr) {
1835 return mlirArrayAttrGetNumElements(arr);
1836 })
1837 .def("__iter__", [](const PyArrayAttribute &arr) {
1838 return PyArrayAttributeIterator(arr);
1839 });
1840 }
1841 };
1842
1843 /// Float Point Attribute subclass - FloatAttr.
1844 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
1845 public:
1846 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
1847 static constexpr const char *pyClassName = "FloatAttr";
1848 using PyConcreteAttribute::PyConcreteAttribute;
1849
bindDerived(ClassTy & c)1850 static void bindDerived(ClassTy &c) {
1851 c.def_static(
1852 "get",
1853 [](PyType &type, double value, DefaultingPyLocation loc) {
1854 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
1855 // TODO: Rework error reporting once diagnostic engine is exposed
1856 // in C API.
1857 if (mlirAttributeIsNull(attr)) {
1858 throw SetPyError(PyExc_ValueError,
1859 Twine("invalid '") +
1860 py::repr(py::cast(type)).cast<std::string>() +
1861 "' and expected floating point type.");
1862 }
1863 return PyFloatAttribute(type.getContext(), attr);
1864 },
1865 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
1866 "Gets an uniqued float point attribute associated to a type");
1867 c.def_static(
1868 "get_f32",
1869 [](double value, DefaultingPyMlirContext context) {
1870 MlirAttribute attr = mlirFloatAttrDoubleGet(
1871 context->get(), mlirF32TypeGet(context->get()), value);
1872 return PyFloatAttribute(context->getRef(), attr);
1873 },
1874 py::arg("value"), py::arg("context") = py::none(),
1875 "Gets an uniqued float point attribute associated to a f32 type");
1876 c.def_static(
1877 "get_f64",
1878 [](double value, DefaultingPyMlirContext context) {
1879 MlirAttribute attr = mlirFloatAttrDoubleGet(
1880 context->get(), mlirF64TypeGet(context->get()), value);
1881 return PyFloatAttribute(context->getRef(), attr);
1882 },
1883 py::arg("value"), py::arg("context") = py::none(),
1884 "Gets an uniqued float point attribute associated to a f64 type");
1885 c.def_property_readonly(
1886 "value",
1887 [](PyFloatAttribute &self) {
1888 return mlirFloatAttrGetValueDouble(self);
1889 },
1890 "Returns the value of the float point attribute");
1891 }
1892 };
1893
1894 /// Integer Attribute subclass - IntegerAttr.
1895 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
1896 public:
1897 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
1898 static constexpr const char *pyClassName = "IntegerAttr";
1899 using PyConcreteAttribute::PyConcreteAttribute;
1900
bindDerived(ClassTy & c)1901 static void bindDerived(ClassTy &c) {
1902 c.def_static(
1903 "get",
1904 [](PyType &type, int64_t value) {
1905 MlirAttribute attr = mlirIntegerAttrGet(type, value);
1906 return PyIntegerAttribute(type.getContext(), attr);
1907 },
1908 py::arg("type"), py::arg("value"),
1909 "Gets an uniqued integer attribute associated to a type");
1910 c.def_property_readonly(
1911 "value",
1912 [](PyIntegerAttribute &self) {
1913 return mlirIntegerAttrGetValueInt(self);
1914 },
1915 "Returns the value of the integer attribute");
1916 }
1917 };
1918
1919 /// Bool Attribute subclass - BoolAttr.
1920 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
1921 public:
1922 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
1923 static constexpr const char *pyClassName = "BoolAttr";
1924 using PyConcreteAttribute::PyConcreteAttribute;
1925
bindDerived(ClassTy & c)1926 static void bindDerived(ClassTy &c) {
1927 c.def_static(
1928 "get",
1929 [](bool value, DefaultingPyMlirContext context) {
1930 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
1931 return PyBoolAttribute(context->getRef(), attr);
1932 },
1933 py::arg("value"), py::arg("context") = py::none(),
1934 "Gets an uniqued bool attribute");
1935 c.def_property_readonly(
1936 "value",
1937 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
1938 "Returns the value of the bool attribute");
1939 }
1940 };
1941
1942 class PyFlatSymbolRefAttribute
1943 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
1944 public:
1945 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
1946 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
1947 using PyConcreteAttribute::PyConcreteAttribute;
1948
bindDerived(ClassTy & c)1949 static void bindDerived(ClassTy &c) {
1950 c.def_static(
1951 "get",
1952 [](std::string value, DefaultingPyMlirContext context) {
1953 MlirAttribute attr =
1954 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
1955 return PyFlatSymbolRefAttribute(context->getRef(), attr);
1956 },
1957 py::arg("value"), py::arg("context") = py::none(),
1958 "Gets a uniqued FlatSymbolRef attribute");
1959 c.def_property_readonly(
1960 "value",
1961 [](PyFlatSymbolRefAttribute &self) {
1962 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
1963 return py::str(stringRef.data, stringRef.length);
1964 },
1965 "Returns the value of the FlatSymbolRef attribute as a string");
1966 }
1967 };
1968
1969 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
1970 public:
1971 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
1972 static constexpr const char *pyClassName = "StringAttr";
1973 using PyConcreteAttribute::PyConcreteAttribute;
1974
bindDerived(ClassTy & c)1975 static void bindDerived(ClassTy &c) {
1976 c.def_static(
1977 "get",
1978 [](std::string value, DefaultingPyMlirContext context) {
1979 MlirAttribute attr =
1980 mlirStringAttrGet(context->get(), toMlirStringRef(value));
1981 return PyStringAttribute(context->getRef(), attr);
1982 },
1983 py::arg("value"), py::arg("context") = py::none(),
1984 "Gets a uniqued string attribute");
1985 c.def_static(
1986 "get_typed",
1987 [](PyType &type, std::string value) {
1988 MlirAttribute attr =
1989 mlirStringAttrTypedGet(type, toMlirStringRef(value));
1990 return PyStringAttribute(type.getContext(), attr);
1991 },
1992
1993 "Gets a uniqued string attribute associated to a type");
1994 c.def_property_readonly(
1995 "value",
1996 [](PyStringAttribute &self) {
1997 MlirStringRef stringRef = mlirStringAttrGetValue(self);
1998 return py::str(stringRef.data, stringRef.length);
1999 },
2000 "Returns the value of the string attribute");
2001 }
2002 };
2003
2004 // TODO: Support construction of bool elements.
2005 // TODO: Support construction of string elements.
2006 class PyDenseElementsAttribute
2007 : public PyConcreteAttribute<PyDenseElementsAttribute> {
2008 public:
2009 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
2010 static constexpr const char *pyClassName = "DenseElementsAttr";
2011 using PyConcreteAttribute::PyConcreteAttribute;
2012
2013 static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,DefaultingPyMlirContext contextWrapper)2014 getFromBuffer(py::buffer array, bool signless,
2015 DefaultingPyMlirContext contextWrapper) {
2016 // Request a contiguous view. In exotic cases, this will cause a copy.
2017 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
2018 Py_buffer *view = new Py_buffer();
2019 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
2020 delete view;
2021 throw py::error_already_set();
2022 }
2023 py::buffer_info arrayInfo(view);
2024
2025 MlirContext context = contextWrapper->get();
2026 // Switch on the types that can be bulk loaded between the Python and
2027 // MLIR-C APIs.
2028 // See: https://docs.python.org/3/library/struct.html#format-characters
2029 if (arrayInfo.format == "f") {
2030 // f32
2031 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
2032 return PyDenseElementsAttribute(
2033 contextWrapper->getRef(),
2034 bulkLoad(context, mlirDenseElementsAttrFloatGet,
2035 mlirF32TypeGet(context), arrayInfo));
2036 } else if (arrayInfo.format == "d") {
2037 // f64
2038 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
2039 return PyDenseElementsAttribute(
2040 contextWrapper->getRef(),
2041 bulkLoad(context, mlirDenseElementsAttrDoubleGet,
2042 mlirF64TypeGet(context), arrayInfo));
2043 } else if (isSignedIntegerFormat(arrayInfo.format)) {
2044 if (arrayInfo.itemsize == 4) {
2045 // i32
2046 MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
2047 : mlirIntegerTypeSignedGet(context, 32);
2048 return PyDenseElementsAttribute(contextWrapper->getRef(),
2049 bulkLoad(context,
2050 mlirDenseElementsAttrInt32Get,
2051 elementType, arrayInfo));
2052 } else if (arrayInfo.itemsize == 8) {
2053 // i64
2054 MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
2055 : mlirIntegerTypeSignedGet(context, 64);
2056 return PyDenseElementsAttribute(contextWrapper->getRef(),
2057 bulkLoad(context,
2058 mlirDenseElementsAttrInt64Get,
2059 elementType, arrayInfo));
2060 }
2061 } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
2062 if (arrayInfo.itemsize == 4) {
2063 // unsigned i32
2064 MlirType elementType = signless
2065 ? mlirIntegerTypeGet(context, 32)
2066 : mlirIntegerTypeUnsignedGet(context, 32);
2067 return PyDenseElementsAttribute(contextWrapper->getRef(),
2068 bulkLoad(context,
2069 mlirDenseElementsAttrUInt32Get,
2070 elementType, arrayInfo));
2071 } else if (arrayInfo.itemsize == 8) {
2072 // unsigned i64
2073 MlirType elementType = signless
2074 ? mlirIntegerTypeGet(context, 64)
2075 : mlirIntegerTypeUnsignedGet(context, 64);
2076 return PyDenseElementsAttribute(contextWrapper->getRef(),
2077 bulkLoad(context,
2078 mlirDenseElementsAttrUInt64Get,
2079 elementType, arrayInfo));
2080 }
2081 }
2082
2083 // TODO: Fall back to string-based get.
2084 std::string message = "unimplemented array format conversion from format: ";
2085 message.append(arrayInfo.format);
2086 throw SetPyError(PyExc_ValueError, message);
2087 }
2088
getSplat(PyType shapedType,PyAttribute & elementAttr)2089 static PyDenseElementsAttribute getSplat(PyType shapedType,
2090 PyAttribute &elementAttr) {
2091 auto contextWrapper =
2092 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
2093 if (!mlirAttributeIsAInteger(elementAttr) &&
2094 !mlirAttributeIsAFloat(elementAttr)) {
2095 std::string message = "Illegal element type for DenseElementsAttr: ";
2096 message.append(py::repr(py::cast(elementAttr)));
2097 throw SetPyError(PyExc_ValueError, message);
2098 }
2099 if (!mlirTypeIsAShaped(shapedType) ||
2100 !mlirShapedTypeHasStaticShape(shapedType)) {
2101 std::string message =
2102 "Expected a static ShapedType for the shaped_type parameter: ";
2103 message.append(py::repr(py::cast(shapedType)));
2104 throw SetPyError(PyExc_ValueError, message);
2105 }
2106 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
2107 MlirType attrType = mlirAttributeGetType(elementAttr);
2108 if (!mlirTypeEqual(shapedElementType, attrType)) {
2109 std::string message =
2110 "Shaped element type and attribute type must be equal: shaped=";
2111 message.append(py::repr(py::cast(shapedType)));
2112 message.append(", element=");
2113 message.append(py::repr(py::cast(elementAttr)));
2114 throw SetPyError(PyExc_ValueError, message);
2115 }
2116
2117 MlirAttribute elements =
2118 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
2119 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
2120 }
2121
dunderLen()2122 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
2123
accessBuffer()2124 py::buffer_info accessBuffer() {
2125 MlirType shapedType = mlirAttributeGetType(*this);
2126 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
2127
2128 if (mlirTypeIsAF32(elementType)) {
2129 // f32
2130 return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
2131 } else if (mlirTypeIsAF64(elementType)) {
2132 // f64
2133 return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
2134 } else if (mlirTypeIsAInteger(elementType) &&
2135 mlirIntegerTypeGetWidth(elementType) == 32) {
2136 if (mlirIntegerTypeIsSignless(elementType) ||
2137 mlirIntegerTypeIsSigned(elementType)) {
2138 // i32
2139 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
2140 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
2141 // unsigned i32
2142 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
2143 }
2144 } else if (mlirTypeIsAInteger(elementType) &&
2145 mlirIntegerTypeGetWidth(elementType) == 64) {
2146 if (mlirIntegerTypeIsSignless(elementType) ||
2147 mlirIntegerTypeIsSigned(elementType)) {
2148 // i64
2149 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
2150 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
2151 // unsigned i64
2152 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
2153 }
2154 }
2155
2156 std::string message = "unimplemented array format.";
2157 throw SetPyError(PyExc_ValueError, message);
2158 }
2159
bindDerived(ClassTy & c)2160 static void bindDerived(ClassTy &c) {
2161 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
2162 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
2163 py::arg("array"), py::arg("signless") = true,
2164 py::arg("context") = py::none(),
2165 "Gets from a buffer or ndarray")
2166 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
2167 py::arg("shaped_type"), py::arg("element_attr"),
2168 "Gets a DenseElementsAttr where all values are the same")
2169 .def_property_readonly("is_splat",
2170 [](PyDenseElementsAttribute &self) -> bool {
2171 return mlirDenseElementsAttrIsSplat(self);
2172 })
2173 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
2174 }
2175
2176 private:
2177 template <typename ElementTy>
2178 static MlirAttribute
bulkLoad(MlirContext context,MlirAttribute (* ctor)(MlirType,intptr_t,ElementTy *),MlirType mlirElementType,py::buffer_info & arrayInfo)2179 bulkLoad(MlirContext context,
2180 MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
2181 MlirType mlirElementType, py::buffer_info &arrayInfo) {
2182 SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
2183 arrayInfo.shape.begin() + arrayInfo.ndim);
2184 auto shapedType =
2185 mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
2186 intptr_t numElements = arrayInfo.size;
2187 const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
2188 return ctor(shapedType, numElements, contents);
2189 }
2190
isUnsignedIntegerFormat(const std::string & format)2191 static bool isUnsignedIntegerFormat(const std::string &format) {
2192 if (format.empty())
2193 return false;
2194 char code = format[0];
2195 return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
2196 code == 'Q';
2197 }
2198
isSignedIntegerFormat(const std::string & format)2199 static bool isSignedIntegerFormat(const std::string &format) {
2200 if (format.empty())
2201 return false;
2202 char code = format[0];
2203 return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
2204 code == 'q';
2205 }
2206
2207 template <typename Type>
bufferInfo(MlirType shapedType,Type (* value)(MlirAttribute,intptr_t))2208 py::buffer_info bufferInfo(MlirType shapedType,
2209 Type (*value)(MlirAttribute, intptr_t)) {
2210 intptr_t rank = mlirShapedTypeGetRank(shapedType);
2211 // Prepare the data for the buffer_info.
2212 // Buffer is configured for read-only access below.
2213 Type *data = static_cast<Type *>(
2214 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
2215 // Prepare the shape for the buffer_info.
2216 SmallVector<intptr_t, 4> shape;
2217 for (intptr_t i = 0; i < rank; ++i)
2218 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
2219 // Prepare the strides for the buffer_info.
2220 SmallVector<intptr_t, 4> strides;
2221 intptr_t strideFactor = 1;
2222 for (intptr_t i = 1; i < rank; ++i) {
2223 strideFactor = 1;
2224 for (intptr_t j = i; j < rank; ++j) {
2225 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
2226 }
2227 strides.push_back(sizeof(Type) * strideFactor);
2228 }
2229 strides.push_back(sizeof(Type));
2230 return py::buffer_info(data, sizeof(Type),
2231 py::format_descriptor<Type>::format(), rank, shape,
2232 strides, /*readonly=*/true);
2233 }
2234 }; // namespace
2235
2236 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
2237 /// (and boolean) values. Supports element access.
2238 class PyDenseIntElementsAttribute
2239 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
2240 PyDenseElementsAttribute> {
2241 public:
2242 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
2243 static constexpr const char *pyClassName = "DenseIntElementsAttr";
2244 using PyConcreteAttribute::PyConcreteAttribute;
2245
2246 /// Returns the element at the given linear position. Asserts if the index is
2247 /// out of range.
dunderGetItem(intptr_t pos)2248 py::int_ dunderGetItem(intptr_t pos) {
2249 if (pos < 0 || pos >= dunderLen()) {
2250 throw SetPyError(PyExc_IndexError,
2251 "attempt to access out of bounds element");
2252 }
2253
2254 MlirType type = mlirAttributeGetType(*this);
2255 type = mlirShapedTypeGetElementType(type);
2256 assert(mlirTypeIsAInteger(type) &&
2257 "expected integer element type in dense int elements attribute");
2258 // Dispatch element extraction to an appropriate C function based on the
2259 // elemental type of the attribute. py::int_ is implicitly constructible
2260 // from any C++ integral type and handles bitwidth correctly.
2261 // TODO: consider caching the type properties in the constructor to avoid
2262 // querying them on each element access.
2263 unsigned width = mlirIntegerTypeGetWidth(type);
2264 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
2265 if (isUnsigned) {
2266 if (width == 1) {
2267 return mlirDenseElementsAttrGetBoolValue(*this, pos);
2268 }
2269 if (width == 32) {
2270 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
2271 }
2272 if (width == 64) {
2273 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
2274 }
2275 } else {
2276 if (width == 1) {
2277 return mlirDenseElementsAttrGetBoolValue(*this, pos);
2278 }
2279 if (width == 32) {
2280 return mlirDenseElementsAttrGetInt32Value(*this, pos);
2281 }
2282 if (width == 64) {
2283 return mlirDenseElementsAttrGetInt64Value(*this, pos);
2284 }
2285 }
2286 throw SetPyError(PyExc_TypeError, "Unsupported integer type");
2287 }
2288
bindDerived(ClassTy & c)2289 static void bindDerived(ClassTy &c) {
2290 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
2291 }
2292 };
2293
2294 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
2295 public:
2296 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
2297 static constexpr const char *pyClassName = "DictAttr";
2298 using PyConcreteAttribute::PyConcreteAttribute;
2299
dunderLen()2300 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
2301
bindDerived(ClassTy & c)2302 static void bindDerived(ClassTy &c) {
2303 c.def("__len__", &PyDictAttribute::dunderLen);
2304 c.def_static(
2305 "get",
2306 [](py::dict attributes, DefaultingPyMlirContext context) {
2307 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
2308 mlirNamedAttributes.reserve(attributes.size());
2309 for (auto &it : attributes) {
2310 auto &mlir_attr = it.second.cast<PyAttribute &>();
2311 auto name = it.first.cast<std::string>();
2312 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
2313 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
2314 toMlirStringRef(name)),
2315 mlir_attr));
2316 }
2317 MlirAttribute attr =
2318 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
2319 mlirNamedAttributes.data());
2320 return PyDictAttribute(context->getRef(), attr);
2321 },
2322 py::arg("value"), py::arg("context") = py::none(),
2323 "Gets an uniqued dict attribute");
2324 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
2325 MlirAttribute attr =
2326 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
2327 if (mlirAttributeIsNull(attr)) {
2328 throw SetPyError(PyExc_KeyError,
2329 "attempt to access a non-existent attribute");
2330 }
2331 return PyAttribute(self.getContext(), attr);
2332 });
2333 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
2334 if (index < 0 || index >= self.dunderLen()) {
2335 throw SetPyError(PyExc_IndexError,
2336 "attempt to access out of bounds attribute");
2337 }
2338 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
2339 return PyNamedAttribute(
2340 namedAttr.attribute,
2341 std::string(mlirIdentifierStr(namedAttr.name).data));
2342 });
2343 }
2344 };
2345
2346 /// Refinement of PyDenseElementsAttribute for attributes containing
2347 /// floating-point values. Supports element access.
2348 class PyDenseFPElementsAttribute
2349 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
2350 PyDenseElementsAttribute> {
2351 public:
2352 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
2353 static constexpr const char *pyClassName = "DenseFPElementsAttr";
2354 using PyConcreteAttribute::PyConcreteAttribute;
2355
dunderGetItem(intptr_t pos)2356 py::float_ dunderGetItem(intptr_t pos) {
2357 if (pos < 0 || pos >= dunderLen()) {
2358 throw SetPyError(PyExc_IndexError,
2359 "attempt to access out of bounds element");
2360 }
2361
2362 MlirType type = mlirAttributeGetType(*this);
2363 type = mlirShapedTypeGetElementType(type);
2364 // Dispatch element extraction to an appropriate C function based on the
2365 // elemental type of the attribute. py::float_ is implicitly constructible
2366 // from float and double.
2367 // TODO: consider caching the type properties in the constructor to avoid
2368 // querying them on each element access.
2369 if (mlirTypeIsAF32(type)) {
2370 return mlirDenseElementsAttrGetFloatValue(*this, pos);
2371 }
2372 if (mlirTypeIsAF64(type)) {
2373 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
2374 }
2375 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
2376 }
2377
bindDerived(ClassTy & c)2378 static void bindDerived(ClassTy &c) {
2379 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
2380 }
2381 };
2382
2383 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
2384 public:
2385 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
2386 static constexpr const char *pyClassName = "TypeAttr";
2387 using PyConcreteAttribute::PyConcreteAttribute;
2388
bindDerived(ClassTy & c)2389 static void bindDerived(ClassTy &c) {
2390 c.def_static(
2391 "get",
2392 [](PyType value, DefaultingPyMlirContext context) {
2393 MlirAttribute attr = mlirTypeAttrGet(value.get());
2394 return PyTypeAttribute(context->getRef(), attr);
2395 },
2396 py::arg("value"), py::arg("context") = py::none(),
2397 "Gets a uniqued Type attribute");
2398 c.def_property_readonly("value", [](PyTypeAttribute &self) {
2399 return PyType(self.getContext()->getRef(),
2400 mlirTypeAttrGetValue(self.get()));
2401 });
2402 }
2403 };
2404
2405 /// Unit Attribute subclass. Unit attributes don't have values.
2406 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
2407 public:
2408 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
2409 static constexpr const char *pyClassName = "UnitAttr";
2410 using PyConcreteAttribute::PyConcreteAttribute;
2411
bindDerived(ClassTy & c)2412 static void bindDerived(ClassTy &c) {
2413 c.def_static(
2414 "get",
2415 [](DefaultingPyMlirContext context) {
2416 return PyUnitAttribute(context->getRef(),
2417 mlirUnitAttrGet(context->get()));
2418 },
2419 py::arg("context") = py::none(), "Create a Unit attribute.");
2420 }
2421 };
2422
2423 } // namespace
2424
2425 //------------------------------------------------------------------------------
2426 // Builtin type subclasses.
2427 //------------------------------------------------------------------------------
2428
2429 namespace {
2430
2431 /// CRTP base classes for Python types that subclass Type and should be
2432 /// castable from it (i.e. via something like IntegerType(t)).
2433 /// By default, type class hierarchies are one level deep (i.e. a
2434 /// concrete type class extends PyType); however, intermediate python-visible
2435 /// base classes can be modeled by specifying a BaseTy.
2436 template <typename DerivedTy, typename BaseTy = PyType>
2437 class PyConcreteType : public BaseTy {
2438 public:
2439 // Derived classes must define statics for:
2440 // IsAFunctionTy isaFunction
2441 // const char *pyClassName
2442 using ClassTy = py::class_<DerivedTy, BaseTy>;
2443 using IsAFunctionTy = bool (*)(MlirType);
2444
2445 PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef,MlirType t)2446 PyConcreteType(PyMlirContextRef contextRef, MlirType t)
2447 : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType & orig)2448 PyConcreteType(PyType &orig)
2449 : PyConcreteType(orig.getContext(), castFrom(orig)) {}
2450
castFrom(PyType & orig)2451 static MlirType castFrom(PyType &orig) {
2452 if (!DerivedTy::isaFunction(orig)) {
2453 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2454 throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
2455 DerivedTy::pyClassName +
2456 " (from " + origRepr + ")");
2457 }
2458 return orig;
2459 }
2460
bind(py::module & m)2461 static void bind(py::module &m) {
2462 auto cls = ClassTy(m, DerivedTy::pyClassName);
2463 cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
2464 DerivedTy::bindDerived(cls);
2465 }
2466
2467 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)2468 static void bindDerived(ClassTy &m) {}
2469 };
2470
2471 class PyIntegerType : public PyConcreteType<PyIntegerType> {
2472 public:
2473 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
2474 static constexpr const char *pyClassName = "IntegerType";
2475 using PyConcreteType::PyConcreteType;
2476
bindDerived(ClassTy & c)2477 static void bindDerived(ClassTy &c) {
2478 c.def_static(
2479 "get_signless",
2480 [](unsigned width, DefaultingPyMlirContext context) {
2481 MlirType t = mlirIntegerTypeGet(context->get(), width);
2482 return PyIntegerType(context->getRef(), t);
2483 },
2484 py::arg("width"), py::arg("context") = py::none(),
2485 "Create a signless integer type");
2486 c.def_static(
2487 "get_signed",
2488 [](unsigned width, DefaultingPyMlirContext context) {
2489 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
2490 return PyIntegerType(context->getRef(), t);
2491 },
2492 py::arg("width"), py::arg("context") = py::none(),
2493 "Create a signed integer type");
2494 c.def_static(
2495 "get_unsigned",
2496 [](unsigned width, DefaultingPyMlirContext context) {
2497 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
2498 return PyIntegerType(context->getRef(), t);
2499 },
2500 py::arg("width"), py::arg("context") = py::none(),
2501 "Create an unsigned integer type");
2502 c.def_property_readonly(
2503 "width",
2504 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
2505 "Returns the width of the integer type");
2506 c.def_property_readonly(
2507 "is_signless",
2508 [](PyIntegerType &self) -> bool {
2509 return mlirIntegerTypeIsSignless(self);
2510 },
2511 "Returns whether this is a signless integer");
2512 c.def_property_readonly(
2513 "is_signed",
2514 [](PyIntegerType &self) -> bool {
2515 return mlirIntegerTypeIsSigned(self);
2516 },
2517 "Returns whether this is a signed integer");
2518 c.def_property_readonly(
2519 "is_unsigned",
2520 [](PyIntegerType &self) -> bool {
2521 return mlirIntegerTypeIsUnsigned(self);
2522 },
2523 "Returns whether this is an unsigned integer");
2524 }
2525 };
2526
2527 /// Index Type subclass - IndexType.
2528 class PyIndexType : public PyConcreteType<PyIndexType> {
2529 public:
2530 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
2531 static constexpr const char *pyClassName = "IndexType";
2532 using PyConcreteType::PyConcreteType;
2533
bindDerived(ClassTy & c)2534 static void bindDerived(ClassTy &c) {
2535 c.def_static(
2536 "get",
2537 [](DefaultingPyMlirContext context) {
2538 MlirType t = mlirIndexTypeGet(context->get());
2539 return PyIndexType(context->getRef(), t);
2540 },
2541 py::arg("context") = py::none(), "Create a index type.");
2542 }
2543 };
2544
2545 /// Floating Point Type subclass - BF16Type.
2546 class PyBF16Type : public PyConcreteType<PyBF16Type> {
2547 public:
2548 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
2549 static constexpr const char *pyClassName = "BF16Type";
2550 using PyConcreteType::PyConcreteType;
2551
bindDerived(ClassTy & c)2552 static void bindDerived(ClassTy &c) {
2553 c.def_static(
2554 "get",
2555 [](DefaultingPyMlirContext context) {
2556 MlirType t = mlirBF16TypeGet(context->get());
2557 return PyBF16Type(context->getRef(), t);
2558 },
2559 py::arg("context") = py::none(), "Create a bf16 type.");
2560 }
2561 };
2562
2563 /// Floating Point Type subclass - F16Type.
2564 class PyF16Type : public PyConcreteType<PyF16Type> {
2565 public:
2566 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
2567 static constexpr const char *pyClassName = "F16Type";
2568 using PyConcreteType::PyConcreteType;
2569
bindDerived(ClassTy & c)2570 static void bindDerived(ClassTy &c) {
2571 c.def_static(
2572 "get",
2573 [](DefaultingPyMlirContext context) {
2574 MlirType t = mlirF16TypeGet(context->get());
2575 return PyF16Type(context->getRef(), t);
2576 },
2577 py::arg("context") = py::none(), "Create a f16 type.");
2578 }
2579 };
2580
2581 /// Floating Point Type subclass - F32Type.
2582 class PyF32Type : public PyConcreteType<PyF32Type> {
2583 public:
2584 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
2585 static constexpr const char *pyClassName = "F32Type";
2586 using PyConcreteType::PyConcreteType;
2587
bindDerived(ClassTy & c)2588 static void bindDerived(ClassTy &c) {
2589 c.def_static(
2590 "get",
2591 [](DefaultingPyMlirContext context) {
2592 MlirType t = mlirF32TypeGet(context->get());
2593 return PyF32Type(context->getRef(), t);
2594 },
2595 py::arg("context") = py::none(), "Create a f32 type.");
2596 }
2597 };
2598
2599 /// Floating Point Type subclass - F64Type.
2600 class PyF64Type : public PyConcreteType<PyF64Type> {
2601 public:
2602 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
2603 static constexpr const char *pyClassName = "F64Type";
2604 using PyConcreteType::PyConcreteType;
2605
bindDerived(ClassTy & c)2606 static void bindDerived(ClassTy &c) {
2607 c.def_static(
2608 "get",
2609 [](DefaultingPyMlirContext context) {
2610 MlirType t = mlirF64TypeGet(context->get());
2611 return PyF64Type(context->getRef(), t);
2612 },
2613 py::arg("context") = py::none(), "Create a f64 type.");
2614 }
2615 };
2616
2617 /// None Type subclass - NoneType.
2618 class PyNoneType : public PyConcreteType<PyNoneType> {
2619 public:
2620 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
2621 static constexpr const char *pyClassName = "NoneType";
2622 using PyConcreteType::PyConcreteType;
2623
bindDerived(ClassTy & c)2624 static void bindDerived(ClassTy &c) {
2625 c.def_static(
2626 "get",
2627 [](DefaultingPyMlirContext context) {
2628 MlirType t = mlirNoneTypeGet(context->get());
2629 return PyNoneType(context->getRef(), t);
2630 },
2631 py::arg("context") = py::none(), "Create a none type.");
2632 }
2633 };
2634
2635 /// Complex Type subclass - ComplexType.
2636 class PyComplexType : public PyConcreteType<PyComplexType> {
2637 public:
2638 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
2639 static constexpr const char *pyClassName = "ComplexType";
2640 using PyConcreteType::PyConcreteType;
2641
bindDerived(ClassTy & c)2642 static void bindDerived(ClassTy &c) {
2643 c.def_static(
2644 "get",
2645 [](PyType &elementType) {
2646 // The element must be a floating point or integer scalar type.
2647 if (mlirTypeIsAIntegerOrFloat(elementType)) {
2648 MlirType t = mlirComplexTypeGet(elementType);
2649 return PyComplexType(elementType.getContext(), t);
2650 }
2651 throw SetPyError(
2652 PyExc_ValueError,
2653 Twine("invalid '") +
2654 py::repr(py::cast(elementType)).cast<std::string>() +
2655 "' and expected floating point or integer type.");
2656 },
2657 "Create a complex type");
2658 c.def_property_readonly(
2659 "element_type",
2660 [](PyComplexType &self) -> PyType {
2661 MlirType t = mlirComplexTypeGetElementType(self);
2662 return PyType(self.getContext(), t);
2663 },
2664 "Returns element type.");
2665 }
2666 };
2667
2668 class PyShapedType : public PyConcreteType<PyShapedType> {
2669 public:
2670 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
2671 static constexpr const char *pyClassName = "ShapedType";
2672 using PyConcreteType::PyConcreteType;
2673
bindDerived(ClassTy & c)2674 static void bindDerived(ClassTy &c) {
2675 c.def_property_readonly(
2676 "element_type",
2677 [](PyShapedType &self) {
2678 MlirType t = mlirShapedTypeGetElementType(self);
2679 return PyType(self.getContext(), t);
2680 },
2681 "Returns the element type of the shaped type.");
2682 c.def_property_readonly(
2683 "has_rank",
2684 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
2685 "Returns whether the given shaped type is ranked.");
2686 c.def_property_readonly(
2687 "rank",
2688 [](PyShapedType &self) {
2689 self.requireHasRank();
2690 return mlirShapedTypeGetRank(self);
2691 },
2692 "Returns the rank of the given ranked shaped type.");
2693 c.def_property_readonly(
2694 "has_static_shape",
2695 [](PyShapedType &self) -> bool {
2696 return mlirShapedTypeHasStaticShape(self);
2697 },
2698 "Returns whether the given shaped type has a static shape.");
2699 c.def(
2700 "is_dynamic_dim",
2701 [](PyShapedType &self, intptr_t dim) -> bool {
2702 self.requireHasRank();
2703 return mlirShapedTypeIsDynamicDim(self, dim);
2704 },
2705 "Returns whether the dim-th dimension of the given shaped type is "
2706 "dynamic.");
2707 c.def(
2708 "get_dim_size",
2709 [](PyShapedType &self, intptr_t dim) {
2710 self.requireHasRank();
2711 return mlirShapedTypeGetDimSize(self, dim);
2712 },
2713 "Returns the dim-th dimension of the given ranked shaped type.");
2714 c.def_static(
2715 "is_dynamic_size",
2716 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
2717 "Returns whether the given dimension size indicates a dynamic "
2718 "dimension.");
2719 c.def(
2720 "is_dynamic_stride_or_offset",
2721 [](PyShapedType &self, int64_t val) -> bool {
2722 self.requireHasRank();
2723 return mlirShapedTypeIsDynamicStrideOrOffset(val);
2724 },
2725 "Returns whether the given value is used as a placeholder for dynamic "
2726 "strides and offsets in shaped types.");
2727 }
2728
2729 private:
requireHasRank()2730 void requireHasRank() {
2731 if (!mlirShapedTypeHasRank(*this)) {
2732 throw SetPyError(
2733 PyExc_ValueError,
2734 "calling this method requires that the type has a rank.");
2735 }
2736 }
2737 };
2738
2739 /// Vector Type subclass - VectorType.
2740 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
2741 public:
2742 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
2743 static constexpr const char *pyClassName = "VectorType";
2744 using PyConcreteType::PyConcreteType;
2745
bindDerived(ClassTy & c)2746 static void bindDerived(ClassTy &c) {
2747 c.def_static(
2748 "get",
2749 [](std::vector<int64_t> shape, PyType &elementType,
2750 DefaultingPyLocation loc) {
2751 MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
2752 elementType, loc);
2753 // TODO: Rework error reporting once diagnostic engine is exposed
2754 // in C API.
2755 if (mlirTypeIsNull(t)) {
2756 throw SetPyError(
2757 PyExc_ValueError,
2758 Twine("invalid '") +
2759 py::repr(py::cast(elementType)).cast<std::string>() +
2760 "' and expected floating point or integer type.");
2761 }
2762 return PyVectorType(elementType.getContext(), t);
2763 },
2764 py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
2765 "Create a vector type");
2766 }
2767 };
2768
2769 /// Ranked Tensor Type subclass - RankedTensorType.
2770 class PyRankedTensorType
2771 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
2772 public:
2773 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2774 static constexpr const char *pyClassName = "RankedTensorType";
2775 using PyConcreteType::PyConcreteType;
2776
bindDerived(ClassTy & c)2777 static void bindDerived(ClassTy &c) {
2778 c.def_static(
2779 "get",
2780 [](std::vector<int64_t> shape, PyType &elementType,
2781 DefaultingPyLocation loc) {
2782 MlirType t = mlirRankedTensorTypeGetChecked(
2783 shape.size(), shape.data(), elementType, loc);
2784 // TODO: Rework error reporting once diagnostic engine is exposed
2785 // in C API.
2786 if (mlirTypeIsNull(t)) {
2787 throw SetPyError(
2788 PyExc_ValueError,
2789 Twine("invalid '") +
2790 py::repr(py::cast(elementType)).cast<std::string>() +
2791 "' and expected floating point, integer, vector or "
2792 "complex "
2793 "type.");
2794 }
2795 return PyRankedTensorType(elementType.getContext(), t);
2796 },
2797 py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
2798 "Create a ranked tensor type");
2799 }
2800 };
2801
2802 /// Unranked Tensor Type subclass - UnrankedTensorType.
2803 class PyUnrankedTensorType
2804 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
2805 public:
2806 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
2807 static constexpr const char *pyClassName = "UnrankedTensorType";
2808 using PyConcreteType::PyConcreteType;
2809
bindDerived(ClassTy & c)2810 static void bindDerived(ClassTy &c) {
2811 c.def_static(
2812 "get",
2813 [](PyType &elementType, DefaultingPyLocation loc) {
2814 MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
2815 // TODO: Rework error reporting once diagnostic engine is exposed
2816 // in C API.
2817 if (mlirTypeIsNull(t)) {
2818 throw SetPyError(
2819 PyExc_ValueError,
2820 Twine("invalid '") +
2821 py::repr(py::cast(elementType)).cast<std::string>() +
2822 "' and expected floating point, integer, vector or "
2823 "complex "
2824 "type.");
2825 }
2826 return PyUnrankedTensorType(elementType.getContext(), t);
2827 },
2828 py::arg("element_type"), py::arg("loc") = py::none(),
2829 "Create a unranked tensor type");
2830 }
2831 };
2832
2833 class PyMemRefLayoutMapList;
2834
2835 /// Ranked MemRef Type subclass - MemRefType.
2836 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
2837 public:
2838 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2839 static constexpr const char *pyClassName = "MemRefType";
2840 using PyConcreteType::PyConcreteType;
2841
2842 PyMemRefLayoutMapList getLayout();
2843
bindDerived(ClassTy & c)2844 static void bindDerived(ClassTy &c) {
2845 c.def_static(
2846 "get",
2847 [](std::vector<int64_t> shape, PyType &elementType,
2848 std::vector<PyAffineMap> layout, unsigned memorySpace,
2849 DefaultingPyLocation loc) {
2850 SmallVector<MlirAffineMap> maps;
2851 maps.reserve(layout.size());
2852 for (PyAffineMap &map : layout)
2853 maps.push_back(map);
2854
2855 MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(),
2856 shape.data(), maps.size(),
2857 maps.data(), memorySpace, loc);
2858 // TODO: Rework error reporting once diagnostic engine is exposed
2859 // in C API.
2860 if (mlirTypeIsNull(t)) {
2861 throw SetPyError(
2862 PyExc_ValueError,
2863 Twine("invalid '") +
2864 py::repr(py::cast(elementType)).cast<std::string>() +
2865 "' and expected floating point, integer, vector or "
2866 "complex "
2867 "type.");
2868 }
2869 return PyMemRefType(elementType.getContext(), t);
2870 },
2871 py::arg("shape"), py::arg("element_type"),
2872 py::arg("layout") = py::list(), py::arg("memory_space") = 0,
2873 py::arg("loc") = py::none(), "Create a memref type")
2874 .def_property_readonly("layout", &PyMemRefType::getLayout,
2875 "The list of layout maps of the MemRef type.")
2876 .def_property_readonly(
2877 "memory_space",
2878 [](PyMemRefType &self) -> unsigned {
2879 return mlirMemRefTypeGetMemorySpace(self);
2880 },
2881 "Returns the memory space of the given MemRef type.");
2882 }
2883 };
2884
2885 /// A list of affine layout maps in a memref type. Internally, these are stored
2886 /// as consecutive elements, random access is cheap. Both the type and the maps
2887 /// are owned by the context, no need to worry about lifetime extension.
2888 class PyMemRefLayoutMapList
2889 : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> {
2890 public:
2891 static constexpr const char *pyClassName = "MemRefLayoutMapList";
2892
PyMemRefLayoutMapList(PyMemRefType type,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)2893 PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0,
2894 intptr_t length = -1, intptr_t step = 1)
2895 : Sliceable(startIndex,
2896 length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length,
2897 step),
2898 memref(type) {}
2899
getNumElements()2900 intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); }
2901
getElement(intptr_t index)2902 PyAffineMap getElement(intptr_t index) {
2903 return PyAffineMap(memref.getContext(),
2904 mlirMemRefTypeGetAffineMap(memref, index));
2905 }
2906
slice(intptr_t startIndex,intptr_t length,intptr_t step)2907 PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length,
2908 intptr_t step) {
2909 return PyMemRefLayoutMapList(memref, startIndex, length, step);
2910 }
2911
2912 private:
2913 PyMemRefType memref;
2914 };
2915
getLayout()2916 PyMemRefLayoutMapList PyMemRefType::getLayout() {
2917 return PyMemRefLayoutMapList(*this);
2918 }
2919
2920 /// Unranked MemRef Type subclass - UnrankedMemRefType.
2921 class PyUnrankedMemRefType
2922 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
2923 public:
2924 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
2925 static constexpr const char *pyClassName = "UnrankedMemRefType";
2926 using PyConcreteType::PyConcreteType;
2927
bindDerived(ClassTy & c)2928 static void bindDerived(ClassTy &c) {
2929 c.def_static(
2930 "get",
2931 [](PyType &elementType, unsigned memorySpace,
2932 DefaultingPyLocation loc) {
2933 MlirType t =
2934 mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
2935 // TODO: Rework error reporting once diagnostic engine is exposed
2936 // in C API.
2937 if (mlirTypeIsNull(t)) {
2938 throw SetPyError(
2939 PyExc_ValueError,
2940 Twine("invalid '") +
2941 py::repr(py::cast(elementType)).cast<std::string>() +
2942 "' and expected floating point, integer, vector or "
2943 "complex "
2944 "type.");
2945 }
2946 return PyUnrankedMemRefType(elementType.getContext(), t);
2947 },
2948 py::arg("element_type"), py::arg("memory_space"),
2949 py::arg("loc") = py::none(), "Create a unranked memref type")
2950 .def_property_readonly(
2951 "memory_space",
2952 [](PyUnrankedMemRefType &self) -> unsigned {
2953 return mlirUnrankedMemrefGetMemorySpace(self);
2954 },
2955 "Returns the memory space of the given Unranked MemRef type.");
2956 }
2957 };
2958
2959 /// Tuple Type subclass - TupleType.
2960 class PyTupleType : public PyConcreteType<PyTupleType> {
2961 public:
2962 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
2963 static constexpr const char *pyClassName = "TupleType";
2964 using PyConcreteType::PyConcreteType;
2965
bindDerived(ClassTy & c)2966 static void bindDerived(ClassTy &c) {
2967 c.def_static(
2968 "get_tuple",
2969 [](py::list elementList, DefaultingPyMlirContext context) {
2970 intptr_t num = py::len(elementList);
2971 // Mapping py::list to SmallVector.
2972 SmallVector<MlirType, 4> elements;
2973 for (auto element : elementList)
2974 elements.push_back(element.cast<PyType>());
2975 MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
2976 return PyTupleType(context->getRef(), t);
2977 },
2978 py::arg("elements"), py::arg("context") = py::none(),
2979 "Create a tuple type");
2980 c.def(
2981 "get_type",
2982 [](PyTupleType &self, intptr_t pos) -> PyType {
2983 MlirType t = mlirTupleTypeGetType(self, pos);
2984 return PyType(self.getContext(), t);
2985 },
2986 "Returns the pos-th type in the tuple type.");
2987 c.def_property_readonly(
2988 "num_types",
2989 [](PyTupleType &self) -> intptr_t {
2990 return mlirTupleTypeGetNumTypes(self);
2991 },
2992 "Returns the number of types contained in a tuple.");
2993 }
2994 };
2995
2996 /// Function type.
2997 class PyFunctionType : public PyConcreteType<PyFunctionType> {
2998 public:
2999 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
3000 static constexpr const char *pyClassName = "FunctionType";
3001 using PyConcreteType::PyConcreteType;
3002
bindDerived(ClassTy & c)3003 static void bindDerived(ClassTy &c) {
3004 c.def_static(
3005 "get",
3006 [](std::vector<PyType> inputs, std::vector<PyType> results,
3007 DefaultingPyMlirContext context) {
3008 SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
3009 SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
3010 MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
3011 inputsRaw.data(), resultsRaw.size(),
3012 resultsRaw.data());
3013 return PyFunctionType(context->getRef(), t);
3014 },
3015 py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
3016 "Gets a FunctionType from a list of input and result types");
3017 c.def_property_readonly(
3018 "inputs",
3019 [](PyFunctionType &self) {
3020 MlirType t = self;
3021 auto contextRef = self.getContext();
3022 py::list types;
3023 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
3024 ++i) {
3025 types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
3026 }
3027 return types;
3028 },
3029 "Returns the list of input types in the FunctionType.");
3030 c.def_property_readonly(
3031 "results",
3032 [](PyFunctionType &self) {
3033 auto contextRef = self.getContext();
3034 py::list types;
3035 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
3036 ++i) {
3037 types.append(
3038 PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
3039 }
3040 return types;
3041 },
3042 "Returns the list of result types in the FunctionType.");
3043 }
3044 };
3045
3046 } // namespace
3047
3048 //------------------------------------------------------------------------------
3049 // PyAffineExpr and subclasses.
3050 //------------------------------------------------------------------------------
3051
3052 namespace {
3053 /// CRTP base class for Python MLIR affine expressions that subclass AffineExpr
3054 /// and should be castable from it. Intermediate hierarchy classes can be
3055 /// modeled by specifying BaseTy.
3056 template <typename DerivedTy, typename BaseTy = PyAffineExpr>
3057 class PyConcreteAffineExpr : public BaseTy {
3058 public:
3059 // Derived classes must define statics for:
3060 // IsAFunctionTy isaFunction
3061 // const char *pyClassName
3062 // and redefine bindDerived.
3063 using ClassTy = py::class_<DerivedTy, BaseTy>;
3064 using IsAFunctionTy = bool (*)(MlirAffineExpr);
3065
3066 PyConcreteAffineExpr() = default;
PyConcreteAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)3067 PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
3068 : BaseTy(std::move(contextRef), affineExpr) {}
PyConcreteAffineExpr(PyAffineExpr & orig)3069 PyConcreteAffineExpr(PyAffineExpr &orig)
3070 : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {}
3071
castFrom(PyAffineExpr & orig)3072 static MlirAffineExpr castFrom(PyAffineExpr &orig) {
3073 if (!DerivedTy::isaFunction(orig)) {
3074 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
3075 throw SetPyError(PyExc_ValueError,
3076 Twine("Cannot cast affine expression to ") +
3077 DerivedTy::pyClassName + " (from " + origRepr + ")");
3078 }
3079 return orig;
3080 }
3081
bind(py::module & m)3082 static void bind(py::module &m) {
3083 auto cls = ClassTy(m, DerivedTy::pyClassName);
3084 cls.def(py::init<PyAffineExpr &>());
3085 DerivedTy::bindDerived(cls);
3086 }
3087
3088 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)3089 static void bindDerived(ClassTy &m) {}
3090 };
3091
3092 class PyAffineConstantExpr : public PyConcreteAffineExpr<PyAffineConstantExpr> {
3093 public:
3094 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant;
3095 static constexpr const char *pyClassName = "AffineConstantExpr";
3096 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3097
get(intptr_t value,DefaultingPyMlirContext context)3098 static PyAffineConstantExpr get(intptr_t value,
3099 DefaultingPyMlirContext context) {
3100 MlirAffineExpr affineExpr =
3101 mlirAffineConstantExprGet(context->get(), static_cast<int64_t>(value));
3102 return PyAffineConstantExpr(context->getRef(), affineExpr);
3103 }
3104
bindDerived(ClassTy & c)3105 static void bindDerived(ClassTy &c) {
3106 c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
3107 py::arg("context") = py::none());
3108 c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
3109 return mlirAffineConstantExprGetValue(self);
3110 });
3111 }
3112 };
3113
3114 class PyAffineDimExpr : public PyConcreteAffineExpr<PyAffineDimExpr> {
3115 public:
3116 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim;
3117 static constexpr const char *pyClassName = "AffineDimExpr";
3118 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3119
get(intptr_t pos,DefaultingPyMlirContext context)3120 static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) {
3121 MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos);
3122 return PyAffineDimExpr(context->getRef(), affineExpr);
3123 }
3124
bindDerived(ClassTy & c)3125 static void bindDerived(ClassTy &c) {
3126 c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
3127 py::arg("context") = py::none());
3128 c.def_property_readonly("position", [](PyAffineDimExpr &self) {
3129 return mlirAffineDimExprGetPosition(self);
3130 });
3131 }
3132 };
3133
3134 class PyAffineSymbolExpr : public PyConcreteAffineExpr<PyAffineSymbolExpr> {
3135 public:
3136 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol;
3137 static constexpr const char *pyClassName = "AffineSymbolExpr";
3138 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3139
get(intptr_t pos,DefaultingPyMlirContext context)3140 static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) {
3141 MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos);
3142 return PyAffineSymbolExpr(context->getRef(), affineExpr);
3143 }
3144
bindDerived(ClassTy & c)3145 static void bindDerived(ClassTy &c) {
3146 c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
3147 py::arg("context") = py::none());
3148 c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
3149 return mlirAffineSymbolExprGetPosition(self);
3150 });
3151 }
3152 };
3153
3154 class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
3155 public:
3156 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary;
3157 static constexpr const char *pyClassName = "AffineBinaryExpr";
3158 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3159
lhs()3160 PyAffineExpr lhs() {
3161 MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
3162 return PyAffineExpr(getContext(), lhsExpr);
3163 }
3164
rhs()3165 PyAffineExpr rhs() {
3166 MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
3167 return PyAffineExpr(getContext(), rhsExpr);
3168 }
3169
bindDerived(ClassTy & c)3170 static void bindDerived(ClassTy &c) {
3171 c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
3172 c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
3173 }
3174 };
3175
3176 class PyAffineAddExpr
3177 : public PyConcreteAffineExpr<PyAffineAddExpr, PyAffineBinaryExpr> {
3178 public:
3179 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd;
3180 static constexpr const char *pyClassName = "AffineAddExpr";
3181 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3182
get(PyAffineExpr lhs,PyAffineExpr rhs)3183 static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3184 MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs);
3185 return PyAffineAddExpr(lhs.getContext(), expr);
3186 }
3187
bindDerived(ClassTy & c)3188 static void bindDerived(ClassTy &c) {
3189 c.def_static("get", &PyAffineAddExpr::get);
3190 }
3191 };
3192
3193 class PyAffineMulExpr
3194 : public PyConcreteAffineExpr<PyAffineMulExpr, PyAffineBinaryExpr> {
3195 public:
3196 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul;
3197 static constexpr const char *pyClassName = "AffineMulExpr";
3198 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3199
get(PyAffineExpr lhs,PyAffineExpr rhs)3200 static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3201 MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs);
3202 return PyAffineMulExpr(lhs.getContext(), expr);
3203 }
3204
bindDerived(ClassTy & c)3205 static void bindDerived(ClassTy &c) {
3206 c.def_static("get", &PyAffineMulExpr::get);
3207 }
3208 };
3209
3210 class PyAffineModExpr
3211 : public PyConcreteAffineExpr<PyAffineModExpr, PyAffineBinaryExpr> {
3212 public:
3213 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod;
3214 static constexpr const char *pyClassName = "AffineModExpr";
3215 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3216
get(PyAffineExpr lhs,PyAffineExpr rhs)3217 static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3218 MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs);
3219 return PyAffineModExpr(lhs.getContext(), expr);
3220 }
3221
bindDerived(ClassTy & c)3222 static void bindDerived(ClassTy &c) {
3223 c.def_static("get", &PyAffineModExpr::get);
3224 }
3225 };
3226
3227 class PyAffineFloorDivExpr
3228 : public PyConcreteAffineExpr<PyAffineFloorDivExpr, PyAffineBinaryExpr> {
3229 public:
3230 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv;
3231 static constexpr const char *pyClassName = "AffineFloorDivExpr";
3232 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3233
get(PyAffineExpr lhs,PyAffineExpr rhs)3234 static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3235 MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs);
3236 return PyAffineFloorDivExpr(lhs.getContext(), expr);
3237 }
3238
bindDerived(ClassTy & c)3239 static void bindDerived(ClassTy &c) {
3240 c.def_static("get", &PyAffineFloorDivExpr::get);
3241 }
3242 };
3243
3244 class PyAffineCeilDivExpr
3245 : public PyConcreteAffineExpr<PyAffineCeilDivExpr, PyAffineBinaryExpr> {
3246 public:
3247 static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv;
3248 static constexpr const char *pyClassName = "AffineCeilDivExpr";
3249 using PyConcreteAffineExpr::PyConcreteAffineExpr;
3250
get(PyAffineExpr lhs,PyAffineExpr rhs)3251 static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) {
3252 MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs);
3253 return PyAffineCeilDivExpr(lhs.getContext(), expr);
3254 }
3255
bindDerived(ClassTy & c)3256 static void bindDerived(ClassTy &c) {
3257 c.def_static("get", &PyAffineCeilDivExpr::get);
3258 }
3259 };
3260 } // namespace
3261
operator ==(const PyAffineExpr & other)3262 bool PyAffineExpr::operator==(const PyAffineExpr &other) {
3263 return mlirAffineExprEqual(affineExpr, other.affineExpr);
3264 }
3265
getCapsule()3266 py::object PyAffineExpr::getCapsule() {
3267 return py::reinterpret_steal<py::object>(
3268 mlirPythonAffineExprToCapsule(*this));
3269 }
3270
createFromCapsule(py::object capsule)3271 PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
3272 MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
3273 if (mlirAffineExprIsNull(rawAffineExpr))
3274 throw py::error_already_set();
3275 return PyAffineExpr(
3276 PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
3277 rawAffineExpr);
3278 }
3279
3280 //------------------------------------------------------------------------------
3281 // PyAffineMap and utilities.
3282 //------------------------------------------------------------------------------
3283
3284 namespace {
3285 /// A list of expressions contained in an affine map. Internally these are
3286 /// stored as a consecutive array leading to inexpensive random access. Both
3287 /// the map and the expression are owned by the context so we need not bother
3288 /// with lifetime extension.
3289 class PyAffineMapExprList
3290 : public Sliceable<PyAffineMapExprList, PyAffineExpr> {
3291 public:
3292 static constexpr const char *pyClassName = "AffineExprList";
3293
PyAffineMapExprList(PyAffineMap map,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)3294 PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0,
3295 intptr_t length = -1, intptr_t step = 1)
3296 : Sliceable(startIndex,
3297 length == -1 ? mlirAffineMapGetNumResults(map) : length,
3298 step),
3299 affineMap(map) {}
3300
getNumElements()3301 intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); }
3302
getElement(intptr_t pos)3303 PyAffineExpr getElement(intptr_t pos) {
3304 return PyAffineExpr(affineMap.getContext(),
3305 mlirAffineMapGetResult(affineMap, pos));
3306 }
3307
slice(intptr_t startIndex,intptr_t length,intptr_t step)3308 PyAffineMapExprList slice(intptr_t startIndex, intptr_t length,
3309 intptr_t step) {
3310 return PyAffineMapExprList(affineMap, startIndex, length, step);
3311 }
3312
3313 private:
3314 PyAffineMap affineMap;
3315 };
3316 } // end namespace
3317
operator ==(const PyAffineMap & other)3318 bool PyAffineMap::operator==(const PyAffineMap &other) {
3319 return mlirAffineMapEqual(affineMap, other.affineMap);
3320 }
3321
getCapsule()3322 py::object PyAffineMap::getCapsule() {
3323 return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
3324 }
3325
createFromCapsule(py::object capsule)3326 PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
3327 MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
3328 if (mlirAffineMapIsNull(rawAffineMap))
3329 throw py::error_already_set();
3330 return PyAffineMap(
3331 PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
3332 rawAffineMap);
3333 }
3334
3335 //------------------------------------------------------------------------------
3336 // PyIntegerSet and utilities.
3337 //------------------------------------------------------------------------------
3338
3339 class PyIntegerSetConstraint {
3340 public:
PyIntegerSetConstraint(PyIntegerSet set,intptr_t pos)3341 PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
3342
getExpr()3343 PyAffineExpr getExpr() {
3344 return PyAffineExpr(set.getContext(),
3345 mlirIntegerSetGetConstraint(set, pos));
3346 }
3347
isEq()3348 bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
3349
bind(py::module & m)3350 static void bind(py::module &m) {
3351 py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
3352 .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
3353 .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
3354 }
3355
3356 private:
3357 PyIntegerSet set;
3358 intptr_t pos;
3359 };
3360
3361 class PyIntegerSetConstraintList
3362 : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
3363 public:
3364 static constexpr const char *pyClassName = "IntegerSetConstraintList";
3365
PyIntegerSetConstraintList(PyIntegerSet set,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)3366 PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
3367 intptr_t length = -1, intptr_t step = 1)
3368 : Sliceable(startIndex,
3369 length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
3370 step),
3371 set(set) {}
3372
getNumElements()3373 intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
3374
getElement(intptr_t pos)3375 PyIntegerSetConstraint getElement(intptr_t pos) {
3376 return PyIntegerSetConstraint(set, pos);
3377 }
3378
slice(intptr_t startIndex,intptr_t length,intptr_t step)3379 PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
3380 intptr_t step) {
3381 return PyIntegerSetConstraintList(set, startIndex, length, step);
3382 }
3383
3384 private:
3385 PyIntegerSet set;
3386 };
3387
operator ==(const PyIntegerSet & other)3388 bool PyIntegerSet::operator==(const PyIntegerSet &other) {
3389 return mlirIntegerSetEqual(integerSet, other.integerSet);
3390 }
3391
getCapsule()3392 py::object PyIntegerSet::getCapsule() {
3393 return py::reinterpret_steal<py::object>(
3394 mlirPythonIntegerSetToCapsule(*this));
3395 }
3396
createFromCapsule(py::object capsule)3397 PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
3398 MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
3399 if (mlirIntegerSetIsNull(rawIntegerSet))
3400 throw py::error_already_set();
3401 return PyIntegerSet(
3402 PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
3403 rawIntegerSet);
3404 }
3405
3406 /// Attempts to populate `result` with the content of `list` casted to the
3407 /// appropriate type (Python and C types are provided as template arguments).
3408 /// Throws errors in case of failure, using "action" to describe what the caller
3409 /// was attempting to do.
3410 template <typename PyType, typename CType>
pyListToVector(py::list list,llvm::SmallVectorImpl<CType> & result,StringRef action)3411 static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
3412 StringRef action) {
3413 result.reserve(py::len(list));
3414 for (py::handle item : list) {
3415 try {
3416 result.push_back(item.cast<PyType>());
3417 } catch (py::cast_error &err) {
3418 std::string msg = (llvm::Twine("Invalid expression when ") + action +
3419 " (" + err.what() + ")")
3420 .str();
3421 throw py::cast_error(msg);
3422 } catch (py::reference_cast_error &err) {
3423 std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
3424 action + " (" + err.what() + ")")
3425 .str();
3426 throw py::cast_error(msg);
3427 }
3428 }
3429 }
3430
3431 //------------------------------------------------------------------------------
3432 // Populates the pybind11 IR submodule.
3433 //------------------------------------------------------------------------------
3434
populateIRSubmodule(py::module & m)3435 void mlir::python::populateIRSubmodule(py::module &m) {
3436 //----------------------------------------------------------------------------
3437 // Mapping of MlirContext
3438 //----------------------------------------------------------------------------
3439 py::class_<PyMlirContext>(m, "Context")
3440 .def(py::init<>(&PyMlirContext::createNewContextForInit))
3441 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
3442 .def("_get_context_again",
3443 [](PyMlirContext &self) {
3444 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
3445 return ref.releaseObject();
3446 })
3447 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
3448 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
3449 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3450 &PyMlirContext::getCapsule)
3451 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
3452 .def("__enter__", &PyMlirContext::contextEnter)
3453 .def("__exit__", &PyMlirContext::contextExit)
3454 .def_property_readonly_static(
3455 "current",
3456 [](py::object & /*class*/) {
3457 auto *context = PyThreadContextEntry::getDefaultContext();
3458 if (!context)
3459 throw SetPyError(PyExc_ValueError, "No current Context");
3460 return context;
3461 },
3462 "Gets the Context bound to the current thread or raises ValueError")
3463 .def_property_readonly(
3464 "dialects",
3465 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3466 "Gets a container for accessing dialects by name")
3467 .def_property_readonly(
3468 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
3469 "Alias for 'dialect'")
3470 .def(
3471 "get_dialect_descriptor",
3472 [=](PyMlirContext &self, std::string &name) {
3473 MlirDialect dialect = mlirContextGetOrLoadDialect(
3474 self.get(), {name.data(), name.size()});
3475 if (mlirDialectIsNull(dialect)) {
3476 throw SetPyError(PyExc_ValueError,
3477 Twine("Dialect '") + name + "' not found");
3478 }
3479 return PyDialectDescriptor(self.getRef(), dialect);
3480 },
3481 "Gets or loads a dialect by name, returning its descriptor object")
3482 .def_property(
3483 "allow_unregistered_dialects",
3484 [](PyMlirContext &self) -> bool {
3485 return mlirContextGetAllowUnregisteredDialects(self.get());
3486 },
3487 [](PyMlirContext &self, bool value) {
3488 mlirContextSetAllowUnregisteredDialects(self.get(), value);
3489 });
3490
3491 //----------------------------------------------------------------------------
3492 // Mapping of PyDialectDescriptor
3493 //----------------------------------------------------------------------------
3494 py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
3495 .def_property_readonly("namespace",
3496 [](PyDialectDescriptor &self) {
3497 MlirStringRef ns =
3498 mlirDialectGetNamespace(self.get());
3499 return py::str(ns.data, ns.length);
3500 })
3501 .def("__repr__", [](PyDialectDescriptor &self) {
3502 MlirStringRef ns = mlirDialectGetNamespace(self.get());
3503 std::string repr("<DialectDescriptor ");
3504 repr.append(ns.data, ns.length);
3505 repr.append(">");
3506 return repr;
3507 });
3508
3509 //----------------------------------------------------------------------------
3510 // Mapping of PyDialects
3511 //----------------------------------------------------------------------------
3512 py::class_<PyDialects>(m, "Dialects")
3513 .def("__getitem__",
3514 [=](PyDialects &self, std::string keyName) {
3515 MlirDialect dialect =
3516 self.getDialectForKey(keyName, /*attrError=*/false);
3517 py::object descriptor =
3518 py::cast(PyDialectDescriptor{self.getContext(), dialect});
3519 return createCustomDialectWrapper(keyName, std::move(descriptor));
3520 })
3521 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
3522 MlirDialect dialect =
3523 self.getDialectForKey(attrName, /*attrError=*/true);
3524 py::object descriptor =
3525 py::cast(PyDialectDescriptor{self.getContext(), dialect});
3526 return createCustomDialectWrapper(attrName, std::move(descriptor));
3527 });
3528
3529 //----------------------------------------------------------------------------
3530 // Mapping of PyDialect
3531 //----------------------------------------------------------------------------
3532 py::class_<PyDialect>(m, "Dialect")
3533 .def(py::init<py::object>(), "descriptor")
3534 .def_property_readonly(
3535 "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
3536 .def("__repr__", [](py::object self) {
3537 auto clazz = self.attr("__class__");
3538 return py::str("<Dialect ") +
3539 self.attr("descriptor").attr("namespace") + py::str(" (class ") +
3540 clazz.attr("__module__") + py::str(".") +
3541 clazz.attr("__name__") + py::str(")>");
3542 });
3543
3544 //----------------------------------------------------------------------------
3545 // Mapping of Location
3546 //----------------------------------------------------------------------------
3547 py::class_<PyLocation>(m, "Location")
3548 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
3549 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
3550 .def("__enter__", &PyLocation::contextEnter)
3551 .def("__exit__", &PyLocation::contextExit)
3552 .def("__eq__",
3553 [](PyLocation &self, PyLocation &other) -> bool {
3554 return mlirLocationEqual(self, other);
3555 })
3556 .def("__eq__", [](PyLocation &self, py::object other) { return false; })
3557 .def_property_readonly_static(
3558 "current",
3559 [](py::object & /*class*/) {
3560 auto *loc = PyThreadContextEntry::getDefaultLocation();
3561 if (!loc)
3562 throw SetPyError(PyExc_ValueError, "No current Location");
3563 return loc;
3564 },
3565 "Gets the Location bound to the current thread or raises ValueError")
3566 .def_static(
3567 "unknown",
3568 [](DefaultingPyMlirContext context) {
3569 return PyLocation(context->getRef(),
3570 mlirLocationUnknownGet(context->get()));
3571 },
3572 py::arg("context") = py::none(),
3573 "Gets a Location representing an unknown location")
3574 .def_static(
3575 "file",
3576 [](std::string filename, int line, int col,
3577 DefaultingPyMlirContext context) {
3578 return PyLocation(
3579 context->getRef(),
3580 mlirLocationFileLineColGet(
3581 context->get(), toMlirStringRef(filename), line, col));
3582 },
3583 py::arg("filename"), py::arg("line"), py::arg("col"),
3584 py::arg("context") = py::none(), kContextGetFileLocationDocstring)
3585 .def_property_readonly(
3586 "context",
3587 [](PyLocation &self) { return self.getContext().getObject(); },
3588 "Context that owns the Location")
3589 .def("__repr__", [](PyLocation &self) {
3590 PyPrintAccumulator printAccum;
3591 mlirLocationPrint(self, printAccum.getCallback(),
3592 printAccum.getUserData());
3593 return printAccum.join();
3594 });
3595
3596 //----------------------------------------------------------------------------
3597 // Mapping of Module
3598 //----------------------------------------------------------------------------
3599 py::class_<PyModule>(m, "Module")
3600 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3601 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3602 .def_static(
3603 "parse",
3604 [](const std::string moduleAsm, DefaultingPyMlirContext context) {
3605 MlirModule module = mlirModuleCreateParse(
3606 context->get(), toMlirStringRef(moduleAsm));
3607 // TODO: Rework error reporting once diagnostic engine is exposed
3608 // in C API.
3609 if (mlirModuleIsNull(module)) {
3610 throw SetPyError(
3611 PyExc_ValueError,
3612 "Unable to parse module assembly (see diagnostics)");
3613 }
3614 return PyModule::forModule(module).releaseObject();
3615 },
3616 py::arg("asm"), py::arg("context") = py::none(),
3617 kModuleParseDocstring)
3618 .def_static(
3619 "create",
3620 [](DefaultingPyLocation loc) {
3621 MlirModule module = mlirModuleCreateEmpty(loc);
3622 return PyModule::forModule(module).releaseObject();
3623 },
3624 py::arg("loc") = py::none(), "Creates an empty module")
3625 .def_property_readonly(
3626 "context",
3627 [](PyModule &self) { return self.getContext().getObject(); },
3628 "Context that created the Module")
3629 .def_property_readonly(
3630 "operation",
3631 [](PyModule &self) {
3632 return PyOperation::forOperation(self.getContext(),
3633 mlirModuleGetOperation(self.get()),
3634 self.getRef().releaseObject())
3635 .releaseObject();
3636 },
3637 "Accesses the module as an operation")
3638 .def_property_readonly(
3639 "body",
3640 [](PyModule &self) {
3641 PyOperationRef module_op = PyOperation::forOperation(
3642 self.getContext(), mlirModuleGetOperation(self.get()),
3643 self.getRef().releaseObject());
3644 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
3645 return returnBlock;
3646 },
3647 "Return the block for this module")
3648 .def(
3649 "dump",
3650 [](PyModule &self) {
3651 mlirOperationDump(mlirModuleGetOperation(self.get()));
3652 },
3653 kDumpDocstring)
3654 .def(
3655 "__str__",
3656 [](PyModule &self) {
3657 MlirOperation operation = mlirModuleGetOperation(self.get());
3658 PyPrintAccumulator printAccum;
3659 mlirOperationPrint(operation, printAccum.getCallback(),
3660 printAccum.getUserData());
3661 return printAccum.join();
3662 },
3663 kOperationStrDunderDocstring);
3664
3665 //----------------------------------------------------------------------------
3666 // Mapping of Operation.
3667 //----------------------------------------------------------------------------
3668 py::class_<PyOperationBase>(m, "_OperationBase")
3669 .def("__eq__",
3670 [](PyOperationBase &self, PyOperationBase &other) {
3671 return &self.getOperation() == &other.getOperation();
3672 })
3673 .def("__eq__",
3674 [](PyOperationBase &self, py::object other) { return false; })
3675 .def_property_readonly("attributes",
3676 [](PyOperationBase &self) {
3677 return PyOpAttributeMap(
3678 self.getOperation().getRef());
3679 })
3680 .def_property_readonly("operands",
3681 [](PyOperationBase &self) {
3682 return PyOpOperandList(
3683 self.getOperation().getRef());
3684 })
3685 .def_property_readonly("regions",
3686 [](PyOperationBase &self) {
3687 return PyRegionList(
3688 self.getOperation().getRef());
3689 })
3690 .def_property_readonly(
3691 "results",
3692 [](PyOperationBase &self) {
3693 return PyOpResultList(self.getOperation().getRef());
3694 },
3695 "Returns the list of Operation results.")
3696 .def_property_readonly(
3697 "result",
3698 [](PyOperationBase &self) {
3699 auto &operation = self.getOperation();
3700 auto numResults = mlirOperationGetNumResults(operation);
3701 if (numResults != 1) {
3702 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
3703 throw SetPyError(
3704 PyExc_ValueError,
3705 Twine("Cannot call .result on operation ") +
3706 StringRef(name.data, name.length) + " which has " +
3707 Twine(numResults) +
3708 " results (it is only valid for operations with a "
3709 "single result)");
3710 }
3711 return PyOpResult(operation.getRef(),
3712 mlirOperationGetResult(operation, 0));
3713 },
3714 "Shortcut to get an op result if it has only one (throws an error "
3715 "otherwise).")
3716 .def("__iter__",
3717 [](PyOperationBase &self) {
3718 return PyRegionIterator(self.getOperation().getRef());
3719 })
3720 .def(
3721 "__str__",
3722 [](PyOperationBase &self) {
3723 return self.getAsm(/*binary=*/false,
3724 /*largeElementsLimit=*/llvm::None,
3725 /*enableDebugInfo=*/false,
3726 /*prettyDebugInfo=*/false,
3727 /*printGenericOpForm=*/false,
3728 /*useLocalScope=*/false);
3729 },
3730 "Returns the assembly form of the operation.")
3731 .def("print", &PyOperationBase::print,
3732 // Careful: Lots of arguments must match up with print method.
3733 py::arg("file") = py::none(), py::arg("binary") = false,
3734 py::arg("large_elements_limit") = py::none(),
3735 py::arg("enable_debug_info") = false,
3736 py::arg("pretty_debug_info") = false,
3737 py::arg("print_generic_op_form") = false,
3738 py::arg("use_local_scope") = false, kOperationPrintDocstring)
3739 .def("get_asm", &PyOperationBase::getAsm,
3740 // Careful: Lots of arguments must match up with get_asm method.
3741 py::arg("binary") = false,
3742 py::arg("large_elements_limit") = py::none(),
3743 py::arg("enable_debug_info") = false,
3744 py::arg("pretty_debug_info") = false,
3745 py::arg("print_generic_op_form") = false,
3746 py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
3747 .def(
3748 "verify",
3749 [](PyOperationBase &self) {
3750 return mlirOperationVerify(self.getOperation());
3751 },
3752 "Verify the operation and return true if it passes, false if it "
3753 "fails.");
3754
3755 py::class_<PyOperation, PyOperationBase>(m, "Operation")
3756 .def_static("create", &PyOperation::create, py::arg("name"),
3757 py::arg("results") = py::none(),
3758 py::arg("operands") = py::none(),
3759 py::arg("attributes") = py::none(),
3760 py::arg("successors") = py::none(), py::arg("regions") = 0,
3761 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3762 kOperationCreateDocstring)
3763 .def_property_readonly("name",
3764 [](PyOperation &self) {
3765 MlirOperation operation = self.get();
3766 MlirStringRef name = mlirIdentifierStr(
3767 mlirOperationGetName(operation));
3768 return py::str(name.data, name.length);
3769 })
3770 .def_property_readonly(
3771 "context",
3772 [](PyOperation &self) { return self.getContext().getObject(); },
3773 "Context that owns the Operation")
3774 .def_property_readonly("opview", &PyOperation::createOpView);
3775
3776 auto opViewClass =
3777 py::class_<PyOpView, PyOperationBase>(m, "OpView")
3778 .def(py::init<py::object>())
3779 .def_property_readonly("operation", &PyOpView::getOperationObject)
3780 .def_property_readonly(
3781 "context",
3782 [](PyOpView &self) {
3783 return self.getOperation().getContext().getObject();
3784 },
3785 "Context that owns the Operation")
3786 .def("__str__", [](PyOpView &self) {
3787 return py::str(self.getOperationObject());
3788 });
3789 opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
3790 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
3791 opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3792 opViewClass.attr("build_generic") = classmethod(
3793 &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3794 py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3795 py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3796 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
3797 "Builds a specific, generated OpView based on class level attributes.");
3798
3799 //----------------------------------------------------------------------------
3800 // Mapping of PyRegion.
3801 //----------------------------------------------------------------------------
3802 py::class_<PyRegion>(m, "Region")
3803 .def_property_readonly(
3804 "blocks",
3805 [](PyRegion &self) {
3806 return PyBlockList(self.getParentOperation(), self.get());
3807 },
3808 "Returns a forward-optimized sequence of blocks.")
3809 .def(
3810 "__iter__",
3811 [](PyRegion &self) {
3812 self.checkValid();
3813 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3814 return PyBlockIterator(self.getParentOperation(), firstBlock);
3815 },
3816 "Iterates over blocks in the region.")
3817 .def("__eq__",
3818 [](PyRegion &self, PyRegion &other) {
3819 return self.get().ptr == other.get().ptr;
3820 })
3821 .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
3822
3823 //----------------------------------------------------------------------------
3824 // Mapping of PyBlock.
3825 //----------------------------------------------------------------------------
3826 py::class_<PyBlock>(m, "Block")
3827 .def_property_readonly(
3828 "arguments",
3829 [](PyBlock &self) {
3830 return PyBlockArgumentList(self.getParentOperation(), self.get());
3831 },
3832 "Returns a list of block arguments.")
3833 .def_property_readonly(
3834 "operations",
3835 [](PyBlock &self) {
3836 return PyOperationList(self.getParentOperation(), self.get());
3837 },
3838 "Returns a forward-optimized sequence of operations.")
3839 .def(
3840 "__iter__",
3841 [](PyBlock &self) {
3842 self.checkValid();
3843 MlirOperation firstOperation =
3844 mlirBlockGetFirstOperation(self.get());
3845 return PyOperationIterator(self.getParentOperation(),
3846 firstOperation);
3847 },
3848 "Iterates over operations in the block.")
3849 .def("__eq__",
3850 [](PyBlock &self, PyBlock &other) {
3851 return self.get().ptr == other.get().ptr;
3852 })
3853 .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
3854 .def(
3855 "__str__",
3856 [](PyBlock &self) {
3857 self.checkValid();
3858 PyPrintAccumulator printAccum;
3859 mlirBlockPrint(self.get(), printAccum.getCallback(),
3860 printAccum.getUserData());
3861 return printAccum.join();
3862 },
3863 "Returns the assembly form of the block.");
3864
3865 //----------------------------------------------------------------------------
3866 // Mapping of PyInsertionPoint.
3867 //----------------------------------------------------------------------------
3868
3869 py::class_<PyInsertionPoint>(m, "InsertionPoint")
3870 .def(py::init<PyBlock &>(), py::arg("block"),
3871 "Inserts after the last operation but still inside the block.")
3872 .def("__enter__", &PyInsertionPoint::contextEnter)
3873 .def("__exit__", &PyInsertionPoint::contextExit)
3874 .def_property_readonly_static(
3875 "current",
3876 [](py::object & /*class*/) {
3877 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3878 if (!ip)
3879 throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
3880 return ip;
3881 },
3882 "Gets the InsertionPoint bound to the current thread or raises "
3883 "ValueError if none has been set")
3884 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
3885 "Inserts before a referenced operation.")
3886 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3887 py::arg("block"), "Inserts at the beginning of the block.")
3888 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3889 py::arg("block"), "Inserts before the block terminator.")
3890 .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
3891 "Inserts an operation.");
3892
3893 //----------------------------------------------------------------------------
3894 // Mapping of PyAttribute.
3895 //----------------------------------------------------------------------------
3896 py::class_<PyAttribute>(m, "Attribute")
3897 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3898 &PyAttribute::getCapsule)
3899 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3900 .def_static(
3901 "parse",
3902 [](std::string attrSpec, DefaultingPyMlirContext context) {
3903 MlirAttribute type = mlirAttributeParseGet(
3904 context->get(), toMlirStringRef(attrSpec));
3905 // TODO: Rework error reporting once diagnostic engine is exposed
3906 // in C API.
3907 if (mlirAttributeIsNull(type)) {
3908 throw SetPyError(PyExc_ValueError,
3909 Twine("Unable to parse attribute: '") +
3910 attrSpec + "'");
3911 }
3912 return PyAttribute(context->getRef(), type);
3913 },
3914 py::arg("asm"), py::arg("context") = py::none(),
3915 "Parses an attribute from an assembly form")
3916 .def_property_readonly(
3917 "context",
3918 [](PyAttribute &self) { return self.getContext().getObject(); },
3919 "Context that owns the Attribute")
3920 .def_property_readonly("type",
3921 [](PyAttribute &self) {
3922 return PyType(self.getContext()->getRef(),
3923 mlirAttributeGetType(self));
3924 })
3925 .def(
3926 "get_named",
3927 [](PyAttribute &self, std::string name) {
3928 return PyNamedAttribute(self, std::move(name));
3929 },
3930 py::keep_alive<0, 1>(), "Binds a name to the attribute")
3931 .def("__eq__",
3932 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3933 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3934 .def(
3935 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3936 kDumpDocstring)
3937 .def(
3938 "__str__",
3939 [](PyAttribute &self) {
3940 PyPrintAccumulator printAccum;
3941 mlirAttributePrint(self, printAccum.getCallback(),
3942 printAccum.getUserData());
3943 return printAccum.join();
3944 },
3945 "Returns the assembly form of the Attribute.")
3946 .def("__repr__", [](PyAttribute &self) {
3947 // Generally, assembly formats are not printed for __repr__ because
3948 // this can cause exceptionally long debug output and exceptions.
3949 // However, attribute values are generally considered useful and are
3950 // printed. This may need to be re-evaluated if debug dumps end up
3951 // being excessive.
3952 PyPrintAccumulator printAccum;
3953 printAccum.parts.append("Attribute(");
3954 mlirAttributePrint(self, printAccum.getCallback(),
3955 printAccum.getUserData());
3956 printAccum.parts.append(")");
3957 return printAccum.join();
3958 });
3959
3960 //----------------------------------------------------------------------------
3961 // Mapping of PyNamedAttribute
3962 //----------------------------------------------------------------------------
3963 py::class_<PyNamedAttribute>(m, "NamedAttribute")
3964 .def("__repr__",
3965 [](PyNamedAttribute &self) {
3966 PyPrintAccumulator printAccum;
3967 printAccum.parts.append("NamedAttribute(");
3968 printAccum.parts.append(
3969 mlirIdentifierStr(self.namedAttr.name).data);
3970 printAccum.parts.append("=");
3971 mlirAttributePrint(self.namedAttr.attribute,
3972 printAccum.getCallback(),
3973 printAccum.getUserData());
3974 printAccum.parts.append(")");
3975 return printAccum.join();
3976 })
3977 .def_property_readonly(
3978 "name",
3979 [](PyNamedAttribute &self) {
3980 return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3981 mlirIdentifierStr(self.namedAttr.name).length);
3982 },
3983 "The name of the NamedAttribute binding")
3984 .def_property_readonly(
3985 "attr",
3986 [](PyNamedAttribute &self) {
3987 // TODO: When named attribute is removed/refactored, also remove
3988 // this constructor (it does an inefficient table lookup).
3989 auto contextRef = PyMlirContext::forContext(
3990 mlirAttributeGetContext(self.namedAttr.attribute));
3991 return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3992 },
3993 py::keep_alive<0, 1>(),
3994 "The underlying generic attribute of the NamedAttribute binding");
3995
3996 // Builtin attribute bindings.
3997 PyFloatAttribute::bind(m);
3998 PyArrayAttribute::bind(m);
3999 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
4000 PyIntegerAttribute::bind(m);
4001 PyBoolAttribute::bind(m);
4002 PyFlatSymbolRefAttribute::bind(m);
4003 PyStringAttribute::bind(m);
4004 PyDenseElementsAttribute::bind(m);
4005 PyDenseIntElementsAttribute::bind(m);
4006 PyDenseFPElementsAttribute::bind(m);
4007 PyDictAttribute::bind(m);
4008 PyTypeAttribute::bind(m);
4009 PyUnitAttribute::bind(m);
4010
4011 //----------------------------------------------------------------------------
4012 // Mapping of PyType.
4013 //----------------------------------------------------------------------------
4014 py::class_<PyType>(m, "Type")
4015 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
4016 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
4017 .def_static(
4018 "parse",
4019 [](std::string typeSpec, DefaultingPyMlirContext context) {
4020 MlirType type =
4021 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
4022 // TODO: Rework error reporting once diagnostic engine is exposed
4023 // in C API.
4024 if (mlirTypeIsNull(type)) {
4025 throw SetPyError(PyExc_ValueError,
4026 Twine("Unable to parse type: '") + typeSpec +
4027 "'");
4028 }
4029 return PyType(context->getRef(), type);
4030 },
4031 py::arg("asm"), py::arg("context") = py::none(),
4032 kContextParseTypeDocstring)
4033 .def_property_readonly(
4034 "context", [](PyType &self) { return self.getContext().getObject(); },
4035 "Context that owns the Type")
4036 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
4037 .def("__eq__", [](PyType &self, py::object &other) { return false; })
4038 .def(
4039 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
4040 .def(
4041 "__str__",
4042 [](PyType &self) {
4043 PyPrintAccumulator printAccum;
4044 mlirTypePrint(self, printAccum.getCallback(),
4045 printAccum.getUserData());
4046 return printAccum.join();
4047 },
4048 "Returns the assembly form of the type.")
4049 .def("__repr__", [](PyType &self) {
4050 // Generally, assembly formats are not printed for __repr__ because
4051 // this can cause exceptionally long debug output and exceptions.
4052 // However, types are an exception as they typically have compact
4053 // assembly forms and printing them is useful.
4054 PyPrintAccumulator printAccum;
4055 printAccum.parts.append("Type(");
4056 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
4057 printAccum.parts.append(")");
4058 return printAccum.join();
4059 });
4060
  // Builtin type bindings.
  // Each `bind` call registers one concrete type class (or type-related
  // helper class) on module `m`.
  // NOTE(review): the order presumably matters where classes derive from one
  // another; e.g. the vector/tensor/memref types are likely subclasses of
  // PyShapedType, which is registered before them. Confirm before
  // reordering these calls.
  PyIntegerType::bind(m);
  PyIndexType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyMemRefLayoutMapList::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
4079
4080 //----------------------------------------------------------------------------
4081 // Mapping of Value.
4082 //----------------------------------------------------------------------------
4083 py::class_<PyValue>(m, "Value")
4084 .def_property_readonly(
4085 "context",
4086 [](PyValue &self) { return self.getParentOperation()->getContext(); },
4087 "Context in which the value lives.")
4088 .def(
4089 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
4090 kDumpDocstring)
4091 .def("__eq__",
4092 [](PyValue &self, PyValue &other) {
4093 return self.get().ptr == other.get().ptr;
4094 })
4095 .def("__eq__", [](PyValue &self, py::object other) { return false; })
4096 .def(
4097 "__str__",
4098 [](PyValue &self) {
4099 PyPrintAccumulator printAccum;
4100 printAccum.parts.append("Value(");
4101 mlirValuePrint(self.get(), printAccum.getCallback(),
4102 printAccum.getUserData());
4103 printAccum.parts.append(")");
4104 return printAccum.join();
4105 },
4106 kValueDunderStrDocstring)
4107 .def_property_readonly("type", [](PyValue &self) {
4108 return PyType(self.getParentOperation()->getContext(),
4109 mlirValueGetType(self.get()));
4110 });
  // Concrete Value subclasses (block arguments and operation results).
  // NOTE(review): these presumably derive from the `Value` class bound just
  // above and therefore must be registered after it. Confirm before moving.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings.
  // Each call registers a Python-side list/iterator/map helper class over IR
  // entities (blocks, operations, regions, operands, results, attributes, as
  // the class names suggest; see each class's `bind` for its exact protocol).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
4125
4126 //----------------------------------------------------------------------------
4127 // Mapping of PyAffineExpr and derived classes.
4128 //----------------------------------------------------------------------------
4129 py::class_<PyAffineExpr>(m, "AffineExpr")
4130 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4131 &PyAffineExpr::getCapsule)
4132 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
4133 .def("__add__",
4134 [](PyAffineExpr &self, PyAffineExpr &other) {
4135 return PyAffineAddExpr::get(self, other);
4136 })
4137 .def("__mul__",
4138 [](PyAffineExpr &self, PyAffineExpr &other) {
4139 return PyAffineMulExpr::get(self, other);
4140 })
4141 .def("__mod__",
4142 [](PyAffineExpr &self, PyAffineExpr &other) {
4143 return PyAffineModExpr::get(self, other);
4144 })
4145 .def("__sub__",
4146 [](PyAffineExpr &self, PyAffineExpr &other) {
4147 auto negOne =
4148 PyAffineConstantExpr::get(-1, *self.getContext().get());
4149 return PyAffineAddExpr::get(self,
4150 PyAffineMulExpr::get(negOne, other));
4151 })
4152 .def("__eq__", [](PyAffineExpr &self,
4153 PyAffineExpr &other) { return self == other; })
4154 .def("__eq__",
4155 [](PyAffineExpr &self, py::object &other) { return false; })
4156 .def("__str__",
4157 [](PyAffineExpr &self) {
4158 PyPrintAccumulator printAccum;
4159 mlirAffineExprPrint(self, printAccum.getCallback(),
4160 printAccum.getUserData());
4161 return printAccum.join();
4162 })
4163 .def("__repr__",
4164 [](PyAffineExpr &self) {
4165 PyPrintAccumulator printAccum;
4166 printAccum.parts.append("AffineExpr(");
4167 mlirAffineExprPrint(self, printAccum.getCallback(),
4168 printAccum.getUserData());
4169 printAccum.parts.append(")");
4170 return printAccum.join();
4171 })
4172 .def_property_readonly(
4173 "context",
4174 [](PyAffineExpr &self) { return self.getContext().getObject(); })
4175 .def_static(
4176 "get_add", &PyAffineAddExpr::get,
4177 "Gets an affine expression containing a sum of two expressions.")
4178 .def_static(
4179 "get_mul", &PyAffineMulExpr::get,
4180 "Gets an affine expression containing a product of two expressions.")
4181 .def_static("get_mod", &PyAffineModExpr::get,
4182 "Gets an affine expression containing the modulo of dividing "
4183 "one expression by another.")
4184 .def_static("get_floor_div", &PyAffineFloorDivExpr::get,
4185 "Gets an affine expression containing the rounded-down "
4186 "result of dividing one expression by another.")
4187 .def_static("get_ceil_div", &PyAffineCeilDivExpr::get,
4188 "Gets an affine expression containing the rounded-up result "
4189 "of dividing one expression by another.")
4190 .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
4191 py::arg("context") = py::none(),
4192 "Gets a constant affine expression with the given value.")
4193 .def_static(
4194 "get_dim", &PyAffineDimExpr::get, py::arg("position"),
4195 py::arg("context") = py::none(),
4196 "Gets an affine expression of a dimension at the given position.")
4197 .def_static(
4198 "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
4199 py::arg("context") = py::none(),
4200 "Gets an affine expression of a symbol at the given position.")
4201 .def(
4202 "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
4203 kDumpDocstring);
  // Concrete AffineExpr subclasses.
  // NOTE(review): these presumably derive from `AffineExpr` (bound above),
  // and the Add/Mul/Mod/FloorDiv/CeilDiv kinds likely derive from
  // PyAffineBinaryExpr, which is registered before them. Keep this order
  // unless the class hierarchy is confirmed.
  PyAffineConstantExpr::bind(m);
  PyAffineDimExpr::bind(m);
  PyAffineSymbolExpr::bind(m);
  PyAffineBinaryExpr::bind(m);
  PyAffineAddExpr::bind(m);
  PyAffineMulExpr::bind(m);
  PyAffineModExpr::bind(m);
  PyAffineFloorDivExpr::bind(m);
  PyAffineCeilDivExpr::bind(m);
4213
4214 //----------------------------------------------------------------------------
4215 // Mapping of PyAffineMap.
4216 //----------------------------------------------------------------------------
4217 py::class_<PyAffineMap>(m, "AffineMap")
4218 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4219 &PyAffineMap::getCapsule)
4220 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
4221 .def("__eq__",
4222 [](PyAffineMap &self, PyAffineMap &other) { return self == other; })
4223 .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
4224 .def("__str__",
4225 [](PyAffineMap &self) {
4226 PyPrintAccumulator printAccum;
4227 mlirAffineMapPrint(self, printAccum.getCallback(),
4228 printAccum.getUserData());
4229 return printAccum.join();
4230 })
4231 .def("__repr__",
4232 [](PyAffineMap &self) {
4233 PyPrintAccumulator printAccum;
4234 printAccum.parts.append("AffineMap(");
4235 mlirAffineMapPrint(self, printAccum.getCallback(),
4236 printAccum.getUserData());
4237 printAccum.parts.append(")");
4238 return printAccum.join();
4239 })
4240 .def_property_readonly(
4241 "context",
4242 [](PyAffineMap &self) { return self.getContext().getObject(); },
4243 "Context that owns the Affine Map")
4244 .def(
4245 "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
4246 kDumpDocstring)
4247 .def_static(
4248 "get",
4249 [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
4250 DefaultingPyMlirContext context) {
4251 SmallVector<MlirAffineExpr> affineExprs;
4252 pyListToVector<PyAffineExpr, MlirAffineExpr>(
4253 exprs, affineExprs, "attempting to create an AffineMap");
4254 MlirAffineMap map =
4255 mlirAffineMapGet(context->get(), dimCount, symbolCount,
4256 affineExprs.size(), affineExprs.data());
4257 return PyAffineMap(context->getRef(), map);
4258 },
4259 py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
4260 py::arg("context") = py::none(),
4261 "Gets a map with the given expressions as results.")
4262 .def_static(
4263 "get_constant",
4264 [](intptr_t value, DefaultingPyMlirContext context) {
4265 MlirAffineMap affineMap =
4266 mlirAffineMapConstantGet(context->get(), value);
4267 return PyAffineMap(context->getRef(), affineMap);
4268 },
4269 py::arg("value"), py::arg("context") = py::none(),
4270 "Gets an affine map with a single constant result")
4271 .def_static(
4272 "get_empty",
4273 [](DefaultingPyMlirContext context) {
4274 MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
4275 return PyAffineMap(context->getRef(), affineMap);
4276 },
4277 py::arg("context") = py::none(), "Gets an empty affine map.")
4278 .def_static(
4279 "get_identity",
4280 [](intptr_t nDims, DefaultingPyMlirContext context) {
4281 MlirAffineMap affineMap =
4282 mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
4283 return PyAffineMap(context->getRef(), affineMap);
4284 },
4285 py::arg("n_dims"), py::arg("context") = py::none(),
4286 "Gets an identity map with the given number of dimensions.")
4287 .def_static(
4288 "get_minor_identity",
4289 [](intptr_t nDims, intptr_t nResults,
4290 DefaultingPyMlirContext context) {
4291 MlirAffineMap affineMap =
4292 mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
4293 return PyAffineMap(context->getRef(), affineMap);
4294 },
4295 py::arg("n_dims"), py::arg("n_results"),
4296 py::arg("context") = py::none(),
4297 "Gets a minor identity map with the given number of dimensions and "
4298 "results.")
4299 .def_static(
4300 "get_permutation",
4301 [](std::vector<unsigned> permutation,
4302 DefaultingPyMlirContext context) {
4303 if (!isPermutation(permutation))
4304 throw py::cast_error("Invalid permutation when attempting to "
4305 "create an AffineMap");
4306 MlirAffineMap affineMap = mlirAffineMapPermutationGet(
4307 context->get(), permutation.size(), permutation.data());
4308 return PyAffineMap(context->getRef(), affineMap);
4309 },
4310 py::arg("permutation"), py::arg("context") = py::none(),
4311 "Gets an affine map that permutes its inputs.")
4312 .def("get_submap",
4313 [](PyAffineMap &self, std::vector<intptr_t> &resultPos) {
4314 intptr_t numResults = mlirAffineMapGetNumResults(self);
4315 for (intptr_t pos : resultPos) {
4316 if (pos < 0 || pos >= numResults)
4317 throw py::value_error("result position out of bounds");
4318 }
4319 MlirAffineMap affineMap = mlirAffineMapGetSubMap(
4320 self, resultPos.size(), resultPos.data());
4321 return PyAffineMap(self.getContext(), affineMap);
4322 })
4323 .def("get_major_submap",
4324 [](PyAffineMap &self, intptr_t nResults) {
4325 if (nResults >= mlirAffineMapGetNumResults(self))
4326 throw py::value_error("number of results out of bounds");
4327 MlirAffineMap affineMap =
4328 mlirAffineMapGetMajorSubMap(self, nResults);
4329 return PyAffineMap(self.getContext(), affineMap);
4330 })
4331 .def("get_minor_submap",
4332 [](PyAffineMap &self, intptr_t nResults) {
4333 if (nResults >= mlirAffineMapGetNumResults(self))
4334 throw py::value_error("number of results out of bounds");
4335 MlirAffineMap affineMap =
4336 mlirAffineMapGetMinorSubMap(self, nResults);
4337 return PyAffineMap(self.getContext(), affineMap);
4338 })
4339 .def_property_readonly(
4340 "is_permutation",
4341 [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
4342 .def_property_readonly("is_projected_permutation",
4343 [](PyAffineMap &self) {
4344 return mlirAffineMapIsProjectedPermutation(self);
4345 })
4346 .def_property_readonly(
4347 "n_dims",
4348 [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
4349 .def_property_readonly(
4350 "n_inputs",
4351 [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
4352 .def_property_readonly(
4353 "n_symbols",
4354 [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
4355 .def_property_readonly("results", [](PyAffineMap &self) {
4356 return PyAffineMapExprList(self);
4357 });
4358 PyAffineMapExprList::bind(m);
4359
4360 //----------------------------------------------------------------------------
4361 // Mapping of PyIntegerSet.
4362 //----------------------------------------------------------------------------
4363 py::class_<PyIntegerSet>(m, "IntegerSet")
4364 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4365 &PyIntegerSet::getCapsule)
4366 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
4367 .def("__eq__", [](PyIntegerSet &self,
4368 PyIntegerSet &other) { return self == other; })
4369 .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
4370 .def("__str__",
4371 [](PyIntegerSet &self) {
4372 PyPrintAccumulator printAccum;
4373 mlirIntegerSetPrint(self, printAccum.getCallback(),
4374 printAccum.getUserData());
4375 return printAccum.join();
4376 })
4377 .def("__repr__",
4378 [](PyIntegerSet &self) {
4379 PyPrintAccumulator printAccum;
4380 printAccum.parts.append("IntegerSet(");
4381 mlirIntegerSetPrint(self, printAccum.getCallback(),
4382 printAccum.getUserData());
4383 printAccum.parts.append(")");
4384 return printAccum.join();
4385 })
4386 .def_property_readonly(
4387 "context",
4388 [](PyIntegerSet &self) { return self.getContext().getObject(); })
4389 .def(
4390 "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
4391 kDumpDocstring)
4392 .def_static(
4393 "get",
4394 [](intptr_t numDims, intptr_t numSymbols, py::list exprs,
4395 std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
4396 if (exprs.size() != eqFlags.size())
4397 throw py::value_error(
4398 "Expected the number of constraints to match "
4399 "that of equality flags");
4400 if (exprs.empty())
4401 throw py::value_error("Expected non-empty list of constraints");
4402
4403 // Copy over to a SmallVector because std::vector has a
4404 // specialization for booleans that packs data and does not
4405 // expose a `bool *`.
4406 SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
4407
4408 SmallVector<MlirAffineExpr> affineExprs;
4409 pyListToVector<PyAffineExpr>(exprs, affineExprs,
4410 "attempting to create an IntegerSet");
4411 MlirIntegerSet set = mlirIntegerSetGet(
4412 context->get(), numDims, numSymbols, exprs.size(),
4413 affineExprs.data(), flags.data());
4414 return PyIntegerSet(context->getRef(), set);
4415 },
4416 py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
4417 py::arg("eq_flags"), py::arg("context") = py::none())
4418 .def_static(
4419 "get_empty",
4420 [](intptr_t numDims, intptr_t numSymbols,
4421 DefaultingPyMlirContext context) {
4422 MlirIntegerSet set =
4423 mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
4424 return PyIntegerSet(context->getRef(), set);
4425 },
4426 py::arg("num_dims"), py::arg("num_symbols"),
4427 py::arg("context") = py::none())
4428 .def("get_replaced",
4429 [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
4430 intptr_t numResultDims, intptr_t numResultSymbols) {
4431 if (static_cast<intptr_t>(dimExprs.size()) !=
4432 mlirIntegerSetGetNumDims(self))
4433 throw py::value_error(
4434 "Expected the number of dimension replacement expressions "
4435 "to match that of dimensions");
4436 if (static_cast<intptr_t>(symbolExprs.size()) !=
4437 mlirIntegerSetGetNumSymbols(self))
4438 throw py::value_error(
4439 "Expected the number of symbol replacement expressions "
4440 "to match that of symbols");
4441
4442 SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
4443 pyListToVector<PyAffineExpr>(
4444 dimExprs, dimAffineExprs,
4445 "attempting to create an IntegerSet by replacing dimensions");
4446 pyListToVector<PyAffineExpr>(
4447 symbolExprs, symbolAffineExprs,
4448 "attempting to create an IntegerSet by replacing symbols");
4449 MlirIntegerSet set = mlirIntegerSetReplaceGet(
4450 self, dimAffineExprs.data(), symbolAffineExprs.data(),
4451 numResultDims, numResultSymbols);
4452 return PyIntegerSet(self.getContext(), set);
4453 })
4454 .def_property_readonly("is_canonical_empty",
4455 [](PyIntegerSet &self) {
4456 return mlirIntegerSetIsCanonicalEmpty(self);
4457 })
4458 .def_property_readonly(
4459 "n_dims",
4460 [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
4461 .def_property_readonly(
4462 "n_symbols",
4463 [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
4464 .def_property_readonly(
4465 "n_inputs",
4466 [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
4467 .def_property_readonly("n_equalities",
4468 [](PyIntegerSet &self) {
4469 return mlirIntegerSetGetNumEqualities(self);
4470 })
4471 .def_property_readonly("n_inequalities",
4472 [](PyIntegerSet &self) {
4473 return mlirIntegerSetGetNumInequalities(self);
4474 })
4475 .def_property_readonly("constraints", [](PyIntegerSet &self) {
4476 return PyIntegerSetConstraintList(self);
4477 });
4478 PyIntegerSetConstraint::bind(m);
4479 PyIntegerSetConstraintList::bind(m);
4480 }
4481