1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23
24 using namespace mlir;
25
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29
30 #include "mlir/IR/BuiltinDialect.cpp.inc"
31
32 namespace {
33 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
34 using OpAsmDialectInterface::OpAsmDialectInterface;
35
getAlias__anon4aef3d7a0111::BuiltinOpAsmDialectInterface36 LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
37 if (attr.isa<AffineMapAttr>()) {
38 os << "map";
39 return success();
40 }
41 if (attr.isa<IntegerSetAttr>()) {
42 os << "set";
43 return success();
44 }
45 if (attr.isa<LocationAttr>()) {
46 os << "loc";
47 return success();
48 }
49 return failure();
50 }
51
getAlias__anon4aef3d7a0111::BuiltinOpAsmDialectInterface52 LogicalResult getAlias(Type type, raw_ostream &os) const final {
53 if (auto tupleType = type.dyn_cast<TupleType>()) {
54 if (tupleType.size() > 16) {
55 os << "tuple";
56 return success();
57 }
58 }
59 return failure();
60 }
61 };
62 } // end anonymous namespace.
63
/// Registers the contents of the builtin dialect: its types, attributes,
/// location attributes, operations, and the assembly alias interface.
void BuiltinDialect::initialize() {
  registerTypes();
  registerAttributes();
  registerLocationAttributes();
  // The template argument list is the op list that BuiltinOps.cpp.inc expands
  // to when GET_OP_LIST is defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
74
75 //===----------------------------------------------------------------------===//
76 // FuncOp
77 //===----------------------------------------------------------------------===//
78
create(Location location,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs)79 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
80 ArrayRef<NamedAttribute> attrs) {
81 OperationState state(location, "func");
82 OpBuilder builder(location->getContext());
83 FuncOp::build(builder, state, name, type, attrs);
84 return cast<FuncOp>(Operation::create(state));
85 }
create(Location location,StringRef name,FunctionType type,Operation::dialect_attr_range attrs)86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
87 Operation::dialect_attr_range attrs) {
88 SmallVector<NamedAttribute, 8> attrRef(attrs);
89 return create(location, name, type, llvm::makeArrayRef(attrRef));
90 }
create(Location location,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)91 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
92 ArrayRef<NamedAttribute> attrs,
93 ArrayRef<DictionaryAttr> argAttrs) {
94 FuncOp func = create(location, name, type, attrs);
95 func.setAllArgAttrs(argAttrs);
96 return func;
97 }
98
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)99 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
100 FunctionType type, ArrayRef<NamedAttribute> attrs,
101 ArrayRef<DictionaryAttr> argAttrs) {
102 state.addAttribute(SymbolTable::getSymbolAttrName(),
103 builder.getStringAttr(name));
104 state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
105 state.attributes.append(attrs.begin(), attrs.end());
106 state.addRegion();
107
108 if (argAttrs.empty())
109 return;
110 assert(type.getNumInputs() == argAttrs.size());
111 function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
112 /*resultAttrs=*/llvm::None);
113 }
114
parseFuncOp(OpAsmParser & parser,OperationState & result)115 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
116 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
117 ArrayRef<Type> results,
118 function_like_impl::VariadicFlag, std::string &) {
119 return builder.getFunctionType(argTypes, results);
120 };
121
122 return function_like_impl::parseFunctionLikeOp(
123 parser, result, /*allowVariadic=*/false, buildFuncType);
124 }
125
print(FuncOp op,OpAsmPrinter & p)126 static void print(FuncOp op, OpAsmPrinter &p) {
127 FunctionType fnType = op.getType();
128 function_like_impl::printFunctionLikeOp(
129 p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
130 }
131
verify(FuncOp op)132 static LogicalResult verify(FuncOp op) {
133 // If this function is external there is nothing to do.
134 if (op.isExternal())
135 return success();
136
137 // Verify that the argument list of the function and the arg list of the entry
138 // block line up. The trait already verified that the number of arguments is
139 // the same between the signature and the block.
140 auto fnInputTypes = op.getType().getInputs();
141 Block &entryBlock = op.front();
142 for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
143 if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
144 return op.emitOpError("type of entry block argument #")
145 << i << '(' << entryBlock.getArgument(i).getType()
146 << ") must match the type of the corresponding argument in "
147 << "function signature(" << fnInputTypes[i] << ')';
148
149 return success();
150 }
151
152 /// Clone the internal blocks from this function into dest and all attributes
153 /// from this function to dest.
cloneInto(FuncOp dest,BlockAndValueMapping & mapper)154 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
155 // Add the attributes of this function to dest.
156 llvm::MapVector<Identifier, Attribute> newAttrs;
157 for (const auto &attr : dest->getAttrs())
158 newAttrs.insert(attr);
159 for (const auto &attr : (*this)->getAttrs())
160 newAttrs.insert(attr);
161 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
162
163 // Clone the body.
164 getBody().cloneInto(&dest.getBody(), mapper);
165 }
166
167 /// Create a deep copy of this function and all of its blocks, remapping
168 /// any operands that use values outside of the function using the map that is
169 /// provided (leaving them alone if no entry is present). Replaces references
170 /// to cloned sub-values with the corresponding value that is copied, and adds
171 /// those mappings to the mapper.
clone(BlockAndValueMapping & mapper)172 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
173 // Create the new function.
174 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
175
176 // If the function has a body, then the user might be deleting arguments to
177 // the function by specifying them in the mapper. If so, we don't add the
178 // argument to the input type vector.
179 if (!isExternal()) {
180 FunctionType oldType = getType();
181
182 unsigned oldNumArgs = oldType.getNumInputs();
183 SmallVector<Type, 4> newInputs;
184 newInputs.reserve(oldNumArgs);
185 for (unsigned i = 0; i != oldNumArgs; ++i)
186 if (!mapper.contains(getArgument(i)))
187 newInputs.push_back(oldType.getInput(i));
188
189 /// If any of the arguments were dropped, update the type and drop any
190 /// necessary argument attributes.
191 if (newInputs.size() != oldNumArgs) {
192 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
193 oldType.getResults()));
194
195 if (ArrayAttr argAttrs = getAllArgAttrs()) {
196 SmallVector<Attribute> newArgAttrs;
197 newArgAttrs.reserve(newInputs.size());
198 for (unsigned i = 0; i != oldNumArgs; ++i)
199 if (!mapper.contains(getArgument(i)))
200 newArgAttrs.push_back(argAttrs[i]);
201 newFunc.setAllArgAttrs(newArgAttrs);
202 }
203 }
204 }
205
206 /// Clone the current function into the new one and return it.
207 cloneInto(newFunc, mapper);
208 return newFunc;
209 }
clone()210 FuncOp FuncOp::clone() {
211 BlockAndValueMapping mapper;
212 return clone(mapper);
213 }
214
215 //===----------------------------------------------------------------------===//
216 // ModuleOp
217 //===----------------------------------------------------------------------===//
218
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)219 void ModuleOp::build(OpBuilder &builder, OperationState &state,
220 Optional<StringRef> name) {
221 state.addRegion()->emplaceBlock();
222 if (name) {
223 state.attributes.push_back(builder.getNamedAttr(
224 mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
225 }
226 }
227
228 /// Construct a module from the given context.
create(Location loc,Optional<StringRef> name)229 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
230 OpBuilder builder(loc->getContext());
231 return builder.create<ModuleOp>(loc, name);
232 }
233
getDataLayoutSpec()234 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
235 // Take the first and only (if present) attribute that implements the
236 // interface. This needs a linear search, but is called only once per data
237 // layout object construction that is used for repeated queries.
238 for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) {
239 if (auto spec = attr.dyn_cast<DataLayoutSpecInterface>())
240 return spec;
241 }
242 return {};
243 }
244
verify(ModuleOp op)245 static LogicalResult verify(ModuleOp op) {
246 // Check that none of the attributes are non-dialect attributes, except for
247 // the symbol related attributes.
248 for (auto attr : op->getAttrs()) {
249 if (!attr.first.strref().contains('.') &&
250 !llvm::is_contained(
251 ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
252 mlir::SymbolTable::getVisibilityAttrName()},
253 attr.first.strref()))
254 return op.emitOpError() << "can only contain attributes with "
255 "dialect-prefixed names, found: '"
256 << attr.first << "'";
257 }
258
259 // Check that there is at most one data layout spec attribute.
260 StringRef layoutSpecAttrName;
261 DataLayoutSpecInterface layoutSpec;
262 for (const NamedAttribute &na : op->getAttrs()) {
263 if (auto spec = na.second.dyn_cast<DataLayoutSpecInterface>()) {
264 if (layoutSpec) {
265 InFlightDiagnostic diag =
266 op.emitOpError() << "expects at most one data layout attribute";
267 diag.attachNote() << "'" << layoutSpecAttrName
268 << "' is a data layout attribute";
269 diag.attachNote() << "'" << na.first << "' is a data layout attribute";
270 }
271 layoutSpecAttrName = na.first.strref();
272 layoutSpec = spec;
273 }
274 }
275
276 return success();
277 }
278
279 //===----------------------------------------------------------------------===//
280 // UnrealizedConversionCastOp
281 //===----------------------------------------------------------------------===//
282
283 LogicalResult
fold(ArrayRef<Attribute> attrOperands,SmallVectorImpl<OpFoldResult> & foldResults)284 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
285 SmallVectorImpl<OpFoldResult> &foldResults) {
286 OperandRange operands = inputs();
287 ResultRange results = outputs();
288
289 if (operands.getType() == results.getType()) {
290 foldResults.append(operands.begin(), operands.end());
291 return success();
292 }
293
294 if (operands.empty())
295 return failure();
296
297 // Check that the input is a cast with results that all feed into this
298 // operation, and operand types that directly match the result types of this
299 // operation.
300 Value firstInput = operands.front();
301 auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
302 if (!inputOp || inputOp.getResults() != operands ||
303 inputOp.getOperandTypes() != results.getTypes())
304 return failure();
305
306 // If everything matches up, we can fold the passthrough.
307 foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
308 return success();
309 }
310
/// Reports whether a cast from `inputs` to `outputs` is allowed. Both
/// parameters are ignored and the answer is always true.
bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
                                                   TypeRange outputs) {
  // `UnrealizedConversionCastOp` is agnostic of the input/output types.
  return true;
}
316
317 //===----------------------------------------------------------------------===//
318 // TableGen'd op method definitions
319 //===----------------------------------------------------------------------===//
320
321 #define GET_OP_CLASSES
322 #include "mlir/IR/BuiltinOps.cpp.inc"
323