1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "PybindUtils.h"
12 
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15 
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19 
20 using llvm::None;
21 using llvm::Optional;
22 using llvm::SmallVector;
23 using llvm::Twine;
24 
25 //------------------------------------------------------------------------------
26 // Docstrings (trivial, non-duplicated docstrings are included inline).
27 //------------------------------------------------------------------------------
28 
// Python docstring attached to the `DenseElementsAttr.get` static method
// (PyDenseElementsAttribute::getFromBuffer). Keep it in sync with the buffer
// format codes that getFromBuffer actually infers element types for.
// NOTE: PyDenseElementsAttribute::accessBuffer refuses to export splat
// attributes, so the "converted back to a corresponding buffer" sentence only
// holds for non-splat attributes.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
70 
71 namespace {
72 
toMlirStringRef(const std::string & s)73 static MlirStringRef toMlirStringRef(const std::string &s) {
74   return mlirStringRefCreate(s.data(), s.size());
75 }
76 
77 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78 public:
79   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80   static constexpr const char *pyClassName = "AffineMapAttr";
81   using PyConcreteAttribute::PyConcreteAttribute;
82 
bindDerived(ClassTy & c)83   static void bindDerived(ClassTy &c) {
84     c.def_static(
85         "get",
86         [](PyAffineMap &affineMap) {
87           MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88           return PyAffineMapAttribute(affineMap.getContext(), attr);
89         },
90         py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91   }
92 };
93 
94 template <typename T>
pyTryCast(py::handle object)95 static T pyTryCast(py::handle object) {
96   try {
97     return object.cast<T>();
98   } catch (py::cast_error &err) {
99     std::string msg =
100         std::string(
101             "Invalid attribute when attempting to create an ArrayAttribute (") +
102         err.what() + ")";
103     throw py::cast_error(msg);
104   } catch (py::reference_cast_error &err) {
105     std::string msg = std::string("Invalid attribute (None?) when attempting "
106                                   "to create an ArrayAttribute (") +
107                       err.what() + ")";
108     throw py::cast_error(msg);
109   }
110 }
111 
112 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
113 public:
114   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
115   static constexpr const char *pyClassName = "ArrayAttr";
116   using PyConcreteAttribute::PyConcreteAttribute;
117 
118   class PyArrayAttributeIterator {
119   public:
PyArrayAttributeIterator(PyAttribute attr)120     PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
121 
dunderIter()122     PyArrayAttributeIterator &dunderIter() { return *this; }
123 
dunderNext()124     PyAttribute dunderNext() {
125       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
126         throw py::stop_iteration();
127       }
128       return PyAttribute(attr.getContext(),
129                          mlirArrayAttrGetElement(attr.get(), nextIndex++));
130     }
131 
bind(py::module & m)132     static void bind(py::module &m) {
133       py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
134                                            py::module_local())
135           .def("__iter__", &PyArrayAttributeIterator::dunderIter)
136           .def("__next__", &PyArrayAttributeIterator::dunderNext);
137     }
138 
139   private:
140     PyAttribute attr;
141     int nextIndex = 0;
142   };
143 
getItem(intptr_t i)144   PyAttribute getItem(intptr_t i) {
145     return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
146   }
147 
bindDerived(ClassTy & c)148   static void bindDerived(ClassTy &c) {
149     c.def_static(
150         "get",
151         [](py::list attributes, DefaultingPyMlirContext context) {
152           SmallVector<MlirAttribute> mlirAttributes;
153           mlirAttributes.reserve(py::len(attributes));
154           for (auto attribute : attributes) {
155             mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
156           }
157           MlirAttribute attr = mlirArrayAttrGet(
158               context->get(), mlirAttributes.size(), mlirAttributes.data());
159           return PyArrayAttribute(context->getRef(), attr);
160         },
161         py::arg("attributes"), py::arg("context") = py::none(),
162         "Gets a uniqued Array attribute");
163     c.def("__getitem__",
164           [](PyArrayAttribute &arr, intptr_t i) {
165             if (i >= mlirArrayAttrGetNumElements(arr))
166               throw py::index_error("ArrayAttribute index out of range");
167             return arr.getItem(i);
168           })
169         .def("__len__",
170              [](const PyArrayAttribute &arr) {
171                return mlirArrayAttrGetNumElements(arr);
172              })
173         .def("__iter__", [](const PyArrayAttribute &arr) {
174           return PyArrayAttributeIterator(arr);
175         });
176     c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
177       std::vector<MlirAttribute> attributes;
178       intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
179       attributes.reserve(numOldElements + py::len(extras));
180       for (intptr_t i = 0; i < numOldElements; ++i)
181         attributes.push_back(arr.getItem(i));
182       for (py::handle attr : extras)
183         attributes.push_back(pyTryCast<PyAttribute>(attr));
184       MlirAttribute arrayAttr = mlirArrayAttrGet(
185           arr.getContext()->get(), attributes.size(), attributes.data());
186       return PyArrayAttribute(arr.getContext(), arrayAttr);
187     });
188   }
189 };
190 
191 /// Float Point Attribute subclass - FloatAttr.
192 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
193 public:
194   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
195   static constexpr const char *pyClassName = "FloatAttr";
196   using PyConcreteAttribute::PyConcreteAttribute;
197 
bindDerived(ClassTy & c)198   static void bindDerived(ClassTy &c) {
199     c.def_static(
200         "get",
201         [](PyType &type, double value, DefaultingPyLocation loc) {
202           MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
203           // TODO: Rework error reporting once diagnostic engine is exposed
204           // in C API.
205           if (mlirAttributeIsNull(attr)) {
206             throw SetPyError(PyExc_ValueError,
207                              Twine("invalid '") +
208                                  py::repr(py::cast(type)).cast<std::string>() +
209                                  "' and expected floating point type.");
210           }
211           return PyFloatAttribute(type.getContext(), attr);
212         },
213         py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
214         "Gets an uniqued float point attribute associated to a type");
215     c.def_static(
216         "get_f32",
217         [](double value, DefaultingPyMlirContext context) {
218           MlirAttribute attr = mlirFloatAttrDoubleGet(
219               context->get(), mlirF32TypeGet(context->get()), value);
220           return PyFloatAttribute(context->getRef(), attr);
221         },
222         py::arg("value"), py::arg("context") = py::none(),
223         "Gets an uniqued float point attribute associated to a f32 type");
224     c.def_static(
225         "get_f64",
226         [](double value, DefaultingPyMlirContext context) {
227           MlirAttribute attr = mlirFloatAttrDoubleGet(
228               context->get(), mlirF64TypeGet(context->get()), value);
229           return PyFloatAttribute(context->getRef(), attr);
230         },
231         py::arg("value"), py::arg("context") = py::none(),
232         "Gets an uniqued float point attribute associated to a f64 type");
233     c.def_property_readonly(
234         "value",
235         [](PyFloatAttribute &self) {
236           return mlirFloatAttrGetValueDouble(self);
237         },
238         "Returns the value of the float point attribute");
239   }
240 };
241 
242 /// Integer Attribute subclass - IntegerAttr.
243 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
244 public:
245   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
246   static constexpr const char *pyClassName = "IntegerAttr";
247   using PyConcreteAttribute::PyConcreteAttribute;
248 
bindDerived(ClassTy & c)249   static void bindDerived(ClassTy &c) {
250     c.def_static(
251         "get",
252         [](PyType &type, int64_t value) {
253           MlirAttribute attr = mlirIntegerAttrGet(type, value);
254           return PyIntegerAttribute(type.getContext(), attr);
255         },
256         py::arg("type"), py::arg("value"),
257         "Gets an uniqued integer attribute associated to a type");
258     c.def_property_readonly(
259         "value",
260         [](PyIntegerAttribute &self) {
261           return mlirIntegerAttrGetValueInt(self);
262         },
263         "Returns the value of the integer attribute");
264   }
265 };
266 
267 /// Bool Attribute subclass - BoolAttr.
268 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
269 public:
270   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
271   static constexpr const char *pyClassName = "BoolAttr";
272   using PyConcreteAttribute::PyConcreteAttribute;
273 
bindDerived(ClassTy & c)274   static void bindDerived(ClassTy &c) {
275     c.def_static(
276         "get",
277         [](bool value, DefaultingPyMlirContext context) {
278           MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
279           return PyBoolAttribute(context->getRef(), attr);
280         },
281         py::arg("value"), py::arg("context") = py::none(),
282         "Gets an uniqued bool attribute");
283     c.def_property_readonly(
284         "value",
285         [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
286         "Returns the value of the bool attribute");
287   }
288 };
289 
290 class PyFlatSymbolRefAttribute
291     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
292 public:
293   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
294   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
295   using PyConcreteAttribute::PyConcreteAttribute;
296 
bindDerived(ClassTy & c)297   static void bindDerived(ClassTy &c) {
298     c.def_static(
299         "get",
300         [](std::string value, DefaultingPyMlirContext context) {
301           MlirAttribute attr =
302               mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
303           return PyFlatSymbolRefAttribute(context->getRef(), attr);
304         },
305         py::arg("value"), py::arg("context") = py::none(),
306         "Gets a uniqued FlatSymbolRef attribute");
307     c.def_property_readonly(
308         "value",
309         [](PyFlatSymbolRefAttribute &self) {
310           MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
311           return py::str(stringRef.data, stringRef.length);
312         },
313         "Returns the value of the FlatSymbolRef attribute as a string");
314   }
315 };
316 
317 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
318 public:
319   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
320   static constexpr const char *pyClassName = "StringAttr";
321   using PyConcreteAttribute::PyConcreteAttribute;
322 
bindDerived(ClassTy & c)323   static void bindDerived(ClassTy &c) {
324     c.def_static(
325         "get",
326         [](std::string value, DefaultingPyMlirContext context) {
327           MlirAttribute attr =
328               mlirStringAttrGet(context->get(), toMlirStringRef(value));
329           return PyStringAttribute(context->getRef(), attr);
330         },
331         py::arg("value"), py::arg("context") = py::none(),
332         "Gets a uniqued string attribute");
333     c.def_static(
334         "get_typed",
335         [](PyType &type, std::string value) {
336           MlirAttribute attr =
337               mlirStringAttrTypedGet(type, toMlirStringRef(value));
338           return PyStringAttribute(type.getContext(), attr);
339         },
340 
341         "Gets a uniqued string attribute associated to a type");
342     c.def_property_readonly(
343         "value",
344         [](PyStringAttribute &self) {
345           MlirStringRef stringRef = mlirStringAttrGetValue(self);
346           return py::str(stringRef.data, stringRef.length);
347         },
348         "Returns the value of the string attribute");
349   }
350 };
351 
352 // TODO: Support construction of string elements.
353 class PyDenseElementsAttribute
354     : public PyConcreteAttribute<PyDenseElementsAttribute> {
355 public:
356   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
357   static constexpr const char *pyClassName = "DenseElementsAttr";
358   using PyConcreteAttribute::PyConcreteAttribute;
359 
360   static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,Optional<PyType> explicitType,Optional<std::vector<int64_t>> explicitShape,DefaultingPyMlirContext contextWrapper)361   getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
362                 Optional<std::vector<int64_t>> explicitShape,
363                 DefaultingPyMlirContext contextWrapper) {
364     // Request a contiguous view. In exotic cases, this will cause a copy.
365     int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
366     Py_buffer *view = new Py_buffer();
367     if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
368       delete view;
369       throw py::error_already_set();
370     }
371     py::buffer_info arrayInfo(view);
372     SmallVector<int64_t> shape;
373     if (explicitShape) {
374       shape.append(explicitShape->begin(), explicitShape->end());
375     } else {
376       shape.append(arrayInfo.shape.begin(),
377                    arrayInfo.shape.begin() + arrayInfo.ndim);
378     }
379 
380     MlirAttribute encodingAttr = mlirAttributeGetNull();
381     MlirContext context = contextWrapper->get();
382 
383     // Detect format codes that are suitable for bulk loading. This includes
384     // all byte aligned integer and floating point types up to 8 bytes.
385     // Notably, this excludes, bool (which needs to be bit-packed) and
386     // other exotics which do not have a direct representation in the buffer
387     // protocol (i.e. complex, etc).
388     Optional<MlirType> bulkLoadElementType;
389     if (explicitType) {
390       bulkLoadElementType = *explicitType;
391     } else if (arrayInfo.format == "f") {
392       // f32
393       assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
394       bulkLoadElementType = mlirF32TypeGet(context);
395     } else if (arrayInfo.format == "d") {
396       // f64
397       assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
398       bulkLoadElementType = mlirF64TypeGet(context);
399     } else if (arrayInfo.format == "e") {
400       // f16
401       assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
402       bulkLoadElementType = mlirF16TypeGet(context);
403     } else if (isSignedIntegerFormat(arrayInfo.format)) {
404       if (arrayInfo.itemsize == 4) {
405         // i32
406         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
407                                        : mlirIntegerTypeSignedGet(context, 32);
408       } else if (arrayInfo.itemsize == 8) {
409         // i64
410         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
411                                        : mlirIntegerTypeSignedGet(context, 64);
412       } else if (arrayInfo.itemsize == 1) {
413         // i8
414         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
415                                        : mlirIntegerTypeSignedGet(context, 8);
416       } else if (arrayInfo.itemsize == 2) {
417         // i16
418         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
419                                        : mlirIntegerTypeSignedGet(context, 16);
420       }
421     } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
422       if (arrayInfo.itemsize == 4) {
423         // unsigned i32
424         bulkLoadElementType = signless
425                                   ? mlirIntegerTypeGet(context, 32)
426                                   : mlirIntegerTypeUnsignedGet(context, 32);
427       } else if (arrayInfo.itemsize == 8) {
428         // unsigned i64
429         bulkLoadElementType = signless
430                                   ? mlirIntegerTypeGet(context, 64)
431                                   : mlirIntegerTypeUnsignedGet(context, 64);
432       } else if (arrayInfo.itemsize == 1) {
433         // i8
434         bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
435                                        : mlirIntegerTypeUnsignedGet(context, 8);
436       } else if (arrayInfo.itemsize == 2) {
437         // i16
438         bulkLoadElementType = signless
439                                   ? mlirIntegerTypeGet(context, 16)
440                                   : mlirIntegerTypeUnsignedGet(context, 16);
441       }
442     }
443     if (bulkLoadElementType) {
444       auto shapedType = mlirRankedTensorTypeGet(
445           shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
446       size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
447       MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
448           shapedType, rawBufferSize, arrayInfo.ptr);
449       if (mlirAttributeIsNull(attr)) {
450         throw std::invalid_argument(
451             "DenseElementsAttr could not be constructed from the given buffer. "
452             "This may mean that the Python buffer layout does not match that "
453             "MLIR expected layout and is a bug.");
454       }
455       return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
456     }
457 
458     throw std::invalid_argument(
459         std::string("unimplemented array format conversion from format: ") +
460         arrayInfo.format);
461   }
462 
getSplat(PyType shapedType,PyAttribute & elementAttr)463   static PyDenseElementsAttribute getSplat(PyType shapedType,
464                                            PyAttribute &elementAttr) {
465     auto contextWrapper =
466         PyMlirContext::forContext(mlirTypeGetContext(shapedType));
467     if (!mlirAttributeIsAInteger(elementAttr) &&
468         !mlirAttributeIsAFloat(elementAttr)) {
469       std::string message = "Illegal element type for DenseElementsAttr: ";
470       message.append(py::repr(py::cast(elementAttr)));
471       throw SetPyError(PyExc_ValueError, message);
472     }
473     if (!mlirTypeIsAShaped(shapedType) ||
474         !mlirShapedTypeHasStaticShape(shapedType)) {
475       std::string message =
476           "Expected a static ShapedType for the shaped_type parameter: ";
477       message.append(py::repr(py::cast(shapedType)));
478       throw SetPyError(PyExc_ValueError, message);
479     }
480     MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
481     MlirType attrType = mlirAttributeGetType(elementAttr);
482     if (!mlirTypeEqual(shapedElementType, attrType)) {
483       std::string message =
484           "Shaped element type and attribute type must be equal: shaped=";
485       message.append(py::repr(py::cast(shapedType)));
486       message.append(", element=");
487       message.append(py::repr(py::cast(elementAttr)));
488       throw SetPyError(PyExc_ValueError, message);
489     }
490 
491     MlirAttribute elements =
492         mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
493     return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
494   }
495 
dunderLen()496   intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
497 
accessBuffer()498   py::buffer_info accessBuffer() {
499     if (mlirDenseElementsAttrIsSplat(*this)) {
500       // TODO: Currently crashes the program.
501       // Reported as https://github.com/pybind/pybind11/issues/3336
502       throw std::invalid_argument(
503           "unsupported data type for conversion to Python buffer");
504     }
505 
506     MlirType shapedType = mlirAttributeGetType(*this);
507     MlirType elementType = mlirShapedTypeGetElementType(shapedType);
508     std::string format;
509 
510     if (mlirTypeIsAF32(elementType)) {
511       // f32
512       return bufferInfo<float>(shapedType);
513     } else if (mlirTypeIsAF64(elementType)) {
514       // f64
515       return bufferInfo<double>(shapedType);
516     } else if (mlirTypeIsAF16(elementType)) {
517       // f16
518       return bufferInfo<uint16_t>(shapedType, "e");
519     } else if (mlirTypeIsAInteger(elementType) &&
520                mlirIntegerTypeGetWidth(elementType) == 32) {
521       if (mlirIntegerTypeIsSignless(elementType) ||
522           mlirIntegerTypeIsSigned(elementType)) {
523         // i32
524         return bufferInfo<int32_t>(shapedType);
525       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
526         // unsigned i32
527         return bufferInfo<uint32_t>(shapedType);
528       }
529     } else if (mlirTypeIsAInteger(elementType) &&
530                mlirIntegerTypeGetWidth(elementType) == 64) {
531       if (mlirIntegerTypeIsSignless(elementType) ||
532           mlirIntegerTypeIsSigned(elementType)) {
533         // i64
534         return bufferInfo<int64_t>(shapedType);
535       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
536         // unsigned i64
537         return bufferInfo<uint64_t>(shapedType);
538       }
539     } else if (mlirTypeIsAInteger(elementType) &&
540                mlirIntegerTypeGetWidth(elementType) == 8) {
541       if (mlirIntegerTypeIsSignless(elementType) ||
542           mlirIntegerTypeIsSigned(elementType)) {
543         // i8
544         return bufferInfo<int8_t>(shapedType);
545       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
546         // unsigned i8
547         return bufferInfo<uint8_t>(shapedType);
548       }
549     } else if (mlirTypeIsAInteger(elementType) &&
550                mlirIntegerTypeGetWidth(elementType) == 16) {
551       if (mlirIntegerTypeIsSignless(elementType) ||
552           mlirIntegerTypeIsSigned(elementType)) {
553         // i16
554         return bufferInfo<int16_t>(shapedType);
555       } else if (mlirIntegerTypeIsUnsigned(elementType)) {
556         // unsigned i16
557         return bufferInfo<uint16_t>(shapedType);
558       }
559     }
560 
561     // TODO: Currently crashes the program.
562     // Reported as https://github.com/pybind/pybind11/issues/3336
563     throw std::invalid_argument(
564         "unsupported data type for conversion to Python buffer");
565   }
566 
bindDerived(ClassTy & c)567   static void bindDerived(ClassTy &c) {
568     c.def("__len__", &PyDenseElementsAttribute::dunderLen)
569         .def_static("get", PyDenseElementsAttribute::getFromBuffer,
570                     py::arg("array"), py::arg("signless") = true,
571                     py::arg("type") = py::none(), py::arg("shape") = py::none(),
572                     py::arg("context") = py::none(),
573                     kDenseElementsAttrGetDocstring)
574         .def_static("get_splat", PyDenseElementsAttribute::getSplat,
575                     py::arg("shaped_type"), py::arg("element_attr"),
576                     "Gets a DenseElementsAttr where all values are the same")
577         .def_property_readonly("is_splat",
578                                [](PyDenseElementsAttribute &self) -> bool {
579                                  return mlirDenseElementsAttrIsSplat(self);
580                                })
581         .def_buffer(&PyDenseElementsAttribute::accessBuffer);
582   }
583 
584 private:
isUnsignedIntegerFormat(const std::string & format)585   static bool isUnsignedIntegerFormat(const std::string &format) {
586     if (format.empty())
587       return false;
588     char code = format[0];
589     return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
590            code == 'Q';
591   }
592 
isSignedIntegerFormat(const std::string & format)593   static bool isSignedIntegerFormat(const std::string &format) {
594     if (format.empty())
595       return false;
596     char code = format[0];
597     return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
598            code == 'q';
599   }
600 
601   template <typename Type>
bufferInfo(MlirType shapedType,const char * explicitFormat=nullptr)602   py::buffer_info bufferInfo(MlirType shapedType,
603                              const char *explicitFormat = nullptr) {
604     intptr_t rank = mlirShapedTypeGetRank(shapedType);
605     // Prepare the data for the buffer_info.
606     // Buffer is configured for read-only access below.
607     Type *data = static_cast<Type *>(
608         const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
609     // Prepare the shape for the buffer_info.
610     SmallVector<intptr_t, 4> shape;
611     for (intptr_t i = 0; i < rank; ++i)
612       shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
613     // Prepare the strides for the buffer_info.
614     SmallVector<intptr_t, 4> strides;
615     intptr_t strideFactor = 1;
616     for (intptr_t i = 1; i < rank; ++i) {
617       strideFactor = 1;
618       for (intptr_t j = i; j < rank; ++j) {
619         strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
620       }
621       strides.push_back(sizeof(Type) * strideFactor);
622     }
623     strides.push_back(sizeof(Type));
624     std::string format;
625     if (explicitFormat) {
626       format = explicitFormat;
627     } else {
628       format = py::format_descriptor<Type>::format();
629     }
630     return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
631                            /*readonly=*/true);
632   }
633 }; // namespace
634 
635 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
636 /// (and boolean) values. Supports element access.
637 class PyDenseIntElementsAttribute
638     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
639                                  PyDenseElementsAttribute> {
640 public:
641   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
642   static constexpr const char *pyClassName = "DenseIntElementsAttr";
643   using PyConcreteAttribute::PyConcreteAttribute;
644 
645   /// Returns the element at the given linear position. Asserts if the index is
646   /// out of range.
dunderGetItem(intptr_t pos)647   py::int_ dunderGetItem(intptr_t pos) {
648     if (pos < 0 || pos >= dunderLen()) {
649       throw SetPyError(PyExc_IndexError,
650                        "attempt to access out of bounds element");
651     }
652 
653     MlirType type = mlirAttributeGetType(*this);
654     type = mlirShapedTypeGetElementType(type);
655     assert(mlirTypeIsAInteger(type) &&
656            "expected integer element type in dense int elements attribute");
657     // Dispatch element extraction to an appropriate C function based on the
658     // elemental type of the attribute. py::int_ is implicitly constructible
659     // from any C++ integral type and handles bitwidth correctly.
660     // TODO: consider caching the type properties in the constructor to avoid
661     // querying them on each element access.
662     unsigned width = mlirIntegerTypeGetWidth(type);
663     bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
664     if (isUnsigned) {
665       if (width == 1) {
666         return mlirDenseElementsAttrGetBoolValue(*this, pos);
667       }
668       if (width == 32) {
669         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
670       }
671       if (width == 64) {
672         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
673       }
674     } else {
675       if (width == 1) {
676         return mlirDenseElementsAttrGetBoolValue(*this, pos);
677       }
678       if (width == 32) {
679         return mlirDenseElementsAttrGetInt32Value(*this, pos);
680       }
681       if (width == 64) {
682         return mlirDenseElementsAttrGetInt64Value(*this, pos);
683       }
684     }
685     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
686   }
687 
bindDerived(ClassTy & c)688   static void bindDerived(ClassTy &c) {
689     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
690   }
691 };
692 
693 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
694 public:
695   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
696   static constexpr const char *pyClassName = "DictAttr";
697   using PyConcreteAttribute::PyConcreteAttribute;
698 
dunderLen()699   intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
700 
bindDerived(ClassTy & c)701   static void bindDerived(ClassTy &c) {
702     c.def("__len__", &PyDictAttribute::dunderLen);
703     c.def_static(
704         "get",
705         [](py::dict attributes, DefaultingPyMlirContext context) {
706           SmallVector<MlirNamedAttribute> mlirNamedAttributes;
707           mlirNamedAttributes.reserve(attributes.size());
708           for (auto &it : attributes) {
709             auto &mlir_attr = it.second.cast<PyAttribute &>();
710             auto name = it.first.cast<std::string>();
711             mlirNamedAttributes.push_back(mlirNamedAttributeGet(
712                 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
713                                   toMlirStringRef(name)),
714                 mlir_attr));
715           }
716           MlirAttribute attr =
717               mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
718                                     mlirNamedAttributes.data());
719           return PyDictAttribute(context->getRef(), attr);
720         },
721         py::arg("value") = py::dict(), py::arg("context") = py::none(),
722         "Gets an uniqued dict attribute");
723     c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
724       MlirAttribute attr =
725           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
726       if (mlirAttributeIsNull(attr)) {
727         throw SetPyError(PyExc_KeyError,
728                          "attempt to access a non-existent attribute");
729       }
730       return PyAttribute(self.getContext(), attr);
731     });
732     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
733       if (index < 0 || index >= self.dunderLen()) {
734         throw SetPyError(PyExc_IndexError,
735                          "attempt to access out of bounds attribute");
736       }
737       MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
738       return PyNamedAttribute(
739           namedAttr.attribute,
740           std::string(mlirIdentifierStr(namedAttr.name).data));
741     });
742   }
743 };
744 
745 /// Refinement of PyDenseElementsAttribute for attributes containing
746 /// floating-point values. Supports element access.
747 class PyDenseFPElementsAttribute
748     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
749                                  PyDenseElementsAttribute> {
750 public:
751   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
752   static constexpr const char *pyClassName = "DenseFPElementsAttr";
753   using PyConcreteAttribute::PyConcreteAttribute;
754 
dunderGetItem(intptr_t pos)755   py::float_ dunderGetItem(intptr_t pos) {
756     if (pos < 0 || pos >= dunderLen()) {
757       throw SetPyError(PyExc_IndexError,
758                        "attempt to access out of bounds element");
759     }
760 
761     MlirType type = mlirAttributeGetType(*this);
762     type = mlirShapedTypeGetElementType(type);
763     // Dispatch element extraction to an appropriate C function based on the
764     // elemental type of the attribute. py::float_ is implicitly constructible
765     // from float and double.
766     // TODO: consider caching the type properties in the constructor to avoid
767     // querying them on each element access.
768     if (mlirTypeIsAF32(type)) {
769       return mlirDenseElementsAttrGetFloatValue(*this, pos);
770     }
771     if (mlirTypeIsAF64(type)) {
772       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
773     }
774     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
775   }
776 
bindDerived(ClassTy & c)777   static void bindDerived(ClassTy &c) {
778     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
779   }
780 };
781 
782 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
783 public:
784   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
785   static constexpr const char *pyClassName = "TypeAttr";
786   using PyConcreteAttribute::PyConcreteAttribute;
787 
bindDerived(ClassTy & c)788   static void bindDerived(ClassTy &c) {
789     c.def_static(
790         "get",
791         [](PyType value, DefaultingPyMlirContext context) {
792           MlirAttribute attr = mlirTypeAttrGet(value.get());
793           return PyTypeAttribute(context->getRef(), attr);
794         },
795         py::arg("value"), py::arg("context") = py::none(),
796         "Gets a uniqued Type attribute");
797     c.def_property_readonly("value", [](PyTypeAttribute &self) {
798       return PyType(self.getContext()->getRef(),
799                     mlirTypeAttrGetValue(self.get()));
800     });
801   }
802 };
803 
804 /// Unit Attribute subclass. Unit attributes don't have values.
805 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
806 public:
807   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
808   static constexpr const char *pyClassName = "UnitAttr";
809   using PyConcreteAttribute::PyConcreteAttribute;
810 
bindDerived(ClassTy & c)811   static void bindDerived(ClassTy &c) {
812     c.def_static(
813         "get",
814         [](DefaultingPyMlirContext context) {
815           return PyUnitAttribute(context->getRef(),
816                                  mlirUnitAttrGet(context->get()));
817         },
818         py::arg("context") = py::none(), "Create a Unit attribute.");
819   }
820 };
821 
822 } // namespace
823 
/// Registers every attribute binding defined in this file on module `m`.
///
/// Keep the statement order: a pybind11 base class must be registered before
/// any class derived from it. PyDenseFPElementsAttribute derives from
/// PyDenseElementsAttribute, so the latter is bound first. The same presumably
/// holds for PyDenseIntElementsAttribute, whose declaration is not visible
/// here.
/// NOTE(review): assumes the shared Attribute base class has already been
/// registered by the caller before this runs — TODO confirm.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  // Nested helper class of PyArrayAttribute; registered separately.
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Base of the dense FP/int element attributes: must precede them.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
840