1 //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModule.h"
10
11 #include "PybindUtils.h"
12
13 #include "mlir-c/BuiltinAttributes.h"
14 #include "mlir-c/BuiltinTypes.h"
15
16 namespace py = pybind11;
17 using namespace mlir;
18 using namespace mlir::python;
19
20 using llvm::None;
21 using llvm::Optional;
22 using llvm::SmallVector;
23 using llvm::Twine;
24
25 //------------------------------------------------------------------------------
26 // Docstrings (trivial, non-duplicated docstrings are included inline).
27 //------------------------------------------------------------------------------
28
// Python docstring for `DenseElementsAttr.get`, which is bound to
// PyDenseElementsAttribute::getFromBuffer further down in this file. Kept as a
// separate constant because it is too long to inline at the binding site.
static const char kDenseElementsAttrGetDocstring[] =
R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

* Integer types (except for i1): the buffer must be byte aligned to the
next byte boundary.
* Floating point types: Must be bit-castable to the given floating point
size.
* i1 (bool): Bit packed into 8bit words where the bit pattern matches a
row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
array: The array or buffer to convert.
signless: If inferring an appropriate MLIR type, use signless types for
integers (defaults True).
type: Skips inference of the MLIR element type and uses this instead. The
storage size must be consistent with the actual contents of the buffer.
shape: Overrides the shape of the buffer when constructing the MLIR
shaped type. This is needed when the physical and logical shape differ (as
for i1).
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
70
71 namespace {
72
toMlirStringRef(const std::string & s)73 static MlirStringRef toMlirStringRef(const std::string &s) {
74 return mlirStringRefCreate(s.data(), s.size());
75 }
76
77 class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
78 public:
79 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
80 static constexpr const char *pyClassName = "AffineMapAttr";
81 using PyConcreteAttribute::PyConcreteAttribute;
82
bindDerived(ClassTy & c)83 static void bindDerived(ClassTy &c) {
84 c.def_static(
85 "get",
86 [](PyAffineMap &affineMap) {
87 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
88 return PyAffineMapAttribute(affineMap.getContext(), attr);
89 },
90 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
91 }
92 };
93
94 template <typename T>
pyTryCast(py::handle object)95 static T pyTryCast(py::handle object) {
96 try {
97 return object.cast<T>();
98 } catch (py::cast_error &err) {
99 std::string msg =
100 std::string(
101 "Invalid attribute when attempting to create an ArrayAttribute (") +
102 err.what() + ")";
103 throw py::cast_error(msg);
104 } catch (py::reference_cast_error &err) {
105 std::string msg = std::string("Invalid attribute (None?) when attempting "
106 "to create an ArrayAttribute (") +
107 err.what() + ")";
108 throw py::cast_error(msg);
109 }
110 }
111
112 class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
113 public:
114 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
115 static constexpr const char *pyClassName = "ArrayAttr";
116 using PyConcreteAttribute::PyConcreteAttribute;
117
118 class PyArrayAttributeIterator {
119 public:
PyArrayAttributeIterator(PyAttribute attr)120 PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {}
121
dunderIter()122 PyArrayAttributeIterator &dunderIter() { return *this; }
123
dunderNext()124 PyAttribute dunderNext() {
125 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) {
126 throw py::stop_iteration();
127 }
128 return PyAttribute(attr.getContext(),
129 mlirArrayAttrGetElement(attr.get(), nextIndex++));
130 }
131
bind(py::module & m)132 static void bind(py::module &m) {
133 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
134 py::module_local())
135 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
136 .def("__next__", &PyArrayAttributeIterator::dunderNext);
137 }
138
139 private:
140 PyAttribute attr;
141 int nextIndex = 0;
142 };
143
getItem(intptr_t i)144 PyAttribute getItem(intptr_t i) {
145 return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i));
146 }
147
bindDerived(ClassTy & c)148 static void bindDerived(ClassTy &c) {
149 c.def_static(
150 "get",
151 [](py::list attributes, DefaultingPyMlirContext context) {
152 SmallVector<MlirAttribute> mlirAttributes;
153 mlirAttributes.reserve(py::len(attributes));
154 for (auto attribute : attributes) {
155 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
156 }
157 MlirAttribute attr = mlirArrayAttrGet(
158 context->get(), mlirAttributes.size(), mlirAttributes.data());
159 return PyArrayAttribute(context->getRef(), attr);
160 },
161 py::arg("attributes"), py::arg("context") = py::none(),
162 "Gets a uniqued Array attribute");
163 c.def("__getitem__",
164 [](PyArrayAttribute &arr, intptr_t i) {
165 if (i >= mlirArrayAttrGetNumElements(arr))
166 throw py::index_error("ArrayAttribute index out of range");
167 return arr.getItem(i);
168 })
169 .def("__len__",
170 [](const PyArrayAttribute &arr) {
171 return mlirArrayAttrGetNumElements(arr);
172 })
173 .def("__iter__", [](const PyArrayAttribute &arr) {
174 return PyArrayAttributeIterator(arr);
175 });
176 c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
177 std::vector<MlirAttribute> attributes;
178 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
179 attributes.reserve(numOldElements + py::len(extras));
180 for (intptr_t i = 0; i < numOldElements; ++i)
181 attributes.push_back(arr.getItem(i));
182 for (py::handle attr : extras)
183 attributes.push_back(pyTryCast<PyAttribute>(attr));
184 MlirAttribute arrayAttr = mlirArrayAttrGet(
185 arr.getContext()->get(), attributes.size(), attributes.data());
186 return PyArrayAttribute(arr.getContext(), arrayAttr);
187 });
188 }
189 };
190
191 /// Float Point Attribute subclass - FloatAttr.
192 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
193 public:
194 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
195 static constexpr const char *pyClassName = "FloatAttr";
196 using PyConcreteAttribute::PyConcreteAttribute;
197
bindDerived(ClassTy & c)198 static void bindDerived(ClassTy &c) {
199 c.def_static(
200 "get",
201 [](PyType &type, double value, DefaultingPyLocation loc) {
202 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
203 // TODO: Rework error reporting once diagnostic engine is exposed
204 // in C API.
205 if (mlirAttributeIsNull(attr)) {
206 throw SetPyError(PyExc_ValueError,
207 Twine("invalid '") +
208 py::repr(py::cast(type)).cast<std::string>() +
209 "' and expected floating point type.");
210 }
211 return PyFloatAttribute(type.getContext(), attr);
212 },
213 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
214 "Gets an uniqued float point attribute associated to a type");
215 c.def_static(
216 "get_f32",
217 [](double value, DefaultingPyMlirContext context) {
218 MlirAttribute attr = mlirFloatAttrDoubleGet(
219 context->get(), mlirF32TypeGet(context->get()), value);
220 return PyFloatAttribute(context->getRef(), attr);
221 },
222 py::arg("value"), py::arg("context") = py::none(),
223 "Gets an uniqued float point attribute associated to a f32 type");
224 c.def_static(
225 "get_f64",
226 [](double value, DefaultingPyMlirContext context) {
227 MlirAttribute attr = mlirFloatAttrDoubleGet(
228 context->get(), mlirF64TypeGet(context->get()), value);
229 return PyFloatAttribute(context->getRef(), attr);
230 },
231 py::arg("value"), py::arg("context") = py::none(),
232 "Gets an uniqued float point attribute associated to a f64 type");
233 c.def_property_readonly(
234 "value",
235 [](PyFloatAttribute &self) {
236 return mlirFloatAttrGetValueDouble(self);
237 },
238 "Returns the value of the float point attribute");
239 }
240 };
241
242 /// Integer Attribute subclass - IntegerAttr.
243 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
244 public:
245 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
246 static constexpr const char *pyClassName = "IntegerAttr";
247 using PyConcreteAttribute::PyConcreteAttribute;
248
bindDerived(ClassTy & c)249 static void bindDerived(ClassTy &c) {
250 c.def_static(
251 "get",
252 [](PyType &type, int64_t value) {
253 MlirAttribute attr = mlirIntegerAttrGet(type, value);
254 return PyIntegerAttribute(type.getContext(), attr);
255 },
256 py::arg("type"), py::arg("value"),
257 "Gets an uniqued integer attribute associated to a type");
258 c.def_property_readonly(
259 "value",
260 [](PyIntegerAttribute &self) {
261 return mlirIntegerAttrGetValueInt(self);
262 },
263 "Returns the value of the integer attribute");
264 }
265 };
266
267 /// Bool Attribute subclass - BoolAttr.
268 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
269 public:
270 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
271 static constexpr const char *pyClassName = "BoolAttr";
272 using PyConcreteAttribute::PyConcreteAttribute;
273
bindDerived(ClassTy & c)274 static void bindDerived(ClassTy &c) {
275 c.def_static(
276 "get",
277 [](bool value, DefaultingPyMlirContext context) {
278 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
279 return PyBoolAttribute(context->getRef(), attr);
280 },
281 py::arg("value"), py::arg("context") = py::none(),
282 "Gets an uniqued bool attribute");
283 c.def_property_readonly(
284 "value",
285 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
286 "Returns the value of the bool attribute");
287 }
288 };
289
290 class PyFlatSymbolRefAttribute
291 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
292 public:
293 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
294 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
295 using PyConcreteAttribute::PyConcreteAttribute;
296
bindDerived(ClassTy & c)297 static void bindDerived(ClassTy &c) {
298 c.def_static(
299 "get",
300 [](std::string value, DefaultingPyMlirContext context) {
301 MlirAttribute attr =
302 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
303 return PyFlatSymbolRefAttribute(context->getRef(), attr);
304 },
305 py::arg("value"), py::arg("context") = py::none(),
306 "Gets a uniqued FlatSymbolRef attribute");
307 c.def_property_readonly(
308 "value",
309 [](PyFlatSymbolRefAttribute &self) {
310 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
311 return py::str(stringRef.data, stringRef.length);
312 },
313 "Returns the value of the FlatSymbolRef attribute as a string");
314 }
315 };
316
317 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
318 public:
319 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
320 static constexpr const char *pyClassName = "StringAttr";
321 using PyConcreteAttribute::PyConcreteAttribute;
322
bindDerived(ClassTy & c)323 static void bindDerived(ClassTy &c) {
324 c.def_static(
325 "get",
326 [](std::string value, DefaultingPyMlirContext context) {
327 MlirAttribute attr =
328 mlirStringAttrGet(context->get(), toMlirStringRef(value));
329 return PyStringAttribute(context->getRef(), attr);
330 },
331 py::arg("value"), py::arg("context") = py::none(),
332 "Gets a uniqued string attribute");
333 c.def_static(
334 "get_typed",
335 [](PyType &type, std::string value) {
336 MlirAttribute attr =
337 mlirStringAttrTypedGet(type, toMlirStringRef(value));
338 return PyStringAttribute(type.getContext(), attr);
339 },
340
341 "Gets a uniqued string attribute associated to a type");
342 c.def_property_readonly(
343 "value",
344 [](PyStringAttribute &self) {
345 MlirStringRef stringRef = mlirStringAttrGetValue(self);
346 return py::str(stringRef.data, stringRef.length);
347 },
348 "Returns the value of the string attribute");
349 }
350 };
351
352 // TODO: Support construction of string elements.
353 class PyDenseElementsAttribute
354 : public PyConcreteAttribute<PyDenseElementsAttribute> {
355 public:
356 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
357 static constexpr const char *pyClassName = "DenseElementsAttr";
358 using PyConcreteAttribute::PyConcreteAttribute;
359
360 static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,Optional<PyType> explicitType,Optional<std::vector<int64_t>> explicitShape,DefaultingPyMlirContext contextWrapper)361 getFromBuffer(py::buffer array, bool signless, Optional<PyType> explicitType,
362 Optional<std::vector<int64_t>> explicitShape,
363 DefaultingPyMlirContext contextWrapper) {
364 // Request a contiguous view. In exotic cases, this will cause a copy.
365 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
366 Py_buffer *view = new Py_buffer();
367 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
368 delete view;
369 throw py::error_already_set();
370 }
371 py::buffer_info arrayInfo(view);
372 SmallVector<int64_t> shape;
373 if (explicitShape) {
374 shape.append(explicitShape->begin(), explicitShape->end());
375 } else {
376 shape.append(arrayInfo.shape.begin(),
377 arrayInfo.shape.begin() + arrayInfo.ndim);
378 }
379
380 MlirAttribute encodingAttr = mlirAttributeGetNull();
381 MlirContext context = contextWrapper->get();
382
383 // Detect format codes that are suitable for bulk loading. This includes
384 // all byte aligned integer and floating point types up to 8 bytes.
385 // Notably, this excludes, bool (which needs to be bit-packed) and
386 // other exotics which do not have a direct representation in the buffer
387 // protocol (i.e. complex, etc).
388 Optional<MlirType> bulkLoadElementType;
389 if (explicitType) {
390 bulkLoadElementType = *explicitType;
391 } else if (arrayInfo.format == "f") {
392 // f32
393 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
394 bulkLoadElementType = mlirF32TypeGet(context);
395 } else if (arrayInfo.format == "d") {
396 // f64
397 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
398 bulkLoadElementType = mlirF64TypeGet(context);
399 } else if (arrayInfo.format == "e") {
400 // f16
401 assert(arrayInfo.itemsize == 2 && "mismatched array itemsize");
402 bulkLoadElementType = mlirF16TypeGet(context);
403 } else if (isSignedIntegerFormat(arrayInfo.format)) {
404 if (arrayInfo.itemsize == 4) {
405 // i32
406 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32)
407 : mlirIntegerTypeSignedGet(context, 32);
408 } else if (arrayInfo.itemsize == 8) {
409 // i64
410 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64)
411 : mlirIntegerTypeSignedGet(context, 64);
412 } else if (arrayInfo.itemsize == 1) {
413 // i8
414 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
415 : mlirIntegerTypeSignedGet(context, 8);
416 } else if (arrayInfo.itemsize == 2) {
417 // i16
418 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16)
419 : mlirIntegerTypeSignedGet(context, 16);
420 }
421 } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
422 if (arrayInfo.itemsize == 4) {
423 // unsigned i32
424 bulkLoadElementType = signless
425 ? mlirIntegerTypeGet(context, 32)
426 : mlirIntegerTypeUnsignedGet(context, 32);
427 } else if (arrayInfo.itemsize == 8) {
428 // unsigned i64
429 bulkLoadElementType = signless
430 ? mlirIntegerTypeGet(context, 64)
431 : mlirIntegerTypeUnsignedGet(context, 64);
432 } else if (arrayInfo.itemsize == 1) {
433 // i8
434 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
435 : mlirIntegerTypeUnsignedGet(context, 8);
436 } else if (arrayInfo.itemsize == 2) {
437 // i16
438 bulkLoadElementType = signless
439 ? mlirIntegerTypeGet(context, 16)
440 : mlirIntegerTypeUnsignedGet(context, 16);
441 }
442 }
443 if (bulkLoadElementType) {
444 auto shapedType = mlirRankedTensorTypeGet(
445 shape.size(), shape.data(), *bulkLoadElementType, encodingAttr);
446 size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize;
447 MlirAttribute attr = mlirDenseElementsAttrRawBufferGet(
448 shapedType, rawBufferSize, arrayInfo.ptr);
449 if (mlirAttributeIsNull(attr)) {
450 throw std::invalid_argument(
451 "DenseElementsAttr could not be constructed from the given buffer. "
452 "This may mean that the Python buffer layout does not match that "
453 "MLIR expected layout and is a bug.");
454 }
455 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
456 }
457
458 throw std::invalid_argument(
459 std::string("unimplemented array format conversion from format: ") +
460 arrayInfo.format);
461 }
462
getSplat(PyType shapedType,PyAttribute & elementAttr)463 static PyDenseElementsAttribute getSplat(PyType shapedType,
464 PyAttribute &elementAttr) {
465 auto contextWrapper =
466 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
467 if (!mlirAttributeIsAInteger(elementAttr) &&
468 !mlirAttributeIsAFloat(elementAttr)) {
469 std::string message = "Illegal element type for DenseElementsAttr: ";
470 message.append(py::repr(py::cast(elementAttr)));
471 throw SetPyError(PyExc_ValueError, message);
472 }
473 if (!mlirTypeIsAShaped(shapedType) ||
474 !mlirShapedTypeHasStaticShape(shapedType)) {
475 std::string message =
476 "Expected a static ShapedType for the shaped_type parameter: ";
477 message.append(py::repr(py::cast(shapedType)));
478 throw SetPyError(PyExc_ValueError, message);
479 }
480 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
481 MlirType attrType = mlirAttributeGetType(elementAttr);
482 if (!mlirTypeEqual(shapedElementType, attrType)) {
483 std::string message =
484 "Shaped element type and attribute type must be equal: shaped=";
485 message.append(py::repr(py::cast(shapedType)));
486 message.append(", element=");
487 message.append(py::repr(py::cast(elementAttr)));
488 throw SetPyError(PyExc_ValueError, message);
489 }
490
491 MlirAttribute elements =
492 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
493 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
494 }
495
dunderLen()496 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
497
accessBuffer()498 py::buffer_info accessBuffer() {
499 if (mlirDenseElementsAttrIsSplat(*this)) {
500 // TODO: Currently crashes the program.
501 // Reported as https://github.com/pybind/pybind11/issues/3336
502 throw std::invalid_argument(
503 "unsupported data type for conversion to Python buffer");
504 }
505
506 MlirType shapedType = mlirAttributeGetType(*this);
507 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
508 std::string format;
509
510 if (mlirTypeIsAF32(elementType)) {
511 // f32
512 return bufferInfo<float>(shapedType);
513 } else if (mlirTypeIsAF64(elementType)) {
514 // f64
515 return bufferInfo<double>(shapedType);
516 } else if (mlirTypeIsAF16(elementType)) {
517 // f16
518 return bufferInfo<uint16_t>(shapedType, "e");
519 } else if (mlirTypeIsAInteger(elementType) &&
520 mlirIntegerTypeGetWidth(elementType) == 32) {
521 if (mlirIntegerTypeIsSignless(elementType) ||
522 mlirIntegerTypeIsSigned(elementType)) {
523 // i32
524 return bufferInfo<int32_t>(shapedType);
525 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
526 // unsigned i32
527 return bufferInfo<uint32_t>(shapedType);
528 }
529 } else if (mlirTypeIsAInteger(elementType) &&
530 mlirIntegerTypeGetWidth(elementType) == 64) {
531 if (mlirIntegerTypeIsSignless(elementType) ||
532 mlirIntegerTypeIsSigned(elementType)) {
533 // i64
534 return bufferInfo<int64_t>(shapedType);
535 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
536 // unsigned i64
537 return bufferInfo<uint64_t>(shapedType);
538 }
539 } else if (mlirTypeIsAInteger(elementType) &&
540 mlirIntegerTypeGetWidth(elementType) == 8) {
541 if (mlirIntegerTypeIsSignless(elementType) ||
542 mlirIntegerTypeIsSigned(elementType)) {
543 // i8
544 return bufferInfo<int8_t>(shapedType);
545 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
546 // unsigned i8
547 return bufferInfo<uint8_t>(shapedType);
548 }
549 } else if (mlirTypeIsAInteger(elementType) &&
550 mlirIntegerTypeGetWidth(elementType) == 16) {
551 if (mlirIntegerTypeIsSignless(elementType) ||
552 mlirIntegerTypeIsSigned(elementType)) {
553 // i16
554 return bufferInfo<int16_t>(shapedType);
555 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
556 // unsigned i16
557 return bufferInfo<uint16_t>(shapedType);
558 }
559 }
560
561 // TODO: Currently crashes the program.
562 // Reported as https://github.com/pybind/pybind11/issues/3336
563 throw std::invalid_argument(
564 "unsupported data type for conversion to Python buffer");
565 }
566
bindDerived(ClassTy & c)567 static void bindDerived(ClassTy &c) {
568 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
569 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
570 py::arg("array"), py::arg("signless") = true,
571 py::arg("type") = py::none(), py::arg("shape") = py::none(),
572 py::arg("context") = py::none(),
573 kDenseElementsAttrGetDocstring)
574 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
575 py::arg("shaped_type"), py::arg("element_attr"),
576 "Gets a DenseElementsAttr where all values are the same")
577 .def_property_readonly("is_splat",
578 [](PyDenseElementsAttribute &self) -> bool {
579 return mlirDenseElementsAttrIsSplat(self);
580 })
581 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
582 }
583
584 private:
isUnsignedIntegerFormat(const std::string & format)585 static bool isUnsignedIntegerFormat(const std::string &format) {
586 if (format.empty())
587 return false;
588 char code = format[0];
589 return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
590 code == 'Q';
591 }
592
isSignedIntegerFormat(const std::string & format)593 static bool isSignedIntegerFormat(const std::string &format) {
594 if (format.empty())
595 return false;
596 char code = format[0];
597 return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
598 code == 'q';
599 }
600
601 template <typename Type>
bufferInfo(MlirType shapedType,const char * explicitFormat=nullptr)602 py::buffer_info bufferInfo(MlirType shapedType,
603 const char *explicitFormat = nullptr) {
604 intptr_t rank = mlirShapedTypeGetRank(shapedType);
605 // Prepare the data for the buffer_info.
606 // Buffer is configured for read-only access below.
607 Type *data = static_cast<Type *>(
608 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
609 // Prepare the shape for the buffer_info.
610 SmallVector<intptr_t, 4> shape;
611 for (intptr_t i = 0; i < rank; ++i)
612 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
613 // Prepare the strides for the buffer_info.
614 SmallVector<intptr_t, 4> strides;
615 intptr_t strideFactor = 1;
616 for (intptr_t i = 1; i < rank; ++i) {
617 strideFactor = 1;
618 for (intptr_t j = i; j < rank; ++j) {
619 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
620 }
621 strides.push_back(sizeof(Type) * strideFactor);
622 }
623 strides.push_back(sizeof(Type));
624 std::string format;
625 if (explicitFormat) {
626 format = explicitFormat;
627 } else {
628 format = py::format_descriptor<Type>::format();
629 }
630 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
631 /*readonly=*/true);
632 }
633 }; // namespace
634
635 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
636 /// (and boolean) values. Supports element access.
637 class PyDenseIntElementsAttribute
638 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
639 PyDenseElementsAttribute> {
640 public:
641 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
642 static constexpr const char *pyClassName = "DenseIntElementsAttr";
643 using PyConcreteAttribute::PyConcreteAttribute;
644
645 /// Returns the element at the given linear position. Asserts if the index is
646 /// out of range.
dunderGetItem(intptr_t pos)647 py::int_ dunderGetItem(intptr_t pos) {
648 if (pos < 0 || pos >= dunderLen()) {
649 throw SetPyError(PyExc_IndexError,
650 "attempt to access out of bounds element");
651 }
652
653 MlirType type = mlirAttributeGetType(*this);
654 type = mlirShapedTypeGetElementType(type);
655 assert(mlirTypeIsAInteger(type) &&
656 "expected integer element type in dense int elements attribute");
657 // Dispatch element extraction to an appropriate C function based on the
658 // elemental type of the attribute. py::int_ is implicitly constructible
659 // from any C++ integral type and handles bitwidth correctly.
660 // TODO: consider caching the type properties in the constructor to avoid
661 // querying them on each element access.
662 unsigned width = mlirIntegerTypeGetWidth(type);
663 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
664 if (isUnsigned) {
665 if (width == 1) {
666 return mlirDenseElementsAttrGetBoolValue(*this, pos);
667 }
668 if (width == 32) {
669 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
670 }
671 if (width == 64) {
672 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
673 }
674 } else {
675 if (width == 1) {
676 return mlirDenseElementsAttrGetBoolValue(*this, pos);
677 }
678 if (width == 32) {
679 return mlirDenseElementsAttrGetInt32Value(*this, pos);
680 }
681 if (width == 64) {
682 return mlirDenseElementsAttrGetInt64Value(*this, pos);
683 }
684 }
685 throw SetPyError(PyExc_TypeError, "Unsupported integer type");
686 }
687
bindDerived(ClassTy & c)688 static void bindDerived(ClassTy &c) {
689 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
690 }
691 };
692
693 class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
694 public:
695 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
696 static constexpr const char *pyClassName = "DictAttr";
697 using PyConcreteAttribute::PyConcreteAttribute;
698
dunderLen()699 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
700
bindDerived(ClassTy & c)701 static void bindDerived(ClassTy &c) {
702 c.def("__len__", &PyDictAttribute::dunderLen);
703 c.def_static(
704 "get",
705 [](py::dict attributes, DefaultingPyMlirContext context) {
706 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
707 mlirNamedAttributes.reserve(attributes.size());
708 for (auto &it : attributes) {
709 auto &mlir_attr = it.second.cast<PyAttribute &>();
710 auto name = it.first.cast<std::string>();
711 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
712 mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
713 toMlirStringRef(name)),
714 mlir_attr));
715 }
716 MlirAttribute attr =
717 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
718 mlirNamedAttributes.data());
719 return PyDictAttribute(context->getRef(), attr);
720 },
721 py::arg("value") = py::dict(), py::arg("context") = py::none(),
722 "Gets an uniqued dict attribute");
723 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
724 MlirAttribute attr =
725 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
726 if (mlirAttributeIsNull(attr)) {
727 throw SetPyError(PyExc_KeyError,
728 "attempt to access a non-existent attribute");
729 }
730 return PyAttribute(self.getContext(), attr);
731 });
732 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
733 if (index < 0 || index >= self.dunderLen()) {
734 throw SetPyError(PyExc_IndexError,
735 "attempt to access out of bounds attribute");
736 }
737 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
738 return PyNamedAttribute(
739 namedAttr.attribute,
740 std::string(mlirIdentifierStr(namedAttr.name).data));
741 });
742 }
743 };
744
745 /// Refinement of PyDenseElementsAttribute for attributes containing
746 /// floating-point values. Supports element access.
747 class PyDenseFPElementsAttribute
748 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
749 PyDenseElementsAttribute> {
750 public:
751 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
752 static constexpr const char *pyClassName = "DenseFPElementsAttr";
753 using PyConcreteAttribute::PyConcreteAttribute;
754
dunderGetItem(intptr_t pos)755 py::float_ dunderGetItem(intptr_t pos) {
756 if (pos < 0 || pos >= dunderLen()) {
757 throw SetPyError(PyExc_IndexError,
758 "attempt to access out of bounds element");
759 }
760
761 MlirType type = mlirAttributeGetType(*this);
762 type = mlirShapedTypeGetElementType(type);
763 // Dispatch element extraction to an appropriate C function based on the
764 // elemental type of the attribute. py::float_ is implicitly constructible
765 // from float and double.
766 // TODO: consider caching the type properties in the constructor to avoid
767 // querying them on each element access.
768 if (mlirTypeIsAF32(type)) {
769 return mlirDenseElementsAttrGetFloatValue(*this, pos);
770 }
771 if (mlirTypeIsAF64(type)) {
772 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
773 }
774 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
775 }
776
bindDerived(ClassTy & c)777 static void bindDerived(ClassTy &c) {
778 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
779 }
780 };
781
782 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
783 public:
784 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
785 static constexpr const char *pyClassName = "TypeAttr";
786 using PyConcreteAttribute::PyConcreteAttribute;
787
bindDerived(ClassTy & c)788 static void bindDerived(ClassTy &c) {
789 c.def_static(
790 "get",
791 [](PyType value, DefaultingPyMlirContext context) {
792 MlirAttribute attr = mlirTypeAttrGet(value.get());
793 return PyTypeAttribute(context->getRef(), attr);
794 },
795 py::arg("value"), py::arg("context") = py::none(),
796 "Gets a uniqued Type attribute");
797 c.def_property_readonly("value", [](PyTypeAttribute &self) {
798 return PyType(self.getContext()->getRef(),
799 mlirTypeAttrGetValue(self.get()));
800 });
801 }
802 };
803
804 /// Unit Attribute subclass. Unit attributes don't have values.
805 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
806 public:
807 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
808 static constexpr const char *pyClassName = "UnitAttr";
809 using PyConcreteAttribute::PyConcreteAttribute;
810
bindDerived(ClassTy & c)811 static void bindDerived(ClassTy &c) {
812 c.def_static(
813 "get",
814 [](DefaultingPyMlirContext context) {
815 return PyUnitAttribute(context->getRef(),
816 mlirUnitAttrGet(context->get()));
817 },
818 py::arg("context") = py::none(), "Create a Unit attribute.");
819 }
820 };
821
822 } // namespace
823
/// Registers every attribute binding defined in this file on module `m`.
///
/// Call order matters. PyDenseElementsAttribute is bound before
/// PyDenseFPElementsAttribute and PyDenseIntElementsAttribute, which declare
/// it as their base class (presumably required so pybind11 can resolve the
/// base when registering the subclass). The ArrayAttr iterator type is
/// registered alongside ArrayAttr.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);
  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Base class first: the two refinements below derive from it.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyDictAttribute::bind(m);
  PyFlatSymbolRefAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  PyUnitAttribute::bind(m);
}
840