//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
16 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/MLIRContext.h"
20 #include "mlir/IR/StandardTypes.h"
21 #include "mlir/Parser.h"
22 #include "mlir/Support/StringExtras.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringMap.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 namespace mlir {
32 namespace spirv {
33 #include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
34 } // namespace spirv
35 } // namespace mlir
36 
37 using namespace mlir;
38 using namespace mlir::spirv;
39 
40 //===----------------------------------------------------------------------===//
41 // InlinerInterface
42 //===----------------------------------------------------------------------===//
43 
44 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
containsReturn(Region & region)45 static inline bool containsReturn(Region &region) {
46   return llvm::any_of(region, [](Block &block) {
47     Operation *terminator = block.getTerminator();
48     return isa<spirv::ReturnOp>(terminator) ||
49            isa<spirv::ReturnValueOp>(terminator);
50   });
51 }
52 
53 namespace {
54 /// This class defines the interface for inlining within the SPIR-V dialect.
55 struct SPIRVInlinerInterface : public DialectInlinerInterface {
56   using DialectInlinerInterface::DialectInlinerInterface;
57 
58   /// Returns true if the given region 'src' can be inlined into the region
59   /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anonf10db8e40211::SPIRVInlinerInterface60   bool isLegalToInline(Region *dest, Region *src,
61                        BlockAndValueMapping &) const final {
62     // Return true here when inlining into spv.selection and spv.loop
63     // operations.
64     auto op = dest->getParentOp();
65     return isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op);
66   }
67 
68   /// Returns true if the given operation 'op', that is registered to this
69   /// dialect, can be inlined into the region 'dest' that is attached to an
70   /// operation registered to the current dialect.
isLegalToInline__anonf10db8e40211::SPIRVInlinerInterface71   bool isLegalToInline(Operation *op, Region *dest,
72                        BlockAndValueMapping &) const final {
73     // TODO(antiagainst): Enable inlining structured control flows with return.
74     if ((isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op)) &&
75         containsReturn(op->getRegion(0)))
76       return false;
77     // TODO(antiagainst): we need to filter OpKill here to avoid inlining it to
78     // a loop continue construct:
79     // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
80     // However OpKill is fragment shader specific and we don't support it yet.
81     return true;
82   }
83 
84   /// Handle the given inlined terminator by replacing it with a new operation
85   /// as necessary.
handleTerminator__anonf10db8e40211::SPIRVInlinerInterface86   void handleTerminator(Operation *op, Block *newDest) const final {
87     if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
88       OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
89       op->erase();
90     } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
91       llvm_unreachable("unimplemented spv.ReturnValue in inliner");
92     }
93   }
94 
95   /// Handle the given inlined terminator by replacing it with a new operation
96   /// as necessary.
handleTerminator__anonf10db8e40211::SPIRVInlinerInterface97   void handleTerminator(Operation *op,
98                         ArrayRef<Value> valuesToRepl) const final {
99     // Only spv.ReturnValue needs to be handled here.
100     auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
101     if (!retValOp)
102       return;
103 
104     // Replace the values directly with the return operands.
105     assert(valuesToRepl.size() == 1 &&
106            "spv.ReturnValue expected to only handle one result");
107     valuesToRepl.front().replaceAllUsesWith(retValOp.value());
108   }
109 };
110 } // namespace
111 
112 //===----------------------------------------------------------------------===//
113 // SPIR-V Dialect
114 //===----------------------------------------------------------------------===//
115 
SPIRVDialect(MLIRContext * context)116 SPIRVDialect::SPIRVDialect(MLIRContext *context)
117     : Dialect(getDialectNamespace(), context) {
118   addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
119 
120   // Add SPIR-V ops.
121   addOperations<
122 #define GET_OP_LIST
123 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
124       >();
125 
126   addInterfaces<SPIRVInlinerInterface>();
127 
128   // Allow unknown operations because SPIR-V is extensible.
129   allowUnknownOperations();
130 }
131 
getAttributeName(Decoration decoration)132 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
133   return convertToSnakeCase(stringifyDecoration(decoration));
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Type Parsing
138 //===----------------------------------------------------------------------===//
139 
140 // Forward declarations.
141 template <typename ValTy>
142 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
143                                       DialectAsmParser &parser);
144 template <>
145 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
146                                     DialectAsmParser &parser);
147 
148 template <>
149 Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
150                                             DialectAsmParser &parser);
151 
isValidSPIRVIntType(IntegerType type)152 static bool isValidSPIRVIntType(IntegerType type) {
153   return llvm::is_contained(ArrayRef<unsigned>({1, 8, 16, 32, 64}),
154                             type.getWidth());
155 }
156 
isValidScalarType(Type type)157 bool SPIRVDialect::isValidScalarType(Type type) {
158   if (type.isa<FloatType>()) {
159     return !type.isBF16();
160   }
161   if (auto intType = type.dyn_cast<IntegerType>()) {
162     return isValidSPIRVIntType(intType);
163   }
164   return false;
165 }
166 
isValidSPIRVVectorType(VectorType type)167 static bool isValidSPIRVVectorType(VectorType type) {
168   return type.getRank() == 1 &&
169          SPIRVDialect::isValidScalarType(type.getElementType()) &&
170          type.getNumElements() >= 2 && type.getNumElements() <= 4;
171 }
172 
isValidType(Type type)173 bool SPIRVDialect::isValidType(Type type) {
174   // Allow SPIR-V dialect types
175   if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
176       type.getKind() <= TypeKind::LAST_SPIRV_TYPE) {
177     return true;
178   }
179   if (SPIRVDialect::isValidScalarType(type)) {
180     return true;
181   }
182   if (auto vectorType = type.dyn_cast<VectorType>()) {
183     return isValidSPIRVVectorType(vectorType);
184   }
185   return false;
186 }
187 
parseAndVerifyType(SPIRVDialect const & dialect,DialectAsmParser & parser)188 static Type parseAndVerifyType(SPIRVDialect const &dialect,
189                                DialectAsmParser &parser) {
190   Type type;
191   llvm::SMLoc typeLoc = parser.getCurrentLocation();
192   if (parser.parseType(type))
193     return Type();
194 
195   // Allow SPIR-V dialect types
196   if (&type.getDialect() == &dialect)
197     return type;
198 
199   // Check other allowed types
200   if (auto t = type.dyn_cast<FloatType>()) {
201     if (type.isBF16()) {
202       parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
203       return Type();
204     }
205   } else if (auto t = type.dyn_cast<IntegerType>()) {
206     if (!isValidSPIRVIntType(t)) {
207       parser.emitError(typeLoc,
208                        "only 1/8/16/32/64-bit integer type allowed but found ")
209           << type;
210       return Type();
211     }
212   } else if (auto t = type.dyn_cast<VectorType>()) {
213     if (t.getRank() != 1) {
214       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215       return Type();
216     }
217     if (t.getNumElements() > 4) {
218       parser.emitError(
219           typeLoc, "vector length has to be less than or equal to 4 but found ")
220           << t.getNumElements();
221       return Type();
222     }
223   } else {
224     parser.emitError(typeLoc, "cannot use ")
225         << type << " to compose SPIR-V types";
226     return Type();
227   }
228 
229   return type;
230 }
231 
232 // element-type ::= integer-type
233 //                | floating-point-type
234 //                | vector-type
235 //                | spirv-type
236 //
237 // array-type ::= `!spv.array<` integer-literal `x` element-type
238 //                (`[` integer-literal `]`)? `>`
parseArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)239 static Type parseArrayType(SPIRVDialect const &dialect,
240                            DialectAsmParser &parser) {
241   if (parser.parseLess())
242     return Type();
243 
244   SmallVector<int64_t, 1> countDims;
245   llvm::SMLoc countLoc = parser.getCurrentLocation();
246   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
247     return Type();
248   if (countDims.size() != 1) {
249     parser.emitError(countLoc,
250                      "expected single integer for array element count");
251     return Type();
252   }
253 
254   // According to the SPIR-V spec:
255   // "Length is the number of elements in the array. It must be at least 1."
256   int64_t count = countDims[0];
257   if (count == 0) {
258     parser.emitError(countLoc, "expected array length greater than 0");
259     return Type();
260   }
261 
262   Type elementType = parseAndVerifyType(dialect, parser);
263   if (!elementType)
264     return Type();
265 
266   ArrayType::LayoutInfo layoutInfo = 0;
267   if (succeeded(parser.parseOptionalLSquare())) {
268     llvm::SMLoc layoutLoc = parser.getCurrentLocation();
269     auto layout = parseAndVerify<ArrayType::LayoutInfo>(dialect, parser);
270     if (!layout)
271       return Type();
272 
273     if (!(layoutInfo = layout.getValue())) {
274       parser.emitError(layoutLoc, "ArrayStride must be greater than zero");
275       return Type();
276     }
277 
278     if (parser.parseRSquare())
279       return Type();
280   }
281 
282   if (parser.parseGreater())
283     return Type();
284   return ArrayType::get(elementType, count, layoutInfo);
285 }
286 
287 // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
288 // methods in alphabetical order
289 //
290 // storage-class ::= `UniformConstant`
291 //                 | `Uniform`
292 //                 | `Workgroup`
293 //                 | <and other storage classes...>
294 //
295 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
parsePointerType(SPIRVDialect const & dialect,DialectAsmParser & parser)296 static Type parsePointerType(SPIRVDialect const &dialect,
297                              DialectAsmParser &parser) {
298   if (parser.parseLess())
299     return Type();
300 
301   auto pointeeType = parseAndVerifyType(dialect, parser);
302   if (!pointeeType)
303     return Type();
304 
305   StringRef storageClassSpec;
306   llvm::SMLoc storageClassLoc = parser.getCurrentLocation();
307   if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
308     return Type();
309 
310   auto storageClass = symbolizeStorageClass(storageClassSpec);
311   if (!storageClass) {
312     parser.emitError(storageClassLoc, "unknown storage class: ")
313         << storageClassSpec;
314     return Type();
315   }
316   if (parser.parseGreater())
317     return Type();
318   return PointerType::get(pointeeType, *storageClass);
319 }
320 
321 // runtime-array-type ::= `!spv.rtarray<` element-type `>`
parseRuntimeArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)322 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
323                                   DialectAsmParser &parser) {
324   if (parser.parseLess())
325     return Type();
326 
327   Type elementType = parseAndVerifyType(dialect, parser);
328   if (!elementType)
329     return Type();
330 
331   if (parser.parseGreater())
332     return Type();
333   return RuntimeArrayType::get(elementType);
334 }
335 
336 // Specialize this function to parse each of the parameters that define an
337 // ImageType. By default it assumes this is an enum type.
338 template <typename ValTy>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)339 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
340                                       DialectAsmParser &parser) {
341   StringRef enumSpec;
342   llvm::SMLoc enumLoc = parser.getCurrentLocation();
343   if (parser.parseKeyword(&enumSpec)) {
344     return llvm::None;
345   }
346 
347   auto val = spirv::symbolizeEnum<ValTy>()(enumSpec);
348   if (!val)
349     parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
350   return val;
351 }
352 
353 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)354 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
355                                     DialectAsmParser &parser) {
356   // TODO(ravishankarm): Further verify that the element type can be sampled
357   auto ty = parseAndVerifyType(dialect, parser);
358   if (!ty)
359     return llvm::None;
360   return ty;
361 }
362 
363 template <typename IntTy>
parseAndVerifyInteger(SPIRVDialect const & dialect,DialectAsmParser & parser)364 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
365                                              DialectAsmParser &parser) {
366   IntTy offsetVal = std::numeric_limits<IntTy>::max();
367   if (parser.parseInteger(offsetVal))
368     return llvm::None;
369   return offsetVal;
370 }
371 
372 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)373 Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
374                                             DialectAsmParser &parser) {
375   return parseAndVerifyInteger<uint64_t>(dialect, parser);
376 }
377 
378 namespace {
379 // Functor object to parse a comma separated list of specs. The function
380 // parseAndVerify does the actual parsing and verification of individual
381 // elements. This is a functor since parsing the last element of the list
382 // (termination condition) needs partial specialization.
383 template <typename ParseType, typename... Args> struct parseCommaSeparatedList {
384   Optional<std::tuple<ParseType, Args...>>
operator ()__anonf10db8e40311::parseCommaSeparatedList385   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
386     auto parseVal = parseAndVerify<ParseType>(dialect, parser);
387     if (!parseVal)
388       return llvm::None;
389 
390     auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
391     if (numArgs != 0 && failed(parser.parseComma()))
392       return llvm::None;
393     auto remainingValues = parseCommaSeparatedList<Args...>{}(dialect, parser);
394     if (!remainingValues)
395       return llvm::None;
396     return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
397                           remainingValues.getValue());
398   }
399 };
400 
401 // Partial specialization of the function to parse a comma separated list of
402 // specs to parse the last element of the list.
403 template <typename ParseType> struct parseCommaSeparatedList<ParseType> {
operator ()__anonf10db8e40311::parseCommaSeparatedList404   Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
405                                              DialectAsmParser &parser) const {
406     if (auto value = parseAndVerify<ParseType>(dialect, parser))
407       return std::tuple<ParseType>(value.getValue());
408     return llvm::None;
409   }
410 };
411 } // namespace
412 
413 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
414 //
415 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
416 //
417 // arrayed-info ::= `NonArrayed` | `Arrayed`
418 //
419 // sampling-info ::= `SingleSampled` | `MultiSampled`
420 //
421 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` |  `NoSampler`
422 //
423 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
424 //
425 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
426 //                              arrayed-info `,` sampling-info `,`
427 //                              sampler-use-info `,` format `>`
parseImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)428 static Type parseImageType(SPIRVDialect const &dialect,
429                            DialectAsmParser &parser) {
430   if (parser.parseLess())
431     return Type();
432 
433   auto value =
434       parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
435                               ImageSamplingInfo, ImageSamplerUseInfo,
436                               ImageFormat>{}(dialect, parser);
437   if (!value)
438     return Type();
439 
440   if (parser.parseGreater())
441     return Type();
442   return ImageType::get(value.getValue());
443 }
444 
445 // Parse decorations associated with a member.
parseStructMemberDecorations(SPIRVDialect const & dialect,DialectAsmParser & parser,ArrayRef<Type> memberTypes,SmallVectorImpl<StructType::LayoutInfo> & layoutInfo,SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorationInfo)446 static ParseResult parseStructMemberDecorations(
447     SPIRVDialect const &dialect, DialectAsmParser &parser,
448     ArrayRef<Type> memberTypes,
449     SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
450     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
451 
452   // Check if the first element is offset.
453   llvm::SMLoc layoutLoc = parser.getCurrentLocation();
454   StructType::LayoutInfo layout = 0;
455   OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
456   if (layoutParseResult.hasValue()) {
457     if (failed(*layoutParseResult))
458       return failure();
459 
460     if (layoutInfo.size() != memberTypes.size() - 1) {
461       return parser.emitError(
462           layoutLoc, "layout specification must be given for all members");
463     }
464     layoutInfo.push_back(layout);
465   }
466 
467   // Check for no spirv::Decorations.
468   if (succeeded(parser.parseOptionalRSquare()))
469     return success();
470 
471   // If there was a layout, make sure to parse the comma.
472   if (layoutParseResult.hasValue() && parser.parseComma())
473     return failure();
474 
475   // Check for spirv::Decorations.
476   do {
477     auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
478     if (!memberDecoration)
479       return failure();
480 
481     memberDecorationInfo.emplace_back(
482         static_cast<uint32_t>(memberTypes.size() - 1),
483         memberDecoration.getValue());
484   } while (succeeded(parser.parseOptionalComma()));
485 
486   return parser.parseRSquare();
487 }
488 
489 // struct-member-decoration ::= integer-literal? spirv-decoration*
490 // struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)?
491 //                     (`, ` spirv-type (`[` struct-member-decoration `]`)? `>`
parseStructType(SPIRVDialect const & dialect,DialectAsmParser & parser)492 static Type parseStructType(SPIRVDialect const &dialect,
493                             DialectAsmParser &parser) {
494   if (parser.parseLess())
495     return Type();
496 
497   if (succeeded(parser.parseOptionalGreater()))
498     return StructType::getEmpty(dialect.getContext());
499 
500   SmallVector<Type, 4> memberTypes;
501   SmallVector<StructType::LayoutInfo, 4> layoutInfo;
502   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
503 
504   do {
505     Type memberType;
506     if (parser.parseType(memberType))
507       return Type();
508     memberTypes.push_back(memberType);
509 
510     if (succeeded(parser.parseOptionalLSquare())) {
511       if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
512                                        memberDecorationInfo)) {
513         return Type();
514       }
515     }
516   } while (succeeded(parser.parseOptionalComma()));
517 
518   if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
519     parser.emitError(parser.getNameLoc(),
520                      "layout specification must be given for all members");
521     return Type();
522   }
523   if (parser.parseGreater())
524     return Type();
525   return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
526 }
527 
528 // spirv-type ::= array-type
529 //              | element-type
530 //              | image-type
531 //              | pointer-type
532 //              | runtime-array-type
533 //              | struct-type
parseType(DialectAsmParser & parser) const534 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
535   StringRef keyword;
536   if (parser.parseKeyword(&keyword))
537     return Type();
538 
539   if (keyword == "array")
540     return parseArrayType(*this, parser);
541   if (keyword == "image")
542     return parseImageType(*this, parser);
543   if (keyword == "ptr")
544     return parsePointerType(*this, parser);
545   if (keyword == "rtarray")
546     return parseRuntimeArrayType(*this, parser);
547   if (keyword == "struct")
548     return parseStructType(*this, parser);
549 
550   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
551   return Type();
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // Type Printing
556 //===----------------------------------------------------------------------===//
557 
print(ArrayType type,DialectAsmPrinter & os)558 static void print(ArrayType type, DialectAsmPrinter &os) {
559   os << "array<" << type.getNumElements() << " x " << type.getElementType();
560   if (type.hasLayout()) {
561     os << " [" << type.getArrayStride() << "]";
562   }
563   os << ">";
564 }
565 
print(RuntimeArrayType type,DialectAsmPrinter & os)566 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
567   os << "rtarray<" << type.getElementType() << ">";
568 }
569 
print(PointerType type,DialectAsmPrinter & os)570 static void print(PointerType type, DialectAsmPrinter &os) {
571   os << "ptr<" << type.getPointeeType() << ", "
572      << stringifyStorageClass(type.getStorageClass()) << ">";
573 }
574 
print(ImageType type,DialectAsmPrinter & os)575 static void print(ImageType type, DialectAsmPrinter &os) {
576   os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
577      << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
578      << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
579      << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
580      << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
581      << stringifyImageFormat(type.getImageFormat()) << ">";
582 }
583 
print(StructType type,DialectAsmPrinter & os)584 static void print(StructType type, DialectAsmPrinter &os) {
585   os << "struct<";
586   auto printMember = [&](unsigned i) {
587     os << type.getElementType(i);
588     SmallVector<spirv::Decoration, 0> decorations;
589     type.getMemberDecorations(i, decorations);
590     if (type.hasLayout() || !decorations.empty()) {
591       os << " [";
592       if (type.hasLayout()) {
593         os << type.getOffset(i);
594         if (!decorations.empty())
595           os << ", ";
596       }
597       auto each_fn = [&os](spirv::Decoration decoration) {
598         os << stringifyDecoration(decoration);
599       };
600       interleaveComma(decorations, os, each_fn);
601       os << "]";
602     }
603   };
604   interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
605                   printMember);
606   os << ">";
607 }
608 
printType(Type type,DialectAsmPrinter & os) const609 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
610   switch (type.getKind()) {
611   case TypeKind::Array:
612     print(type.cast<ArrayType>(), os);
613     return;
614   case TypeKind::Pointer:
615     print(type.cast<PointerType>(), os);
616     return;
617   case TypeKind::RuntimeArray:
618     print(type.cast<RuntimeArrayType>(), os);
619     return;
620   case TypeKind::Image:
621     print(type.cast<ImageType>(), os);
622     return;
623   case TypeKind::Struct:
624     print(type.cast<StructType>(), os);
625     return;
626   default:
627     llvm_unreachable("unhandled SPIR-V type");
628   }
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // Constant
633 //===----------------------------------------------------------------------===//
634 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)635 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
636                                              Attribute value, Type type,
637                                              Location loc) {
638   if (!ConstantOp::isBuildableWith(type))
639     return nullptr;
640 
641   return builder.create<spirv::ConstantOp>(loc, type, value);
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // Shader Interface ABI
646 //===----------------------------------------------------------------------===//
647 
verifyOperationAttribute(Operation * op,NamedAttribute attribute)648 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
649                                                      NamedAttribute attribute) {
650   StringRef symbol = attribute.first.strref();
651   Attribute attr = attribute.second;
652 
653   // TODO(antiagainst): figure out a way to generate the description from the
654   // StructAttr definition.
655   if (symbol == spirv::getEntryPointABIAttrName()) {
656     if (!attr.isa<spirv::EntryPointABIAttr>())
657       return op->emitError("'")
658              << symbol
659              << "' attribute must be a dictionary attribute containing one "
660                 "32-bit integer elements attribute: 'local_size'";
661   } else if (symbol == spirv::getTargetEnvAttrName()) {
662     if (!attr.isa<spirv::TargetEnvAttr>())
663       return op->emitError("'")
664              << symbol
665              << "' must be a dictionary attribute containing one 32-bit "
666                 "integer attribute 'version', one string array attribute "
667                 "'extensions', and one 32-bit integer array attribute "
668                 "'capabilities'";
669   } else {
670     return op->emitError("found unsupported '")
671            << symbol << "' attribute on operation";
672   }
673 
674   return success();
675 }
676 
677 // Verifies the given SPIR-V `attribute` attached to a region's argument or
678 // result and reports error to the given location if invalid.
679 static LogicalResult
verifyRegionAttribute(Location loc,NamedAttribute attribute,bool forArg)680 verifyRegionAttribute(Location loc, NamedAttribute attribute, bool forArg) {
681   StringRef symbol = attribute.first.strref();
682   Attribute attr = attribute.second;
683 
684   if (symbol != spirv::getInterfaceVarABIAttrName())
685     return emitError(loc, "found unsupported '")
686            << symbol << "' attribute on region "
687            << (forArg ? "argument" : "result");
688 
689   if (!attr.isa<spirv::InterfaceVarABIAttr>())
690     return emitError(loc, "'")
691            << symbol
692            << "' attribute must be a dictionary attribute containing three "
693               "32-bit integer attributes: 'descriptor_set', 'binding', and "
694               "'storage_class'";
695 
696   return success();
697 }
698 
verifyRegionArgAttribute(Operation * op,unsigned,unsigned,NamedAttribute attribute)699 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
700                                                      unsigned /*regionIndex*/,
701                                                      unsigned /*argIndex*/,
702                                                      NamedAttribute attribute) {
703   return verifyRegionAttribute(op->getLoc(), attribute,
704                                /*forArg=*/true);
705 }
706 
verifyRegionResultAttribute(Operation * op,unsigned,unsigned,NamedAttribute attribute)707 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
708     Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
709     NamedAttribute attribute) {
710   return verifyRegionAttribute(op->getLoc(), attribute,
711                                /*forArg=*/false);
712 }
713