1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This header file defines utilities that operate on standard types and are
// useful across multiple dialects that use structured ops abstractions. These
// abstractions consist of defining custom operations that encode and transport
// information about their semantics (e.g. type of iterators like parallel,
// reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/Attributes.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/StringRef.h"
24
25 namespace mlir {
26
isRowMajorMatmul(ArrayAttr indexingMaps)27 inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
28 auto context = indexingMaps.getContext();
29 AffineExpr m, n, k;
30 bindDims(context, m, n, k);
31 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
32 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
33 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
34 auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
35 return indexingMaps == maps;
36 }
37
isColumnMajorMatmul(ArrayAttr indexingMaps)38 inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
39 auto context = indexingMaps.getContext();
40 AffineExpr m, n, k;
41 bindDims(context, m, n, k);
42 auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
43 auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
44 auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
45 auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
46 return indexingMaps == maps;
47 }
48
49 /// Attribute name for the AffineArrayAttr which encodes the relationship
50 /// between a structured op iterators' and its operands.
getIndexingMapsAttrName()51 constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
52
53 /// Attribute name for the StrArrayAttr which encodes the type of a structured
54 /// op's iterators.
getIteratorTypesAttrName()55 constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }
56
57 /// Attribute name for the IntegerAttr which encodes the number of input buffer
58 /// arguments.
getArgsInAttrName()59 constexpr StringRef getArgsInAttrName() { return "args_in"; }
60
61 /// Attribute name for the IntegerAttr which encodes the number of input buffer
62 /// arguments.
getArgsOutAttrName()63 constexpr StringRef getArgsOutAttrName() { return "args_out"; }
64
65 /// Attribute name for the StringAttr which encodes an optional documentation
66 /// string of the structured op.
getDocAttrName()67 constexpr StringRef getDocAttrName() { return "doc"; }
68
69 /// Attribute name for the StrArrayAttr which encodes the external library
70 /// function that implements the structured op.
getLibraryCallAttrName()71 constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
72
73 /// Attribute name for the StrArrayAttr which encodes the value of strides.
getStridesAttrName()74 constexpr StringRef getStridesAttrName() { return "strides"; }
75
76 /// Attribute name for the StrArrayAttr which encodes the value of dilations.
getDilationsAttrName()77 constexpr StringRef getDilationsAttrName() { return "dilations"; }
78
79 /// Attribute name for the StrArrayAttr which encodes the value of paddings.
getPaddingAttrName()80 constexpr StringRef getPaddingAttrName() { return "padding"; }
81
82 /// Use to encode that a particular iterator type has parallel semantics.
getParallelIteratorTypeName()83 constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }
isParallelIterator(Attribute attr)84 inline bool isParallelIterator(Attribute attr) {
85 auto strAttr = attr.dyn_cast_or_null<StringAttr>();
86 return strAttr && strAttr.getValue() == getParallelIteratorTypeName();
87 }
88
89 /// Use to encode that a particular iterator type has reduction semantics.
getReductionIteratorTypeName()90 constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }
isReductionIterator(Attribute attr)91 inline bool isReductionIterator(Attribute attr) {
92 auto strAttr = attr.dyn_cast_or_null<StringAttr>();
93 return strAttr && strAttr.getValue() == getReductionIteratorTypeName();
94 }
95
96 /// Use to encode that a particular iterator type has window semantics.
getWindowIteratorTypeName()97 constexpr StringRef getWindowIteratorTypeName() { return "window"; }
isWindowIterator(Attribute attr)98 inline bool isWindowIterator(Attribute attr) {
99 auto strAttr = attr.dyn_cast_or_null<StringAttr>();
100 return strAttr && strAttr.getValue() == getWindowIteratorTypeName();
101 }
102
103 /// Use to encode that a particular iterator type has window semantics.
getAllIteratorTypeNames()104 inline ArrayRef<StringRef> getAllIteratorTypeNames() {
105 static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
106 getReductionIteratorTypeName(),
107 getWindowIteratorTypeName()};
108 return llvm::makeArrayRef(names);
109 }
110
111 /// Returns the iterator of a certain type.
getNumIterators(StringRef name,ArrayAttr iteratorTypes)112 inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
113 auto names = getAllIteratorTypeNames();
114 (void)names;
115 assert(llvm::is_contained(names, name));
116 return llvm::count_if(iteratorTypes, [name](Attribute a) {
117 return a.cast<StringAttr>().getValue() == name;
118 });
119 }
120
getNumIterators(ArrayAttr iteratorTypes)121 inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
122 unsigned res = 0;
123 for (auto n : getAllIteratorTypeNames())
124 res += getNumIterators(n, iteratorTypes);
125 return res;
126 }
127
128 /// Typed representation for loop type strings.
129 enum class IteratorType { Parallel, Reduction };
130
toString(IteratorType t)131 inline StringRef toString(IteratorType t) {
132 switch (t) {
133 case IteratorType::Parallel:
134 return getParallelIteratorTypeName();
135 case IteratorType::Reduction:
136 return getReductionIteratorTypeName();
137 }
138 llvm_unreachable("Unsupported IteratorType");
139 }
140
141 } // end namespace mlir
142
#endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
144