1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11 
12 #include <vector>
13 
14 #include "PybindUtils.h"
15 
16 #include "mlir-c/AffineExpr.h"
17 #include "mlir-c/AffineMap.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/IntegerSet.h"
20 #include "llvm/ADT/DenseMap.h"
21 
22 namespace mlir {
23 namespace python {
24 
25 class PyBlock;
26 class PyInsertionPoint;
27 class PyLocation;
28 class DefaultingPyLocation;
29 class PyMlirContext;
30 class DefaultingPyMlirContext;
31 class PyModule;
32 class PyOperation;
33 class PyType;
34 class PyValue;
35 
36 /// Template for a reference to a concrete type which captures a python
37 /// reference to its underlying python object.
38 template <typename T>
39 class PyObjectRef {
40 public:
PyObjectRef(T * referrent,pybind11::object object)41   PyObjectRef(T *referrent, pybind11::object object)
42       : referrent(referrent), object(std::move(object)) {
43     assert(this->referrent &&
44            "cannot construct PyObjectRef with null referrent");
45     assert(this->object && "cannot construct PyObjectRef with null object");
46   }
PyObjectRef(PyObjectRef && other)47   PyObjectRef(PyObjectRef &&other)
48       : referrent(other.referrent), object(std::move(other.object)) {
49     other.referrent = nullptr;
50     assert(!other.object);
51   }
PyObjectRef(const PyObjectRef & other)52   PyObjectRef(const PyObjectRef &other)
53       : referrent(other.referrent), object(other.object /* copies */) {}
~PyObjectRef()54   ~PyObjectRef() {}
55 
getRefCount()56   int getRefCount() {
57     if (!object)
58       return 0;
59     return object.ref_count();
60   }
61 
62   /// Releases the object held by this instance, returning it.
63   /// This is the proper thing to return from a function that wants to return
64   /// the reference. Note that this does not work from initializers.
releaseObject()65   pybind11::object releaseObject() {
66     assert(referrent && object);
67     referrent = nullptr;
68     auto stolen = std::move(object);
69     return stolen;
70   }
71 
get()72   T *get() { return referrent; }
73   T *operator->() {
74     assert(referrent && object);
75     return referrent;
76   }
getObject()77   pybind11::object getObject() {
78     assert(referrent && object);
79     return object;
80   }
81   operator bool() const { return referrent && object; }
82 
83 private:
84   T *referrent;
85   pybind11::object object;
86 };
87 
88 /// Tracks an entry in the thread context stack. New entries are pushed onto
89 /// here for each with block that activates a new InsertionPoint, Context or
90 /// Location.
91 ///
92 /// Pushing either a Location or InsertionPoint also pushes its associated
93 /// Context. Pushing a Context will not modify the Location or InsertionPoint
94 /// unless if they are from a different context, in which case, they are
95 /// cleared.
96 class PyThreadContextEntry {
97 public:
98   enum class FrameKind {
99     Context,
100     InsertionPoint,
101     Location,
102   };
103 
PyThreadContextEntry(FrameKind frameKind,pybind11::object context,pybind11::object insertionPoint,pybind11::object location)104   PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
105                        pybind11::object insertionPoint,
106                        pybind11::object location)
107       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
108         location(std::move(location)), frameKind(frameKind) {}
109 
110   /// Gets the top of stack context and return nullptr if not defined.
111   static PyMlirContext *getDefaultContext();
112 
113   /// Gets the top of stack insertion point and return nullptr if not defined.
114   static PyInsertionPoint *getDefaultInsertionPoint();
115 
116   /// Gets the top of stack location and returns nullptr if not defined.
117   static PyLocation *getDefaultLocation();
118 
119   PyMlirContext *getContext();
120   PyInsertionPoint *getInsertionPoint();
121   PyLocation *getLocation();
getFrameKind()122   FrameKind getFrameKind() { return frameKind; }
123 
124   /// Stack management.
125   static PyThreadContextEntry *getTopOfStack();
126   static pybind11::object pushContext(PyMlirContext &context);
127   static void popContext(PyMlirContext &context);
128   static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
129   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
130   static pybind11::object pushLocation(PyLocation &location);
131   static void popLocation(PyLocation &location);
132 
133   /// Gets the thread local stack.
134   static std::vector<PyThreadContextEntry> &getStack();
135 
136 private:
137   static void push(FrameKind frameKind, pybind11::object context,
138                    pybind11::object insertionPoint, pybind11::object location);
139 
140   /// An object reference to the PyContext.
141   pybind11::object context;
142   /// An object reference to the current insertion point.
143   pybind11::object insertionPoint;
144   /// An object reference to the current location.
145   pybind11::object location;
146   // The kind of push that was performed.
147   FrameKind frameKind;
148 };
149 
150 /// Wrapper around MlirContext.
151 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
152 class PyMlirContext {
153 public:
154   PyMlirContext() = delete;
155   PyMlirContext(const PyMlirContext &) = delete;
156   PyMlirContext(PyMlirContext &&) = delete;
157 
158   /// For the case of a python __init__ (py::init) method, pybind11 is quite
159   /// strict about needing to return a pointer that is not yet associated to
160   /// an py::object. Since the forContext() method acts like a pool, possibly
161   /// returning a recycled context, it does not satisfy this need. The usual
162   /// way in python to accomplish such a thing is to override __new__, but
163   /// that is also not supported by pybind11. Instead, we use this entry
164   /// point which always constructs a fresh context (which cannot alias an
165   /// existing one because it is fresh).
166   static PyMlirContext *createNewContextForInit();
167 
168   /// Returns a context reference for the singleton PyMlirContext wrapper for
169   /// the given context.
170   static PyMlirContextRef forContext(MlirContext context);
171   ~PyMlirContext();
172 
173   /// Accesses the underlying MlirContext.
get()174   MlirContext get() { return context; }
175 
176   /// Gets a strong reference to this context, which will ensure it is kept
177   /// alive for the life of the reference.
getRef()178   PyMlirContextRef getRef() {
179     return PyMlirContextRef(this, pybind11::cast(this));
180   }
181 
182   /// Gets a capsule wrapping the void* within the MlirContext.
183   pybind11::object getCapsule();
184 
185   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
186   /// Note that PyMlirContext instances are uniqued, so the returned object
187   /// may be a pre-existing object. Ownership of the underlying MlirContext
188   /// is taken by calling this function.
189   static pybind11::object createFromCapsule(pybind11::object capsule);
190 
191   /// Gets the count of live context objects. Used for testing.
192   static size_t getLiveCount();
193 
194   /// Gets the count of live operations associated with this context.
195   /// Used for testing.
196   size_t getLiveOperationCount();
197 
198   /// Gets the count of live modules associated with this context.
199   /// Used for testing.
200   size_t getLiveModuleCount();
201 
202   /// Enter and exit the context manager.
203   pybind11::object contextEnter();
204   void contextExit(pybind11::object excType, pybind11::object excVal,
205                    pybind11::object excTb);
206 
207 private:
208   PyMlirContext(MlirContext context);
209   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
210   // preserving the relationship that an MlirContext maps to a single
211   // PyMlirContext wrapper. This could be replaced in the future with an
212   // extension mechanism on the MlirContext for stashing user pointers.
213   // Note that this holds a handle, which does not imply ownership.
214   // Mappings will be removed when the context is destructed.
215   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
216   static LiveContextMap &getLiveContexts();
217 
218   // Interns all live modules associated with this context. Modules tracked
219   // in this map are valid. When a module is invalidated, it is removed
220   // from this map, and while it still exists as an instance, any
221   // attempt to access it will raise an error.
222   using LiveModuleMap =
223       llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
224   LiveModuleMap liveModules;
225 
226   // Interns all live operations associated with this context. Operations
227   // tracked in this map are valid. When an operation is invalidated, it is
228   // removed from this map, and while it still exists as an instance, any
229   // attempt to access it will raise an error.
230   using LiveOperationMap =
231       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
232   LiveOperationMap liveOperations;
233 
234   MlirContext context;
235   friend class PyModule;
236   friend class PyOperation;
237 };
238 
239 /// Used in function arguments when None should resolve to the current context
240 /// manager set instance.
241 class DefaultingPyMlirContext
242     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
243 public:
244   using Defaulting::Defaulting;
245   static constexpr const char kTypeDescription[] =
246       "[ThreadContextAware] mlir.ir.Context";
247   static PyMlirContext &resolve();
248 };
249 
250 /// Base class for all objects that directly or indirectly depend on an
251 /// MlirContext. The lifetime of the context will extend at least to the
252 /// lifetime of these instances.
253 /// Immutable objects that depend on a context extend this directly.
254 class BaseContextObject {
255 public:
BaseContextObject(PyMlirContextRef ref)256   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
257     assert(this->contextRef &&
258            "context object constructed with null context ref");
259   }
260 
261   /// Accesses the context reference.
getContext()262   PyMlirContextRef &getContext() { return contextRef; }
263 
264 private:
265   PyMlirContextRef contextRef;
266 };
267 
268 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
269 /// order to differentiate it from the `Dialect` base class which is extended by
270 /// plugins which extend dialect functionality through extension python code.
271 /// This should be seen as the "low-level" object and `Dialect` as the
272 /// high-level, user facing object.
273 class PyDialectDescriptor : public BaseContextObject {
274 public:
PyDialectDescriptor(PyMlirContextRef contextRef,MlirDialect dialect)275   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
276       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
277 
get()278   MlirDialect get() { return dialect; }
279 
280 private:
281   MlirDialect dialect;
282 };
283 
284 /// User-level object for accessing dialects with dotted syntax such as:
285 ///   ctx.dialect.std
286 class PyDialects : public BaseContextObject {
287 public:
PyDialects(PyMlirContextRef contextRef)288   PyDialects(PyMlirContextRef contextRef)
289       : BaseContextObject(std::move(contextRef)) {}
290 
291   MlirDialect getDialectForKey(const std::string &key, bool attrError);
292 };
293 
294 /// User-level dialect object. For dialects that have a registered extension,
295 /// this will be the base class of the extension dialect type. For un-extended,
296 /// objects of this type will be returned directly.
297 class PyDialect {
298 public:
PyDialect(pybind11::object descriptor)299   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
300 
getDescriptor()301   pybind11::object getDescriptor() { return descriptor; }
302 
303 private:
304   pybind11::object descriptor;
305 };
306 
307 /// Wrapper around an MlirLocation.
308 class PyLocation : public BaseContextObject {
309 public:
PyLocation(PyMlirContextRef contextRef,MlirLocation loc)310   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
311       : BaseContextObject(std::move(contextRef)), loc(loc) {}
312 
MlirLocation()313   operator MlirLocation() const { return loc; }
get()314   MlirLocation get() const { return loc; }
315 
316   /// Enter and exit the context manager.
317   pybind11::object contextEnter();
318   void contextExit(pybind11::object excType, pybind11::object excVal,
319                    pybind11::object excTb);
320 
321   /// Gets a capsule wrapping the void* within the MlirLocation.
322   pybind11::object getCapsule();
323 
324   /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
325   /// Note that PyLocation instances are uniqued, so the returned object
326   /// may be a pre-existing object. Ownership of the underlying MlirLocation
327   /// is taken by calling this function.
328   static PyLocation createFromCapsule(pybind11::object capsule);
329 
330 private:
331   MlirLocation loc;
332 };
333 
334 /// Used in function arguments when None should resolve to the current context
335 /// manager set instance.
336 class DefaultingPyLocation
337     : public Defaulting<DefaultingPyLocation, PyLocation> {
338 public:
339   using Defaulting::Defaulting;
340   static constexpr const char kTypeDescription[] =
341       "[ThreadContextAware] mlir.ir.Location";
342   static PyLocation &resolve();
343 
MlirLocation()344   operator MlirLocation() const { return *get(); }
345 };
346 
347 /// Wrapper around MlirModule.
348 /// This is the top-level, user-owned object that contains regions/ops/blocks.
349 class PyModule;
350 using PyModuleRef = PyObjectRef<PyModule>;
351 class PyModule : public BaseContextObject {
352 public:
353   /// Returns a PyModule reference for the given MlirModule. This may return
354   /// a pre-existing or new object.
355   static PyModuleRef forModule(MlirModule module);
356   PyModule(PyModule &) = delete;
357   PyModule(PyMlirContext &&) = delete;
358   ~PyModule();
359 
360   /// Gets the backing MlirModule.
get()361   MlirModule get() { return module; }
362 
363   /// Gets a strong reference to this module.
getRef()364   PyModuleRef getRef() {
365     return PyModuleRef(this,
366                        pybind11::reinterpret_borrow<pybind11::object>(handle));
367   }
368 
369   /// Gets a capsule wrapping the void* within the MlirModule.
370   /// Note that the module does not (yet) provide a corresponding factory for
371   /// constructing from a capsule as that would require uniquing PyModule
372   /// instances, which is not currently done.
373   pybind11::object getCapsule();
374 
375   /// Creates a PyModule from the MlirModule wrapped by a capsule.
376   /// Note that PyModule instances are uniqued, so the returned object
377   /// may be a pre-existing object. Ownership of the underlying MlirModule
378   /// is taken by calling this function.
379   static pybind11::object createFromCapsule(pybind11::object capsule);
380 
381 private:
382   PyModule(PyMlirContextRef contextRef, MlirModule module);
383   MlirModule module;
384   pybind11::handle handle;
385 };
386 
387 /// Base class for PyOperation and PyOpView which exposes the primary, user
388 /// visible methods for manipulating it.
389 class PyOperationBase {
390 public:
391   virtual ~PyOperationBase() = default;
392   /// Implements the bound 'print' method and helps with others.
393   void print(pybind11::object fileObject, bool binary,
394              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
395              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
396   pybind11::object getAsm(bool binary,
397                           llvm::Optional<int64_t> largeElementsLimit,
398                           bool enableDebugInfo, bool prettyDebugInfo,
399                           bool printGenericOpForm, bool useLocalScope);
400 
401   /// Each must provide access to the raw Operation.
402   virtual PyOperation &getOperation() = 0;
403 };
404 
405 /// Wrapper around PyOperation.
406 /// Operations exist in either an attached (dependent) or detached (top-level)
407 /// state. In the detached state (as on creation), an operation is owned by
408 /// the creator and its lifetime extends either until its reference count
409 /// drops to zero or it is attached to a parent, at which point its lifetime
410 /// is bounded by its top-level parent reference.
411 class PyOperation;
412 using PyOperationRef = PyObjectRef<PyOperation>;
413 class PyOperation : public PyOperationBase, public BaseContextObject {
414 public:
415   ~PyOperation();
getOperation()416   PyOperation &getOperation() override { return *this; }
417 
418   /// Returns a PyOperation for the given MlirOperation, optionally associating
419   /// it with a parentKeepAlive.
420   static PyOperationRef
421   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
422                pybind11::object parentKeepAlive = pybind11::object());
423 
424   /// Creates a detached operation. The operation must not be associated with
425   /// any existing live operation.
426   static PyOperationRef
427   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
428                  pybind11::object parentKeepAlive = pybind11::object());
429 
430   /// Gets the backing operation.
MlirOperation()431   operator MlirOperation() const { return get(); }
get()432   MlirOperation get() const {
433     checkValid();
434     return operation;
435   }
436 
getRef()437   PyOperationRef getRef() {
438     return PyOperationRef(
439         this, pybind11::reinterpret_borrow<pybind11::object>(handle));
440   }
441 
isAttached()442   bool isAttached() { return attached; }
setAttached()443   void setAttached() {
444     assert(!attached && "operation already attached");
445     attached = true;
446   }
447   void checkValid() const;
448 
449   /// Gets the owning block or raises an exception if the operation has no
450   /// owning block.
451   PyBlock getBlock();
452 
453   /// Gets the parent operation or raises an exception if the operation has
454   /// no parent.
455   PyOperationRef getParentOperation();
456 
457   /// Gets a capsule wrapping the void* within the MlirOperation.
458   pybind11::object getCapsule();
459 
460   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
461   /// Ownership of the underlying MlirOperation is taken by calling this
462   /// function.
463   static pybind11::object createFromCapsule(pybind11::object capsule);
464 
465   /// Creates an operation. See corresponding python docstring.
466   static pybind11::object
467   create(std::string name, llvm::Optional<std::vector<PyType *>> results,
468          llvm::Optional<std::vector<PyValue *>> operands,
469          llvm::Optional<pybind11::dict> attributes,
470          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
471          DefaultingPyLocation location, pybind11::object ip);
472 
473   /// Creates an OpView suitable for this operation.
474   pybind11::object createOpView();
475 
476   /// Erases the underlying MlirOperation, removes its pointer from the
477   /// parent context's live operations map, and sets the valid bit false.
478   void erase();
479 
480 private:
481   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
482   static PyOperationRef createInstance(PyMlirContextRef contextRef,
483                                        MlirOperation operation,
484                                        pybind11::object parentKeepAlive);
485 
486   MlirOperation operation;
487   pybind11::handle handle;
488   // Keeps the parent alive, regardless of whether it is an Operation or
489   // Module.
490   // TODO: As implemented, this facility is only sufficient for modeling the
491   // trivial module parent back-reference. Generalize this to also account for
492   // transitions from detached to attached and address TODOs in the
493   // ir_operation.py regarding testing corresponding lifetime guarantees.
494   pybind11::object parentKeepAlive;
495   bool attached = true;
496   bool valid = true;
497 };
498 
499 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
500 /// providing more instance-specific accessors and serve as the base class for
501 /// custom ODS-style operation classes. Since this class is subclass on the
502 /// python side, it must present an __init__ method that operates in pure
503 /// python types.
504 class PyOpView : public PyOperationBase {
505 public:
506   PyOpView(pybind11::object operationObject);
getOperation()507   PyOperation &getOperation() override { return operation; }
508 
509   static pybind11::object createRawSubclass(pybind11::object userClass);
510 
getOperationObject()511   pybind11::object getOperationObject() { return operationObject; }
512 
513   static pybind11::object
514   buildGeneric(pybind11::object cls, pybind11::list resultTypeList,
515                pybind11::list operandList,
516                llvm::Optional<pybind11::dict> attributes,
517                llvm::Optional<std::vector<PyBlock *>> successors,
518                llvm::Optional<int> regions, DefaultingPyLocation location,
519                pybind11::object maybeIp);
520 
521 private:
522   PyOperation &operation;           // For efficient, cast-free access from C++
523   pybind11::object operationObject; // Holds the reference.
524 };
525 
526 /// Wrapper around an MlirRegion.
527 /// Regions are managed completely by their containing operation. Unlike the
528 /// C++ API, the python API does not support detached regions.
529 class PyRegion {
530 public:
PyRegion(PyOperationRef parentOperation,MlirRegion region)531   PyRegion(PyOperationRef parentOperation, MlirRegion region)
532       : parentOperation(std::move(parentOperation)), region(region) {
533     assert(!mlirRegionIsNull(region) && "python region cannot be null");
534   }
535 
get()536   MlirRegion get() { return region; }
getParentOperation()537   PyOperationRef &getParentOperation() { return parentOperation; }
538 
checkValid()539   void checkValid() { return parentOperation->checkValid(); }
540 
541 private:
542   PyOperationRef parentOperation;
543   MlirRegion region;
544 };
545 
546 /// Wrapper around an MlirBlock.
547 /// Blocks are managed completely by their containing operation. Unlike the
548 /// C++ API, the python API does not support detached blocks.
549 class PyBlock {
550 public:
PyBlock(PyOperationRef parentOperation,MlirBlock block)551   PyBlock(PyOperationRef parentOperation, MlirBlock block)
552       : parentOperation(std::move(parentOperation)), block(block) {
553     assert(!mlirBlockIsNull(block) && "python block cannot be null");
554   }
555 
get()556   MlirBlock get() { return block; }
getParentOperation()557   PyOperationRef &getParentOperation() { return parentOperation; }
558 
checkValid()559   void checkValid() { return parentOperation->checkValid(); }
560 
561 private:
562   PyOperationRef parentOperation;
563   MlirBlock block;
564 };
565 
566 /// An insertion point maintains a pointer to a Block and a reference operation.
567 /// Calls to insert() will insert a new operation before the
568 /// reference operation. If the reference operation is null, then appends to
569 /// the end of the block.
570 class PyInsertionPoint {
571 public:
572   /// Creates an insertion point positioned after the last operation in the
573   /// block, but still inside the block.
574   PyInsertionPoint(PyBlock &block);
575   /// Creates an insertion point positioned before a reference operation.
576   PyInsertionPoint(PyOperationBase &beforeOperationBase);
577 
578   /// Shortcut to create an insertion point at the beginning of the block.
579   static PyInsertionPoint atBlockBegin(PyBlock &block);
580   /// Shortcut to create an insertion point before the block terminator.
581   static PyInsertionPoint atBlockTerminator(PyBlock &block);
582 
583   /// Inserts an operation.
584   void insert(PyOperationBase &operationBase);
585 
586   /// Enter and exit the context manager.
587   pybind11::object contextEnter();
588   void contextExit(pybind11::object excType, pybind11::object excVal,
589                    pybind11::object excTb);
590 
getBlock()591   PyBlock &getBlock() { return block; }
592 
593 private:
594   // Trampoline constructor that avoids null initializing members while
595   // looking up parents.
PyInsertionPoint(PyBlock block,llvm::Optional<PyOperationRef> refOperation)596   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
597       : refOperation(std::move(refOperation)), block(std::move(block)) {}
598 
599   llvm::Optional<PyOperationRef> refOperation;
600   PyBlock block;
601 };
602 
603 /// Wrapper around the generic MlirAttribute.
604 /// The lifetime of a type is bound by the PyContext that created it.
605 class PyAttribute : public BaseContextObject {
606 public:
PyAttribute(PyMlirContextRef contextRef,MlirAttribute attr)607   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
608       : BaseContextObject(std::move(contextRef)), attr(attr) {}
609   bool operator==(const PyAttribute &other);
MlirAttribute()610   operator MlirAttribute() const { return attr; }
get()611   MlirAttribute get() const { return attr; }
612 
613   /// Gets a capsule wrapping the void* within the MlirAttribute.
614   pybind11::object getCapsule();
615 
616   /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
617   /// Note that PyAttribute instances are uniqued, so the returned object
618   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
619   /// is taken by calling this function.
620   static PyAttribute createFromCapsule(pybind11::object capsule);
621 
622 private:
623   MlirAttribute attr;
624 };
625 
626 /// Represents a Python MlirNamedAttr, carrying an optional owned name.
627 /// TODO: Refactor this and the C-API to be based on an Identifier owned
628 /// by the context so as to avoid ownership issues here.
629 class PyNamedAttribute {
630 public:
631   /// Constructs a PyNamedAttr that retains an owned name. This should be
632   /// used in any code that originates an MlirNamedAttribute from a python
633   /// string.
634   /// The lifetime of the PyNamedAttr must extend to the lifetime of the
635   /// passed attribute.
636   PyNamedAttribute(MlirAttribute attr, std::string ownedName);
637 
638   MlirNamedAttribute namedAttr;
639 
640 private:
641   // Since the MlirNamedAttr contains an internal pointer to the actual
642   // memory of the owned string, it must be heap allocated to remain valid.
643   // Otherwise, strings that fit within the small object optimization threshold
644   // will have their memory address change as the containing object is moved,
645   // resulting in an invalid aliased pointer.
646   std::unique_ptr<std::string> ownedName;
647 };
648 
649 /// CRTP base classes for Python attributes that subclass Attribute and should
650 /// be castable from it (i.e. via something like StringAttr(attr)).
651 /// By default, attribute class hierarchies are one level deep (i.e. a
652 /// concrete attribute class extends PyAttribute); however, intermediate
653 /// python-visible base classes can be modeled by specifying a BaseTy.
654 template <typename DerivedTy, typename BaseTy = PyAttribute>
655 class PyConcreteAttribute : public BaseTy {
656 public:
657   // Derived classes must define statics for:
658   //   IsAFunctionTy isaFunction
659   //   const char *pyClassName
660   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
661   using IsAFunctionTy = bool (*)(MlirAttribute);
662 
663   PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)664   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
665       : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute & orig)666   PyConcreteAttribute(PyAttribute &orig)
667       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
668 
castFrom(PyAttribute & orig)669   static MlirAttribute castFrom(PyAttribute &orig) {
670     if (!DerivedTy::isaFunction(orig)) {
671       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
672       throw SetPyError(PyExc_ValueError,
673                        llvm::Twine("Cannot cast attribute to ") +
674                            DerivedTy::pyClassName + " (from " + origRepr + ")");
675     }
676     return orig;
677   }
678 
bind(pybind11::module & m)679   static void bind(pybind11::module &m) {
680     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol());
681     cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
682     DerivedTy::bindDerived(cls);
683   }
684 
685   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)686   static void bindDerived(ClassTy &m) {}
687 };
688 
689 /// Wrapper around the generic MlirType.
690 /// The lifetime of a type is bound by the PyContext that created it.
691 class PyType : public BaseContextObject {
692 public:
PyType(PyMlirContextRef contextRef,MlirType type)693   PyType(PyMlirContextRef contextRef, MlirType type)
694       : BaseContextObject(std::move(contextRef)), type(type) {}
695   bool operator==(const PyType &other);
MlirType()696   operator MlirType() const { return type; }
get()697   MlirType get() const { return type; }
698 
699   /// Gets a capsule wrapping the void* within the MlirType.
700   pybind11::object getCapsule();
701 
702   /// Creates a PyType from the MlirType wrapped by a capsule.
703   /// Note that PyType instances are uniqued, so the returned object
704   /// may be a pre-existing object. Ownership of the underlying MlirType
705   /// is taken by calling this function.
706   static PyType createFromCapsule(pybind11::object capsule);
707 
708 private:
709   MlirType type;
710 };
711 
712 /// CRTP base classes for Python types that subclass Type and should be
713 /// castable from it (i.e. via something like IntegerType(t)).
714 /// By default, type class hierarchies are one level deep (i.e. a
715 /// concrete type class extends PyType); however, intermediate python-visible
716 /// base classes can be modeled by specifying a BaseTy.
717 template <typename DerivedTy, typename BaseTy = PyType>
718 class PyConcreteType : public BaseTy {
719 public:
720   // Derived classes must define statics for:
721   //   IsAFunctionTy isaFunction
722   //   const char *pyClassName
723   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
724   using IsAFunctionTy = bool (*)(MlirType);
725 
726   PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef,MlirType t)727   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
728       : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType & orig)729   PyConcreteType(PyType &orig)
730       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
731 
castFrom(PyType & orig)732   static MlirType castFrom(PyType &orig) {
733     if (!DerivedTy::isaFunction(orig)) {
734       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
735       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
736                                              DerivedTy::pyClassName +
737                                              " (from " + origRepr + ")");
738     }
739     return orig;
740   }
741 
bind(pybind11::module & m)742   static void bind(pybind11::module &m) {
743     auto cls = ClassTy(m, DerivedTy::pyClassName);
744     cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
745     cls.def_static("isinstance", [](PyType &otherType) -> bool {
746       return DerivedTy::isaFunction(otherType);
747     });
748     DerivedTy::bindDerived(cls);
749   }
750 
751   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)752   static void bindDerived(ClassTy &m) {}
753 };
754 
755 /// Wrapper around the generic MlirValue.
756 /// Values are managed completely by the operation that resulted in their
757 /// definition. For op result value, this is the operation that defines the
758 /// value. For block argument values, this is the operation that contains the
759 /// block to which the value is an argument (blocks cannot be detached in Python
760 /// bindings so such operation always exists).
761 class PyValue {
762 public:
PyValue(PyOperationRef parentOperation,MlirValue value)763   PyValue(PyOperationRef parentOperation, MlirValue value)
764       : parentOperation(parentOperation), value(value) {}
765 
get()766   MlirValue get() { return value; }
getParentOperation()767   PyOperationRef &getParentOperation() { return parentOperation; }
768 
checkValid()769   void checkValid() { return parentOperation->checkValid(); }
770 
771   /// Gets a capsule wrapping the void* within the MlirValue.
772   pybind11::object getCapsule();
773 
774   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
775   /// the underlying MlirValue is still tied to the owning operation.
776   static PyValue createFromCapsule(pybind11::object capsule);
777 
778 private:
779   PyOperationRef parentOperation;
780   MlirValue value;
781 };
782 
783 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
784 class PyAffineExpr : public BaseContextObject {
785 public:
PyAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)786   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
787       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
788   bool operator==(const PyAffineExpr &other);
MlirAffineExpr()789   operator MlirAffineExpr() const { return affineExpr; }
get()790   MlirAffineExpr get() const { return affineExpr; }
791 
792   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
793   pybind11::object getCapsule();
794 
795   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
796   /// Note that PyAffineExpr instances are uniqued, so the returned object
797   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
798   /// is taken by calling this function.
799   static PyAffineExpr createFromCapsule(pybind11::object capsule);
800 
801   PyAffineExpr add(const PyAffineExpr &other) const;
802   PyAffineExpr mul(const PyAffineExpr &other) const;
803   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
804   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
805   PyAffineExpr mod(const PyAffineExpr &other) const;
806 
807 private:
808   MlirAffineExpr affineExpr;
809 };
810 
811 class PyAffineMap : public BaseContextObject {
812 public:
PyAffineMap(PyMlirContextRef contextRef,MlirAffineMap affineMap)813   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
814       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
815   bool operator==(const PyAffineMap &other);
MlirAffineMap()816   operator MlirAffineMap() const { return affineMap; }
get()817   MlirAffineMap get() const { return affineMap; }
818 
819   /// Gets a capsule wrapping the void* within the MlirAffineMap.
820   pybind11::object getCapsule();
821 
822   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
823   /// Note that PyAffineMap instances are uniqued, so the returned object
824   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
825   /// is taken by calling this function.
826   static PyAffineMap createFromCapsule(pybind11::object capsule);
827 
828 private:
829   MlirAffineMap affineMap;
830 };
831 
832 class PyIntegerSet : public BaseContextObject {
833 public:
PyIntegerSet(PyMlirContextRef contextRef,MlirIntegerSet integerSet)834   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
835       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
836   bool operator==(const PyIntegerSet &other);
MlirIntegerSet()837   operator MlirIntegerSet() const { return integerSet; }
get()838   MlirIntegerSet get() const { return integerSet; }
839 
840   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
841   pybind11::object getCapsule();
842 
843   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
844   /// Note that PyIntegerSet instances may be uniqued, so the returned object
845   /// may be a pre-existing object. Integer sets are owned by the context.
846   static PyIntegerSet createFromCapsule(pybind11::object capsule);
847 
848 private:
849   MlirIntegerSet integerSet;
850 };
851 
/// Entry points that register bindings on the given pybind11 module. They are
/// only declared here; the definitions live in other translation units
/// (presumably one per IR area, as the names suggest; confirm).
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
void populateIRTypes(pybind11::module &m);
856 
857 } // namespace python
858 } // namespace mlir
859 
namespace pybind11 {
namespace detail {

/// pybind11 type_caster specializations for the "defaulting" context and
/// location wrappers. They add nothing of their own: all conversion logic is
/// inherited from MlirDefaultingCaster (declared elsewhere, presumably in
/// PybindUtils.h; confirm).
template <>
struct type_caster<mlir::python::DefaultingPyMlirContext>
    : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
template <>
struct type_caster<mlir::python::DefaultingPyLocation>
    : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};

} // namespace detail
} // namespace pybind11
872 
873 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
874