1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionImplementation.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionSupport.h"
12 #include "mlir/IR/SymbolTable.h"
13 
14 using namespace mlir;
15 
parseFunctionArgumentList(OpAsmParser & parser,bool allowAttributes,bool allowVariadic,SmallVectorImpl<OpAsmParser::OperandType> & argNames,SmallVectorImpl<Type> & argTypes,SmallVectorImpl<NamedAttrList> & argAttrs,bool & isVariadic)16 ParseResult mlir::function_like_impl::parseFunctionArgumentList(
17     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
18     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
19     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
20     bool &isVariadic) {
21   if (parser.parseLParen())
22     return failure();
23 
24   // The argument list either has to consistently have ssa-id's followed by
25   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
26   // sometimes not.
27   auto parseArgument = [&]() -> ParseResult {
28     llvm::SMLoc loc = parser.getCurrentLocation();
29 
30     // Parse argument name if present.
31     OpAsmParser::OperandType argument;
32     Type argumentType;
33     if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
34         !argument.name.empty()) {
35       // Reject this if the preceding argument was missing a name.
36       if (argNames.empty() && !argTypes.empty())
37         return parser.emitError(loc, "expected type instead of SSA identifier");
38       argNames.push_back(argument);
39 
40       if (parser.parseColonType(argumentType))
41         return failure();
42     } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
43       isVariadic = true;
44       return success();
45     } else if (!argNames.empty()) {
46       // Reject this if the preceding argument had a name.
47       return parser.emitError(loc, "expected SSA identifier");
48     } else if (parser.parseType(argumentType)) {
49       return failure();
50     }
51 
52     // Add the argument type.
53     argTypes.push_back(argumentType);
54 
55     // Parse any argument attributes.
56     NamedAttrList attrs;
57     if (parser.parseOptionalAttrDict(attrs))
58       return failure();
59     if (!allowAttributes && !attrs.empty())
60       return parser.emitError(loc, "expected arguments without attributes");
61     argAttrs.push_back(attrs);
62 
63     // Parse a location if specified.  TODO: Don't drop it on the floor.
64     Optional<Location> explicitLoc;
65     if (!argument.name.empty() &&
66         parser.parseOptionalLocationSpecifier(explicitLoc))
67       return failure();
68 
69     return success();
70   };
71 
72   // Parse the function arguments.
73   isVariadic = false;
74   if (failed(parser.parseOptionalRParen())) {
75     do {
76       unsigned numTypedArguments = argTypes.size();
77       if (parseArgument())
78         return failure();
79 
80       llvm::SMLoc loc = parser.getCurrentLocation();
81       if (argTypes.size() == numTypedArguments &&
82           succeeded(parser.parseOptionalComma()))
83         return parser.emitError(
84             loc, "variadic arguments must be in the end of the argument list");
85     } while (succeeded(parser.parseOptionalComma()));
86     parser.parseRParen();
87   }
88 
89   return success();
90 }
91 
92 /// Parse a function result list.
93 ///
94 ///   function-result-list ::= function-result-list-parens
95 ///                          | non-function-type
96 ///   function-result-list-parens ::= `(` `)`
97 ///                                 | `(` function-result-list-no-parens `)`
98 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
99 ///   function-result ::= type attribute-dict?
100 ///
101 static ParseResult
parseFunctionResultList(OpAsmParser & parser,SmallVectorImpl<Type> & resultTypes,SmallVectorImpl<NamedAttrList> & resultAttrs)102 parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
103                         SmallVectorImpl<NamedAttrList> &resultAttrs) {
104   if (failed(parser.parseOptionalLParen())) {
105     // We already know that there is no `(`, so parse a type.
106     // Because there is no `(`, it cannot be a function type.
107     Type ty;
108     if (parser.parseType(ty))
109       return failure();
110     resultTypes.push_back(ty);
111     resultAttrs.emplace_back();
112     return success();
113   }
114 
115   // Special case for an empty set of parens.
116   if (succeeded(parser.parseOptionalRParen()))
117     return success();
118 
119   // Parse individual function results.
120   do {
121     resultTypes.emplace_back();
122     resultAttrs.emplace_back();
123     if (parser.parseType(resultTypes.back()) ||
124         parser.parseOptionalAttrDict(resultAttrs.back())) {
125       return failure();
126     }
127   } while (succeeded(parser.parseOptionalComma()));
128   return parser.parseRParen();
129 }
130 
131 /// Parses a function signature using `parser`. The `allowVariadic` argument
132 /// indicates whether functions with variadic arguments are supported. The
133 /// trailing arguments are populated by this function with names, types and
134 /// attributes of the arguments and those of the results.
parseFunctionSignature(OpAsmParser & parser,bool allowVariadic,SmallVectorImpl<OpAsmParser::OperandType> & argNames,SmallVectorImpl<Type> & argTypes,SmallVectorImpl<NamedAttrList> & argAttrs,bool & isVariadic,SmallVectorImpl<Type> & resultTypes,SmallVectorImpl<NamedAttrList> & resultAttrs)135 ParseResult mlir::function_like_impl::parseFunctionSignature(
136     OpAsmParser &parser, bool allowVariadic,
137     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
138     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
139     bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
140     SmallVectorImpl<NamedAttrList> &resultAttrs) {
141   bool allowArgAttrs = true;
142   if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
143                                 argTypes, argAttrs, isVariadic))
144     return failure();
145   if (succeeded(parser.parseOptionalArrow()))
146     return parseFunctionResultList(parser, resultTypes, resultAttrs);
147   return success();
148 }
149 
150 /// Implementation of `addArgAndResultAttrs` that is attribute list type
151 /// agnostic.
152 template <typename AttrListT, typename AttrArrayBuildFnT>
addArgAndResultAttrsImpl(Builder & builder,OperationState & result,ArrayRef<AttrListT> argAttrs,ArrayRef<AttrListT> resultAttrs,AttrArrayBuildFnT && buildAttrArrayFn)153 static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
154                                      ArrayRef<AttrListT> argAttrs,
155                                      ArrayRef<AttrListT> resultAttrs,
156                                      AttrArrayBuildFnT &&buildAttrArrayFn) {
157   auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
158 
159   // Add the attributes to the function arguments.
160   if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
161     ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
162     result.addAttribute(function_like_impl::getArgDictAttrName(), attrDicts);
163   }
164   // Add the attributes to the function results.
165   if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
166     ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
167     result.addAttribute(function_like_impl::getResultDictAttrName(), attrDicts);
168   }
169 }
170 
addArgAndResultAttrs(Builder & builder,OperationState & result,ArrayRef<DictionaryAttr> argAttrs,ArrayRef<DictionaryAttr> resultAttrs)171 void mlir::function_like_impl::addArgAndResultAttrs(
172     Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
173     ArrayRef<DictionaryAttr> resultAttrs) {
174   auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
175     return ArrayRef<Attribute>(attrs.data(), attrs.size());
176   };
177   addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
178 }
addArgAndResultAttrs(Builder & builder,OperationState & result,ArrayRef<NamedAttrList> argAttrs,ArrayRef<NamedAttrList> resultAttrs)179 void mlir::function_like_impl::addArgAndResultAttrs(
180     Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
181     ArrayRef<NamedAttrList> resultAttrs) {
182   MLIRContext *context = builder.getContext();
183   auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
184     return llvm::to_vector<8>(
185         llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
186           return attrList.getDictionary(context);
187         }));
188   };
189   addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
190 }
191 
192 /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
193 /// to construct the custom function type given lists of input and output types.
parseFunctionLikeOp(OpAsmParser & parser,OperationState & result,bool allowVariadic,FuncTypeBuilder funcTypeBuilder)194 ParseResult mlir::function_like_impl::parseFunctionLikeOp(
195     OpAsmParser &parser, OperationState &result, bool allowVariadic,
196     FuncTypeBuilder funcTypeBuilder) {
197   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
198   SmallVector<NamedAttrList, 4> argAttrs;
199   SmallVector<NamedAttrList, 4> resultAttrs;
200   SmallVector<Type, 4> argTypes;
201   SmallVector<Type, 4> resultTypes;
202   auto &builder = parser.getBuilder();
203 
204   // Parse visibility.
205   impl::parseOptionalVisibilityKeyword(parser, result.attributes);
206 
207   // Parse the name as a symbol.
208   StringAttr nameAttr;
209   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
210                              result.attributes))
211     return failure();
212 
213   // Parse the function signature.
214   llvm::SMLoc signatureLocation = parser.getCurrentLocation();
215   bool isVariadic = false;
216   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
217                              argAttrs, isVariadic, resultTypes, resultAttrs))
218     return failure();
219 
220   std::string errorMessage;
221   Type type = funcTypeBuilder(builder, argTypes, resultTypes,
222                               VariadicFlag(isVariadic), errorMessage);
223   if (!type) {
224     return parser.emitError(signatureLocation)
225            << "failed to construct function type"
226            << (errorMessage.empty() ? "" : ": ") << errorMessage;
227   }
228   result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
229 
230   // If function attributes are present, parse them.
231   NamedAttrList parsedAttributes;
232   llvm::SMLoc attributeDictLocation = parser.getCurrentLocation();
233   if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
234     return failure();
235 
236   // Disallow attributes that are inferred from elsewhere in the attribute
237   // dictionary.
238   for (StringRef disallowed :
239        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
240         getTypeAttrName()}) {
241     if (parsedAttributes.get(disallowed))
242       return parser.emitError(attributeDictLocation, "'")
243              << disallowed
244              << "' is an inferred attribute and should not be specified in the "
245                 "explicit attribute dictionary";
246   }
247   result.attributes.append(parsedAttributes);
248 
249   // Add the attributes to the function arguments.
250   assert(argAttrs.size() == argTypes.size());
251   assert(resultAttrs.size() == resultTypes.size());
252   addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
253 
254   // Parse the optional function body. The printer will not print the body if
255   // its empty, so disallow parsing of empty body in the parser.
256   auto *body = result.addRegion();
257   llvm::SMLoc loc = parser.getCurrentLocation();
258   OptionalParseResult parseResult = parser.parseOptionalRegion(
259       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
260       /*enableNameShadowing=*/false);
261   if (parseResult.hasValue()) {
262     if (failed(*parseResult))
263       return failure();
264     // Function body was parsed, make sure its not empty.
265     if (body->empty())
266       return parser.emitError(loc, "expected non-empty function body");
267   }
268   return success();
269 }
270 
271 /// Print a function result list. The provided `attrs` must either be null, or
272 /// contain a set of DictionaryAttrs of the same arity as `types`.
printFunctionResultList(OpAsmPrinter & p,ArrayRef<Type> types,ArrayAttr attrs)273 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
274                                     ArrayAttr attrs) {
275   assert(!types.empty() && "Should not be called for empty result list.");
276   assert((!attrs || attrs.size() == types.size()) &&
277          "Invalid number of attributes.");
278 
279   auto &os = p.getStream();
280   bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
281                      (attrs && !attrs[0].cast<DictionaryAttr>().empty());
282   if (needsParens)
283     os << '(';
284   llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
285     p.printType(types[i]);
286     if (attrs)
287       p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
288   });
289   if (needsParens)
290     os << ')';
291 }
292 
293 /// Print the signature of the function-like operation `op`.  Assumes `op` has
294 /// the FunctionLike trait and passed the verification.
printFunctionSignature(OpAsmPrinter & p,Operation * op,ArrayRef<Type> argTypes,bool isVariadic,ArrayRef<Type> resultTypes)295 void mlir::function_like_impl::printFunctionSignature(
296     OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
297     ArrayRef<Type> resultTypes) {
298   Region &body = op->getRegion(0);
299   bool isExternal = body.empty();
300 
301   p << '(';
302   ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
303   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
304     if (i > 0)
305       p << ", ";
306 
307     if (!isExternal) {
308       ArrayRef<NamedAttribute> attrs;
309       if (argAttrs)
310         attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
311       p.printRegionArgument(body.getArgument(i), attrs);
312     } else {
313       p.printType(argTypes[i]);
314       if (argAttrs)
315         p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
316     }
317   }
318 
319   if (isVariadic) {
320     if (!argTypes.empty())
321       p << ", ";
322     p << "...";
323   }
324 
325   p << ')';
326 
327   if (!resultTypes.empty()) {
328     p.getStream() << " -> ";
329     auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
330     printFunctionResultList(p, resultTypes, resultAttrs);
331   }
332 }
333 
334 /// Prints the list of function prefixed with the "attributes" keyword. The
335 /// attributes with names listed in "elided" as well as those used by the
336 /// function-like operation internally are not printed. Nothing is printed
337 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
338 /// passed the verification.
printFunctionAttributes(OpAsmPrinter & p,Operation * op,unsigned numInputs,unsigned numResults,ArrayRef<StringRef> elided)339 void mlir::function_like_impl::printFunctionAttributes(
340     OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
341     ArrayRef<StringRef> elided) {
342   // Print out function attributes, if present.
343   SmallVector<StringRef, 2> ignoredAttrs = {
344       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
345       getArgDictAttrName(), getResultDictAttrName()};
346   ignoredAttrs.append(elided.begin(), elided.end());
347 
348   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
349 }
350 
351 /// Printer implementation for function-like operations.  Accepts lists of
352 /// argument and result types to use while printing.
printFunctionLikeOp(OpAsmPrinter & p,Operation * op,ArrayRef<Type> argTypes,bool isVariadic,ArrayRef<Type> resultTypes)353 void mlir::function_like_impl::printFunctionLikeOp(OpAsmPrinter &p,
354                                                    Operation *op,
355                                                    ArrayRef<Type> argTypes,
356                                                    bool isVariadic,
357                                                    ArrayRef<Type> resultTypes) {
358   // Print the operation and the function name.
359   auto funcName =
360       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
361           .getValue();
362   p << op->getName() << ' ';
363 
364   StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
365   if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
366     p << visibility.getValue() << ' ';
367   p.printSymbolName(funcName);
368 
369   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
370   printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
371                           {visibilityAttrName});
372   // Print the body if this is not an external function.
373   Region &body = op->getRegion(0);
374   if (!body.empty())
375     p.printRegion(body, /*printEntryBlockArgs=*/false,
376                   /*printBlockTerminators=*/true);
377 }
378