1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29 
30 #include "mlir/IR/BuiltinDialect.cpp.inc"
31 
32 namespace {
33 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
34   using OpAsmDialectInterface::OpAsmDialectInterface;
35 
getAlias__anon4aef3d7a0111::BuiltinOpAsmDialectInterface36   LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
37     if (attr.isa<AffineMapAttr>()) {
38       os << "map";
39       return success();
40     }
41     if (attr.isa<IntegerSetAttr>()) {
42       os << "set";
43       return success();
44     }
45     if (attr.isa<LocationAttr>()) {
46       os << "loc";
47       return success();
48     }
49     return failure();
50   }
51 
getAlias__anon4aef3d7a0111::BuiltinOpAsmDialectInterface52   LogicalResult getAlias(Type type, raw_ostream &os) const final {
53     if (auto tupleType = type.dyn_cast<TupleType>()) {
54       if (tupleType.size() > 16) {
55         os << "tuple";
56         return success();
57       }
58     }
59     return failure();
60   }
61 };
62 } // end anonymous namespace.
63 
/// Registers everything owned by the builtin dialect: its types, attributes
/// and location attributes, the operations listed by the TableGen-generated
/// BuiltinOps.cpp.inc, and the BuiltinOpAsmDialectInterface.
void BuiltinDialect::initialize() {
  registerTypes();
  registerAttributes();
  registerLocationAttributes();
  // With GET_OP_LIST defined, the generated file expands to the list of op
  // classes used as the template arguments of addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
74 
75 //===----------------------------------------------------------------------===//
76 // FuncOp
77 //===----------------------------------------------------------------------===//
78 
create(Location location,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs)79 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
80                       ArrayRef<NamedAttribute> attrs) {
81   OperationState state(location, "func");
82   OpBuilder builder(location->getContext());
83   FuncOp::build(builder, state, name, type, attrs);
84   return cast<FuncOp>(Operation::create(state));
85 }
create(Location location,StringRef name,FunctionType type,Operation::dialect_attr_range attrs)86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
87                       Operation::dialect_attr_range attrs) {
88   SmallVector<NamedAttribute, 8> attrRef(attrs);
89   return create(location, name, type, llvm::makeArrayRef(attrRef));
90 }
create(Location location,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)91 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
92                       ArrayRef<NamedAttribute> attrs,
93                       ArrayRef<DictionaryAttr> argAttrs) {
94   FuncOp func = create(location, name, type, attrs);
95   func.setAllArgAttrs(argAttrs);
96   return func;
97 }
98 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)99 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
100                    FunctionType type, ArrayRef<NamedAttribute> attrs,
101                    ArrayRef<DictionaryAttr> argAttrs) {
102   state.addAttribute(SymbolTable::getSymbolAttrName(),
103                      builder.getStringAttr(name));
104   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
105   state.attributes.append(attrs.begin(), attrs.end());
106   state.addRegion();
107 
108   if (argAttrs.empty())
109     return;
110   assert(type.getNumInputs() == argAttrs.size());
111   function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
112                                            /*resultAttrs=*/llvm::None);
113 }
114 
parseFuncOp(OpAsmParser & parser,OperationState & result)115 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
116   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
117                           ArrayRef<Type> results,
118                           function_like_impl::VariadicFlag, std::string &) {
119     return builder.getFunctionType(argTypes, results);
120   };
121 
122   return function_like_impl::parseFunctionLikeOp(
123       parser, result, /*allowVariadic=*/false, buildFuncType);
124 }
125 
print(FuncOp op,OpAsmPrinter & p)126 static void print(FuncOp op, OpAsmPrinter &p) {
127   FunctionType fnType = op.getType();
128   function_like_impl::printFunctionLikeOp(
129       p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
130 }
131 
verify(FuncOp op)132 static LogicalResult verify(FuncOp op) {
133   // If this function is external there is nothing to do.
134   if (op.isExternal())
135     return success();
136 
137   // Verify that the argument list of the function and the arg list of the entry
138   // block line up.  The trait already verified that the number of arguments is
139   // the same between the signature and the block.
140   auto fnInputTypes = op.getType().getInputs();
141   Block &entryBlock = op.front();
142   for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
143     if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
144       return op.emitOpError("type of entry block argument #")
145              << i << '(' << entryBlock.getArgument(i).getType()
146              << ") must match the type of the corresponding argument in "
147              << "function signature(" << fnInputTypes[i] << ')';
148 
149   return success();
150 }
151 
152 /// Clone the internal blocks from this function into dest and all attributes
153 /// from this function to dest.
cloneInto(FuncOp dest,BlockAndValueMapping & mapper)154 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
155   // Add the attributes of this function to dest.
156   llvm::MapVector<Identifier, Attribute> newAttrs;
157   for (const auto &attr : dest->getAttrs())
158     newAttrs.insert(attr);
159   for (const auto &attr : (*this)->getAttrs())
160     newAttrs.insert(attr);
161   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
162 
163   // Clone the body.
164   getBody().cloneInto(&dest.getBody(), mapper);
165 }
166 
167 /// Create a deep copy of this function and all of its blocks, remapping
168 /// any operands that use values outside of the function using the map that is
169 /// provided (leaving them alone if no entry is present). Replaces references
170 /// to cloned sub-values with the corresponding value that is copied, and adds
171 /// those mappings to the mapper.
clone(BlockAndValueMapping & mapper)172 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
173   // Create the new function.
174   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
175 
176   // If the function has a body, then the user might be deleting arguments to
177   // the function by specifying them in the mapper. If so, we don't add the
178   // argument to the input type vector.
179   if (!isExternal()) {
180     FunctionType oldType = getType();
181 
182     unsigned oldNumArgs = oldType.getNumInputs();
183     SmallVector<Type, 4> newInputs;
184     newInputs.reserve(oldNumArgs);
185     for (unsigned i = 0; i != oldNumArgs; ++i)
186       if (!mapper.contains(getArgument(i)))
187         newInputs.push_back(oldType.getInput(i));
188 
189     /// If any of the arguments were dropped, update the type and drop any
190     /// necessary argument attributes.
191     if (newInputs.size() != oldNumArgs) {
192       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
193                                         oldType.getResults()));
194 
195       if (ArrayAttr argAttrs = getAllArgAttrs()) {
196         SmallVector<Attribute> newArgAttrs;
197         newArgAttrs.reserve(newInputs.size());
198         for (unsigned i = 0; i != oldNumArgs; ++i)
199           if (!mapper.contains(getArgument(i)))
200             newArgAttrs.push_back(argAttrs[i]);
201         newFunc.setAllArgAttrs(newArgAttrs);
202       }
203     }
204   }
205 
206   /// Clone the current function into the new one and return it.
207   cloneInto(newFunc, mapper);
208   return newFunc;
209 }
clone()210 FuncOp FuncOp::clone() {
211   BlockAndValueMapping mapper;
212   return clone(mapper);
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // ModuleOp
217 //===----------------------------------------------------------------------===//
218 
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)219 void ModuleOp::build(OpBuilder &builder, OperationState &state,
220                      Optional<StringRef> name) {
221   state.addRegion()->emplaceBlock();
222   if (name) {
223     state.attributes.push_back(builder.getNamedAttr(
224         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
225   }
226 }
227 
228 /// Construct a module from the given context.
create(Location loc,Optional<StringRef> name)229 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
230   OpBuilder builder(loc->getContext());
231   return builder.create<ModuleOp>(loc, name);
232 }
233 
getDataLayoutSpec()234 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
235   // Take the first and only (if present) attribute that implements the
236   // interface. This needs a linear search, but is called only once per data
237   // layout object construction that is used for repeated queries.
238   for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) {
239     if (auto spec = attr.dyn_cast<DataLayoutSpecInterface>())
240       return spec;
241   }
242   return {};
243 }
244 
verify(ModuleOp op)245 static LogicalResult verify(ModuleOp op) {
246   // Check that none of the attributes are non-dialect attributes, except for
247   // the symbol related attributes.
248   for (auto attr : op->getAttrs()) {
249     if (!attr.first.strref().contains('.') &&
250         !llvm::is_contained(
251             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
252                                 mlir::SymbolTable::getVisibilityAttrName()},
253             attr.first.strref()))
254       return op.emitOpError() << "can only contain attributes with "
255                                  "dialect-prefixed names, found: '"
256                               << attr.first << "'";
257   }
258 
259   // Check that there is at most one data layout spec attribute.
260   StringRef layoutSpecAttrName;
261   DataLayoutSpecInterface layoutSpec;
262   for (const NamedAttribute &na : op->getAttrs()) {
263     if (auto spec = na.second.dyn_cast<DataLayoutSpecInterface>()) {
264       if (layoutSpec) {
265         InFlightDiagnostic diag =
266             op.emitOpError() << "expects at most one data layout attribute";
267         diag.attachNote() << "'" << layoutSpecAttrName
268                           << "' is a data layout attribute";
269         diag.attachNote() << "'" << na.first << "' is a data layout attribute";
270       }
271       layoutSpecAttrName = na.first.strref();
272       layoutSpec = spec;
273     }
274   }
275 
276   return success();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // UnrealizedConversionCastOp
281 //===----------------------------------------------------------------------===//
282 
283 LogicalResult
fold(ArrayRef<Attribute> attrOperands,SmallVectorImpl<OpFoldResult> & foldResults)284 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
285                                  SmallVectorImpl<OpFoldResult> &foldResults) {
286   OperandRange operands = inputs();
287   ResultRange results = outputs();
288 
289   if (operands.getType() == results.getType()) {
290     foldResults.append(operands.begin(), operands.end());
291     return success();
292   }
293 
294   if (operands.empty())
295     return failure();
296 
297   // Check that the input is a cast with results that all feed into this
298   // operation, and operand types that directly match the result types of this
299   // operation.
300   Value firstInput = operands.front();
301   auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
302   if (!inputOp || inputOp.getResults() != operands ||
303       inputOp.getOperandTypes() != results.getTypes())
304     return failure();
305 
306   // If everything matches up, we can fold the passthrough.
307   foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
308   return success();
309 }
310 
areCastCompatible(TypeRange inputs,TypeRange outputs)311 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
312                                                    TypeRange outputs) {
313   // `UnrealizedConversionCastOp` is agnostic of the input/output types.
314   return true;
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // TableGen'd op method definitions
319 //===----------------------------------------------------------------------===//
320 
321 #define GET_OP_CLASSES
322 #include "mlir/IR/BuiltinOps.cpp.inc"
323