1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <vector> 13 14 #include "PybindUtils.h" 15 16 #include "mlir-c/AffineExpr.h" 17 #include "mlir-c/AffineMap.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/IntegerSet.h" 20 #include "llvm/ADT/DenseMap.h" 21 22 namespace mlir { 23 namespace python { 24 25 class PyBlock; 26 class PyInsertionPoint; 27 class PyLocation; 28 class DefaultingPyLocation; 29 class PyMlirContext; 30 class DefaultingPyMlirContext; 31 class PyModule; 32 class PyOperation; 33 class PyType; 34 class PyValue; 35 36 /// Template for a reference to a concrete type which captures a python 37 /// reference to its underlying python object. 38 template <typename T> 39 class PyObjectRef { 40 public: PyObjectRef(T * referrent,pybind11::object object)41 PyObjectRef(T *referrent, pybind11::object object) 42 : referrent(referrent), object(std::move(object)) { 43 assert(this->referrent && 44 "cannot construct PyObjectRef with null referrent"); 45 assert(this->object && "cannot construct PyObjectRef with null object"); 46 } PyObjectRef(PyObjectRef && other)47 PyObjectRef(PyObjectRef &&other) 48 : referrent(other.referrent), object(std::move(other.object)) { 49 other.referrent = nullptr; 50 assert(!other.object); 51 } PyObjectRef(const PyObjectRef & other)52 PyObjectRef(const PyObjectRef &other) 53 : referrent(other.referrent), object(other.object /* copies */) {} ~PyObjectRef()54 ~PyObjectRef() {} 55 getRefCount()56 int getRefCount() { 57 if (!object) 58 return 0; 59 return object.ref_count(); 60 } 61 62 /// Releases the object held by this instance, returning it. 63 /// This is the proper thing to return from a function that wants to return 64 /// the reference. Note that this does not work from initializers. releaseObject()65 pybind11::object releaseObject() { 66 assert(referrent && object); 67 referrent = nullptr; 68 auto stolen = std::move(object); 69 return stolen; 70 } 71 get()72 T *get() { return referrent; } 73 T *operator->() { 74 assert(referrent && object); 75 return referrent; 76 } getObject()77 pybind11::object getObject() { 78 assert(referrent && object); 79 return object; 80 } 81 operator bool() const { return referrent && object; } 82 83 private: 84 T *referrent; 85 pybind11::object object; 86 }; 87 88 /// Tracks an entry in the thread context stack. New entries are pushed onto 89 /// here for each with block that activates a new InsertionPoint, Context or 90 /// Location. 91 /// 92 /// Pushing either a Location or InsertionPoint also pushes its associated 93 /// Context. Pushing a Context will not modify the Location or InsertionPoint 94 /// unless if they are from a different context, in which case, they are 95 /// cleared. 96 class PyThreadContextEntry { 97 public: 98 enum class FrameKind { 99 Context, 100 InsertionPoint, 101 Location, 102 }; 103 PyThreadContextEntry(FrameKind frameKind,pybind11::object context,pybind11::object insertionPoint,pybind11::object location)104 PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 105 pybind11::object insertionPoint, 106 pybind11::object location) 107 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 108 location(std::move(location)), frameKind(frameKind) {} 109 110 /// Gets the top of stack context and return nullptr if not defined. 111 static PyMlirContext *getDefaultContext(); 112 113 /// Gets the top of stack insertion point and return nullptr if not defined. 114 static PyInsertionPoint *getDefaultInsertionPoint(); 115 116 /// Gets the top of stack location and returns nullptr if not defined. 117 static PyLocation *getDefaultLocation(); 118 119 PyMlirContext *getContext(); 120 PyInsertionPoint *getInsertionPoint(); 121 PyLocation *getLocation(); getFrameKind()122 FrameKind getFrameKind() { return frameKind; } 123 124 /// Stack management. 125 static PyThreadContextEntry *getTopOfStack(); 126 static pybind11::object pushContext(PyMlirContext &context); 127 static void popContext(PyMlirContext &context); 128 static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 129 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 130 static pybind11::object pushLocation(PyLocation &location); 131 static void popLocation(PyLocation &location); 132 133 /// Gets the thread local stack. 134 static std::vector<PyThreadContextEntry> &getStack(); 135 136 private: 137 static void push(FrameKind frameKind, pybind11::object context, 138 pybind11::object insertionPoint, pybind11::object location); 139 140 /// An object reference to the PyContext. 141 pybind11::object context; 142 /// An object reference to the current insertion point. 143 pybind11::object insertionPoint; 144 /// An object reference to the current location. 145 pybind11::object location; 146 // The kind of push that was performed. 147 FrameKind frameKind; 148 }; 149 150 /// Wrapper around MlirContext. 151 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 152 class PyMlirContext { 153 public: 154 PyMlirContext() = delete; 155 PyMlirContext(const PyMlirContext &) = delete; 156 PyMlirContext(PyMlirContext &&) = delete; 157 158 /// For the case of a python __init__ (py::init) method, pybind11 is quite 159 /// strict about needing to return a pointer that is not yet associated to 160 /// an py::object. Since the forContext() method acts like a pool, possibly 161 /// returning a recycled context, it does not satisfy this need. The usual 162 /// way in python to accomplish such a thing is to override __new__, but 163 /// that is also not supported by pybind11. Instead, we use this entry 164 /// point which always constructs a fresh context (which cannot alias an 165 /// existing one because it is fresh). 166 static PyMlirContext *createNewContextForInit(); 167 168 /// Returns a context reference for the singleton PyMlirContext wrapper for 169 /// the given context. 170 static PyMlirContextRef forContext(MlirContext context); 171 ~PyMlirContext(); 172 173 /// Accesses the underlying MlirContext. get()174 MlirContext get() { return context; } 175 176 /// Gets a strong reference to this context, which will ensure it is kept 177 /// alive for the life of the reference. getRef()178 PyMlirContextRef getRef() { 179 return PyMlirContextRef(this, pybind11::cast(this)); 180 } 181 182 /// Gets a capsule wrapping the void* within the MlirContext. 183 pybind11::object getCapsule(); 184 185 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 186 /// Note that PyMlirContext instances are uniqued, so the returned object 187 /// may be a pre-existing object. Ownership of the underlying MlirContext 188 /// is taken by calling this function. 189 static pybind11::object createFromCapsule(pybind11::object capsule); 190 191 /// Gets the count of live context objects. Used for testing. 192 static size_t getLiveCount(); 193 194 /// Gets the count of live operations associated with this context. 195 /// Used for testing. 196 size_t getLiveOperationCount(); 197 198 /// Gets the count of live modules associated with this context. 199 /// Used for testing. 200 size_t getLiveModuleCount(); 201 202 /// Enter and exit the context manager. 203 pybind11::object contextEnter(); 204 void contextExit(pybind11::object excType, pybind11::object excVal, 205 pybind11::object excTb); 206 207 private: 208 PyMlirContext(MlirContext context); 209 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 210 // preserving the relationship that an MlirContext maps to a single 211 // PyMlirContext wrapper. This could be replaced in the future with an 212 // extension mechanism on the MlirContext for stashing user pointers. 213 // Note that this holds a handle, which does not imply ownership. 214 // Mappings will be removed when the context is destructed. 215 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 216 static LiveContextMap &getLiveContexts(); 217 218 // Interns all live modules associated with this context. Modules tracked 219 // in this map are valid. When a module is invalidated, it is removed 220 // from this map, and while it still exists as an instance, any 221 // attempt to access it will raise an error. 222 using LiveModuleMap = 223 llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 224 LiveModuleMap liveModules; 225 226 // Interns all live operations associated with this context. Operations 227 // tracked in this map are valid. When an operation is invalidated, it is 228 // removed from this map, and while it still exists as an instance, any 229 // attempt to access it will raise an error. 230 using LiveOperationMap = 231 llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 232 LiveOperationMap liveOperations; 233 234 MlirContext context; 235 friend class PyModule; 236 friend class PyOperation; 237 }; 238 239 /// Used in function arguments when None should resolve to the current context 240 /// manager set instance. 241 class DefaultingPyMlirContext 242 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 243 public: 244 using Defaulting::Defaulting; 245 static constexpr const char kTypeDescription[] = 246 "[ThreadContextAware] mlir.ir.Context"; 247 static PyMlirContext &resolve(); 248 }; 249 250 /// Base class for all objects that directly or indirectly depend on an 251 /// MlirContext. The lifetime of the context will extend at least to the 252 /// lifetime of these instances. 253 /// Immutable objects that depend on a context extend this directly. 254 class BaseContextObject { 255 public: BaseContextObject(PyMlirContextRef ref)256 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 257 assert(this->contextRef && 258 "context object constructed with null context ref"); 259 } 260 261 /// Accesses the context reference. getContext()262 PyMlirContextRef &getContext() { return contextRef; } 263 264 private: 265 PyMlirContextRef contextRef; 266 }; 267 268 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 269 /// order to differentiate it from the `Dialect` base class which is extended by 270 /// plugins which extend dialect functionality through extension python code. 271 /// This should be seen as the "low-level" object and `Dialect` as the 272 /// high-level, user facing object. 273 class PyDialectDescriptor : public BaseContextObject { 274 public: PyDialectDescriptor(PyMlirContextRef contextRef,MlirDialect dialect)275 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 276 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 277 get()278 MlirDialect get() { return dialect; } 279 280 private: 281 MlirDialect dialect; 282 }; 283 284 /// User-level object for accessing dialects with dotted syntax such as: 285 /// ctx.dialect.std 286 class PyDialects : public BaseContextObject { 287 public: PyDialects(PyMlirContextRef contextRef)288 PyDialects(PyMlirContextRef contextRef) 289 : BaseContextObject(std::move(contextRef)) {} 290 291 MlirDialect getDialectForKey(const std::string &key, bool attrError); 292 }; 293 294 /// User-level dialect object. For dialects that have a registered extension, 295 /// this will be the base class of the extension dialect type. For un-extended, 296 /// objects of this type will be returned directly. 297 class PyDialect { 298 public: PyDialect(pybind11::object descriptor)299 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 300 getDescriptor()301 pybind11::object getDescriptor() { return descriptor; } 302 303 private: 304 pybind11::object descriptor; 305 }; 306 307 /// Wrapper around an MlirLocation. 308 class PyLocation : public BaseContextObject { 309 public: PyLocation(PyMlirContextRef contextRef,MlirLocation loc)310 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 311 : BaseContextObject(std::move(contextRef)), loc(loc) {} 312 MlirLocation()313 operator MlirLocation() const { return loc; } get()314 MlirLocation get() const { return loc; } 315 316 /// Enter and exit the context manager. 317 pybind11::object contextEnter(); 318 void contextExit(pybind11::object excType, pybind11::object excVal, 319 pybind11::object excTb); 320 321 /// Gets a capsule wrapping the void* within the MlirLocation. 322 pybind11::object getCapsule(); 323 324 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 325 /// Note that PyLocation instances are uniqued, so the returned object 326 /// may be a pre-existing object. Ownership of the underlying MlirLocation 327 /// is taken by calling this function. 328 static PyLocation createFromCapsule(pybind11::object capsule); 329 330 private: 331 MlirLocation loc; 332 }; 333 334 /// Used in function arguments when None should resolve to the current context 335 /// manager set instance. 336 class DefaultingPyLocation 337 : public Defaulting<DefaultingPyLocation, PyLocation> { 338 public: 339 using Defaulting::Defaulting; 340 static constexpr const char kTypeDescription[] = 341 "[ThreadContextAware] mlir.ir.Location"; 342 static PyLocation &resolve(); 343 MlirLocation()344 operator MlirLocation() const { return *get(); } 345 }; 346 347 /// Wrapper around MlirModule. 348 /// This is the top-level, user-owned object that contains regions/ops/blocks. 349 class PyModule; 350 using PyModuleRef = PyObjectRef<PyModule>; 351 class PyModule : public BaseContextObject { 352 public: 353 /// Returns a PyModule reference for the given MlirModule. This may return 354 /// a pre-existing or new object. 355 static PyModuleRef forModule(MlirModule module); 356 PyModule(PyModule &) = delete; 357 PyModule(PyMlirContext &&) = delete; 358 ~PyModule(); 359 360 /// Gets the backing MlirModule. get()361 MlirModule get() { return module; } 362 363 /// Gets a strong reference to this module. getRef()364 PyModuleRef getRef() { 365 return PyModuleRef(this, 366 pybind11::reinterpret_borrow<pybind11::object>(handle)); 367 } 368 369 /// Gets a capsule wrapping the void* within the MlirModule. 370 /// Note that the module does not (yet) provide a corresponding factory for 371 /// constructing from a capsule as that would require uniquing PyModule 372 /// instances, which is not currently done. 373 pybind11::object getCapsule(); 374 375 /// Creates a PyModule from the MlirModule wrapped by a capsule. 376 /// Note that PyModule instances are uniqued, so the returned object 377 /// may be a pre-existing object. Ownership of the underlying MlirModule 378 /// is taken by calling this function. 379 static pybind11::object createFromCapsule(pybind11::object capsule); 380 381 private: 382 PyModule(PyMlirContextRef contextRef, MlirModule module); 383 MlirModule module; 384 pybind11::handle handle; 385 }; 386 387 /// Base class for PyOperation and PyOpView which exposes the primary, user 388 /// visible methods for manipulating it. 389 class PyOperationBase { 390 public: 391 virtual ~PyOperationBase() = default; 392 /// Implements the bound 'print' method and helps with others. 393 void print(pybind11::object fileObject, bool binary, 394 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 395 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); 396 pybind11::object getAsm(bool binary, 397 llvm::Optional<int64_t> largeElementsLimit, 398 bool enableDebugInfo, bool prettyDebugInfo, 399 bool printGenericOpForm, bool useLocalScope); 400 401 /// Each must provide access to the raw Operation. 402 virtual PyOperation &getOperation() = 0; 403 }; 404 405 /// Wrapper around PyOperation. 406 /// Operations exist in either an attached (dependent) or detached (top-level) 407 /// state. In the detached state (as on creation), an operation is owned by 408 /// the creator and its lifetime extends either until its reference count 409 /// drops to zero or it is attached to a parent, at which point its lifetime 410 /// is bounded by its top-level parent reference. 411 class PyOperation; 412 using PyOperationRef = PyObjectRef<PyOperation>; 413 class PyOperation : public PyOperationBase, public BaseContextObject { 414 public: 415 ~PyOperation(); getOperation()416 PyOperation &getOperation() override { return *this; } 417 418 /// Returns a PyOperation for the given MlirOperation, optionally associating 419 /// it with a parentKeepAlive. 420 static PyOperationRef 421 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 422 pybind11::object parentKeepAlive = pybind11::object()); 423 424 /// Creates a detached operation. The operation must not be associated with 425 /// any existing live operation. 426 static PyOperationRef 427 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 428 pybind11::object parentKeepAlive = pybind11::object()); 429 430 /// Gets the backing operation. MlirOperation()431 operator MlirOperation() const { return get(); } get()432 MlirOperation get() const { 433 checkValid(); 434 return operation; 435 } 436 getRef()437 PyOperationRef getRef() { 438 return PyOperationRef( 439 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 440 } 441 isAttached()442 bool isAttached() { return attached; } setAttached()443 void setAttached() { 444 assert(!attached && "operation already attached"); 445 attached = true; 446 } 447 void checkValid() const; 448 449 /// Gets the owning block or raises an exception if the operation has no 450 /// owning block. 451 PyBlock getBlock(); 452 453 /// Gets the parent operation or raises an exception if the operation has 454 /// no parent. 455 PyOperationRef getParentOperation(); 456 457 /// Gets a capsule wrapping the void* within the MlirOperation. 458 pybind11::object getCapsule(); 459 460 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 461 /// Ownership of the underlying MlirOperation is taken by calling this 462 /// function. 463 static pybind11::object createFromCapsule(pybind11::object capsule); 464 465 /// Creates an operation. See corresponding python docstring. 466 static pybind11::object 467 create(std::string name, llvm::Optional<std::vector<PyType *>> results, 468 llvm::Optional<std::vector<PyValue *>> operands, 469 llvm::Optional<pybind11::dict> attributes, 470 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 471 DefaultingPyLocation location, pybind11::object ip); 472 473 /// Creates an OpView suitable for this operation. 474 pybind11::object createOpView(); 475 476 /// Erases the underlying MlirOperation, removes its pointer from the 477 /// parent context's live operations map, and sets the valid bit false. 478 void erase(); 479 480 private: 481 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 482 static PyOperationRef createInstance(PyMlirContextRef contextRef, 483 MlirOperation operation, 484 pybind11::object parentKeepAlive); 485 486 MlirOperation operation; 487 pybind11::handle handle; 488 // Keeps the parent alive, regardless of whether it is an Operation or 489 // Module. 490 // TODO: As implemented, this facility is only sufficient for modeling the 491 // trivial module parent back-reference. Generalize this to also account for 492 // transitions from detached to attached and address TODOs in the 493 // ir_operation.py regarding testing corresponding lifetime guarantees. 494 pybind11::object parentKeepAlive; 495 bool attached = true; 496 bool valid = true; 497 }; 498 499 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 500 /// providing more instance-specific accessors and serve as the base class for 501 /// custom ODS-style operation classes. Since this class is subclass on the 502 /// python side, it must present an __init__ method that operates in pure 503 /// python types. 504 class PyOpView : public PyOperationBase { 505 public: 506 PyOpView(pybind11::object operationObject); getOperation()507 PyOperation &getOperation() override { return operation; } 508 509 static pybind11::object createRawSubclass(pybind11::object userClass); 510 getOperationObject()511 pybind11::object getOperationObject() { return operationObject; } 512 513 static pybind11::object 514 buildGeneric(pybind11::object cls, pybind11::list resultTypeList, 515 pybind11::list operandList, 516 llvm::Optional<pybind11::dict> attributes, 517 llvm::Optional<std::vector<PyBlock *>> successors, 518 llvm::Optional<int> regions, DefaultingPyLocation location, 519 pybind11::object maybeIp); 520 521 private: 522 PyOperation &operation; // For efficient, cast-free access from C++ 523 pybind11::object operationObject; // Holds the reference. 524 }; 525 526 /// Wrapper around an MlirRegion. 527 /// Regions are managed completely by their containing operation. Unlike the 528 /// C++ API, the python API does not support detached regions. 529 class PyRegion { 530 public: PyRegion(PyOperationRef parentOperation,MlirRegion region)531 PyRegion(PyOperationRef parentOperation, MlirRegion region) 532 : parentOperation(std::move(parentOperation)), region(region) { 533 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 534 } 535 get()536 MlirRegion get() { return region; } getParentOperation()537 PyOperationRef &getParentOperation() { return parentOperation; } 538 checkValid()539 void checkValid() { return parentOperation->checkValid(); } 540 541 private: 542 PyOperationRef parentOperation; 543 MlirRegion region; 544 }; 545 546 /// Wrapper around an MlirBlock. 547 /// Blocks are managed completely by their containing operation. Unlike the 548 /// C++ API, the python API does not support detached blocks. 549 class PyBlock { 550 public: PyBlock(PyOperationRef parentOperation,MlirBlock block)551 PyBlock(PyOperationRef parentOperation, MlirBlock block) 552 : parentOperation(std::move(parentOperation)), block(block) { 553 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 554 } 555 get()556 MlirBlock get() { return block; } getParentOperation()557 PyOperationRef &getParentOperation() { return parentOperation; } 558 checkValid()559 void checkValid() { return parentOperation->checkValid(); } 560 561 private: 562 PyOperationRef parentOperation; 563 MlirBlock block; 564 }; 565 566 /// An insertion point maintains a pointer to a Block and a reference operation. 567 /// Calls to insert() will insert a new operation before the 568 /// reference operation. If the reference operation is null, then appends to 569 /// the end of the block. 570 class PyInsertionPoint { 571 public: 572 /// Creates an insertion point positioned after the last operation in the 573 /// block, but still inside the block. 574 PyInsertionPoint(PyBlock &block); 575 /// Creates an insertion point positioned before a reference operation. 576 PyInsertionPoint(PyOperationBase &beforeOperationBase); 577 578 /// Shortcut to create an insertion point at the beginning of the block. 579 static PyInsertionPoint atBlockBegin(PyBlock &block); 580 /// Shortcut to create an insertion point before the block terminator. 581 static PyInsertionPoint atBlockTerminator(PyBlock &block); 582 583 /// Inserts an operation. 584 void insert(PyOperationBase &operationBase); 585 586 /// Enter and exit the context manager. 587 pybind11::object contextEnter(); 588 void contextExit(pybind11::object excType, pybind11::object excVal, 589 pybind11::object excTb); 590 getBlock()591 PyBlock &getBlock() { return block; } 592 593 private: 594 // Trampoline constructor that avoids null initializing members while 595 // looking up parents. PyInsertionPoint(PyBlock block,llvm::Optional<PyOperationRef> refOperation)596 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 597 : refOperation(std::move(refOperation)), block(std::move(block)) {} 598 599 llvm::Optional<PyOperationRef> refOperation; 600 PyBlock block; 601 }; 602 603 /// Wrapper around the generic MlirAttribute. 604 /// The lifetime of a type is bound by the PyContext that created it. 605 class PyAttribute : public BaseContextObject { 606 public: PyAttribute(PyMlirContextRef contextRef,MlirAttribute attr)607 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 608 : BaseContextObject(std::move(contextRef)), attr(attr) {} 609 bool operator==(const PyAttribute &other); MlirAttribute()610 operator MlirAttribute() const { return attr; } get()611 MlirAttribute get() const { return attr; } 612 613 /// Gets a capsule wrapping the void* within the MlirAttribute. 614 pybind11::object getCapsule(); 615 616 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 617 /// Note that PyAttribute instances are uniqued, so the returned object 618 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 619 /// is taken by calling this function. 620 static PyAttribute createFromCapsule(pybind11::object capsule); 621 622 private: 623 MlirAttribute attr; 624 }; 625 626 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 627 /// TODO: Refactor this and the C-API to be based on an Identifier owned 628 /// by the context so as to avoid ownership issues here. 629 class PyNamedAttribute { 630 public: 631 /// Constructs a PyNamedAttr that retains an owned name. This should be 632 /// used in any code that originates an MlirNamedAttribute from a python 633 /// string. 634 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 635 /// passed attribute. 636 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 637 638 MlirNamedAttribute namedAttr; 639 640 private: 641 // Since the MlirNamedAttr contains an internal pointer to the actual 642 // memory of the owned string, it must be heap allocated to remain valid. 643 // Otherwise, strings that fit within the small object optimization threshold 644 // will have their memory address change as the containing object is moved, 645 // resulting in an invalid aliased pointer. 646 std::unique_ptr<std::string> ownedName; 647 }; 648 649 /// CRTP base classes for Python attributes that subclass Attribute and should 650 /// be castable from it (i.e. via something like StringAttr(attr)). 651 /// By default, attribute class hierarchies are one level deep (i.e. a 652 /// concrete attribute class extends PyAttribute); however, intermediate 653 /// python-visible base classes can be modeled by specifying a BaseTy. 654 template <typename DerivedTy, typename BaseTy = PyAttribute> 655 class PyConcreteAttribute : public BaseTy { 656 public: 657 // Derived classes must define statics for: 658 // IsAFunctionTy isaFunction 659 // const char *pyClassName 660 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 661 using IsAFunctionTy = bool (*)(MlirAttribute); 662 663 PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)664 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 665 : BaseTy(std::move(contextRef), attr) {} PyConcreteAttribute(PyAttribute & orig)666 PyConcreteAttribute(PyAttribute &orig) 667 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 668 castFrom(PyAttribute & orig)669 static MlirAttribute castFrom(PyAttribute &orig) { 670 if (!DerivedTy::isaFunction(orig)) { 671 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 672 throw SetPyError(PyExc_ValueError, 673 llvm::Twine("Cannot cast attribute to ") + 674 DerivedTy::pyClassName + " (from " + origRepr + ")"); 675 } 676 return orig; 677 } 678 bind(pybind11::module & m)679 static void bind(pybind11::module &m) { 680 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol()); 681 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>()); 682 DerivedTy::bindDerived(cls); 683 } 684 685 /// Implemented by derived classes to add methods to the Python subclass. bindDerived(ClassTy & m)686 static void bindDerived(ClassTy &m) {} 687 }; 688 689 /// Wrapper around the generic MlirType. 690 /// The lifetime of a type is bound by the PyContext that created it. 691 class PyType : public BaseContextObject { 692 public: PyType(PyMlirContextRef contextRef,MlirType type)693 PyType(PyMlirContextRef contextRef, MlirType type) 694 : BaseContextObject(std::move(contextRef)), type(type) {} 695 bool operator==(const PyType &other); MlirType()696 operator MlirType() const { return type; } get()697 MlirType get() const { return type; } 698 699 /// Gets a capsule wrapping the void* within the MlirType. 700 pybind11::object getCapsule(); 701 702 /// Creates a PyType from the MlirType wrapped by a capsule. 703 /// Note that PyType instances are uniqued, so the returned object 704 /// may be a pre-existing object. Ownership of the underlying MlirType 705 /// is taken by calling this function. 706 static PyType createFromCapsule(pybind11::object capsule); 707 708 private: 709 MlirType type; 710 }; 711 712 /// CRTP base classes for Python types that subclass Type and should be 713 /// castable from it (i.e. via something like IntegerType(t)). 714 /// By default, type class hierarchies are one level deep (i.e. a 715 /// concrete type class extends PyType); however, intermediate python-visible 716 /// base classes can be modeled by specifying a BaseTy. 717 template <typename DerivedTy, typename BaseTy = PyType> 718 class PyConcreteType : public BaseTy { 719 public: 720 // Derived classes must define statics for: 721 // IsAFunctionTy isaFunction 722 // const char *pyClassName 723 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 724 using IsAFunctionTy = bool (*)(MlirType); 725 726 PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef,MlirType t)727 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 728 : BaseTy(std::move(contextRef), t) {} PyConcreteType(PyType & orig)729 PyConcreteType(PyType &orig) 730 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 731 castFrom(PyType & orig)732 static MlirType castFrom(PyType &orig) { 733 if (!DerivedTy::isaFunction(orig)) { 734 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 735 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 736 DerivedTy::pyClassName + 737 " (from " + origRepr + ")"); 738 } 739 return orig; 740 } 741 bind(pybind11::module & m)742 static void bind(pybind11::module &m) { 743 auto cls = ClassTy(m, DerivedTy::pyClassName); 744 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>()); 745 cls.def_static("isinstance", [](PyType &otherType) -> bool { 746 return DerivedTy::isaFunction(otherType); 747 }); 748 DerivedTy::bindDerived(cls); 749 } 750 751 /// Implemented by derived classes to add methods to the Python subclass. bindDerived(ClassTy & m)752 static void bindDerived(ClassTy &m) {} 753 }; 754 755 /// Wrapper around the generic MlirValue. 756 /// Values are managed completely by the operation that resulted in their 757 /// definition. For op result value, this is the operation that defines the 758 /// value. For block argument values, this is the operation that contains the 759 /// block to which the value is an argument (blocks cannot be detached in Python 760 /// bindings so such operation always exists). 761 class PyValue { 762 public: PyValue(PyOperationRef parentOperation,MlirValue value)763 PyValue(PyOperationRef parentOperation, MlirValue value) 764 : parentOperation(parentOperation), value(value) {} 765 get()766 MlirValue get() { return value; } getParentOperation()767 PyOperationRef &getParentOperation() { return parentOperation; } 768 checkValid()769 void checkValid() { return parentOperation->checkValid(); } 770 771 /// Gets a capsule wrapping the void* within the MlirValue. 772 pybind11::object getCapsule(); 773 774 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 775 /// the underlying MlirValue is still tied to the owning operation. 776 static PyValue createFromCapsule(pybind11::object capsule); 777 778 private: 779 PyOperationRef parentOperation; 780 MlirValue value; 781 }; 782 783 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 784 class PyAffineExpr : public BaseContextObject { 785 public: PyAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)786 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 787 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 788 bool operator==(const PyAffineExpr &other); MlirAffineExpr()789 operator MlirAffineExpr() const { return affineExpr; } get()790 MlirAffineExpr get() const { return affineExpr; } 791 792 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 793 pybind11::object getCapsule(); 794 795 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 796 /// Note that PyAffineExpr instances are uniqued, so the returned object 797 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 798 /// is taken by calling this function. 799 static PyAffineExpr createFromCapsule(pybind11::object capsule); 800 801 PyAffineExpr add(const PyAffineExpr &other) const; 802 PyAffineExpr mul(const PyAffineExpr &other) const; 803 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 804 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 805 PyAffineExpr mod(const PyAffineExpr &other) const; 806 807 private: 808 MlirAffineExpr affineExpr; 809 }; 810 811 class PyAffineMap : public BaseContextObject { 812 public: PyAffineMap(PyMlirContextRef contextRef,MlirAffineMap affineMap)813 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 814 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 815 bool operator==(const PyAffineMap &other); MlirAffineMap()816 operator MlirAffineMap() const { return affineMap; } get()817 MlirAffineMap get() const { return affineMap; } 818 819 /// Gets a capsule wrapping the void* within the MlirAffineMap. 820 pybind11::object getCapsule(); 821 822 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 823 /// Note that PyAffineMap instances are uniqued, so the returned object 824 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 825 /// is taken by calling this function. 826 static PyAffineMap createFromCapsule(pybind11::object capsule); 827 828 private: 829 MlirAffineMap affineMap; 830 }; 831 832 class PyIntegerSet : public BaseContextObject { 833 public: PyIntegerSet(PyMlirContextRef contextRef,MlirIntegerSet integerSet)834 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 835 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 836 bool operator==(const PyIntegerSet &other); MlirIntegerSet()837 operator MlirIntegerSet() const { return integerSet; } get()838 MlirIntegerSet get() const { return integerSet; } 839 840 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 841 pybind11::object getCapsule(); 842 843 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 844 /// Note that PyIntegerSet instances may be uniqued, so the returned object 845 /// may be a pre-existing object. Integer sets are owned by the context. 846 static PyIntegerSet createFromCapsule(pybind11::object capsule); 847 848 private: 849 MlirIntegerSet integerSet; 850 }; 851 852 void populateIRAffine(pybind11::module &m); 853 void populateIRAttributes(pybind11::module &m); 854 void populateIRCore(pybind11::module &m); 855 void populateIRTypes(pybind11::module &m); 856 857 } // namespace python 858 } // namespace mlir 859 860 namespace pybind11 { 861 namespace detail { 862 863 template <> 864 struct type_caster<mlir::python::DefaultingPyMlirContext> 865 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 866 template <> 867 struct type_caster<mlir::python::DefaultingPyLocation> 868 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 869 870 } // namespace detail 871 } // namespace pybind11 872 873 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 874