1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/FunctionImplementation.h"
25 #include "mlir/IR/OpDefinition.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Interfaces/CallInterfaces.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/bit.h"
33 
34 using namespace mlir;
35 
// Names of attributes (plus a few custom-syntax keywords such as `control` and
// `cluster_size`) that the hand-written parsers, printers and verifiers in
// this file read or write.
// TODO: generate these strings using ODS.

// Memory access / alignment attributes. The `source_*` variants are used by
// the parse/print/verify helpers for a second (source) memory operand.
static constexpr const char kMemoryAccessAttrName[] = "memory_access";
static constexpr const char kSourceMemoryAccessAttrName[] =
    "source_memory_access";
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kSourceAlignmentAttrName[] = "source_alignment";
// Other attribute names and keywords.
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kCallee[] = "callee";
static constexpr const char kClusterSize[] = "cluster_size";
static constexpr const char kControl[] = "control";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kGroupOperationAttrName[] = "group_operation";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kInitializerAttrName[] = "initializer";
static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kMemoryScopeAttrName[] = "memory_scope";
static constexpr const char kSemanticsAttrName[] = "semantics";
static constexpr const char kSpecIdAttrName[] = "spec_id";
static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
62 
63 //===----------------------------------------------------------------------===//
64 // Common utility functions
65 //===----------------------------------------------------------------------===//
66 
67 /// Returns true if the given op is a function-like op or nested in a
68 /// function-like op without a module-like op in the middle.
isNestedInFunctionLikeOp(Operation * op)69 static bool isNestedInFunctionLikeOp(Operation *op) {
70   if (!op)
71     return false;
72   if (op->hasTrait<OpTrait::SymbolTable>())
73     return false;
74   if (op->hasTrait<OpTrait::FunctionLike>())
75     return true;
76   return isNestedInFunctionLikeOp(op->getParentOp());
77 }
78 
79 /// Returns true if the given op is an module-like op that maintains a symbol
80 /// table.
isDirectInModuleLikeOp(Operation * op)81 static bool isDirectInModuleLikeOp(Operation *op) {
82   return op && op->hasTrait<OpTrait::SymbolTable>();
83 }
84 
extractValueFromConstOp(Operation * op,int32_t & value)85 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
86   auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
87   if (!constOp) {
88     return failure();
89   }
90   auto valueAttr = constOp.value();
91   auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
92   if (!integerValueAttr) {
93     return failure();
94   }
95   value = integerValueAttr.getInt();
96   return success();
97 }
98 
99 template <typename Ty>
100 static ArrayAttr
getStrArrayAttrForEnumList(Builder & builder,ArrayRef<Ty> enumValues,function_ref<StringRef (Ty)> stringifyFn)101 getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
102                            function_ref<StringRef(Ty)> stringifyFn) {
103   if (enumValues.empty()) {
104     return nullptr;
105   }
106   SmallVector<StringRef, 1> enumValStrs;
107   enumValStrs.reserve(enumValues.size());
108   for (auto val : enumValues) {
109     enumValStrs.emplace_back(stringifyFn(val));
110   }
111   return builder.getStrArrayAttr(enumValStrs);
112 }
113 
114 /// Parses the next string attribute in `parser` as an enumerant of the given
115 /// `EnumClass`.
116 template <typename EnumClass>
117 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,StringRef attrName=spirv::attributeName<EnumClass> ())118 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
119                  StringRef attrName = spirv::attributeName<EnumClass>()) {
120   Attribute attrVal;
121   NamedAttrList attr;
122   auto loc = parser.getCurrentLocation();
123   if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
124                             attrName, attr)) {
125     return failure();
126   }
127   if (!attrVal.isa<StringAttr>()) {
128     return parser.emitError(loc, "expected ")
129            << attrName << " attribute specified as string";
130   }
131   auto attrOptional =
132       spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
133   if (!attrOptional) {
134     return parser.emitError(loc, "invalid ")
135            << attrName << " attribute specification: " << attrVal;
136   }
137   value = attrOptional.getValue();
138   return success();
139 }
140 
141 /// Parses the next string attribute in `parser` as an enumerant of the given
142 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
143 /// attribute with the enum class's name as attribute name.
144 template <typename EnumClass>
145 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())146 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
147                  StringRef attrName = spirv::attributeName<EnumClass>()) {
148   if (parseEnumStrAttr(value, parser)) {
149     return failure();
150   }
151   state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
152                                    llvm::bit_cast<int32_t>(value)));
153   return success();
154 }
155 
156 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
157 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
158 /// the enum class's name as attribute name.
159 template <typename EnumClass>
160 static ParseResult
parseEnumKeywordAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())161 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
162                      OperationState &state,
163                      StringRef attrName = spirv::attributeName<EnumClass>()) {
164   if (parseEnumKeywordAttr(value, parser)) {
165     return failure();
166   }
167   state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
168                                    llvm::bit_cast<int32_t>(value)));
169   return success();
170 }
171 
172 /// Parses Function, Selection and Loop control attributes. If no control is
173 /// specified, "None" is used as a default.
174 template <typename EnumClass>
175 static ParseResult
parseControlAttribute(OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())176 parseControlAttribute(OpAsmParser &parser, OperationState &state,
177                       StringRef attrName = spirv::attributeName<EnumClass>()) {
178   if (succeeded(parser.parseOptionalKeyword(kControl))) {
179     EnumClass control;
180     if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
181         parser.parseRParen())
182       return failure();
183     return success();
184   }
185   // Set control to "None" otherwise.
186   Builder builder = parser.getBuilder();
187   state.addAttribute(attrName, builder.getI32IntegerAttr(0));
188   return success();
189 }
190 
191 /// Parses optional memory access attributes attached to a memory access
192 /// operand/pointer. Specifically, parses the following syntax:
193 ///     (`[` memory-access `]`)?
194 /// where:
195 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
196 ///         integer-literal | `"NonTemporal"`
parseMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)197 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
198                                                OperationState &state) {
199   // Parse an optional list of attributes staring with '['
200   if (parser.parseOptionalLSquare()) {
201     // Nothing to do
202     return success();
203   }
204 
205   spirv::MemoryAccess memoryAccessAttr;
206   if (parseEnumStrAttr(memoryAccessAttr, parser, state,
207                        kMemoryAccessAttrName)) {
208     return failure();
209   }
210 
211   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
212     // Parse integer attribute for alignment.
213     Attribute alignmentAttr;
214     Type i32Type = parser.getBuilder().getIntegerType(32);
215     if (parser.parseComma() ||
216         parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
217                               state.attributes)) {
218       return failure();
219     }
220   }
221   return parser.parseRSquare();
222 }
223 
224 // TODO Make sure to merge this and the previous function into one template
225 // parameterized by memory access attribute name and alignment. Doing so now
226 // results in VS2017 in producing an internal error (at the call site) that's
227 // not detailed enough to understand what is happening.
parseSourceMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)228 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
229                                                      OperationState &state) {
230   // Parse an optional list of attributes staring with '['
231   if (parser.parseOptionalLSquare()) {
232     // Nothing to do
233     return success();
234   }
235 
236   spirv::MemoryAccess memoryAccessAttr;
237   if (parseEnumStrAttr(memoryAccessAttr, parser, state,
238                        kSourceMemoryAccessAttrName)) {
239     return failure();
240   }
241 
242   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
243     // Parse integer attribute for alignment.
244     Attribute alignmentAttr;
245     Type i32Type = parser.getBuilder().getIntegerType(32);
246     if (parser.parseComma() ||
247         parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
248                               state.attributes)) {
249       return failure();
250     }
251   }
252   return parser.parseRSquare();
253 }
254 
255 template <typename MemoryOpTy>
printMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)256 static void printMemoryAccessAttribute(
257     MemoryOpTy memoryOp, OpAsmPrinter &printer,
258     SmallVectorImpl<StringRef> &elidedAttrs,
259     Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
260     Optional<uint32_t> alignmentAttrValue = None) {
261   // Print optional memory access attribute.
262   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
263                                               : memoryOp.memory_access())) {
264     elidedAttrs.push_back(kMemoryAccessAttrName);
265 
266     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
267 
268     if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
269       // Print integer alignment attribute.
270       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
271                                                : memoryOp.alignment())) {
272         elidedAttrs.push_back(kAlignmentAttrName);
273         printer << ", " << alignment;
274       }
275     }
276     printer << "]";
277   }
278   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
279 }
280 
281 // TODO Make sure to merge this and the previous function into one template
282 // parameterized by memory access attribute name and alignment. Doing so now
283 // results in VS2017 in producing an internal error (at the call site) that's
284 // not detailed enough to understand what is happening.
285 template <typename MemoryOpTy>
printSourceMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)286 static void printSourceMemoryAccessAttribute(
287     MemoryOpTy memoryOp, OpAsmPrinter &printer,
288     SmallVectorImpl<StringRef> &elidedAttrs,
289     Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
290     Optional<uint32_t> alignmentAttrValue = None) {
291 
292   printer << ", ";
293 
294   // Print optional memory access attribute.
295   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
296                                               : memoryOp.memory_access())) {
297     elidedAttrs.push_back(kSourceMemoryAccessAttrName);
298 
299     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
300 
301     if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
302       // Print integer alignment attribute.
303       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
304                                                : memoryOp.alignment())) {
305         elidedAttrs.push_back(kSourceAlignmentAttrName);
306         printer << ", " << alignment;
307       }
308     }
309     printer << "]";
310   }
311   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
312 }
313 
verifyCastOp(Operation * op,bool requireSameBitWidth=true,bool skipBitWidthCheck=false)314 static LogicalResult verifyCastOp(Operation *op,
315                                   bool requireSameBitWidth = true,
316                                   bool skipBitWidthCheck = false) {
317   // Some CastOps have no limit on bit widths for result and operand type.
318   if (skipBitWidthCheck)
319     return success();
320 
321   Type operandType = op->getOperand(0).getType();
322   Type resultType = op->getResult(0).getType();
323 
324   // ODS checks that result type and operand type have the same shape.
325   if (auto vectorType = operandType.dyn_cast<VectorType>()) {
326     operandType = vectorType.getElementType();
327     resultType = resultType.cast<VectorType>().getElementType();
328   }
329 
330   if (auto coopMatrixType =
331           operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
332     operandType = coopMatrixType.getElementType();
333     resultType =
334         resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
335   }
336 
337   auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
338   auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
339   auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
340 
341   if (requireSameBitWidth) {
342     if (!isSameBitWidth) {
343       return op->emitOpError(
344                  "expected the same bit widths for operand type and result "
345                  "type, but provided ")
346              << operandType << " and " << resultType;
347     }
348     return success();
349   }
350 
351   if (isSameBitWidth) {
352     return op->emitOpError(
353                "expected the different bit widths for operand type and result "
354                "type, but provided ")
355            << operandType << " and " << resultType;
356   }
357   return success();
358 }
359 
360 template <typename MemoryOpTy>
verifyMemoryAccessAttribute(MemoryOpTy memoryOp)361 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
362   // ODS checks for attributes values. Just need to verify that if the
363   // memory-access attribute is Aligned, then the alignment attribute must be
364   // present.
365   auto *op = memoryOp.getOperation();
366   auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
367   if (!memAccessAttr) {
368     // Alignment attribute shouldn't be present if memory access attribute is
369     // not present.
370     if (op->getAttr(kAlignmentAttrName)) {
371       return memoryOp.emitOpError(
372           "invalid alignment specification without aligned memory access "
373           "specification");
374     }
375     return success();
376   }
377 
378   auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
379   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
380 
381   if (!memAccess) {
382     return memoryOp.emitOpError("invalid memory access specifier: ")
383            << memAccessVal;
384   }
385 
386   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
387     if (!op->getAttr(kAlignmentAttrName)) {
388       return memoryOp.emitOpError("missing alignment value");
389     }
390   } else {
391     if (op->getAttr(kAlignmentAttrName)) {
392       return memoryOp.emitOpError(
393           "invalid alignment specification with non-aligned memory access "
394           "specification");
395     }
396   }
397   return success();
398 }
399 
400 // TODO Make sure to merge this and the previous function into one template
401 // parameterized by memory access attribute name and alignment. Doing so now
402 // results in VS2017 in producing an internal error (at the call site) that's
403 // not detailed enough to understand what is happening.
404 template <typename MemoryOpTy>
verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)405 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
406   // ODS checks for attributes values. Just need to verify that if the
407   // memory-access attribute is Aligned, then the alignment attribute must be
408   // present.
409   auto *op = memoryOp.getOperation();
410   auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
411   if (!memAccessAttr) {
412     // Alignment attribute shouldn't be present if memory access attribute is
413     // not present.
414     if (op->getAttr(kSourceAlignmentAttrName)) {
415       return memoryOp.emitOpError(
416           "invalid alignment specification without aligned memory access "
417           "specification");
418     }
419     return success();
420   }
421 
422   auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
423   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
424 
425   if (!memAccess) {
426     return memoryOp.emitOpError("invalid memory access specifier: ")
427            << memAccessVal;
428   }
429 
430   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
431     if (!op->getAttr(kSourceAlignmentAttrName)) {
432       return memoryOp.emitOpError("missing alignment value");
433     }
434   } else {
435     if (op->getAttr(kSourceAlignmentAttrName)) {
436       return memoryOp.emitOpError(
437           "invalid alignment specification with non-aligned memory access "
438           "specification");
439     }
440   }
441   return success();
442 }
443 
444 template <typename BarrierOp>
verifyMemorySemantics(BarrierOp op)445 static LogicalResult verifyMemorySemantics(BarrierOp op) {
446   // According to the SPIR-V specification:
447   // "Despite being a mask and allowing multiple bits to be combined, it is
448   // invalid for more than one of these four bits to be set: Acquire, Release,
449   // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
450   // Release semantics is done by setting the AcquireRelease bit, not by setting
451   // two bits."
452   auto memorySemantics = op.memory_semantics();
453   auto atMostOneInSet = spirv::MemorySemantics::Acquire |
454                         spirv::MemorySemantics::Release |
455                         spirv::MemorySemantics::AcquireRelease |
456                         spirv::MemorySemantics::SequentiallyConsistent;
457 
458   auto bitCount = llvm::countPopulation(
459       static_cast<uint32_t>(memorySemantics & atMostOneInSet));
460   if (bitCount > 1) {
461     return op.emitError("expected at most one of these four memory constraints "
462                         "to be set: `Acquire`, `Release`,"
463                         "`AcquireRelease` or `SequentiallyConsistent`");
464   }
465   return success();
466 }
467 
468 template <typename LoadStoreOpTy>
verifyLoadStorePtrAndValTypes(LoadStoreOpTy op,Value ptr,Value val)469 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
470                                                    Value val) {
471   // ODS already checks ptr is spirv::PointerType. Just check that the pointee
472   // type of the pointer and the type of the value are the same
473   //
474   // TODO: Check that the value type satisfies restrictions of
475   // SPIR-V OpLoad/OpStore operations
476   if (val.getType() !=
477       ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
478     return op.emitOpError("mismatch in result type and pointer type");
479   }
480   return success();
481 }
482 
483 template <typename BlockReadWriteOpTy>
verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,Value ptr,Value val)484 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
485                                                         Value ptr, Value val) {
486   auto valType = val.getType();
487   if (auto valVecTy = valType.dyn_cast<VectorType>())
488     valType = valVecTy.getElementType();
489 
490   if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
491     return op.emitOpError("mismatch in result type and pointer type");
492   }
493   return success();
494 }
495 
parseVariableDecorations(OpAsmParser & parser,OperationState & state)496 static ParseResult parseVariableDecorations(OpAsmParser &parser,
497                                             OperationState &state) {
498   auto builtInName = llvm::convertToSnakeFromCamelCase(
499       stringifyDecoration(spirv::Decoration::BuiltIn));
500   if (succeeded(parser.parseOptionalKeyword("bind"))) {
501     Attribute set, binding;
502     // Parse optional descriptor binding
503     auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
504         stringifyDecoration(spirv::Decoration::DescriptorSet));
505     auto bindingName = llvm::convertToSnakeFromCamelCase(
506         stringifyDecoration(spirv::Decoration::Binding));
507     Type i32Type = parser.getBuilder().getIntegerType(32);
508     if (parser.parseLParen() ||
509         parser.parseAttribute(set, i32Type, descriptorSetName,
510                               state.attributes) ||
511         parser.parseComma() ||
512         parser.parseAttribute(binding, i32Type, bindingName,
513                               state.attributes) ||
514         parser.parseRParen()) {
515       return failure();
516     }
517   } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
518     StringAttr builtIn;
519     if (parser.parseLParen() ||
520         parser.parseAttribute(builtIn, builtInName, state.attributes) ||
521         parser.parseRParen()) {
522       return failure();
523     }
524   }
525 
526   // Parse other attributes
527   if (parser.parseOptionalAttrDict(state.attributes))
528     return failure();
529 
530   return success();
531 }
532 
printVariableDecorations(Operation * op,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs)533 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
534                                      SmallVectorImpl<StringRef> &elidedAttrs) {
535   // Print optional descriptor binding
536   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
537       stringifyDecoration(spirv::Decoration::DescriptorSet));
538   auto bindingName = llvm::convertToSnakeFromCamelCase(
539       stringifyDecoration(spirv::Decoration::Binding));
540   auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
541   auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
542   if (descriptorSet && binding) {
543     elidedAttrs.push_back(descriptorSetName);
544     elidedAttrs.push_back(bindingName);
545     printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
546             << ")";
547   }
548 
549   // Print BuiltIn attribute if present
550   auto builtInName = llvm::convertToSnakeFromCamelCase(
551       stringifyDecoration(spirv::Decoration::BuiltIn));
552   if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
553     printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
554     elidedAttrs.push_back(builtInName);
555   }
556 
557   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
558 }
559 
560 // Get bit width of types.
getBitWidth(Type type)561 static unsigned getBitWidth(Type type) {
562   if (type.isa<spirv::PointerType>()) {
563     // Just return 64 bits for pointer types for now.
564     // TODO: Make sure not caller relies on the actual pointer width value.
565     return 64;
566   }
567 
568   if (type.isIntOrFloat())
569     return type.getIntOrFloatBitWidth();
570 
571   if (auto vectorType = type.dyn_cast<VectorType>()) {
572     assert(vectorType.getElementType().isIntOrFloat());
573     return vectorType.getNumElements() *
574            vectorType.getElementType().getIntOrFloatBitWidth();
575   }
576   llvm_unreachable("unhandled bit width computation for type");
577 }
578 
579 /// Walks the given type hierarchy with the given indices, potentially down
580 /// to component granularity, to select an element type. Returns null type and
581 /// emits errors with the given loc on failure.
582 static Type
getElementType(Type type,ArrayRef<int32_t> indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)583 getElementType(Type type, ArrayRef<int32_t> indices,
584                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
585   if (indices.empty()) {
586     emitErrorFn("expected at least one index for spv.CompositeExtract");
587     return nullptr;
588   }
589 
590   for (auto index : indices) {
591     if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
592       if (cType.hasCompileTimeKnownNumElements() &&
593           (index < 0 ||
594            static_cast<uint64_t>(index) >= cType.getNumElements())) {
595         emitErrorFn("index ") << index << " out of bounds for " << type;
596         return nullptr;
597       }
598       type = cType.getElementType(index);
599     } else {
600       emitErrorFn("cannot extract from non-composite type ")
601           << type << " with index " << index;
602       return nullptr;
603     }
604   }
605   return type;
606 }
607 
608 static Type
getElementType(Type type,Attribute indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)609 getElementType(Type type, Attribute indices,
610                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
611   auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
612   if (!indicesArrayAttr) {
613     emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
614     return nullptr;
615   }
616   if (!indicesArrayAttr.size()) {
617     emitErrorFn("expected at least one index for spv.CompositeExtract");
618     return nullptr;
619   }
620 
621   SmallVector<int32_t, 2> indexVals;
622   for (auto indexAttr : indicesArrayAttr) {
623     auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
624     if (!indexIntAttr) {
625       emitErrorFn("expected an 32-bit integer for index, but found '")
626           << indexAttr << "'";
627       return nullptr;
628     }
629     indexVals.push_back(indexIntAttr.getInt());
630   }
631   return getElementType(type, indexVals, emitErrorFn);
632 }
633 
getElementType(Type type,Attribute indices,Location loc)634 static Type getElementType(Type type, Attribute indices, Location loc) {
635   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
636     return ::mlir::emitError(loc, err);
637   };
638   return getElementType(type, indices, errorFn);
639 }
640 
getElementType(Type type,Attribute indices,OpAsmParser & parser,llvm::SMLoc loc)641 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
642                            llvm::SMLoc loc) {
643   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
644     return parser.emitError(loc, err);
645   };
646   return getElementType(type, indices, errorFn);
647 }
648 
649 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
isMergeBlock(Block & block)650 static inline bool isMergeBlock(Block &block) {
651   return !block.empty() && std::next(block.begin()) == block.end() &&
652          isa<spirv::MergeOp>(block.front());
653 }
654 
655 //===----------------------------------------------------------------------===//
656 // Common parsers and printers
657 //===----------------------------------------------------------------------===//
658 
659 // Parses an atomic update op. If the update op does not take a value (like
660 // AtomicIIncrement) `hasValue` must be false.
parseAtomicUpdateOp(OpAsmParser & parser,OperationState & state,bool hasValue)661 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
662                                        OperationState &state, bool hasValue) {
663   spirv::Scope scope;
664   spirv::MemorySemantics memoryScope;
665   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
666   OpAsmParser::OperandType ptrInfo, valueInfo;
667   Type type;
668   llvm::SMLoc loc;
669   if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
670       parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
671       parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
672       parser.getCurrentLocation(&loc) || parser.parseColonType(type))
673     return failure();
674 
675   auto ptrType = type.dyn_cast<spirv::PointerType>();
676   if (!ptrType)
677     return parser.emitError(loc, "expected pointer type");
678 
679   SmallVector<Type, 2> operandTypes;
680   operandTypes.push_back(ptrType);
681   if (hasValue)
682     operandTypes.push_back(ptrType.getPointeeType());
683   if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
684                              state.operands))
685     return failure();
686   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
687 }
688 
689 // Prints an atomic update op.
printAtomicUpdateOp(Operation * op,OpAsmPrinter & printer)690 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
691   printer << op->getName() << " \"";
692   auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
693   printer << spirv::stringifyScope(
694                  static_cast<spirv::Scope>(scopeAttr.getInt()))
695           << "\" \"";
696   auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
697   printer << spirv::stringifyMemorySemantics(
698                  static_cast<spirv::MemorySemantics>(
699                      memorySemanticsAttr.getInt()))
700           << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
701 }
702 
703 // Verifies an atomic update op.
verifyAtomicUpdateOp(Operation * op)704 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
705   auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
706   auto elementType = ptrType.getPointeeType();
707   if (!elementType.isa<IntegerType>())
708     return op->emitOpError(
709                "pointer operand must point to an integer value, found ")
710            << elementType;
711 
712   if (op->getNumOperands() > 1) {
713     auto valueType = op->getOperand(1).getType();
714     if (valueType != elementType)
715       return op->emitOpError("expected value to have the same type as the "
716                              "pointer operand's pointee type ")
717              << elementType << ", but found " << valueType;
718   }
719   return success();
720 }
721 
parseGroupNonUniformArithmeticOp(OpAsmParser & parser,OperationState & state)722 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
723                                                     OperationState &state) {
724   spirv::Scope executionScope;
725   spirv::GroupOperation groupOperation;
726   OpAsmParser::OperandType valueInfo;
727   if (parseEnumStrAttr(executionScope, parser, state,
728                        kExecutionScopeAttrName) ||
729       parseEnumStrAttr(groupOperation, parser, state,
730                        kGroupOperationAttrName) ||
731       parser.parseOperand(valueInfo))
732     return failure();
733 
734   Optional<OpAsmParser::OperandType> clusterSizeInfo;
735   if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
736     clusterSizeInfo = OpAsmParser::OperandType();
737     if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
738         parser.parseRParen())
739       return failure();
740   }
741 
742   Type resultType;
743   if (parser.parseColonType(resultType))
744     return failure();
745 
746   if (parser.resolveOperand(valueInfo, resultType, state.operands))
747     return failure();
748 
749   if (clusterSizeInfo.hasValue()) {
750     Type i32Type = parser.getBuilder().getIntegerType(32);
751     if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
752       return failure();
753   }
754 
755   return parser.addTypeToList(resultType, state.types);
756 }
757 
printGroupNonUniformArithmeticOp(Operation * groupOp,OpAsmPrinter & printer)758 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
759                                              OpAsmPrinter &printer) {
760   printer << groupOp->getName() << " \""
761           << stringifyScope(static_cast<spirv::Scope>(
762                  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
763                      .getInt()))
764           << "\" \""
765           << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
766                  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
767                      .getInt()))
768           << "\" " << groupOp->getOperand(0);
769 
770   if (groupOp->getNumOperands() > 1)
771     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
772   printer << " : " << groupOp->getResult(0).getType();
773 }
774 
verifyGroupNonUniformArithmeticOp(Operation * groupOp)775 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
776   spirv::Scope scope = static_cast<spirv::Scope>(
777       groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
778   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
779     return groupOp->emitOpError(
780         "execution scope must be 'Workgroup' or 'Subgroup'");
781 
782   spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
783       groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
784   if (operation == spirv::GroupOperation::ClusteredReduce &&
785       groupOp->getNumOperands() == 1)
786     return groupOp->emitOpError("cluster size operand must be provided for "
787                                 "'ClusteredReduce' group operation");
788   if (groupOp->getNumOperands() > 1) {
789     Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
790     int32_t clusterSize = 0;
791 
792     // TODO: support specialization constant here.
793     if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
794       return groupOp->emitOpError(
795           "cluster size operand must come from a constant op");
796 
797     if (!llvm::isPowerOf2_32(clusterSize))
798       return groupOp->emitOpError(
799           "cluster size operand must be a power of two");
800   }
801   return success();
802 }
803 
parseUnaryOp(OpAsmParser & parser,OperationState & state)804 static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
805   OpAsmParser::OperandType operandInfo;
806   Type type;
807   if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
808       parser.resolveOperands(operandInfo, type, state.operands)) {
809     return failure();
810   }
811   state.addTypes(type);
812   return success();
813 }
814 
printUnaryOp(Operation * unaryOp,OpAsmPrinter & printer)815 static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
816   printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : "
817           << unaryOp->getOperand(0).getType();
818 }
819 
820 /// Result of a logical op must be a scalar or vector of boolean type.
getUnaryOpResultType(Builder & builder,Type operandType)821 static Type getUnaryOpResultType(Builder &builder, Type operandType) {
822   Type resultType = builder.getIntegerType(1);
823   if (auto vecType = operandType.dyn_cast<VectorType>()) {
824     return VectorType::get(vecType.getNumElements(), resultType);
825   }
826   return resultType;
827 }
828 
parseLogicalUnaryOp(OpAsmParser & parser,OperationState & state)829 static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
830                                        OperationState &state) {
831   OpAsmParser::OperandType operandInfo;
832   Type type;
833   if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
834       parser.resolveOperand(operandInfo, type, state.operands)) {
835     return failure();
836   }
837   state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
838   return success();
839 }
840 
parseLogicalBinaryOp(OpAsmParser & parser,OperationState & result)841 static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
842                                         OperationState &result) {
843   SmallVector<OpAsmParser::OperandType, 2> ops;
844   Type type;
845   if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
846       parser.resolveOperands(ops, type, result.operands)) {
847     return failure();
848   }
849   result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
850   return success();
851 }
852 
printLogicalOp(Operation * logicalOp,OpAsmPrinter & printer)853 static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
854   printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
855           << logicalOp->getOperand(0).getType();
856 }
857 
parseShiftOp(OpAsmParser & parser,OperationState & state)858 static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
859   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
860   Type baseType;
861   Type shiftType;
862   auto loc = parser.getCurrentLocation();
863 
864   if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
865       parser.parseType(baseType) || parser.parseComma() ||
866       parser.parseType(shiftType) ||
867       parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
868                              state.operands)) {
869     return failure();
870   }
871   state.addTypes(baseType);
872   return success();
873 }
874 
printShiftOp(Operation * op,OpAsmPrinter & printer)875 static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
876   Value base = op->getOperand(0);
877   Value shift = op->getOperand(1);
878   printer << op->getName() << ' ' << base << ", " << shift << " : "
879           << base.getType() << ", " << shift.getType();
880 }
881 
verifyShiftOp(Operation * op)882 static LogicalResult verifyShiftOp(Operation *op) {
883   if (op->getOperand(0).getType() != op->getResult(0).getType()) {
884     return op->emitError("expected the same type for the first operand and "
885                          "result, but provided ")
886            << op->getOperand(0).getType() << " and "
887            << op->getResult(0).getType();
888   }
889   return success();
890 }
891 
buildLogicalBinaryOp(OpBuilder & builder,OperationState & state,Value lhs,Value rhs)892 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
893                                  Value lhs, Value rhs) {
894   assert(lhs.getType() == rhs.getType());
895 
896   Type boolType = builder.getI1Type();
897   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
898     boolType = VectorType::get(vecType.getShape(), boolType);
899   state.addTypes(boolType);
900 
901   state.addOperands({lhs, rhs});
902 }
903 
buildLogicalUnaryOp(OpBuilder & builder,OperationState & state,Value value)904 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
905                                 Value value) {
906   Type boolType = builder.getI1Type();
907   if (auto vecType = value.getType().dyn_cast<VectorType>())
908     boolType = VectorType::get(vecType.getShape(), boolType);
909   state.addTypes(boolType);
910 
911   state.addOperands(value);
912 }
913 
914 //===----------------------------------------------------------------------===//
915 // spv.AccessChainOp
916 //===----------------------------------------------------------------------===//
917 
getElementPtrType(Type type,ValueRange indices,Location baseLoc)918 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
919   auto ptrType = type.dyn_cast<spirv::PointerType>();
920   if (!ptrType) {
921     emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
922                        "to composite type, but provided ")
923         << type;
924     return nullptr;
925   }
926 
927   auto resultType = ptrType.getPointeeType();
928   auto resultStorageClass = ptrType.getStorageClass();
929   int32_t index = 0;
930 
931   for (auto indexSSA : indices) {
932     auto cType = resultType.dyn_cast<spirv::CompositeType>();
933     if (!cType) {
934       emitError(baseLoc,
935                 "'spv.AccessChain' op cannot extract from non-composite type ")
936           << resultType << " with index " << index;
937       return nullptr;
938     }
939     index = 0;
940     if (resultType.isa<spirv::StructType>()) {
941       Operation *op = indexSSA.getDefiningOp();
942       if (!op) {
943         emitError(baseLoc, "'spv.AccessChain' op index must be an "
944                            "integer spv.Constant to access "
945                            "element of spv.struct");
946         return nullptr;
947       }
948 
949       // TODO: this should be relaxed to allow
950       // integer literals of other bitwidths.
951       if (failed(extractValueFromConstOp(op, index))) {
952         emitError(baseLoc,
953                   "'spv.AccessChain' index must be an integer spv.Constant to "
954                   "access element of spv.struct, but provided ")
955             << op->getName();
956         return nullptr;
957       }
958       if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
959         emitError(baseLoc, "'spv.AccessChain' op index ")
960             << index << " out of bounds for " << resultType;
961         return nullptr;
962       }
963     }
964     resultType = cType.getElementType(index);
965   }
966   return spirv::PointerType::get(resultType, resultStorageClass);
967 }
968 
build(OpBuilder & builder,OperationState & state,Value basePtr,ValueRange indices)969 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
970                                  Value basePtr, ValueRange indices) {
971   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
972   assert(type && "Unable to deduce return type based on basePtr and indices");
973   build(builder, state, type, basePtr, indices);
974 }
975 
parseAccessChainOp(OpAsmParser & parser,OperationState & state)976 static ParseResult parseAccessChainOp(OpAsmParser &parser,
977                                       OperationState &state) {
978   OpAsmParser::OperandType ptrInfo;
979   SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
980   Type type;
981   auto loc = parser.getCurrentLocation();
982   SmallVector<Type, 4> indicesTypes;
983 
984   if (parser.parseOperand(ptrInfo) ||
985       parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
986       parser.parseColonType(type) ||
987       parser.resolveOperand(ptrInfo, type, state.operands)) {
988     return failure();
989   }
990 
991   // Check that the provided indices list is not empty before parsing their
992   // type list.
993   if (indicesInfo.empty()) {
994     return emitError(state.location, "'spv.AccessChain' op expected at "
995                                      "least one index ");
996   }
997 
998   if (parser.parseComma() || parser.parseTypeList(indicesTypes))
999     return failure();
1000 
1001   // Check that the indices types list is not empty and that it has a one-to-one
1002   // mapping to the provided indices.
1003   if (indicesTypes.size() != indicesInfo.size()) {
1004     return emitError(state.location, "'spv.AccessChain' op indices "
1005                                      "types' count must be equal to indices "
1006                                      "info count");
1007   }
1008 
1009   if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
1010     return failure();
1011 
1012   auto resultType = getElementPtrType(
1013       type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
1014   if (!resultType) {
1015     return failure();
1016   }
1017 
1018   state.addTypes(resultType);
1019   return success();
1020 }
1021 
print(spirv::AccessChainOp op,OpAsmPrinter & printer)1022 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
1023   printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
1024           << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
1025           << op.indices().getTypes();
1026 }
1027 
verify(spirv::AccessChainOp accessChainOp)1028 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
1029   SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
1030                                 accessChainOp.indices().end());
1031   auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1032                                       indices, accessChainOp.getLoc());
1033   if (!resultType) {
1034     return failure();
1035   }
1036 
1037   auto providedResultType =
1038       accessChainOp.getType().dyn_cast<spirv::PointerType>();
1039   if (!providedResultType) {
1040     return accessChainOp.emitOpError(
1041                "result type must be a pointer, but provided")
1042            << providedResultType;
1043   }
1044 
1045   if (resultType != providedResultType) {
1046     return accessChainOp.emitOpError("invalid result type: expected ")
1047            << resultType << ", but provided " << providedResultType;
1048   }
1049 
1050   return success();
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // spv.mlir.addressof
1055 //===----------------------------------------------------------------------===//
1056 
build(OpBuilder & builder,OperationState & state,spirv::GlobalVariableOp var)1057 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1058                                spirv::GlobalVariableOp var) {
1059   build(builder, state, var.type(), builder.getSymbolRefAttr(var));
1060 }
1061 
verify(spirv::AddressOfOp addressOfOp)1062 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
1063   auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1064       SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
1065                                            addressOfOp.variable()));
1066   if (!varOp) {
1067     return addressOfOp.emitOpError("expected spv.GlobalVariable symbol");
1068   }
1069   if (addressOfOp.pointer().getType() != varOp.type()) {
1070     return addressOfOp.emitOpError(
1071         "result type mismatch with the referenced global variable's type");
1072   }
1073   return success();
1074 }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // spv.AtomicCompareExchangeWeak
1078 //===----------------------------------------------------------------------===//
1079 
parseAtomicCompareExchangeWeakOp(OpAsmParser & parser,OperationState & state)1080 static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
1081                                                     OperationState &state) {
1082   spirv::Scope memoryScope;
1083   spirv::MemorySemantics equalSemantics, unequalSemantics;
1084   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
1085   Type type;
1086   if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1087       parseEnumStrAttr(equalSemantics, parser, state,
1088                        kEqualSemanticsAttrName) ||
1089       parseEnumStrAttr(unequalSemantics, parser, state,
1090                        kUnequalSemanticsAttrName) ||
1091       parser.parseOperandList(operandInfo, 3))
1092     return failure();
1093 
1094   auto loc = parser.getCurrentLocation();
1095   if (parser.parseColonType(type))
1096     return failure();
1097 
1098   auto ptrType = type.dyn_cast<spirv::PointerType>();
1099   if (!ptrType)
1100     return parser.emitError(loc, "expected pointer type");
1101 
1102   if (parser.resolveOperands(
1103           operandInfo,
1104           {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1105           parser.getNameLoc(), state.operands))
1106     return failure();
1107 
1108   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1109 }
1110 
print(spirv::AtomicCompareExchangeWeakOp atomOp,OpAsmPrinter & printer)1111 static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
1112                   OpAsmPrinter &printer) {
1113   printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
1114           << stringifyScope(atomOp.memory_scope()) << "\" \""
1115           << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1116           << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1117           << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1118 }
1119 
verify(spirv::AtomicCompareExchangeWeakOp atomOp)1120 static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
1121   // According to the spec:
1122   // "The type of Value must be the same as Result Type. The type of the value
1123   // pointed to by Pointer must be the same as Result Type. This type must also
1124   // match the type of Comparator."
1125   if (atomOp.getType() != atomOp.value().getType())
1126     return atomOp.emitOpError("value operand must have the same type as the op "
1127                               "result, but found ")
1128            << atomOp.value().getType() << " vs " << atomOp.getType();
1129 
1130   if (atomOp.getType() != atomOp.comparator().getType())
1131     return atomOp.emitOpError(
1132                "comparator operand must have the same type as the op "
1133                "result, but found ")
1134            << atomOp.comparator().getType() << " vs " << atomOp.getType();
1135 
1136   Type pointeeType =
1137       atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
1138   if (atomOp.getType() != pointeeType)
1139     return atomOp.emitOpError(
1140                "pointer operand's pointee type must have the same "
1141                "as the op result type, but found ")
1142            << pointeeType << " vs " << atomOp.getType();
1143 
1144   // TODO: Unequal cannot be set to Release or Acquire and Release.
1145   // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1146 
1147   return success();
1148 }
1149 
1150 //===----------------------------------------------------------------------===//
1151 // spv.BitcastOp
1152 //===----------------------------------------------------------------------===//
1153 
verify(spirv::BitcastOp bitcastOp)1154 static LogicalResult verify(spirv::BitcastOp bitcastOp) {
1155   // TODO: The SPIR-V spec validation rules are different for different
1156   // versions.
1157   auto operandType = bitcastOp.operand().getType();
1158   auto resultType = bitcastOp.result().getType();
1159   if (operandType == resultType) {
1160     return bitcastOp.emitError(
1161         "result type must be different from operand type");
1162   }
1163   if (operandType.isa<spirv::PointerType>() &&
1164       !resultType.isa<spirv::PointerType>()) {
1165     return bitcastOp.emitError(
1166         "unhandled bit cast conversion from pointer type to non-pointer type");
1167   }
1168   if (!operandType.isa<spirv::PointerType>() &&
1169       resultType.isa<spirv::PointerType>()) {
1170     return bitcastOp.emitError(
1171         "unhandled bit cast conversion from non-pointer type to pointer type");
1172   }
1173   auto operandBitWidth = getBitWidth(operandType);
1174   auto resultBitWidth = getBitWidth(resultType);
1175   if (operandBitWidth != resultBitWidth) {
1176     return bitcastOp.emitOpError("mismatch in result type bitwidth ")
1177            << resultBitWidth << " and operand type bitwidth "
1178            << operandBitWidth;
1179   }
1180   return success();
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // spv.BranchOp
1185 //===----------------------------------------------------------------------===//
1186 
1187 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1188 spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
1189   assert(index == 0 && "invalid successor index");
1190   return targetOperandsMutable();
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // spv.BranchConditionalOp
1195 //===----------------------------------------------------------------------===//
1196 
1197 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1198 spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
1199   assert(index < 2 && "invalid successor index");
1200   return index == kTrueIndex ? trueTargetOperandsMutable()
1201                              : falseTargetOperandsMutable();
1202 }
1203 
parseBranchConditionalOp(OpAsmParser & parser,OperationState & state)1204 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
1205                                             OperationState &state) {
1206   auto &builder = parser.getBuilder();
1207   OpAsmParser::OperandType condInfo;
1208   Block *dest;
1209 
1210   // Parse the condition.
1211   Type boolTy = builder.getI1Type();
1212   if (parser.parseOperand(condInfo) ||
1213       parser.resolveOperand(condInfo, boolTy, state.operands))
1214     return failure();
1215 
1216   // Parse the optional branch weights.
1217   if (succeeded(parser.parseOptionalLSquare())) {
1218     IntegerAttr trueWeight, falseWeight;
1219     NamedAttrList weights;
1220 
1221     auto i32Type = builder.getIntegerType(32);
1222     if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1223         parser.parseComma() ||
1224         parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1225         parser.parseRSquare())
1226       return failure();
1227 
1228     state.addAttribute(kBranchWeightAttrName,
1229                        builder.getArrayAttr({trueWeight, falseWeight}));
1230   }
1231 
1232   // Parse the true branch.
1233   SmallVector<Value, 4> trueOperands;
1234   if (parser.parseComma() ||
1235       parser.parseSuccessorAndUseList(dest, trueOperands))
1236     return failure();
1237   state.addSuccessors(dest);
1238   state.addOperands(trueOperands);
1239 
1240   // Parse the false branch.
1241   SmallVector<Value, 4> falseOperands;
1242   if (parser.parseComma() ||
1243       parser.parseSuccessorAndUseList(dest, falseOperands))
1244     return failure();
1245   state.addSuccessors(dest);
1246   state.addOperands(falseOperands);
1247   state.addAttribute(
1248       spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1249       builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1250                                 static_cast<int32_t>(falseOperands.size())}));
1251 
1252   return success();
1253 }
1254 
print(spirv::BranchConditionalOp branchOp,OpAsmPrinter & printer)1255 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
1256   printer << spirv::BranchConditionalOp::getOperationName() << ' '
1257           << branchOp.condition();
1258 
1259   if (auto weights = branchOp.branch_weights()) {
1260     printer << " [";
1261     llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1262       printer << a.cast<IntegerAttr>().getInt();
1263     });
1264     printer << "]";
1265   }
1266 
1267   printer << ", ";
1268   printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
1269                                    branchOp.getTrueBlockArguments());
1270   printer << ", ";
1271   printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
1272                                    branchOp.getFalseBlockArguments());
1273 }
1274 
verify(spirv::BranchConditionalOp branchOp)1275 static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
1276   if (auto weights = branchOp.branch_weights()) {
1277     if (weights->getValue().size() != 2) {
1278       return branchOp.emitOpError("must have exactly two branch weights");
1279     }
1280     if (llvm::all_of(*weights, [](Attribute attr) {
1281           return attr.cast<IntegerAttr>().getValue().isNullValue();
1282         }))
1283       return branchOp.emitOpError("branch weights cannot both be zero");
1284   }
1285 
1286   return success();
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // spv.CompositeConstruct
1291 //===----------------------------------------------------------------------===//
1292 
parseCompositeConstructOp(OpAsmParser & parser,OperationState & state)1293 static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
1294                                              OperationState &state) {
1295   SmallVector<OpAsmParser::OperandType, 4> operands;
1296   Type type;
1297   auto loc = parser.getCurrentLocation();
1298 
1299   if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1300     return failure();
1301   }
1302   auto cType = type.dyn_cast<spirv::CompositeType>();
1303   if (!cType) {
1304     return parser.emitError(
1305                loc, "result type must be a composite type, but provided ")
1306            << type;
1307   }
1308 
1309   if (cType.hasCompileTimeKnownNumElements() &&
1310       operands.size() != cType.getNumElements()) {
1311     return parser.emitError(loc, "has incorrect number of operands: expected ")
1312            << cType.getNumElements() << ", but provided " << operands.size();
1313   }
1314   // TODO: Add support for constructing a vector type from the vector operands.
1315   // According to the spec: "for constructing a vector, the operands may
1316   // also be vectors with the same component type as the Result Type component
1317   // type".
1318   SmallVector<Type, 4> elementTypes;
1319   elementTypes.reserve(operands.size());
1320   for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1321     elementTypes.push_back(cType.getElementType(index));
1322   }
1323   state.addTypes(type);
1324   return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1325 }
1326 
print(spirv::CompositeConstructOp compositeConstructOp,OpAsmPrinter & printer)1327 static void print(spirv::CompositeConstructOp compositeConstructOp,
1328                   OpAsmPrinter &printer) {
1329   printer << spirv::CompositeConstructOp::getOperationName() << " "
1330           << compositeConstructOp.constituents() << " : "
1331           << compositeConstructOp.getResult().getType();
1332 }
1333 
verify(spirv::CompositeConstructOp compositeConstructOp)1334 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
1335   auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
1336   SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
1337 
1338   if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1339     if (constituents.size() != 1)
1340       return compositeConstructOp.emitError(
1341                  "has incorrect number of operands: expected ")
1342              << "1, but provided " << constituents.size();
1343   } else if (constituents.size() != cType.getNumElements()) {
1344     return compositeConstructOp.emitError(
1345                "has incorrect number of operands: expected ")
1346            << cType.getNumElements() << ", but provided "
1347            << constituents.size();
1348   }
1349 
1350   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1351     if (constituents[index].getType() != cType.getElementType(index)) {
1352       return compositeConstructOp.emitError(
1353                  "operand type mismatch: expected operand type ")
1354              << cType.getElementType(index) << ", but provided "
1355              << constituents[index].getType();
1356     }
1357   }
1358 
1359   return success();
1360 }
1361 
1362 //===----------------------------------------------------------------------===//
1363 // spv.CompositeExtractOp
1364 //===----------------------------------------------------------------------===//
1365 
build(OpBuilder & builder,OperationState & state,Value composite,ArrayRef<int32_t> indices)1366 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1367                                       Value composite,
1368                                       ArrayRef<int32_t> indices) {
1369   auto indexAttr = builder.getI32ArrayAttr(indices);
1370   auto elementType =
1371       getElementType(composite.getType(), indexAttr, state.location);
1372   if (!elementType) {
1373     return;
1374   }
1375   build(builder, state, elementType, composite, indexAttr);
1376 }
1377 
parseCompositeExtractOp(OpAsmParser & parser,OperationState & state)1378 static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
1379                                            OperationState &state) {
1380   OpAsmParser::OperandType compositeInfo;
1381   Attribute indicesAttr;
1382   Type compositeType;
1383   llvm::SMLoc attrLocation;
1384 
1385   if (parser.parseOperand(compositeInfo) ||
1386       parser.getCurrentLocation(&attrLocation) ||
1387       parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1388       parser.parseColonType(compositeType) ||
1389       parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1390     return failure();
1391   }
1392 
1393   Type resultType =
1394       getElementType(compositeType, indicesAttr, parser, attrLocation);
1395   if (!resultType) {
1396     return failure();
1397   }
1398   state.addTypes(resultType);
1399   return success();
1400 }
1401 
print(spirv::CompositeExtractOp compositeExtractOp,OpAsmPrinter & printer)1402 static void print(spirv::CompositeExtractOp compositeExtractOp,
1403                   OpAsmPrinter &printer) {
1404   printer << spirv::CompositeExtractOp::getOperationName() << ' '
1405           << compositeExtractOp.composite() << compositeExtractOp.indices()
1406           << " : " << compositeExtractOp.composite().getType();
1407 }
1408 
verify(spirv::CompositeExtractOp compExOp)1409 static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
1410   auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
1411   auto resultType = getElementType(compExOp.composite().getType(),
1412                                    indicesArrayAttr, compExOp.getLoc());
1413   if (!resultType)
1414     return failure();
1415 
1416   if (resultType != compExOp.getType()) {
1417     return compExOp.emitOpError("invalid result type: expected ")
1418            << resultType << " but provided " << compExOp.getType();
1419   }
1420 
1421   return success();
1422 }
1423 
1424 //===----------------------------------------------------------------------===//
1425 // spv.CompositeInsert
1426 //===----------------------------------------------------------------------===//
1427 
build(OpBuilder & builder,OperationState & state,Value object,Value composite,ArrayRef<int32_t> indices)1428 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1429                                      Value object, Value composite,
1430                                      ArrayRef<int32_t> indices) {
1431   auto indexAttr = builder.getI32ArrayAttr(indices);
1432   build(builder, state, composite.getType(), object, composite, indexAttr);
1433 }
1434 
parseCompositeInsertOp(OpAsmParser & parser,OperationState & state)1435 static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
1436                                           OperationState &state) {
1437   SmallVector<OpAsmParser::OperandType, 2> operands;
1438   Type objectType, compositeType;
1439   Attribute indicesAttr;
1440   auto loc = parser.getCurrentLocation();
1441 
1442   return failure(
1443       parser.parseOperandList(operands, 2) ||
1444       parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1445       parser.parseColonType(objectType) ||
1446       parser.parseKeywordType("into", compositeType) ||
1447       parser.resolveOperands(operands, {objectType, compositeType}, loc,
1448                              state.operands) ||
1449       parser.addTypesToList(compositeType, state.types));
1450 }
1451 
verify(spirv::CompositeInsertOp compositeInsertOp)1452 static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
1453   auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
1454   auto objectType =
1455       getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
1456                      compositeInsertOp.getLoc());
1457   if (!objectType)
1458     return failure();
1459 
1460   if (objectType != compositeInsertOp.object().getType()) {
1461     return compositeInsertOp.emitOpError("object operand type should be ")
1462            << objectType << ", but found "
1463            << compositeInsertOp.object().getType();
1464   }
1465 
1466   if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
1467     return compositeInsertOp.emitOpError("result type should be the same as "
1468                                          "the composite type, but found ")
1469            << compositeInsertOp.composite().getType() << " vs "
1470            << compositeInsertOp.getType();
1471   }
1472 
1473   return success();
1474 }
1475 
print(spirv::CompositeInsertOp compositeInsertOp,OpAsmPrinter & printer)1476 static void print(spirv::CompositeInsertOp compositeInsertOp,
1477                   OpAsmPrinter &printer) {
1478   printer << spirv::CompositeInsertOp::getOperationName() << " "
1479           << compositeInsertOp.object() << ", " << compositeInsertOp.composite()
1480           << compositeInsertOp.indices() << " : "
1481           << compositeInsertOp.object().getType() << " into "
1482           << compositeInsertOp.composite().getType();
1483 }
1484 
1485 //===----------------------------------------------------------------------===//
1486 // spv.Constant
1487 //===----------------------------------------------------------------------===//
1488 
parseConstantOp(OpAsmParser & parser,OperationState & state)1489 static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
1490   Attribute value;
1491   if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1492     return failure();
1493 
1494   Type type = value.getType();
1495   if (type.isa<NoneType, TensorType>()) {
1496     if (parser.parseColonType(type))
1497       return failure();
1498   }
1499 
1500   return parser.addTypeToList(type, state.types);
1501 }
1502 
print(spirv::ConstantOp constOp,OpAsmPrinter & printer)1503 static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
1504   printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
1505   if (constOp.getType().isa<spirv::ArrayType>())
1506     printer << " : " << constOp.getType();
1507 }
1508 
verify(spirv::ConstantOp constOp)1509 static LogicalResult verify(spirv::ConstantOp constOp) {
1510   auto opType = constOp.getType();
1511   auto value = constOp.value();
1512   auto valueType = value.getType();
1513 
1514   // ODS already generates checks to make sure the result type is valid. We just
1515   // need to additionally check that the value's attribute type is consistent
1516   // with the result type.
1517   if (value.isa<IntegerAttr, FloatAttr>()) {
1518     if (valueType != opType)
1519       return constOp.emitOpError("result type (")
1520              << opType << ") does not match value type (" << valueType << ")";
1521     return success();
1522   }
1523   if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1524     if (valueType == opType)
1525       return success();
1526     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1527     auto shapedType = valueType.dyn_cast<ShapedType>();
1528     if (!arrayType) {
1529       return constOp.emitOpError(
1530           "must have spv.array result type for array value");
1531     }
1532 
1533     int numElements = arrayType.getNumElements();
1534     auto opElemType = arrayType.getElementType();
1535     while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1536       numElements *= t.getNumElements();
1537       opElemType = t.getElementType();
1538     }
1539     if (!opElemType.isIntOrFloat())
1540       return constOp.emitOpError("only support nested array result type");
1541 
1542     auto valueElemType = shapedType.getElementType();
1543     if (valueElemType != opElemType) {
1544       return constOp.emitOpError("result element type (")
1545              << opElemType << ") does not match value element type ("
1546              << valueElemType << ")";
1547     }
1548 
1549     if (numElements != shapedType.getNumElements()) {
1550       return constOp.emitOpError("result number of elements (")
1551              << numElements << ") does not match value number of elements ("
1552              << shapedType.getNumElements() << ")";
1553     }
1554     return success();
1555   }
1556   if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
1557     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1558     if (!arrayType)
1559       return constOp.emitOpError(
1560           "must have spv.array result type for array value");
1561     Type elemType = arrayType.getElementType();
1562     for (Attribute element : attayAttr.getValue()) {
1563       if (element.getType() != elemType)
1564         return constOp.emitOpError("has array element whose type (")
1565                << element.getType()
1566                << ") does not match the result element type (" << elemType
1567                << ')';
1568     }
1569     return success();
1570   }
1571   return constOp.emitOpError("cannot have value of type ") << valueType;
1572 }
1573 
isBuildableWith(Type type)1574 bool spirv::ConstantOp::isBuildableWith(Type type) {
1575   // Must be valid SPIR-V type first.
1576   if (!type.isa<spirv::SPIRVType>())
1577     return false;
1578 
1579   if (isa<SPIRVDialect>(type.getDialect())) {
1580     // TODO: support constant struct
1581     return type.isa<spirv::ArrayType>();
1582   }
1583 
1584   return true;
1585 }
1586 
getZero(Type type,Location loc,OpBuilder & builder)1587 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1588                                              OpBuilder &builder) {
1589   if (auto intType = type.dyn_cast<IntegerType>()) {
1590     unsigned width = intType.getWidth();
1591     if (width == 1)
1592       return builder.create<spirv::ConstantOp>(loc, type,
1593                                                builder.getBoolAttr(false));
1594     return builder.create<spirv::ConstantOp>(
1595         loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1596   }
1597   if (auto floatType = type.dyn_cast<FloatType>()) {
1598     return builder.create<spirv::ConstantOp>(
1599         loc, type, builder.getFloatAttr(floatType, 0.0));
1600   }
1601   if (auto vectorType = type.dyn_cast<VectorType>()) {
1602     Type elemType = vectorType.getElementType();
1603     if (elemType.isa<IntegerType>()) {
1604       return builder.create<spirv::ConstantOp>(
1605           loc, type,
1606           DenseElementsAttr::get(vectorType,
1607                                  IntegerAttr::get(elemType, 0.0).getValue()));
1608     }
1609     if (elemType.isa<FloatType>()) {
1610       return builder.create<spirv::ConstantOp>(
1611           loc, type,
1612           DenseFPElementsAttr::get(vectorType,
1613                                    FloatAttr::get(elemType, 0.0).getValue()));
1614     }
1615   }
1616 
1617   llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1618 }
1619 
getOne(Type type,Location loc,OpBuilder & builder)1620 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1621                                             OpBuilder &builder) {
1622   if (auto intType = type.dyn_cast<IntegerType>()) {
1623     unsigned width = intType.getWidth();
1624     if (width == 1)
1625       return builder.create<spirv::ConstantOp>(loc, type,
1626                                                builder.getBoolAttr(true));
1627     return builder.create<spirv::ConstantOp>(
1628         loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1629   }
1630   if (auto floatType = type.dyn_cast<FloatType>()) {
1631     return builder.create<spirv::ConstantOp>(
1632         loc, type, builder.getFloatAttr(floatType, 1.0));
1633   }
1634   if (auto vectorType = type.dyn_cast<VectorType>()) {
1635     Type elemType = vectorType.getElementType();
1636     if (elemType.isa<IntegerType>()) {
1637       return builder.create<spirv::ConstantOp>(
1638           loc, type,
1639           DenseElementsAttr::get(vectorType,
1640                                  IntegerAttr::get(elemType, 1.0).getValue()));
1641     }
1642     if (elemType.isa<FloatType>()) {
1643       return builder.create<spirv::ConstantOp>(
1644           loc, type,
1645           DenseFPElementsAttr::get(vectorType,
1646                                    FloatAttr::get(elemType, 1.0).getValue()));
1647     }
1648   }
1649 
1650   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1651 }
1652 
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)1653 void mlir::spirv::ConstantOp::getAsmResultNames(
1654     llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1655   Type type = getType();
1656 
1657   SmallString<32> specialNameBuffer;
1658   llvm::raw_svector_ostream specialName(specialNameBuffer);
1659   specialName << "cst";
1660 
1661   IntegerType intTy = type.dyn_cast<IntegerType>();
1662 
1663   if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1664     if (intTy && intTy.getWidth() == 1) {
1665       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1666     }
1667 
1668     if (intTy.isSignless()) {
1669       specialName << intCst.getInt();
1670     } else {
1671       specialName << intCst.getSInt();
1672     }
1673   }
1674 
1675   if (intTy || type.isa<FloatType>()) {
1676     specialName << '_' << type;
1677   }
1678 
1679   if (auto vecType = type.dyn_cast<VectorType>()) {
1680     specialName << "_vec_";
1681     specialName << vecType.getDimSize(0);
1682 
1683     Type elementType = vecType.getElementType();
1684 
1685     if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1686       specialName << "x" << elementType;
1687     }
1688   }
1689 
1690   setNameFn(getResult(), specialName.str());
1691 }
1692 
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)1693 void mlir::spirv::AddressOfOp::getAsmResultNames(
1694     llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1695   SmallString<32> specialNameBuffer;
1696   llvm::raw_svector_ostream specialName(specialNameBuffer);
1697   specialName << variable() << "_addr";
1698   setNameFn(getResult(), specialName.str());
1699 }
1700 
1701 //===----------------------------------------------------------------------===//
1702 // spv.EntryPoint
1703 //===----------------------------------------------------------------------===//
1704 
build(OpBuilder & builder,OperationState & state,spirv::ExecutionModel executionModel,spirv::FuncOp function,ArrayRef<Attribute> interfaceVars)1705 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
1706                                 spirv::ExecutionModel executionModel,
1707                                 spirv::FuncOp function,
1708                                 ArrayRef<Attribute> interfaceVars) {
1709   build(builder, state,
1710         spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
1711         builder.getSymbolRefAttr(function),
1712         builder.getArrayAttr(interfaceVars));
1713 }
1714 
parseEntryPointOp(OpAsmParser & parser,OperationState & state)1715 static ParseResult parseEntryPointOp(OpAsmParser &parser,
1716                                      OperationState &state) {
1717   spirv::ExecutionModel execModel;
1718   SmallVector<OpAsmParser::OperandType, 0> identifiers;
1719   SmallVector<Type, 0> idTypes;
1720   SmallVector<Attribute, 4> interfaceVars;
1721 
1722   FlatSymbolRefAttr fn;
1723   if (parseEnumStrAttr(execModel, parser, state) ||
1724       parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
1725     return failure();
1726   }
1727 
1728   if (!parser.parseOptionalComma()) {
1729     // Parse the interface variables
1730     do {
1731       // The name of the interface variable attribute isnt important
1732       auto attrName = "var_symbol";
1733       FlatSymbolRefAttr var;
1734       NamedAttrList attrs;
1735       if (parser.parseAttribute(var, Type(), attrName, attrs)) {
1736         return failure();
1737       }
1738       interfaceVars.push_back(var);
1739     } while (!parser.parseOptionalComma());
1740   }
1741   state.addAttribute(kInterfaceAttrName,
1742                      parser.getBuilder().getArrayAttr(interfaceVars));
1743   return success();
1744 }
1745 
print(spirv::EntryPointOp entryPointOp,OpAsmPrinter & printer)1746 static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
1747   printer << spirv::EntryPointOp::getOperationName() << " \""
1748           << stringifyExecutionModel(entryPointOp.execution_model()) << "\" ";
1749   printer.printSymbolName(entryPointOp.fn());
1750   auto interfaceVars = entryPointOp.interface().getValue();
1751   if (!interfaceVars.empty()) {
1752     printer << ", ";
1753     llvm::interleaveComma(interfaceVars, printer);
1754   }
1755 }
1756 
verify(spirv::EntryPointOp entryPointOp)1757 static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
1758   // Checks for fn and interface symbol reference are done in spirv::ModuleOp
1759   // verification.
1760   return success();
1761 }
1762 
1763 //===----------------------------------------------------------------------===//
1764 // spv.ExecutionMode
1765 //===----------------------------------------------------------------------===//
1766 
build(OpBuilder & builder,OperationState & state,spirv::FuncOp function,spirv::ExecutionMode executionMode,ArrayRef<int32_t> params)1767 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
1768                                    spirv::FuncOp function,
1769                                    spirv::ExecutionMode executionMode,
1770                                    ArrayRef<int32_t> params) {
1771   build(builder, state, builder.getSymbolRefAttr(function),
1772         spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
1773         builder.getI32ArrayAttr(params));
1774 }
1775 
parseExecutionModeOp(OpAsmParser & parser,OperationState & state)1776 static ParseResult parseExecutionModeOp(OpAsmParser &parser,
1777                                         OperationState &state) {
1778   spirv::ExecutionMode execMode;
1779   Attribute fn;
1780   if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
1781       parseEnumStrAttr(execMode, parser, state)) {
1782     return failure();
1783   }
1784 
1785   SmallVector<int32_t, 4> values;
1786   Type i32Type = parser.getBuilder().getIntegerType(32);
1787   while (!parser.parseOptionalComma()) {
1788     NamedAttrList attr;
1789     Attribute value;
1790     if (parser.parseAttribute(value, i32Type, "value", attr)) {
1791       return failure();
1792     }
1793     values.push_back(value.cast<IntegerAttr>().getInt());
1794   }
1795   state.addAttribute(kValuesAttrName,
1796                      parser.getBuilder().getI32ArrayAttr(values));
1797   return success();
1798 }
1799 
print(spirv::ExecutionModeOp execModeOp,OpAsmPrinter & printer)1800 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
1801   printer << spirv::ExecutionModeOp::getOperationName() << " ";
1802   printer.printSymbolName(execModeOp.fn());
1803   printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
1804           << "\"";
1805   auto values = execModeOp.values();
1806   if (!values.size())
1807     return;
1808   printer << ", ";
1809   llvm::interleaveComma(values, printer, [&](Attribute a) {
1810     printer << a.cast<IntegerAttr>().getInt();
1811   });
1812 }
1813 
1814 //===----------------------------------------------------------------------===//
1815 // spv.func
1816 //===----------------------------------------------------------------------===//
1817 
parseFuncOp(OpAsmParser & parser,OperationState & state)1818 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
1819   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
1820   SmallVector<NamedAttrList, 4> argAttrs;
1821   SmallVector<NamedAttrList, 4> resultAttrs;
1822   SmallVector<Type, 4> argTypes;
1823   SmallVector<Type, 4> resultTypes;
1824   auto &builder = parser.getBuilder();
1825 
1826   // Parse the name as a symbol.
1827   StringAttr nameAttr;
1828   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1829                              state.attributes))
1830     return failure();
1831 
1832   // Parse the function signature.
1833   bool isVariadic = false;
1834   if (function_like_impl::parseFunctionSignature(
1835           parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
1836           isVariadic, resultTypes, resultAttrs))
1837     return failure();
1838 
1839   auto fnType = builder.getFunctionType(argTypes, resultTypes);
1840   state.addAttribute(function_like_impl::getTypeAttrName(),
1841                      TypeAttr::get(fnType));
1842 
1843   // Parse the optional function control keyword.
1844   spirv::FunctionControl fnControl;
1845   if (parseEnumStrAttr(fnControl, parser, state))
1846     return failure();
1847 
1848   // If additional attributes are present, parse them.
1849   if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1850     return failure();
1851 
1852   // Add the attributes to the function arguments.
1853   assert(argAttrs.size() == argTypes.size());
1854   assert(resultAttrs.size() == resultTypes.size());
1855   function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
1856                                            resultAttrs);
1857 
1858   // Parse the optional function body.
1859   auto *body = state.addRegion();
1860   OptionalParseResult result = parser.parseOptionalRegion(
1861       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1862   return failure(result.hasValue() && failed(*result));
1863 }
1864 
print(spirv::FuncOp fnOp,OpAsmPrinter & printer)1865 static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
1866   // Print function name, signature, and control.
1867   printer << spirv::FuncOp::getOperationName() << " ";
1868   printer.printSymbolName(fnOp.sym_name());
1869   auto fnType = fnOp.getType();
1870   function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
1871                                              /*isVariadic=*/false,
1872                                              fnType.getResults());
1873   printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
1874           << "\"";
1875   function_like_impl::printFunctionAttributes(
1876       printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
1877       {spirv::attributeName<spirv::FunctionControl>()});
1878 
1879   // Print the body if this is not an external function.
1880   Region &body = fnOp.body();
1881   if (!body.empty())
1882     printer.printRegion(body, /*printEntryBlockArgs=*/false,
1883                         /*printBlockTerminators=*/true);
1884 }
1885 
verifyType()1886 LogicalResult spirv::FuncOp::verifyType() {
1887   auto type = getTypeAttr().getValue();
1888   if (!type.isa<FunctionType>())
1889     return emitOpError("requires '" + getTypeAttrName() +
1890                        "' attribute of function type");
1891   if (getType().getNumResults() > 1)
1892     return emitOpError("cannot have more than one result");
1893   return success();
1894 }
1895 
verifyBody()1896 LogicalResult spirv::FuncOp::verifyBody() {
1897   FunctionType fnType = getType();
1898 
1899   auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1900     if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1901       if (fnType.getNumResults() != 0)
1902         return retOp.emitOpError("cannot be used in functions returning value");
1903     } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1904       if (fnType.getNumResults() != 1)
1905         return retOp.emitOpError(
1906                    "returns 1 value but enclosing function requires ")
1907                << fnType.getNumResults() << " results";
1908 
1909       auto retOperandType = retOp.value().getType();
1910       auto fnResultType = fnType.getResult(0);
1911       if (retOperandType != fnResultType)
1912         return retOp.emitOpError(" return value's type (")
1913                << retOperandType << ") mismatch with function's result type ("
1914                << fnResultType << ")";
1915     }
1916     return WalkResult::advance();
1917   });
1918 
1919   // TODO: verify other bits like linkage type.
1920 
1921   return failure(walkResult.wasInterrupted());
1922 }
1923 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,spirv::FunctionControl control,ArrayRef<NamedAttribute> attrs)1924 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1925                           StringRef name, FunctionType type,
1926                           spirv::FunctionControl control,
1927                           ArrayRef<NamedAttribute> attrs) {
1928   state.addAttribute(SymbolTable::getSymbolAttrName(),
1929                      builder.getStringAttr(name));
1930   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
1931   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1932                      builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
1933   state.attributes.append(attrs.begin(), attrs.end());
1934   state.addRegion();
1935 }
1936 
1937 // CallableOpInterface
getCallableRegion()1938 Region *spirv::FuncOp::getCallableRegion() {
1939   return isExternal() ? nullptr : &body();
1940 }
1941 
1942 // CallableOpInterface
getCallableResults()1943 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
1944   return getType().getResults();
1945 }
1946 
1947 //===----------------------------------------------------------------------===//
1948 // spv.FunctionCall
1949 //===----------------------------------------------------------------------===//
1950 
verify(spirv::FunctionCallOp functionCallOp)1951 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
1952   auto fnName = functionCallOp.callee();
1953 
1954   auto funcOp =
1955       dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
1956           functionCallOp->getParentOp(), fnName));
1957   if (!funcOp) {
1958     return functionCallOp.emitOpError("callee function '")
1959            << fnName << "' not found in nearest symbol table";
1960   }
1961 
1962   auto functionType = funcOp.getType();
1963 
1964   if (functionCallOp.getNumResults() > 1) {
1965     return functionCallOp.emitOpError(
1966                "expected callee function to have 0 or 1 result, but provided ")
1967            << functionCallOp.getNumResults();
1968   }
1969 
1970   if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
1971     return functionCallOp.emitOpError(
1972                "has incorrect number of operands for callee: expected ")
1973            << functionType.getNumInputs() << ", but provided "
1974            << functionCallOp.getNumOperands();
1975   }
1976 
1977   for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
1978     if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
1979       return functionCallOp.emitOpError(
1980                  "operand type mismatch: expected operand type ")
1981              << functionType.getInput(i) << ", but provided "
1982              << functionCallOp.getOperand(i).getType() << " for operand number "
1983              << i;
1984     }
1985   }
1986 
1987   if (functionType.getNumResults() != functionCallOp.getNumResults()) {
1988     return functionCallOp.emitOpError(
1989                "has incorrect number of results has for callee: expected ")
1990            << functionType.getNumResults() << ", but provided "
1991            << functionCallOp.getNumResults();
1992   }
1993 
1994   if (functionCallOp.getNumResults() &&
1995       (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
1996     return functionCallOp.emitOpError("result type mismatch: expected ")
1997            << functionType.getResult(0) << ", but provided "
1998            << functionCallOp.getResult(0).getType();
1999   }
2000 
2001   return success();
2002 }
2003 
getCallableForCallee()2004 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2005   return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2006 }
2007 
getArgOperands()2008 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2009   return arguments();
2010 }
2011 
2012 //===----------------------------------------------------------------------===//
2013 // spv.GlobalVariable
2014 //===----------------------------------------------------------------------===//
2015 
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,unsigned descriptorSet,unsigned binding)2016 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2017                                     Type type, StringRef name,
2018                                     unsigned descriptorSet, unsigned binding) {
2019   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
2020         nullptr);
2021   state.addAttribute(
2022       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2023       builder.getI32IntegerAttr(descriptorSet));
2024   state.addAttribute(
2025       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2026       builder.getI32IntegerAttr(binding));
2027 }
2028 
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,spirv::BuiltIn builtin)2029 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2030                                     Type type, StringRef name,
2031                                     spirv::BuiltIn builtin) {
2032   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
2033         nullptr);
2034   state.addAttribute(
2035       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2036       builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2037 }
2038 
parseGlobalVariableOp(OpAsmParser & parser,OperationState & state)2039 static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
2040                                          OperationState &state) {
2041   // Parse variable name.
2042   StringAttr nameAttr;
2043   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2044                              state.attributes)) {
2045     return failure();
2046   }
2047 
2048   // Parse optional initializer
2049   if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2050     FlatSymbolRefAttr initSymbol;
2051     if (parser.parseLParen() ||
2052         parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2053                               state.attributes) ||
2054         parser.parseRParen())
2055       return failure();
2056   }
2057 
2058   if (parseVariableDecorations(parser, state)) {
2059     return failure();
2060   }
2061 
2062   Type type;
2063   auto loc = parser.getCurrentLocation();
2064   if (parser.parseColonType(type)) {
2065     return failure();
2066   }
2067   if (!type.isa<spirv::PointerType>()) {
2068     return parser.emitError(loc, "expected spv.ptr type");
2069   }
2070   state.addAttribute(kTypeAttrName, TypeAttr::get(type));
2071 
2072   return success();
2073 }
2074 
print(spirv::GlobalVariableOp varOp,OpAsmPrinter & printer)2075 static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
2076   auto *op = varOp.getOperation();
2077   SmallVector<StringRef, 4> elidedAttrs{
2078       spirv::attributeName<spirv::StorageClass>()};
2079   printer << spirv::GlobalVariableOp::getOperationName();
2080 
2081   // Print variable name.
2082   printer << ' ';
2083   printer.printSymbolName(varOp.sym_name());
2084   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2085 
2086   // Print optional initializer
2087   if (auto initializer = varOp.initializer()) {
2088     printer << " " << kInitializerAttrName << '(';
2089     printer.printSymbolName(initializer.getValue());
2090     printer << ')';
2091     elidedAttrs.push_back(kInitializerAttrName);
2092   }
2093 
2094   elidedAttrs.push_back(kTypeAttrName);
2095   printVariableDecorations(op, printer, elidedAttrs);
2096   printer << " : " << varOp.type();
2097 }
2098 
verify(spirv::GlobalVariableOp varOp)2099 static LogicalResult verify(spirv::GlobalVariableOp varOp) {
2100   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2101   // object. It cannot be Generic. It must be the same as the Storage Class
2102   // operand of the Result Type."
2103   // Also, Function storage class is reserved by spv.Variable.
2104   auto storageClass = varOp.storageClass();
2105   if (storageClass == spirv::StorageClass::Generic ||
2106       storageClass == spirv::StorageClass::Function) {
2107     return varOp.emitOpError("storage class cannot be '")
2108            << stringifyStorageClass(storageClass) << "'";
2109   }
2110 
2111   if (auto init =
2112           varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2113     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2114         varOp->getParentOp(), init.getValue());
2115     // TODO: Currently only variable initialization with specialization
2116     // constants and other variables is supported. They could be normal
2117     // constants in the module scope as well.
2118     if (!initOp ||
2119         !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2120       return varOp.emitOpError("initializer must be result of a "
2121                                "spv.SpecConstant or spv.GlobalVariable op");
2122     }
2123   }
2124 
2125   return success();
2126 }
2127 
2128 //===----------------------------------------------------------------------===//
2129 // spv.GroupBroadcast
2130 //===----------------------------------------------------------------------===//
2131 
verify(spirv::GroupBroadcastOp broadcastOp)2132 static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
2133   spirv::Scope scope = broadcastOp.execution_scope();
2134   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2135     return broadcastOp.emitOpError(
2136         "execution scope must be 'Workgroup' or 'Subgroup'");
2137 
2138   if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
2139     if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
2140       return broadcastOp.emitOpError("localid is a vector and can be with only "
2141                                      " 2 or 3 components, actual number is ")
2142              << localIdTy.getNumElements();
2143 
2144   return success();
2145 }
2146 
2147 //===----------------------------------------------------------------------===//
2148 // spv.GroupNonUniformBallotOp
2149 //===----------------------------------------------------------------------===//
2150 
verify(spirv::GroupNonUniformBallotOp ballotOp)2151 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
2152   spirv::Scope scope = ballotOp.execution_scope();
2153   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2154     return ballotOp.emitOpError(
2155         "execution scope must be 'Workgroup' or 'Subgroup'");
2156 
2157   return success();
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // spv.GroupNonUniformBroadcast
2162 //===----------------------------------------------------------------------===//
2163 
verify(spirv::GroupNonUniformBroadcastOp broadcastOp)2164 static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
2165   spirv::Scope scope = broadcastOp.execution_scope();
2166   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2167     return broadcastOp.emitOpError(
2168         "execution scope must be 'Workgroup' or 'Subgroup'");
2169 
2170   // SPIR-V spec: "Before version 1.5, Id must come from a
2171   // constant instruction.
2172   auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
2173   if (auto spirvModule = broadcastOp->getParentOfType<spirv::ModuleOp>())
2174     targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2175 
2176   if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2177     auto *idOp = broadcastOp.id().getDefiningOp();
2178     if (!idOp || !isa<spirv::ConstantOp,           // for normal constant
2179                       spirv::ReferenceOfOp>(idOp)) // for spec constant
2180       return broadcastOp.emitOpError("id must be the result of a constant op");
2181   }
2182 
2183   return success();
2184 }
2185 
2186 //===----------------------------------------------------------------------===//
2187 // spv.SubgroupBlockReadINTEL
2188 //===----------------------------------------------------------------------===//
2189 
parseSubgroupBlockReadINTELOp(OpAsmParser & parser,OperationState & state)2190 static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
2191                                                  OperationState &state) {
2192   // Parse the storage class specification
2193   spirv::StorageClass storageClass;
2194   OpAsmParser::OperandType ptrInfo;
2195   Type elementType;
2196   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2197       parser.parseColon() || parser.parseType(elementType)) {
2198     return failure();
2199   }
2200 
2201   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2202   if (auto valVecTy = elementType.dyn_cast<VectorType>())
2203     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2204 
2205   if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2206     return failure();
2207   }
2208 
2209   state.addTypes(elementType);
2210   return success();
2211 }
2212 
print(spirv::SubgroupBlockReadINTELOp blockReadOp,OpAsmPrinter & printer)2213 static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
2214                   OpAsmPrinter &printer) {
2215   SmallVector<StringRef, 4> elidedAttrs;
2216   printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
2217           << blockReadOp.ptr();
2218   printer << " : " << blockReadOp.getType();
2219 }
2220 
verify(spirv::SubgroupBlockReadINTELOp blockReadOp)2221 static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
2222   if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
2223                                                 blockReadOp.value())))
2224     return failure();
2225 
2226   return success();
2227 }
2228 
2229 //===----------------------------------------------------------------------===//
2230 // spv.SubgroupBlockWriteINTEL
2231 //===----------------------------------------------------------------------===//
2232 
parseSubgroupBlockWriteINTELOp(OpAsmParser & parser,OperationState & state)2233 static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
2234                                                   OperationState &state) {
2235   // Parse the storage class specification
2236   spirv::StorageClass storageClass;
2237   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2238   auto loc = parser.getCurrentLocation();
2239   Type elementType;
2240   if (parseEnumStrAttr(storageClass, parser) ||
2241       parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2242       parser.parseType(elementType)) {
2243     return failure();
2244   }
2245 
2246   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2247   if (auto valVecTy = elementType.dyn_cast<VectorType>())
2248     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2249 
2250   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2251                              state.operands)) {
2252     return failure();
2253   }
2254   return success();
2255 }
2256 
print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,OpAsmPrinter & printer)2257 static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
2258                   OpAsmPrinter &printer) {
2259   SmallVector<StringRef, 4> elidedAttrs;
2260   printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
2261           << blockWriteOp.ptr() << ", " << blockWriteOp.value();
2262   printer << " : " << blockWriteOp.value().getType();
2263 }
2264 
verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp)2265 static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
2266   if (failed(verifyBlockReadWritePtrAndValTypes(
2267           blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
2268     return failure();
2269 
2270   return success();
2271 }
2272 
2273 //===----------------------------------------------------------------------===//
2274 // spv.GroupNonUniformElectOp
2275 //===----------------------------------------------------------------------===//
2276 
build(OpBuilder & builder,OperationState & state,spirv::Scope scope)2277 void spirv::GroupNonUniformElectOp::build(OpBuilder &builder,
2278                                           OperationState &state,
2279                                           spirv::Scope scope) {
2280   build(builder, state, builder.getI1Type(), scope);
2281 }
2282 
verify(spirv::GroupNonUniformElectOp groupOp)2283 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
2284   spirv::Scope scope = groupOp.execution_scope();
2285   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2286     return groupOp.emitOpError(
2287         "execution scope must be 'Workgroup' or 'Subgroup'");
2288 
2289   return success();
2290 }
2291 
2292 //===----------------------------------------------------------------------===//
2293 // spv.LoadOp
2294 //===----------------------------------------------------------------------===//
2295 
build(OpBuilder & builder,OperationState & state,Value basePtr,MemoryAccessAttr memoryAccess,IntegerAttr alignment)2296 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2297                           Value basePtr, MemoryAccessAttr memoryAccess,
2298                           IntegerAttr alignment) {
2299   auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2300   build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
2301         alignment);
2302 }
2303 
parseLoadOp(OpAsmParser & parser,OperationState & state)2304 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
2305   // Parse the storage class specification
2306   spirv::StorageClass storageClass;
2307   OpAsmParser::OperandType ptrInfo;
2308   Type elementType;
2309   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2310       parseMemoryAccessAttributes(parser, state) ||
2311       parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2312       parser.parseType(elementType)) {
2313     return failure();
2314   }
2315 
2316   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2317   if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2318     return failure();
2319   }
2320 
2321   state.addTypes(elementType);
2322   return success();
2323 }
2324 
print(spirv::LoadOp loadOp,OpAsmPrinter & printer)2325 static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
2326   auto *op = loadOp.getOperation();
2327   SmallVector<StringRef, 4> elidedAttrs;
2328   StringRef sc = stringifyStorageClass(
2329       loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2330   printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
2331           << loadOp.ptr();
2332 
2333   printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
2334 
2335   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2336   printer << " : " << loadOp.getType();
2337 }
2338 
verify(spirv::LoadOp loadOp)2339 static LogicalResult verify(spirv::LoadOp loadOp) {
2340   // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2341   // type with fixed size; i.e., it cannot be, nor include, any
2342   // OpTypeRuntimeArray types."
2343   if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
2344                                            loadOp.value()))) {
2345     return failure();
2346   }
2347   return verifyMemoryAccessAttribute(loadOp);
2348 }
2349 
2350 //===----------------------------------------------------------------------===//
2351 // spv.mlir.loop
2352 //===----------------------------------------------------------------------===//
2353 
build(OpBuilder & builder,OperationState & state)2354 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2355   state.addAttribute("loop_control",
2356                      builder.getI32IntegerAttr(
2357                          static_cast<uint32_t>(spirv::LoopControl::None)));
2358   state.addRegion();
2359 }
2360 
parseLoopOp(OpAsmParser & parser,OperationState & state)2361 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
2362   if (parseControlAttribute<spirv::LoopControl>(parser, state))
2363     return failure();
2364   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2365                             /*argTypes=*/{});
2366 }
2367 
print(spirv::LoopOp loopOp,OpAsmPrinter & printer)2368 static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
2369   auto *op = loopOp.getOperation();
2370 
2371   printer << spirv::LoopOp::getOperationName();
2372   auto control = loopOp.loop_control();
2373   if (control != spirv::LoopControl::None)
2374     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2375   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2376                       /*printBlockTerminators=*/true);
2377 }
2378 
2379 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2380 /// given `dstBlock`.
hasOneBranchOpTo(Block & srcBlock,Block & dstBlock)2381 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2382   // Check that there is only one op in the `srcBlock`.
2383   if (!llvm::hasSingleElement(srcBlock))
2384     return false;
2385 
2386   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2387   return branchOp && branchOp.getSuccessor() == &dstBlock;
2388 }
2389 
verify(spirv::LoopOp loopOp)2390 static LogicalResult verify(spirv::LoopOp loopOp) {
2391   auto *op = loopOp.getOperation();
2392 
2393   // We need to verify that the blocks follow the following layout:
2394   //
2395   //                     +-------------+
2396   //                     | entry block |
2397   //                     +-------------+
2398   //                            |
2399   //                            v
2400   //                     +-------------+
2401   //                     | loop header | <-----+
2402   //                     +-------------+       |
2403   //                                           |
2404   //                           ...             |
2405   //                          \ | /            |
2406   //                            v              |
2407   //                    +---------------+      |
2408   //                    | loop continue | -----+
2409   //                    +---------------+
2410   //
2411   //                           ...
2412   //                          \ | /
2413   //                            v
2414   //                     +-------------+
2415   //                     | merge block |
2416   //                     +-------------+
2417 
2418   auto &region = op->getRegion(0);
2419   // Allow empty region as a degenerated case, which can come from
2420   // optimizations.
2421   if (region.empty())
2422     return success();
2423 
2424   // The last block is the merge block.
2425   Block &merge = region.back();
2426   if (!isMergeBlock(merge))
2427     return loopOp.emitOpError(
2428         "last block must be the merge block with only one 'spv.mlir.merge' op");
2429 
2430   if (std::next(region.begin()) == region.end())
2431     return loopOp.emitOpError(
2432         "must have an entry block branching to the loop header block");
2433   // The first block is the entry block.
2434   Block &entry = region.front();
2435 
2436   if (std::next(region.begin(), 2) == region.end())
2437     return loopOp.emitOpError(
2438         "must have a loop header block branched from the entry block");
2439   // The second block is the loop header block.
2440   Block &header = *std::next(region.begin(), 1);
2441 
2442   if (!hasOneBranchOpTo(entry, header))
2443     return loopOp.emitOpError(
2444         "entry block must only have one 'spv.Branch' op to the second block");
2445 
2446   if (std::next(region.begin(), 3) == region.end())
2447     return loopOp.emitOpError(
2448         "requires a loop continue block branching to the loop header block");
2449   // The second to last block is the loop continue block.
2450   Block &cont = *std::prev(region.end(), 2);
2451 
2452   // Make sure that we have a branch from the loop continue block to the loop
2453   // header block.
2454   if (llvm::none_of(
2455           llvm::seq<unsigned>(0, cont.getNumSuccessors()),
2456           [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
2457     return loopOp.emitOpError("second to last block must be the loop continue "
2458                               "block that branches to the loop header block");
2459 
2460   // Make sure that no other blocks (except the entry and loop continue block)
2461   // branches to the loop header block.
2462   for (auto &block : llvm::make_range(std::next(region.begin(), 2),
2463                                       std::prev(region.end(), 2))) {
2464     for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
2465       if (block.getSuccessor(i) == &header) {
2466         return loopOp.emitOpError("can only have the entry and loop continue "
2467                                   "block branching to the loop header block");
2468       }
2469     }
2470   }
2471 
2472   return success();
2473 }
2474 
getEntryBlock()2475 Block *spirv::LoopOp::getEntryBlock() {
2476   assert(!body().empty() && "op region should not be empty!");
2477   return &body().front();
2478 }
2479 
getHeaderBlock()2480 Block *spirv::LoopOp::getHeaderBlock() {
2481   assert(!body().empty() && "op region should not be empty!");
2482   // The second block is the loop header block.
2483   return &*std::next(body().begin());
2484 }
2485 
getContinueBlock()2486 Block *spirv::LoopOp::getContinueBlock() {
2487   assert(!body().empty() && "op region should not be empty!");
2488   // The second to last block is the loop continue block.
2489   return &*std::prev(body().end(), 2);
2490 }
2491 
getMergeBlock()2492 Block *spirv::LoopOp::getMergeBlock() {
2493   assert(!body().empty() && "op region should not be empty!");
2494   // The last block is the loop merge block.
2495   return &body().back();
2496 }
2497 
addEntryAndMergeBlock()2498 void spirv::LoopOp::addEntryAndMergeBlock() {
2499   assert(body().empty() && "entry and merge block already exist");
2500   body().push_back(new Block());
2501   auto *mergeBlock = new Block();
2502   body().push_back(mergeBlock);
2503   OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2504 
2505   // Add a spv.mlir.merge op into the merge block.
2506   builder.create<spirv::MergeOp>(getLoc());
2507 }
2508 
2509 //===----------------------------------------------------------------------===//
2510 // spv.mlir.merge
2511 //===----------------------------------------------------------------------===//
2512 
verify(spirv::MergeOp mergeOp)2513 static LogicalResult verify(spirv::MergeOp mergeOp) {
2514   auto *parentOp = mergeOp->getParentOp();
2515   if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
2516     return mergeOp.emitOpError(
2517         "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
2518 
2519   Block &parentLastBlock = mergeOp->getParentRegion()->back();
2520   if (mergeOp.getOperation() != parentLastBlock.getTerminator())
2521     return mergeOp.emitOpError("can only be used in the last block of "
2522                                "'spv.mlir.selection' or 'spv.mlir.loop'");
2523   return success();
2524 }
2525 
2526 //===----------------------------------------------------------------------===//
2527 // spv.module
2528 //===----------------------------------------------------------------------===//
2529 
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)2530 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2531                             Optional<StringRef> name) {
2532   OpBuilder::InsertionGuard guard(builder);
2533   builder.createBlock(state.addRegion());
2534   if (name) {
2535     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2536                             builder.getStringAttr(*name));
2537   }
2538 }
2539 
build(OpBuilder & builder,OperationState & state,spirv::AddressingModel addressingModel,spirv::MemoryModel memoryModel,Optional<StringRef> name)2540 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2541                             spirv::AddressingModel addressingModel,
2542                             spirv::MemoryModel memoryModel,
2543                             Optional<StringRef> name) {
2544   state.addAttribute(
2545       "addressing_model",
2546       builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
2547   state.addAttribute("memory_model", builder.getI32IntegerAttr(
2548                                          static_cast<int32_t>(memoryModel)));
2549   OpBuilder::InsertionGuard guard(builder);
2550   builder.createBlock(state.addRegion());
2551   if (name) {
2552     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2553                             builder.getStringAttr(*name));
2554   }
2555 }
2556 
parseModuleOp(OpAsmParser & parser,OperationState & state)2557 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
2558   Region *body = state.addRegion();
2559 
2560   // If the name is present, parse it.
2561   StringAttr nameAttr;
2562   parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2563                                  state.attributes);
2564 
2565   // Parse attributes
2566   spirv::AddressingModel addrModel;
2567   spirv::MemoryModel memoryModel;
2568   if (parseEnumKeywordAttr(addrModel, parser, state) ||
2569       parseEnumKeywordAttr(memoryModel, parser, state))
2570     return failure();
2571 
2572   if (succeeded(parser.parseOptionalKeyword("requires"))) {
2573     spirv::VerCapExtAttr vceTriple;
2574     if (parser.parseAttribute(vceTriple,
2575                               spirv::ModuleOp::getVCETripleAttrName(),
2576                               state.attributes))
2577       return failure();
2578   }
2579 
2580   if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
2581     return failure();
2582 
2583   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2584     return failure();
2585 
2586   // Make sure we have at least one block.
2587   if (body->empty())
2588     body->push_back(new Block());
2589 
2590   return success();
2591 }
2592 
print(spirv::ModuleOp moduleOp,OpAsmPrinter & printer)2593 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
2594   printer << spirv::ModuleOp::getOperationName();
2595 
2596   if (Optional<StringRef> name = moduleOp.getName()) {
2597     printer << ' ';
2598     printer.printSymbolName(*name);
2599   }
2600 
2601   SmallVector<StringRef, 2> elidedAttrs;
2602 
2603   printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
2604           << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
2605   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
2606   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
2607   elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
2608                       SymbolTable::getSymbolAttrName()});
2609 
2610   if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
2611     printer << " requires " << *triple;
2612     elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
2613   }
2614 
2615   printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
2616   printer.printRegion(moduleOp.getRegion());
2617 }
2618 
verify(spirv::ModuleOp moduleOp)2619 static LogicalResult verify(spirv::ModuleOp moduleOp) {
2620   auto &op = *moduleOp.getOperation();
2621   auto *dialect = op.getDialect();
2622   DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
2623       entryPoints;
2624   SymbolTable table(moduleOp);
2625 
2626   for (auto &op : *moduleOp.getBody()) {
2627     if (op.getDialect() != dialect)
2628       return op.emitError("'spv.module' can only contain spv.* ops");
2629 
2630     // For EntryPoint op, check that the function and execution model is not
2631     // duplicated in EntryPointOps. Also verify that the interface specified
2632     // comes from globalVariables here to make this check cheaper.
2633     if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
2634       auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
2635       if (!funcOp) {
2636         return entryPointOp.emitError("function '")
2637                << entryPointOp.fn() << "' not found in 'spv.module'";
2638       }
2639       if (auto interface = entryPointOp.interface()) {
2640         for (Attribute varRef : interface) {
2641           auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
2642           if (!varSymRef) {
2643             return entryPointOp.emitError(
2644                        "expected symbol reference for interface "
2645                        "specification instead of '")
2646                    << varRef;
2647           }
2648           auto variableOp =
2649               table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
2650           if (!variableOp) {
2651             return entryPointOp.emitError("expected spv.GlobalVariable "
2652                                           "symbol reference instead of'")
2653                    << varSymRef << "'";
2654           }
2655         }
2656       }
2657 
2658       auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
2659           funcOp, entryPointOp.execution_model());
2660       auto entryPtIt = entryPoints.find(key);
2661       if (entryPtIt != entryPoints.end()) {
2662         return entryPointOp.emitError("duplicate of a previous EntryPointOp");
2663       }
2664       entryPoints[key] = entryPointOp;
2665     } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
2666       if (funcOp.isExternal())
2667         return op.emitError("'spv.module' cannot contain external functions");
2668 
2669       // TODO: move this check to spv.func.
2670       for (auto &block : funcOp)
2671         for (auto &op : block) {
2672           if (op.getDialect() != dialect)
2673             return op.emitError(
2674                 "functions in 'spv.module' can only contain spv.* ops");
2675         }
2676     }
2677   }
2678 
2679   return success();
2680 }
2681 
2682 //===----------------------------------------------------------------------===//
2683 // spv.mlir.referenceof
2684 //===----------------------------------------------------------------------===//
2685 
verify(spirv::ReferenceOfOp referenceOfOp)2686 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
2687   auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
2688       referenceOfOp->getParentOp(), referenceOfOp.spec_const());
2689   Type constType;
2690 
2691   auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
2692   if (specConstOp)
2693     constType = specConstOp.default_value().getType();
2694 
2695   auto specConstCompositeOp =
2696       dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
2697   if (specConstCompositeOp)
2698     constType = specConstCompositeOp.type();
2699 
2700   if (!specConstOp && !specConstCompositeOp)
2701     return referenceOfOp.emitOpError(
2702         "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
2703 
2704   if (referenceOfOp.reference().getType() != constType)
2705     return referenceOfOp.emitOpError("result type mismatch with the referenced "
2706                                      "specialization constant's type");
2707 
2708   return success();
2709 }
2710 
2711 //===----------------------------------------------------------------------===//
2712 // spv.Return
2713 //===----------------------------------------------------------------------===//
2714 
verify(spirv::ReturnOp returnOp)2715 static LogicalResult verify(spirv::ReturnOp returnOp) {
2716   // Verification is performed in spv.func op.
2717   return success();
2718 }
2719 
2720 //===----------------------------------------------------------------------===//
2721 // spv.ReturnValue
2722 //===----------------------------------------------------------------------===//
2723 
verify(spirv::ReturnValueOp retValOp)2724 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
2725   // Verification is performed in spv.func op.
2726   return success();
2727 }
2728 
2729 //===----------------------------------------------------------------------===//
2730 // spv.Select
2731 //===----------------------------------------------------------------------===//
2732 
build(OpBuilder & builder,OperationState & state,Value cond,Value trueValue,Value falseValue)2733 void spirv::SelectOp::build(OpBuilder &builder, OperationState &state,
2734                             Value cond, Value trueValue, Value falseValue) {
2735   build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
2736 }
2737 
verify(spirv::SelectOp op)2738 static LogicalResult verify(spirv::SelectOp op) {
2739   if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
2740     auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
2741     if (!resultVectorTy) {
2742       return op.emitOpError("result expected to be of vector type when "
2743                             "condition is of vector type");
2744     }
2745     if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
2746       return op.emitOpError("result should have the same number of elements as "
2747                             "the condition when condition is of vector type");
2748     }
2749   }
2750   return success();
2751 }
2752 
2753 //===----------------------------------------------------------------------===//
2754 // spv.mlir.selection
2755 //===----------------------------------------------------------------------===//
2756 
parseSelectionOp(OpAsmParser & parser,OperationState & state)2757 static ParseResult parseSelectionOp(OpAsmParser &parser,
2758                                     OperationState &state) {
2759   if (parseControlAttribute<spirv::SelectionControl>(parser, state))
2760     return failure();
2761   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2762                             /*argTypes=*/{});
2763 }
2764 
print(spirv::SelectionOp selectionOp,OpAsmPrinter & printer)2765 static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
2766   auto *op = selectionOp.getOperation();
2767 
2768   printer << spirv::SelectionOp::getOperationName();
2769   auto control = selectionOp.selection_control();
2770   if (control != spirv::SelectionControl::None)
2771     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
2772   printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2773                       /*printBlockTerminators=*/true);
2774 }
2775 
verify(spirv::SelectionOp selectionOp)2776 static LogicalResult verify(spirv::SelectionOp selectionOp) {
2777   auto *op = selectionOp.getOperation();
2778 
2779   // We need to verify that the blocks follow the following layout:
2780   //
2781   //                     +--------------+
2782   //                     | header block |
2783   //                     +--------------+
2784   //                          / | \
2785   //                           ...
2786   //
2787   //
2788   //         +---------+   +---------+   +---------+
2789   //         | case #0 |   | case #1 |   | case #2 |  ...
2790   //         +---------+   +---------+   +---------+
2791   //
2792   //
2793   //                           ...
2794   //                          \ | /
2795   //                            v
2796   //                     +-------------+
2797   //                     | merge block |
2798   //                     +-------------+
2799 
2800   auto &region = op->getRegion(0);
2801   // Allow empty region as a degenerated case, which can come from
2802   // optimizations.
2803   if (region.empty())
2804     return success();
2805 
2806   // The last block is the merge block.
2807   if (!isMergeBlock(region.back()))
2808     return selectionOp.emitOpError(
2809         "last block must be the merge block with only one 'spv.mlir.merge' op");
2810 
2811   if (std::next(region.begin()) == region.end())
2812     return selectionOp.emitOpError("must have a selection header block");
2813 
2814   return success();
2815 }
2816 
getHeaderBlock()2817 Block *spirv::SelectionOp::getHeaderBlock() {
2818   assert(!body().empty() && "op region should not be empty!");
2819   // The first block is the loop header block.
2820   return &body().front();
2821 }
2822 
getMergeBlock()2823 Block *spirv::SelectionOp::getMergeBlock() {
2824   assert(!body().empty() && "op region should not be empty!");
2825   // The last block is the loop merge block.
2826   return &body().back();
2827 }
2828 
addMergeBlock()2829 void spirv::SelectionOp::addMergeBlock() {
2830   assert(body().empty() && "entry and merge block already exist");
2831   auto *mergeBlock = new Block();
2832   body().push_back(mergeBlock);
2833   OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2834 
2835   // Add a spv.mlir.merge op into the merge block.
2836   builder.create<spirv::MergeOp>(getLoc());
2837 }
2838 
createIfThen(Location loc,Value condition,function_ref<void (OpBuilder & builder)> thenBody,OpBuilder & builder)2839 spirv::SelectionOp spirv::SelectionOp::createIfThen(
2840     Location loc, Value condition,
2841     function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
2842   auto selectionOp =
2843       builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
2844 
2845   selectionOp.addMergeBlock();
2846   Block *mergeBlock = selectionOp.getMergeBlock();
2847   Block *thenBlock = nullptr;
2848 
2849   // Build the "then" block.
2850   {
2851     OpBuilder::InsertionGuard guard(builder);
2852     thenBlock = builder.createBlock(mergeBlock);
2853     thenBody(builder);
2854     builder.create<spirv::BranchOp>(loc, mergeBlock);
2855   }
2856 
2857   // Build the header block.
2858   {
2859     OpBuilder::InsertionGuard guard(builder);
2860     builder.createBlock(thenBlock);
2861     builder.create<spirv::BranchConditionalOp>(
2862         loc, condition, thenBlock,
2863         /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
2864         /*falseArguments=*/ArrayRef<Value>());
2865   }
2866 
2867   return selectionOp;
2868 }
2869 
2870 //===----------------------------------------------------------------------===//
2871 // spv.SpecConstant
2872 //===----------------------------------------------------------------------===//
2873 
parseSpecConstantOp(OpAsmParser & parser,OperationState & state)2874 static ParseResult parseSpecConstantOp(OpAsmParser &parser,
2875                                        OperationState &state) {
2876   StringAttr nameAttr;
2877   Attribute valueAttr;
2878 
2879   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2880                              state.attributes))
2881     return failure();
2882 
2883   // Parse optional spec_id.
2884   if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
2885     IntegerAttr specIdAttr;
2886     if (parser.parseLParen() ||
2887         parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
2888         parser.parseRParen())
2889       return failure();
2890   }
2891 
2892   if (parser.parseEqual() ||
2893       parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
2894     return failure();
2895 
2896   return success();
2897 }
2898 
print(spirv::SpecConstantOp constOp,OpAsmPrinter & printer)2899 static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
2900   printer << spirv::SpecConstantOp::getOperationName() << ' ';
2901   printer.printSymbolName(constOp.sym_name());
2902   if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2903     printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
2904   printer << " = " << constOp.default_value();
2905 }
2906 
verify(spirv::SpecConstantOp constOp)2907 static LogicalResult verify(spirv::SpecConstantOp constOp) {
2908   if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2909     if (specID.getValue().isNegative())
2910       return constOp.emitOpError("SpecId cannot be negative");
2911 
2912   auto value = constOp.default_value();
2913   if (value.isa<IntegerAttr, FloatAttr>()) {
2914     // Make sure bitwidth is allowed.
2915     if (!value.getType().isa<spirv::SPIRVType>())
2916       return constOp.emitOpError("default value bitwidth disallowed");
2917     return success();
2918   }
2919   return constOp.emitOpError(
2920       "default value can only be a bool, integer, or float scalar");
2921 }
2922 
2923 //===----------------------------------------------------------------------===//
2924 // spv.StoreOp
2925 //===----------------------------------------------------------------------===//
2926 
parseStoreOp(OpAsmParser & parser,OperationState & state)2927 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
2928   // Parse the storage class specification
2929   spirv::StorageClass storageClass;
2930   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2931   auto loc = parser.getCurrentLocation();
2932   Type elementType;
2933   if (parseEnumStrAttr(storageClass, parser) ||
2934       parser.parseOperandList(operandInfo, 2) ||
2935       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
2936       parser.parseType(elementType)) {
2937     return failure();
2938   }
2939 
2940   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2941   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2942                              state.operands)) {
2943     return failure();
2944   }
2945   return success();
2946 }
2947 
print(spirv::StoreOp storeOp,OpAsmPrinter & printer)2948 static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
2949   auto *op = storeOp.getOperation();
2950   SmallVector<StringRef, 4> elidedAttrs;
2951   StringRef sc = stringifyStorageClass(
2952       storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2953   printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
2954           << storeOp.ptr() << ", " << storeOp.value();
2955 
2956   printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
2957 
2958   printer << " : " << storeOp.value().getType();
2959   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2960 }
2961 
verify(spirv::StoreOp storeOp)2962 static LogicalResult verify(spirv::StoreOp storeOp) {
2963   // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
2964   // OpTypePointer whose Type operand is the same as the type of Object."
2965   if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
2966                                            storeOp.value()))) {
2967     return failure();
2968   }
2969   return verifyMemoryAccessAttribute(storeOp);
2970 }
2971 
2972 //===----------------------------------------------------------------------===//
2973 // spv.Unreachable
2974 //===----------------------------------------------------------------------===//
2975 
verify(spirv::UnreachableOp unreachableOp)2976 static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
2977   auto *op = unreachableOp.getOperation();
2978   auto *block = op->getBlock();
2979   // Fast track: if this is in entry block, its invalid. Otherwise, if no
2980   // predecessors, it's valid.
2981   if (block->isEntryBlock())
2982     return unreachableOp.emitOpError("cannot be used in reachable block");
2983   if (block->hasNoPredecessors())
2984     return success();
2985 
2986   // TODO: further verification needs to analyze reachability from
2987   // the entry block.
2988 
2989   return success();
2990 }
2991 
2992 //===----------------------------------------------------------------------===//
2993 // spv.Variable
2994 //===----------------------------------------------------------------------===//
2995 
parseVariableOp(OpAsmParser & parser,OperationState & state)2996 static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
2997   // Parse optional initializer
2998   Optional<OpAsmParser::OperandType> initInfo;
2999   if (succeeded(parser.parseOptionalKeyword("init"))) {
3000     initInfo = OpAsmParser::OperandType();
3001     if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3002         parser.parseRParen())
3003       return failure();
3004   }
3005 
3006   if (parseVariableDecorations(parser, state)) {
3007     return failure();
3008   }
3009 
3010   // Parse result pointer type
3011   Type type;
3012   if (parser.parseColon())
3013     return failure();
3014   auto loc = parser.getCurrentLocation();
3015   if (parser.parseType(type))
3016     return failure();
3017 
3018   auto ptrType = type.dyn_cast<spirv::PointerType>();
3019   if (!ptrType)
3020     return parser.emitError(loc, "expected spv.ptr type");
3021   state.addTypes(ptrType);
3022 
3023   // Resolve the initializer operand
3024   if (initInfo) {
3025     if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3026                               state.operands))
3027       return failure();
3028   }
3029 
3030   auto attr = parser.getBuilder().getI32IntegerAttr(
3031       llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
3032   state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3033 
3034   return success();
3035 }
3036 
print(spirv::VariableOp varOp,OpAsmPrinter & printer)3037 static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
3038   SmallVector<StringRef, 4> elidedAttrs{
3039       spirv::attributeName<spirv::StorageClass>()};
3040   printer << spirv::VariableOp::getOperationName();
3041 
3042   // Print optional initializer
3043   if (varOp.getNumOperands() != 0)
3044     printer << " init(" << varOp.initializer() << ")";
3045 
3046   printVariableDecorations(varOp, printer, elidedAttrs);
3047   printer << " : " << varOp.getType();
3048 }
3049 
verify(spirv::VariableOp varOp)3050 static LogicalResult verify(spirv::VariableOp varOp) {
3051   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3052   // object. It cannot be Generic. It must be the same as the Storage Class
3053   // operand of the Result Type."
3054   if (varOp.storage_class() != spirv::StorageClass::Function) {
3055     return varOp.emitOpError(
3056         "can only be used to model function-level variables. Use "
3057         "spv.GlobalVariable for module-level variables.");
3058   }
3059 
3060   auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
3061   if (varOp.storage_class() != pointerType.getStorageClass())
3062     return varOp.emitOpError(
3063         "storage class must match result pointer's storage class");
3064 
3065   if (varOp.getNumOperands() != 0) {
3066     // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3067     // a global (module scope) OpVariable instruction".
3068     auto *initOp = varOp.getOperand(0).getDefiningOp();
3069     if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
3070                         spirv::ReferenceOfOp, // for spec constant
3071                         spirv::AddressOfOp>(initOp))
3072       return varOp.emitOpError("initializer must be the result of a "
3073                                "constant or spv.GlobalVariable op");
3074   }
3075 
3076   // TODO: generate these strings using ODS.
3077   auto *op = varOp.getOperation();
3078   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3079       stringifyDecoration(spirv::Decoration::DescriptorSet));
3080   auto bindingName = llvm::convertToSnakeFromCamelCase(
3081       stringifyDecoration(spirv::Decoration::Binding));
3082   auto builtInName = llvm::convertToSnakeFromCamelCase(
3083       stringifyDecoration(spirv::Decoration::BuiltIn));
3084 
3085   for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3086     if (op->getAttr(attr))
3087       return varOp.emitOpError("cannot have '")
3088              << attr << "' attribute (only allowed in spv.GlobalVariable)";
3089   }
3090 
3091   return success();
3092 }
3093 
3094 //===----------------------------------------------------------------------===//
3095 // spv.VectorShuffle
3096 //===----------------------------------------------------------------------===//
3097 
verify(spirv::VectorShuffleOp shuffleOp)3098 static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
3099   VectorType resultType = shuffleOp.getType().cast<VectorType>();
3100 
3101   size_t numResultElements = resultType.getNumElements();
3102   if (numResultElements != shuffleOp.components().size())
3103     return shuffleOp.emitOpError("result type element count (")
3104            << numResultElements
3105            << ") mismatch with the number of component selectors ("
3106            << shuffleOp.components().size() << ")";
3107 
3108   size_t totalSrcElements =
3109       shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
3110       shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
3111 
3112   for (const auto &selector :
3113        shuffleOp.components().getAsValueRange<IntegerAttr>()) {
3114     uint32_t index = selector.getZExtValue();
3115     if (index >= totalSrcElements &&
3116         index != std::numeric_limits<uint32_t>().max())
3117       return shuffleOp.emitOpError("component selector ")
3118              << index << " out of range: expected to be in [0, "
3119              << totalSrcElements << ") or 0xffffffff";
3120   }
3121   return success();
3122 }
3123 
3124 //===----------------------------------------------------------------------===//
3125 // spv.CooperativeMatrixLoadNV
3126 //===----------------------------------------------------------------------===//
3127 
parseCooperativeMatrixLoadNVOp(OpAsmParser & parser,OperationState & state)3128 static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
3129                                                   OperationState &state) {
3130   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
3131   Type strideType = parser.getBuilder().getIntegerType(32);
3132   Type columnMajorType = parser.getBuilder().getIntegerType(1);
3133   Type ptrType;
3134   Type elementType;
3135   if (parser.parseOperandList(operandInfo, 3) ||
3136       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3137       parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3138     return failure();
3139   }
3140   if (parser.resolveOperands(operandInfo,
3141                              {ptrType, strideType, columnMajorType},
3142                              parser.getNameLoc(), state.operands)) {
3143     return failure();
3144   }
3145 
3146   state.addTypes(elementType);
3147   return success();
3148 }
3149 
print(spirv::CooperativeMatrixLoadNVOp M,OpAsmPrinter & printer)3150 static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
3151   printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
3152           << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
3153   // Print optional memory access attribute.
3154   if (auto memAccess = M.memory_access())
3155     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3156   printer << " : " << M.pointer().getType() << " as " << M.getType();
3157 }
3158 
verifyPointerAndCoopMatrixType(Operation * op,Type pointer,Type coopMatrix)3159 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3160                                                     Type coopMatrix) {
3161   Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3162   if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3163     return op->emitError(
3164                "Pointer must point to a scalar or vector type but provided ")
3165            << pointeeType;
3166   spirv::StorageClass storage =
3167       pointer.cast<spirv::PointerType>().getStorageClass();
3168   if (storage != spirv::StorageClass::Workgroup &&
3169       storage != spirv::StorageClass::StorageBuffer &&
3170       storage != spirv::StorageClass::PhysicalStorageBuffer)
3171     return op->emitError(
3172                "Pointer storage class must be Workgroup, StorageBuffer or "
3173                "PhysicalStorageBufferEXT but provided ")
3174            << stringifyStorageClass(storage);
3175   return success();
3176 }
3177 
3178 //===----------------------------------------------------------------------===//
3179 // spv.CooperativeMatrixStoreNV
3180 //===----------------------------------------------------------------------===//
3181 
parseCooperativeMatrixStoreNVOp(OpAsmParser & parser,OperationState & state)3182 static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
3183                                                    OperationState &state) {
3184   SmallVector<OpAsmParser::OperandType, 4> operandInfo;
3185   Type strideType = parser.getBuilder().getIntegerType(32);
3186   Type columnMajorType = parser.getBuilder().getIntegerType(1);
3187   Type ptrType;
3188   Type elementType;
3189   if (parser.parseOperandList(operandInfo, 4) ||
3190       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3191       parser.parseType(ptrType) || parser.parseComma() ||
3192       parser.parseType(elementType)) {
3193     return failure();
3194   }
3195   if (parser.resolveOperands(
3196           operandInfo, {ptrType, elementType, strideType, columnMajorType},
3197           parser.getNameLoc(), state.operands)) {
3198     return failure();
3199   }
3200 
3201   return success();
3202 }
3203 
print(spirv::CooperativeMatrixStoreNVOp coopMatrix,OpAsmPrinter & printer)3204 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
3205                   OpAsmPrinter &printer) {
3206   printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
3207           << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
3208           << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
3209   // Print optional memory access attribute.
3210   if (auto memAccess = coopMatrix.memory_access())
3211     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3212   printer << " : " << coopMatrix.pointer().getType() << ", "
3213           << coopMatrix.getOperand(1).getType();
3214 }
3215 
3216 //===----------------------------------------------------------------------===//
3217 // spv.CooperativeMatrixMulAddNV
3218 //===----------------------------------------------------------------------===//
3219 
3220 static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)3221 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3222   if (op.c().getType() != op.result().getType())
3223     return op.emitOpError("result and third operand must have the same type");
3224   auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3225   auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3226   auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3227   auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3228   if (typeA.getRows() != typeR.getRows() ||
3229       typeA.getColumns() != typeB.getRows() ||
3230       typeB.getColumns() != typeR.getColumns())
3231     return op.emitOpError("matrix size must match");
3232   if (typeR.getScope() != typeA.getScope() ||
3233       typeR.getScope() != typeB.getScope() ||
3234       typeR.getScope() != typeC.getScope())
3235     return op.emitOpError("matrix scope must match");
3236   if (typeA.getElementType() != typeB.getElementType() ||
3237       typeR.getElementType() != typeC.getElementType())
3238     return op.emitOpError("matrix element type must match");
3239   return success();
3240 }
3241 
3242 //===----------------------------------------------------------------------===//
3243 // spv.MatrixTimesScalar
3244 //===----------------------------------------------------------------------===//
3245 
verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op)3246 static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
3247   // We already checked that result and matrix are both of matrix type in the
3248   // auto-generated verify method.
3249 
3250   auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3251   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3252 
3253   // Check that the scalar type is the same as the matrix element type.
3254   if (op.scalar().getType() != inputMatrix.getElementType())
3255     return op.emitError("input matrix components' type and scaling value must "
3256                         "have the same type");
3257 
3258   // Note that the next three checks could be done using the AllTypesMatch
3259   // trait in the Op definition file but it generates a vague error message.
3260 
3261   // Check that the input and result matrices have the same columns' count
3262   if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3263     return op.emitError("input and result matrices must have the same "
3264                         "number of columns");
3265 
3266   // Check that the input and result matrices' have the same rows count
3267   if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3268     return op.emitError("input and result matrices' columns must have "
3269                         "the same size");
3270 
3271   // Check that the input and result matrices' have the same component type
3272   if (inputMatrix.getElementType() != resultMatrix.getElementType())
3273     return op.emitError("input and result matrices' columns must have "
3274                         "the same component type");
3275 
3276   return success();
3277 }
3278 
3279 //===----------------------------------------------------------------------===//
3280 // spv.CopyMemory
3281 //===----------------------------------------------------------------------===//
3282 
print(spirv::CopyMemoryOp copyMemory,OpAsmPrinter & printer)3283 static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
3284   auto *op = copyMemory.getOperation();
3285   printer << spirv::CopyMemoryOp::getOperationName() << ' ';
3286 
3287   StringRef targetStorageClass =
3288       stringifyStorageClass(copyMemory.target()
3289                                 .getType()
3290                                 .cast<spirv::PointerType>()
3291                                 .getStorageClass());
3292   printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
3293           << ", ";
3294 
3295   StringRef sourceStorageClass =
3296       stringifyStorageClass(copyMemory.source()
3297                                 .getType()
3298                                 .cast<spirv::PointerType>()
3299                                 .getStorageClass());
3300   printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
3301 
3302   SmallVector<StringRef, 4> elidedAttrs;
3303   printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
3304   printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
3305                                    copyMemory.source_memory_access(),
3306                                    copyMemory.source_alignment());
3307 
3308   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3309 
3310   Type pointeeType =
3311       copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3312   printer << " : " << pointeeType;
3313 }
3314 
parseCopyMemoryOp(OpAsmParser & parser,OperationState & state)3315 static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
3316                                      OperationState &state) {
3317   spirv::StorageClass targetStorageClass;
3318   OpAsmParser::OperandType targetPtrInfo;
3319 
3320   spirv::StorageClass sourceStorageClass;
3321   OpAsmParser::OperandType sourcePtrInfo;
3322 
3323   Type elementType;
3324 
3325   if (parseEnumStrAttr(targetStorageClass, parser) ||
3326       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3327       parseEnumStrAttr(sourceStorageClass, parser) ||
3328       parser.parseOperand(sourcePtrInfo) ||
3329       parseMemoryAccessAttributes(parser, state)) {
3330     return failure();
3331   }
3332 
3333   if (!parser.parseOptionalComma()) {
3334     // Parse 2nd memory access attributes.
3335     if (parseSourceMemoryAccessAttributes(parser, state)) {
3336       return failure();
3337     }
3338   }
3339 
3340   if (parser.parseColon() || parser.parseType(elementType))
3341     return failure();
3342 
3343   if (parser.parseOptionalAttrDict(state.attributes))
3344     return failure();
3345 
3346   auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3347   auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3348 
3349   if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3350       parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3351     return failure();
3352   }
3353 
3354   return success();
3355 }
3356 
verifyCopyMemory(spirv::CopyMemoryOp copyMemory)3357 static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
3358   Type targetType =
3359       copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3360 
3361   Type sourceType =
3362       copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
3363 
3364   if (targetType != sourceType) {
3365     return copyMemory.emitOpError(
3366         "both operands must be pointers to the same type");
3367   }
3368 
3369   if (failed(verifyMemoryAccessAttribute(copyMemory))) {
3370     return failure();
3371   }
3372 
3373   // TODO - According to the spec:
3374   //
3375   // If two masks are present, the first applies to Target and cannot include
3376   // MakePointerVisible, and the second applies to Source and cannot include
3377   // MakePointerAvailable.
3378   //
3379   // Add such verification here.
3380 
3381   return verifySourceMemoryAccessAttribute(copyMemory);
3382 }
3383 
3384 //===----------------------------------------------------------------------===//
3385 // spv.Transpose
3386 //===----------------------------------------------------------------------===//
3387 
verifyTranspose(spirv::TransposeOp op)3388 static LogicalResult verifyTranspose(spirv::TransposeOp op) {
3389   auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3390   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3391 
3392   // Verify that the input and output matrices have correct shapes.
3393   if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3394     return op.emitError("input matrix rows count must be equal to "
3395                         "output matrix columns count");
3396 
3397   if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3398     return op.emitError("input matrix columns count must be equal to "
3399                         "output matrix rows count");
3400 
3401   // Verify that the input and output matrices have the same component type
3402   if (inputMatrix.getElementType() != resultMatrix.getElementType())
3403     return op.emitError("input and output matrices must have the same "
3404                         "component type");
3405 
3406   return success();
3407 }
3408 
3409 //===----------------------------------------------------------------------===//
3410 // spv.MatrixTimesMatrix
3411 //===----------------------------------------------------------------------===//
3412 
verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op)3413 static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
3414   auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
3415   auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
3416   auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3417 
3418   // left matrix columns' count and right matrix rows' count must be equal
3419   if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
3420     return op.emitError("left matrix columns' count must be equal to "
3421                         "the right matrix rows' count");
3422 
3423   // right and result matrices columns' count must be the same
3424   if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
3425     return op.emitError(
3426         "right and result matrices must have equal columns' count");
3427 
3428   // right and result matrices component type must be the same
3429   if (rightMatrix.getElementType() != resultMatrix.getElementType())
3430     return op.emitError("right and result matrices' component type must"
3431                         " be the same");
3432 
3433   // left and result matrices component type must be the same
3434   if (leftMatrix.getElementType() != resultMatrix.getElementType())
3435     return op.emitError("left and result matrices' component type"
3436                         " must be the same");
3437 
3438   // left and result matrices rows count must be the same
3439   if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
3440     return op.emitError("left and result matrices must have equal rows'"
3441                         " count");
3442 
3443   return success();
3444 }
3445 
3446 //===----------------------------------------------------------------------===//
3447 // spv.SpecConstantComposite
3448 //===----------------------------------------------------------------------===//
3449 
parseSpecConstantCompositeOp(OpAsmParser & parser,OperationState & state)3450 static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
3451                                                 OperationState &state) {
3452 
3453   StringAttr compositeName;
3454   if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
3455                              state.attributes))
3456     return failure();
3457 
3458   if (parser.parseLParen())
3459     return failure();
3460 
3461   SmallVector<Attribute, 4> constituents;
3462 
3463   do {
3464     // The name of the constituent attribute isn't important
3465     const char *attrName = "spec_const";
3466     FlatSymbolRefAttr specConstRef;
3467     NamedAttrList attrs;
3468 
3469     if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
3470       return failure();
3471 
3472     constituents.push_back(specConstRef);
3473   } while (!parser.parseOptionalComma());
3474 
3475   if (parser.parseRParen())
3476     return failure();
3477 
3478   state.addAttribute(kCompositeSpecConstituentsName,
3479                      parser.getBuilder().getArrayAttr(constituents));
3480 
3481   Type type;
3482   if (parser.parseColonType(type))
3483     return failure();
3484 
3485   state.addAttribute(kTypeAttrName, TypeAttr::get(type));
3486 
3487   return success();
3488 }
3489 
print(spirv::SpecConstantCompositeOp op,OpAsmPrinter & printer)3490 static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
3491   printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
3492   printer.printSymbolName(op.sym_name());
3493   printer << " (";
3494   auto constituents = op.constituents().getValue();
3495 
3496   if (!constituents.empty())
3497     llvm::interleaveComma(constituents, printer);
3498 
3499   printer << ") : " << op.type();
3500 }
3501 
verify(spirv::SpecConstantCompositeOp constOp)3502 static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
3503   auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
3504   auto constituents = constOp.constituents().getValue();
3505 
3506   if (!cType)
3507     return constOp.emitError(
3508                "result type must be a composite type, but provided ")
3509            << constOp.type();
3510 
3511   if (cType.isa<spirv::CooperativeMatrixNVType>())
3512     return constOp.emitError("unsupported composite type  ") << cType;
3513   else if (constituents.size() != cType.getNumElements())
3514     return constOp.emitError("has incorrect number of operands: expected ")
3515            << cType.getNumElements() << ", but provided "
3516            << constituents.size();
3517 
3518   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
3519     auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
3520 
3521     auto constituentSpecConstOp =
3522         dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
3523             constOp->getParentOp(), constituent.getValue()));
3524 
3525     if (constituentSpecConstOp.default_value().getType() !=
3526         cType.getElementType(index))
3527       return constOp.emitError("has incorrect types of operands: expected ")
3528              << cType.getElementType(index) << ", but provided "
3529              << constituentSpecConstOp.default_value().getType();
3530   }
3531 
3532   return success();
3533 }
3534 
3535 //===----------------------------------------------------------------------===//
3536 // spv.SpecConstantOperation
3537 //===----------------------------------------------------------------------===//
3538 
parseSpecConstantOperationOp(OpAsmParser & parser,OperationState & state)3539 static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
3540                                                 OperationState &state) {
3541   Region *body = state.addRegion();
3542 
3543   if (parser.parseKeyword("wraps"))
3544     return failure();
3545 
3546   body->push_back(new Block);
3547   Block &block = body->back();
3548   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
3549 
3550   if (!wrappedOp)
3551     return failure();
3552 
3553   OpBuilder builder(parser.getBuilder().getContext());
3554   builder.setInsertionPointToEnd(&block);
3555   builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
3556   state.location = wrappedOp->getLoc();
3557 
3558   state.addTypes(wrappedOp->getResult(0).getType());
3559 
3560   if (parser.parseOptionalAttrDict(state.attributes))
3561     return failure();
3562 
3563   return success();
3564 }
3565 
print(spirv::SpecConstantOperationOp op,OpAsmPrinter & printer)3566 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
3567   printer << op.getOperationName() << " wraps ";
3568   printer.printGenericOp(&op.body().front().front());
3569 }
3570 
verify(spirv::SpecConstantOperationOp constOp)3571 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
3572   Block &block = constOp.getRegion().getBlocks().front();
3573 
3574   if (block.getOperations().size() != 2)
3575     return constOp.emitOpError("expected exactly 2 nested ops");
3576 
3577   Operation &enclosedOp = block.getOperations().front();
3578 
3579   if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
3580     return constOp.emitOpError("invalid enclosed op");
3581 
3582   for (auto operand : enclosedOp.getOperands())
3583     if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
3584              spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
3585       return constOp.emitOpError(
3586           "invalid operand, must be defined by a constant operation");
3587 
3588   return success();
3589 }
3590 
3591 //===----------------------------------------------------------------------===//
3592 // spv.GLSL.FrexpStruct
3593 //===----------------------------------------------------------------------===//
3594 static LogicalResult
verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp)3595 verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp) {
3596   spirv::StructType structTy =
3597       frexpStructOp.result().getType().dyn_cast<spirv::StructType>();
3598 
3599   if (structTy.getNumElements() != 2)
3600     return frexpStructOp.emitError("result type must be a struct  type "
3601                                    "with two memebers");
3602 
3603   Type significandTy = structTy.getElementType(0);
3604   Type exponentTy = structTy.getElementType(1);
3605   VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
3606   IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
3607 
3608   Type operandTy = frexpStructOp.operand().getType();
3609   VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
3610   FloatType operandFTy = operandTy.dyn_cast<FloatType>();
3611 
3612   if (significandTy != operandTy)
3613     return frexpStructOp.emitError("member zero of the resulting struct type "
3614                                    "must be the same type as the operand");
3615 
3616   if (exponentVecTy) {
3617     IntegerType componentIntTy =
3618         exponentVecTy.getElementType().dyn_cast<IntegerType>();
3619     if (!(componentIntTy && componentIntTy.getWidth() == 32))
3620       return frexpStructOp.emitError(
3621           "member one of the resulting struct type must"
3622           "be a scalar or vector of 32 bit integer type");
3623   } else if (!(exponentIntTy && exponentIntTy.getWidth() == 32)) {
3624     return frexpStructOp.emitError(
3625         "member one of the resulting struct type "
3626         "must be a scalar or vector of 32 bit integer type");
3627   }
3628 
3629   // Check that the two member types have the same number of components
3630   if (operandVecTy && exponentVecTy &&
3631       (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
3632     return success();
3633 
3634   if (operandFTy && exponentIntTy)
3635     return success();
3636 
3637   return frexpStructOp.emitError(
3638       "member one of the resulting struct type "
3639       "must have the same number of components as the operand type");
3640 }
3641 
3642 //===----------------------------------------------------------------------===//
3643 // spv.GLSL.Ldexp
3644 //===----------------------------------------------------------------------===//
3645 
verify(spirv::GLSLLdexpOp ldexpOp)3646 static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) {
3647   Type significandType = ldexpOp.x().getType();
3648   Type exponentType = ldexpOp.exp().getType();
3649 
3650   if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
3651     return ldexpOp.emitOpError("operands must both be scalars or vectors");
3652 
3653   auto getNumElements = [](Type type) -> unsigned {
3654     if (auto vectorType = type.dyn_cast<VectorType>())
3655       return vectorType.getNumElements();
3656     return 1;
3657   };
3658 
3659   if (getNumElements(significandType) != getNumElements(exponentType))
3660     return ldexpOp.emitOpError(
3661         "operands must have the same number of elements");
3662 
3663   return success();
3664 }
3665 
3666 //===----------------------------------------------------------------------===//
3667 // spv.ImageDrefGather
3668 //===----------------------------------------------------------------------===//
3669 
verify(spirv::ImageDrefGatherOp imageDrefGatherOp)3670 static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) {
3671   // TODO: Support optional operands.
3672   VectorType resultType =
3673       imageDrefGatherOp.result().getType().cast<VectorType>();
3674   auto sampledImageType = imageDrefGatherOp.sampledimage()
3675                               .getType()
3676                               .cast<spirv::SampledImageType>();
3677   auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
3678 
3679   if (resultType.getNumElements() != 4)
3680     return imageDrefGatherOp.emitOpError(
3681         "result type must be a vector of four components");
3682 
3683   Type elementType = resultType.getElementType();
3684   Type sampledElementType = imageType.getElementType();
3685   if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
3686     return imageDrefGatherOp.emitOpError(
3687         "the component type of result must be the same as sampled type of the "
3688         "underlying image type");
3689 
3690   spirv::Dim imageDim = imageType.getDim();
3691   spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
3692 
3693   if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
3694       imageDim != spirv::Dim::Rect)
3695     return imageDrefGatherOp.emitOpError(
3696         "the Dim operand of the underlying image type must be 2D, Cube, or "
3697         "Rect");
3698 
3699   if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
3700     return imageDrefGatherOp.emitOpError(
3701         "the MS operand of the underlying image type must be 0");
3702 
3703   return success();
3704 }
3705 
3706 //===----------------------------------------------------------------------===//
3707 // spv.ImageQuerySize
3708 //===----------------------------------------------------------------------===//
3709 
verify(spirv::ImageQuerySizeOp imageQuerySizeOp)3710 static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
3711   spirv::ImageType imageType =
3712       imageQuerySizeOp.image().getType().cast<spirv::ImageType>();
3713   Type resultType = imageQuerySizeOp.result().getType();
3714 
3715   spirv::Dim dim = imageType.getDim();
3716   spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
3717   spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
3718   switch (dim) {
3719   case spirv::Dim::Dim1D:
3720   case spirv::Dim::Dim2D:
3721   case spirv::Dim::Dim3D:
3722   case spirv::Dim::Cube:
3723     if (!(samplingInfo == spirv::ImageSamplingInfo::MultiSampled ||
3724           samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown ||
3725           samplerInfo == spirv::ImageSamplerUseInfo::NoSampler))
3726       return imageQuerySizeOp.emitError(
3727           "if Dim is 1D, 2D, 3D, or Cube, "
3728           "it must also have either an MS of 1 or a Sampled of 0 or 2");
3729     break;
3730   case spirv::Dim::Buffer:
3731   case spirv::Dim::Rect:
3732     break;
3733   default:
3734     return imageQuerySizeOp.emitError("the Dim operand of the image type must "
3735                                       "be 1D, 2D, 3D, Buffer, Cube, or Rect");
3736   }
3737 
3738   unsigned componentNumber = 0;
3739   switch (dim) {
3740   case spirv::Dim::Dim1D:
3741   case spirv::Dim::Buffer:
3742     componentNumber = 1;
3743     break;
3744   case spirv::Dim::Dim2D:
3745   case spirv::Dim::Cube:
3746   case spirv::Dim::Rect:
3747     componentNumber = 2;
3748     break;
3749   case spirv::Dim::Dim3D:
3750     componentNumber = 3;
3751     break;
3752   default:
3753     break;
3754   }
3755 
3756   if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
3757     componentNumber += 1;
3758 
3759   unsigned resultComponentNumber = 1;
3760   if (auto resultVectorType = resultType.dyn_cast<VectorType>())
3761     resultComponentNumber = resultVectorType.getNumElements();
3762 
3763   if (componentNumber != resultComponentNumber)
3764     return imageQuerySizeOp.emitError("expected the result to have ")
3765            << componentNumber << " component(s), but found "
3766            << resultComponentNumber << " component(s)";
3767 
3768   return success();
3769 }
3770 
3771 namespace mlir {
3772 namespace spirv {
3773 
3774 // TableGen'erated operation interfaces for querying versions, extensions, and
3775 // capabilities.
3776 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
3777 } // namespace spirv
3778 } // namespace mlir
3779 
// TableGen'erated operation definitions.
3781 #define GET_OP_CLASSES
3782 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
3783 
3784 namespace mlir {
3785 namespace spirv {
3786 // TableGen'erated operation availability interface implementations.
3787 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
3788 
3789 } // namespace spirv
3790 } // namespace mlir
3791