1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/FunctionImplementation.h"
25 #include "mlir/IR/OpDefinition.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Interfaces/CallInterfaces.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/bit.h"
33
34 using namespace mlir;
35
// Attribute names and keywords shared by the custom parsers, printers and
// verifiers in this file.
// TODO: generate these strings using ODS.

// Memory access and alignment attributes.
static constexpr char kMemoryAccessAttrName[] = "memory_access";
static constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kSourceAlignmentAttrName[] = "source_alignment";

// Scopes, memory semantics and group operations.
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kClusterSize[] = "cluster_size";

// Everything else, alphabetically.
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kCompositeSpecConstituentsName[] = "constituents";
static constexpr char kControl[] = "control";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
62
63 //===----------------------------------------------------------------------===//
64 // Common utility functions
65 //===----------------------------------------------------------------------===//
66
67 /// Returns true if the given op is a function-like op or nested in a
68 /// function-like op without a module-like op in the middle.
isNestedInFunctionLikeOp(Operation * op)69 static bool isNestedInFunctionLikeOp(Operation *op) {
70 if (!op)
71 return false;
72 if (op->hasTrait<OpTrait::SymbolTable>())
73 return false;
74 if (op->hasTrait<OpTrait::FunctionLike>())
75 return true;
76 return isNestedInFunctionLikeOp(op->getParentOp());
77 }
78
79 /// Returns true if the given op is an module-like op that maintains a symbol
80 /// table.
isDirectInModuleLikeOp(Operation * op)81 static bool isDirectInModuleLikeOp(Operation *op) {
82 return op && op->hasTrait<OpTrait::SymbolTable>();
83 }
84
extractValueFromConstOp(Operation * op,int32_t & value)85 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
86 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
87 if (!constOp) {
88 return failure();
89 }
90 auto valueAttr = constOp.value();
91 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
92 if (!integerValueAttr) {
93 return failure();
94 }
95 value = integerValueAttr.getInt();
96 return success();
97 }
98
99 template <typename Ty>
100 static ArrayAttr
getStrArrayAttrForEnumList(Builder & builder,ArrayRef<Ty> enumValues,function_ref<StringRef (Ty)> stringifyFn)101 getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
102 function_ref<StringRef(Ty)> stringifyFn) {
103 if (enumValues.empty()) {
104 return nullptr;
105 }
106 SmallVector<StringRef, 1> enumValStrs;
107 enumValStrs.reserve(enumValues.size());
108 for (auto val : enumValues) {
109 enumValStrs.emplace_back(stringifyFn(val));
110 }
111 return builder.getStrArrayAttr(enumValStrs);
112 }
113
114 /// Parses the next string attribute in `parser` as an enumerant of the given
115 /// `EnumClass`.
116 template <typename EnumClass>
117 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,StringRef attrName=spirv::attributeName<EnumClass> ())118 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
119 StringRef attrName = spirv::attributeName<EnumClass>()) {
120 Attribute attrVal;
121 NamedAttrList attr;
122 auto loc = parser.getCurrentLocation();
123 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
124 attrName, attr)) {
125 return failure();
126 }
127 if (!attrVal.isa<StringAttr>()) {
128 return parser.emitError(loc, "expected ")
129 << attrName << " attribute specified as string";
130 }
131 auto attrOptional =
132 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
133 if (!attrOptional) {
134 return parser.emitError(loc, "invalid ")
135 << attrName << " attribute specification: " << attrVal;
136 }
137 value = attrOptional.getValue();
138 return success();
139 }
140
141 /// Parses the next string attribute in `parser` as an enumerant of the given
142 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
143 /// attribute with the enum class's name as attribute name.
144 template <typename EnumClass>
145 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())146 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
147 StringRef attrName = spirv::attributeName<EnumClass>()) {
148 if (parseEnumStrAttr(value, parser)) {
149 return failure();
150 }
151 state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
152 llvm::bit_cast<int32_t>(value)));
153 return success();
154 }
155
156 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
157 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
158 /// the enum class's name as attribute name.
159 template <typename EnumClass>
160 static ParseResult
parseEnumKeywordAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())161 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
162 OperationState &state,
163 StringRef attrName = spirv::attributeName<EnumClass>()) {
164 if (parseEnumKeywordAttr(value, parser)) {
165 return failure();
166 }
167 state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
168 llvm::bit_cast<int32_t>(value)));
169 return success();
170 }
171
172 /// Parses Function, Selection and Loop control attributes. If no control is
173 /// specified, "None" is used as a default.
174 template <typename EnumClass>
175 static ParseResult
parseControlAttribute(OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())176 parseControlAttribute(OpAsmParser &parser, OperationState &state,
177 StringRef attrName = spirv::attributeName<EnumClass>()) {
178 if (succeeded(parser.parseOptionalKeyword(kControl))) {
179 EnumClass control;
180 if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
181 parser.parseRParen())
182 return failure();
183 return success();
184 }
185 // Set control to "None" otherwise.
186 Builder builder = parser.getBuilder();
187 state.addAttribute(attrName, builder.getI32IntegerAttr(0));
188 return success();
189 }
190
191 /// Parses optional memory access attributes attached to a memory access
192 /// operand/pointer. Specifically, parses the following syntax:
193 /// (`[` memory-access `]`)?
194 /// where:
195 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
196 /// integer-literal | `"NonTemporal"`
parseMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)197 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
198 OperationState &state) {
199 // Parse an optional list of attributes staring with '['
200 if (parser.parseOptionalLSquare()) {
201 // Nothing to do
202 return success();
203 }
204
205 spirv::MemoryAccess memoryAccessAttr;
206 if (parseEnumStrAttr(memoryAccessAttr, parser, state,
207 kMemoryAccessAttrName)) {
208 return failure();
209 }
210
211 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
212 // Parse integer attribute for alignment.
213 Attribute alignmentAttr;
214 Type i32Type = parser.getBuilder().getIntegerType(32);
215 if (parser.parseComma() ||
216 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
217 state.attributes)) {
218 return failure();
219 }
220 }
221 return parser.parseRSquare();
222 }
223
224 // TODO Make sure to merge this and the previous function into one template
225 // parameterized by memory access attribute name and alignment. Doing so now
226 // results in VS2017 in producing an internal error (at the call site) that's
227 // not detailed enough to understand what is happening.
parseSourceMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)228 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
229 OperationState &state) {
230 // Parse an optional list of attributes staring with '['
231 if (parser.parseOptionalLSquare()) {
232 // Nothing to do
233 return success();
234 }
235
236 spirv::MemoryAccess memoryAccessAttr;
237 if (parseEnumStrAttr(memoryAccessAttr, parser, state,
238 kSourceMemoryAccessAttrName)) {
239 return failure();
240 }
241
242 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
243 // Parse integer attribute for alignment.
244 Attribute alignmentAttr;
245 Type i32Type = parser.getBuilder().getIntegerType(32);
246 if (parser.parseComma() ||
247 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
248 state.attributes)) {
249 return failure();
250 }
251 }
252 return parser.parseRSquare();
253 }
254
255 template <typename MemoryOpTy>
printMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)256 static void printMemoryAccessAttribute(
257 MemoryOpTy memoryOp, OpAsmPrinter &printer,
258 SmallVectorImpl<StringRef> &elidedAttrs,
259 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
260 Optional<uint32_t> alignmentAttrValue = None) {
261 // Print optional memory access attribute.
262 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
263 : memoryOp.memory_access())) {
264 elidedAttrs.push_back(kMemoryAccessAttrName);
265
266 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
267
268 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
269 // Print integer alignment attribute.
270 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
271 : memoryOp.alignment())) {
272 elidedAttrs.push_back(kAlignmentAttrName);
273 printer << ", " << alignment;
274 }
275 }
276 printer << "]";
277 }
278 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
279 }
280
281 // TODO Make sure to merge this and the previous function into one template
282 // parameterized by memory access attribute name and alignment. Doing so now
283 // results in VS2017 in producing an internal error (at the call site) that's
284 // not detailed enough to understand what is happening.
285 template <typename MemoryOpTy>
printSourceMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)286 static void printSourceMemoryAccessAttribute(
287 MemoryOpTy memoryOp, OpAsmPrinter &printer,
288 SmallVectorImpl<StringRef> &elidedAttrs,
289 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
290 Optional<uint32_t> alignmentAttrValue = None) {
291
292 printer << ", ";
293
294 // Print optional memory access attribute.
295 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
296 : memoryOp.memory_access())) {
297 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
298
299 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
300
301 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
302 // Print integer alignment attribute.
303 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
304 : memoryOp.alignment())) {
305 elidedAttrs.push_back(kSourceAlignmentAttrName);
306 printer << ", " << alignment;
307 }
308 }
309 printer << "]";
310 }
311 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
312 }
313
verifyCastOp(Operation * op,bool requireSameBitWidth=true,bool skipBitWidthCheck=false)314 static LogicalResult verifyCastOp(Operation *op,
315 bool requireSameBitWidth = true,
316 bool skipBitWidthCheck = false) {
317 // Some CastOps have no limit on bit widths for result and operand type.
318 if (skipBitWidthCheck)
319 return success();
320
321 Type operandType = op->getOperand(0).getType();
322 Type resultType = op->getResult(0).getType();
323
324 // ODS checks that result type and operand type have the same shape.
325 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
326 operandType = vectorType.getElementType();
327 resultType = resultType.cast<VectorType>().getElementType();
328 }
329
330 if (auto coopMatrixType =
331 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
332 operandType = coopMatrixType.getElementType();
333 resultType =
334 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
335 }
336
337 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
338 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
339 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
340
341 if (requireSameBitWidth) {
342 if (!isSameBitWidth) {
343 return op->emitOpError(
344 "expected the same bit widths for operand type and result "
345 "type, but provided ")
346 << operandType << " and " << resultType;
347 }
348 return success();
349 }
350
351 if (isSameBitWidth) {
352 return op->emitOpError(
353 "expected the different bit widths for operand type and result "
354 "type, but provided ")
355 << operandType << " and " << resultType;
356 }
357 return success();
358 }
359
360 template <typename MemoryOpTy>
verifyMemoryAccessAttribute(MemoryOpTy memoryOp)361 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
362 // ODS checks for attributes values. Just need to verify that if the
363 // memory-access attribute is Aligned, then the alignment attribute must be
364 // present.
365 auto *op = memoryOp.getOperation();
366 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
367 if (!memAccessAttr) {
368 // Alignment attribute shouldn't be present if memory access attribute is
369 // not present.
370 if (op->getAttr(kAlignmentAttrName)) {
371 return memoryOp.emitOpError(
372 "invalid alignment specification without aligned memory access "
373 "specification");
374 }
375 return success();
376 }
377
378 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
379 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
380
381 if (!memAccess) {
382 return memoryOp.emitOpError("invalid memory access specifier: ")
383 << memAccessVal;
384 }
385
386 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
387 if (!op->getAttr(kAlignmentAttrName)) {
388 return memoryOp.emitOpError("missing alignment value");
389 }
390 } else {
391 if (op->getAttr(kAlignmentAttrName)) {
392 return memoryOp.emitOpError(
393 "invalid alignment specification with non-aligned memory access "
394 "specification");
395 }
396 }
397 return success();
398 }
399
400 // TODO Make sure to merge this and the previous function into one template
401 // parameterized by memory access attribute name and alignment. Doing so now
402 // results in VS2017 in producing an internal error (at the call site) that's
403 // not detailed enough to understand what is happening.
404 template <typename MemoryOpTy>
verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)405 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
406 // ODS checks for attributes values. Just need to verify that if the
407 // memory-access attribute is Aligned, then the alignment attribute must be
408 // present.
409 auto *op = memoryOp.getOperation();
410 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
411 if (!memAccessAttr) {
412 // Alignment attribute shouldn't be present if memory access attribute is
413 // not present.
414 if (op->getAttr(kSourceAlignmentAttrName)) {
415 return memoryOp.emitOpError(
416 "invalid alignment specification without aligned memory access "
417 "specification");
418 }
419 return success();
420 }
421
422 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
423 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
424
425 if (!memAccess) {
426 return memoryOp.emitOpError("invalid memory access specifier: ")
427 << memAccessVal;
428 }
429
430 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
431 if (!op->getAttr(kSourceAlignmentAttrName)) {
432 return memoryOp.emitOpError("missing alignment value");
433 }
434 } else {
435 if (op->getAttr(kSourceAlignmentAttrName)) {
436 return memoryOp.emitOpError(
437 "invalid alignment specification with non-aligned memory access "
438 "specification");
439 }
440 }
441 return success();
442 }
443
444 template <typename BarrierOp>
verifyMemorySemantics(BarrierOp op)445 static LogicalResult verifyMemorySemantics(BarrierOp op) {
446 // According to the SPIR-V specification:
447 // "Despite being a mask and allowing multiple bits to be combined, it is
448 // invalid for more than one of these four bits to be set: Acquire, Release,
449 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
450 // Release semantics is done by setting the AcquireRelease bit, not by setting
451 // two bits."
452 auto memorySemantics = op.memory_semantics();
453 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
454 spirv::MemorySemantics::Release |
455 spirv::MemorySemantics::AcquireRelease |
456 spirv::MemorySemantics::SequentiallyConsistent;
457
458 auto bitCount = llvm::countPopulation(
459 static_cast<uint32_t>(memorySemantics & atMostOneInSet));
460 if (bitCount > 1) {
461 return op.emitError("expected at most one of these four memory constraints "
462 "to be set: `Acquire`, `Release`,"
463 "`AcquireRelease` or `SequentiallyConsistent`");
464 }
465 return success();
466 }
467
468 template <typename LoadStoreOpTy>
verifyLoadStorePtrAndValTypes(LoadStoreOpTy op,Value ptr,Value val)469 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
470 Value val) {
471 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
472 // type of the pointer and the type of the value are the same
473 //
474 // TODO: Check that the value type satisfies restrictions of
475 // SPIR-V OpLoad/OpStore operations
476 if (val.getType() !=
477 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
478 return op.emitOpError("mismatch in result type and pointer type");
479 }
480 return success();
481 }
482
483 template <typename BlockReadWriteOpTy>
verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,Value ptr,Value val)484 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
485 Value ptr, Value val) {
486 auto valType = val.getType();
487 if (auto valVecTy = valType.dyn_cast<VectorType>())
488 valType = valVecTy.getElementType();
489
490 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
491 return op.emitOpError("mismatch in result type and pointer type");
492 }
493 return success();
494 }
495
parseVariableDecorations(OpAsmParser & parser,OperationState & state)496 static ParseResult parseVariableDecorations(OpAsmParser &parser,
497 OperationState &state) {
498 auto builtInName = llvm::convertToSnakeFromCamelCase(
499 stringifyDecoration(spirv::Decoration::BuiltIn));
500 if (succeeded(parser.parseOptionalKeyword("bind"))) {
501 Attribute set, binding;
502 // Parse optional descriptor binding
503 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
504 stringifyDecoration(spirv::Decoration::DescriptorSet));
505 auto bindingName = llvm::convertToSnakeFromCamelCase(
506 stringifyDecoration(spirv::Decoration::Binding));
507 Type i32Type = parser.getBuilder().getIntegerType(32);
508 if (parser.parseLParen() ||
509 parser.parseAttribute(set, i32Type, descriptorSetName,
510 state.attributes) ||
511 parser.parseComma() ||
512 parser.parseAttribute(binding, i32Type, bindingName,
513 state.attributes) ||
514 parser.parseRParen()) {
515 return failure();
516 }
517 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
518 StringAttr builtIn;
519 if (parser.parseLParen() ||
520 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
521 parser.parseRParen()) {
522 return failure();
523 }
524 }
525
526 // Parse other attributes
527 if (parser.parseOptionalAttrDict(state.attributes))
528 return failure();
529
530 return success();
531 }
532
printVariableDecorations(Operation * op,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs)533 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
534 SmallVectorImpl<StringRef> &elidedAttrs) {
535 // Print optional descriptor binding
536 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
537 stringifyDecoration(spirv::Decoration::DescriptorSet));
538 auto bindingName = llvm::convertToSnakeFromCamelCase(
539 stringifyDecoration(spirv::Decoration::Binding));
540 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
541 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
542 if (descriptorSet && binding) {
543 elidedAttrs.push_back(descriptorSetName);
544 elidedAttrs.push_back(bindingName);
545 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
546 << ")";
547 }
548
549 // Print BuiltIn attribute if present
550 auto builtInName = llvm::convertToSnakeFromCamelCase(
551 stringifyDecoration(spirv::Decoration::BuiltIn));
552 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
553 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
554 elidedAttrs.push_back(builtInName);
555 }
556
557 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
558 }
559
560 // Get bit width of types.
getBitWidth(Type type)561 static unsigned getBitWidth(Type type) {
562 if (type.isa<spirv::PointerType>()) {
563 // Just return 64 bits for pointer types for now.
564 // TODO: Make sure not caller relies on the actual pointer width value.
565 return 64;
566 }
567
568 if (type.isIntOrFloat())
569 return type.getIntOrFloatBitWidth();
570
571 if (auto vectorType = type.dyn_cast<VectorType>()) {
572 assert(vectorType.getElementType().isIntOrFloat());
573 return vectorType.getNumElements() *
574 vectorType.getElementType().getIntOrFloatBitWidth();
575 }
576 llvm_unreachable("unhandled bit width computation for type");
577 }
578
579 /// Walks the given type hierarchy with the given indices, potentially down
580 /// to component granularity, to select an element type. Returns null type and
581 /// emits errors with the given loc on failure.
582 static Type
getElementType(Type type,ArrayRef<int32_t> indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)583 getElementType(Type type, ArrayRef<int32_t> indices,
584 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
585 if (indices.empty()) {
586 emitErrorFn("expected at least one index for spv.CompositeExtract");
587 return nullptr;
588 }
589
590 for (auto index : indices) {
591 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
592 if (cType.hasCompileTimeKnownNumElements() &&
593 (index < 0 ||
594 static_cast<uint64_t>(index) >= cType.getNumElements())) {
595 emitErrorFn("index ") << index << " out of bounds for " << type;
596 return nullptr;
597 }
598 type = cType.getElementType(index);
599 } else {
600 emitErrorFn("cannot extract from non-composite type ")
601 << type << " with index " << index;
602 return nullptr;
603 }
604 }
605 return type;
606 }
607
608 static Type
getElementType(Type type,Attribute indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)609 getElementType(Type type, Attribute indices,
610 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
611 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
612 if (!indicesArrayAttr) {
613 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
614 return nullptr;
615 }
616 if (!indicesArrayAttr.size()) {
617 emitErrorFn("expected at least one index for spv.CompositeExtract");
618 return nullptr;
619 }
620
621 SmallVector<int32_t, 2> indexVals;
622 for (auto indexAttr : indicesArrayAttr) {
623 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
624 if (!indexIntAttr) {
625 emitErrorFn("expected an 32-bit integer for index, but found '")
626 << indexAttr << "'";
627 return nullptr;
628 }
629 indexVals.push_back(indexIntAttr.getInt());
630 }
631 return getElementType(type, indexVals, emitErrorFn);
632 }
633
getElementType(Type type,Attribute indices,Location loc)634 static Type getElementType(Type type, Attribute indices, Location loc) {
635 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
636 return ::mlir::emitError(loc, err);
637 };
638 return getElementType(type, indices, errorFn);
639 }
640
getElementType(Type type,Attribute indices,OpAsmParser & parser,llvm::SMLoc loc)641 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
642 llvm::SMLoc loc) {
643 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
644 return parser.emitError(loc, err);
645 };
646 return getElementType(type, indices, errorFn);
647 }
648
649 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
isMergeBlock(Block & block)650 static inline bool isMergeBlock(Block &block) {
651 return !block.empty() && std::next(block.begin()) == block.end() &&
652 isa<spirv::MergeOp>(block.front());
653 }
654
655 //===----------------------------------------------------------------------===//
656 // Common parsers and printers
657 //===----------------------------------------------------------------------===//
658
659 // Parses an atomic update op. If the update op does not take a value (like
660 // AtomicIIncrement) `hasValue` must be false.
parseAtomicUpdateOp(OpAsmParser & parser,OperationState & state,bool hasValue)661 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
662 OperationState &state, bool hasValue) {
663 spirv::Scope scope;
664 spirv::MemorySemantics memoryScope;
665 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
666 OpAsmParser::OperandType ptrInfo, valueInfo;
667 Type type;
668 llvm::SMLoc loc;
669 if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
670 parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
671 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
672 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
673 return failure();
674
675 auto ptrType = type.dyn_cast<spirv::PointerType>();
676 if (!ptrType)
677 return parser.emitError(loc, "expected pointer type");
678
679 SmallVector<Type, 2> operandTypes;
680 operandTypes.push_back(ptrType);
681 if (hasValue)
682 operandTypes.push_back(ptrType.getPointeeType());
683 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
684 state.operands))
685 return failure();
686 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
687 }
688
689 // Prints an atomic update op.
printAtomicUpdateOp(Operation * op,OpAsmPrinter & printer)690 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
691 printer << op->getName() << " \"";
692 auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
693 printer << spirv::stringifyScope(
694 static_cast<spirv::Scope>(scopeAttr.getInt()))
695 << "\" \"";
696 auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
697 printer << spirv::stringifyMemorySemantics(
698 static_cast<spirv::MemorySemantics>(
699 memorySemanticsAttr.getInt()))
700 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
701 }
702
703 // Verifies an atomic update op.
verifyAtomicUpdateOp(Operation * op)704 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
705 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
706 auto elementType = ptrType.getPointeeType();
707 if (!elementType.isa<IntegerType>())
708 return op->emitOpError(
709 "pointer operand must point to an integer value, found ")
710 << elementType;
711
712 if (op->getNumOperands() > 1) {
713 auto valueType = op->getOperand(1).getType();
714 if (valueType != elementType)
715 return op->emitOpError("expected value to have the same type as the "
716 "pointer operand's pointee type ")
717 << elementType << ", but found " << valueType;
718 }
719 return success();
720 }
721
parseGroupNonUniformArithmeticOp(OpAsmParser & parser,OperationState & state)722 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
723 OperationState &state) {
724 spirv::Scope executionScope;
725 spirv::GroupOperation groupOperation;
726 OpAsmParser::OperandType valueInfo;
727 if (parseEnumStrAttr(executionScope, parser, state,
728 kExecutionScopeAttrName) ||
729 parseEnumStrAttr(groupOperation, parser, state,
730 kGroupOperationAttrName) ||
731 parser.parseOperand(valueInfo))
732 return failure();
733
734 Optional<OpAsmParser::OperandType> clusterSizeInfo;
735 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
736 clusterSizeInfo = OpAsmParser::OperandType();
737 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
738 parser.parseRParen())
739 return failure();
740 }
741
742 Type resultType;
743 if (parser.parseColonType(resultType))
744 return failure();
745
746 if (parser.resolveOperand(valueInfo, resultType, state.operands))
747 return failure();
748
749 if (clusterSizeInfo.hasValue()) {
750 Type i32Type = parser.getBuilder().getIntegerType(32);
751 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
752 return failure();
753 }
754
755 return parser.addTypeToList(resultType, state.types);
756 }
757
printGroupNonUniformArithmeticOp(Operation * groupOp,OpAsmPrinter & printer)758 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
759 OpAsmPrinter &printer) {
760 printer << groupOp->getName() << " \""
761 << stringifyScope(static_cast<spirv::Scope>(
762 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
763 .getInt()))
764 << "\" \""
765 << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
766 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
767 .getInt()))
768 << "\" " << groupOp->getOperand(0);
769
770 if (groupOp->getNumOperands() > 1)
771 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
772 printer << " : " << groupOp->getResult(0).getType();
773 }
774
verifyGroupNonUniformArithmeticOp(Operation * groupOp)775 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
776 spirv::Scope scope = static_cast<spirv::Scope>(
777 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
778 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
779 return groupOp->emitOpError(
780 "execution scope must be 'Workgroup' or 'Subgroup'");
781
782 spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
783 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
784 if (operation == spirv::GroupOperation::ClusteredReduce &&
785 groupOp->getNumOperands() == 1)
786 return groupOp->emitOpError("cluster size operand must be provided for "
787 "'ClusteredReduce' group operation");
788 if (groupOp->getNumOperands() > 1) {
789 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
790 int32_t clusterSize = 0;
791
792 // TODO: support specialization constant here.
793 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
794 return groupOp->emitOpError(
795 "cluster size operand must come from a constant op");
796
797 if (!llvm::isPowerOf2_32(clusterSize))
798 return groupOp->emitOpError(
799 "cluster size operand must be a power of two");
800 }
801 return success();
802 }
803
parseUnaryOp(OpAsmParser & parser,OperationState & state)804 static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
805 OpAsmParser::OperandType operandInfo;
806 Type type;
807 if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
808 parser.resolveOperands(operandInfo, type, state.operands)) {
809 return failure();
810 }
811 state.addTypes(type);
812 return success();
813 }
814
printUnaryOp(Operation * unaryOp,OpAsmPrinter & printer)815 static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
816 printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : "
817 << unaryOp->getOperand(0).getType();
818 }
819
820 /// Result of a logical op must be a scalar or vector of boolean type.
getUnaryOpResultType(Builder & builder,Type operandType)821 static Type getUnaryOpResultType(Builder &builder, Type operandType) {
822 Type resultType = builder.getIntegerType(1);
823 if (auto vecType = operandType.dyn_cast<VectorType>()) {
824 return VectorType::get(vecType.getNumElements(), resultType);
825 }
826 return resultType;
827 }
828
parseLogicalUnaryOp(OpAsmParser & parser,OperationState & state)829 static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
830 OperationState &state) {
831 OpAsmParser::OperandType operandInfo;
832 Type type;
833 if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
834 parser.resolveOperand(operandInfo, type, state.operands)) {
835 return failure();
836 }
837 state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
838 return success();
839 }
840
parseLogicalBinaryOp(OpAsmParser & parser,OperationState & result)841 static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
842 OperationState &result) {
843 SmallVector<OpAsmParser::OperandType, 2> ops;
844 Type type;
845 if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
846 parser.resolveOperands(ops, type, result.operands)) {
847 return failure();
848 }
849 result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
850 return success();
851 }
852
printLogicalOp(Operation * logicalOp,OpAsmPrinter & printer)853 static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
854 printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
855 << logicalOp->getOperand(0).getType();
856 }
857
parseShiftOp(OpAsmParser & parser,OperationState & state)858 static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
859 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
860 Type baseType;
861 Type shiftType;
862 auto loc = parser.getCurrentLocation();
863
864 if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
865 parser.parseType(baseType) || parser.parseComma() ||
866 parser.parseType(shiftType) ||
867 parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
868 state.operands)) {
869 return failure();
870 }
871 state.addTypes(baseType);
872 return success();
873 }
874
printShiftOp(Operation * op,OpAsmPrinter & printer)875 static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
876 Value base = op->getOperand(0);
877 Value shift = op->getOperand(1);
878 printer << op->getName() << ' ' << base << ", " << shift << " : "
879 << base.getType() << ", " << shift.getType();
880 }
881
verifyShiftOp(Operation * op)882 static LogicalResult verifyShiftOp(Operation *op) {
883 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
884 return op->emitError("expected the same type for the first operand and "
885 "result, but provided ")
886 << op->getOperand(0).getType() << " and "
887 << op->getResult(0).getType();
888 }
889 return success();
890 }
891
buildLogicalBinaryOp(OpBuilder & builder,OperationState & state,Value lhs,Value rhs)892 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
893 Value lhs, Value rhs) {
894 assert(lhs.getType() == rhs.getType());
895
896 Type boolType = builder.getI1Type();
897 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
898 boolType = VectorType::get(vecType.getShape(), boolType);
899 state.addTypes(boolType);
900
901 state.addOperands({lhs, rhs});
902 }
903
buildLogicalUnaryOp(OpBuilder & builder,OperationState & state,Value value)904 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
905 Value value) {
906 Type boolType = builder.getI1Type();
907 if (auto vecType = value.getType().dyn_cast<VectorType>())
908 boolType = VectorType::get(vecType.getShape(), boolType);
909 state.addTypes(boolType);
910
911 state.addOperands(value);
912 }
913
914 //===----------------------------------------------------------------------===//
915 // spv.AccessChainOp
916 //===----------------------------------------------------------------------===//
917
getElementPtrType(Type type,ValueRange indices,Location baseLoc)918 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
919 auto ptrType = type.dyn_cast<spirv::PointerType>();
920 if (!ptrType) {
921 emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
922 "to composite type, but provided ")
923 << type;
924 return nullptr;
925 }
926
927 auto resultType = ptrType.getPointeeType();
928 auto resultStorageClass = ptrType.getStorageClass();
929 int32_t index = 0;
930
931 for (auto indexSSA : indices) {
932 auto cType = resultType.dyn_cast<spirv::CompositeType>();
933 if (!cType) {
934 emitError(baseLoc,
935 "'spv.AccessChain' op cannot extract from non-composite type ")
936 << resultType << " with index " << index;
937 return nullptr;
938 }
939 index = 0;
940 if (resultType.isa<spirv::StructType>()) {
941 Operation *op = indexSSA.getDefiningOp();
942 if (!op) {
943 emitError(baseLoc, "'spv.AccessChain' op index must be an "
944 "integer spv.Constant to access "
945 "element of spv.struct");
946 return nullptr;
947 }
948
949 // TODO: this should be relaxed to allow
950 // integer literals of other bitwidths.
951 if (failed(extractValueFromConstOp(op, index))) {
952 emitError(baseLoc,
953 "'spv.AccessChain' index must be an integer spv.Constant to "
954 "access element of spv.struct, but provided ")
955 << op->getName();
956 return nullptr;
957 }
958 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
959 emitError(baseLoc, "'spv.AccessChain' op index ")
960 << index << " out of bounds for " << resultType;
961 return nullptr;
962 }
963 }
964 resultType = cType.getElementType(index);
965 }
966 return spirv::PointerType::get(resultType, resultStorageClass);
967 }
968
build(OpBuilder & builder,OperationState & state,Value basePtr,ValueRange indices)969 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
970 Value basePtr, ValueRange indices) {
971 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
972 assert(type && "Unable to deduce return type based on basePtr and indices");
973 build(builder, state, type, basePtr, indices);
974 }
975
parseAccessChainOp(OpAsmParser & parser,OperationState & state)976 static ParseResult parseAccessChainOp(OpAsmParser &parser,
977 OperationState &state) {
978 OpAsmParser::OperandType ptrInfo;
979 SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
980 Type type;
981 auto loc = parser.getCurrentLocation();
982 SmallVector<Type, 4> indicesTypes;
983
984 if (parser.parseOperand(ptrInfo) ||
985 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
986 parser.parseColonType(type) ||
987 parser.resolveOperand(ptrInfo, type, state.operands)) {
988 return failure();
989 }
990
991 // Check that the provided indices list is not empty before parsing their
992 // type list.
993 if (indicesInfo.empty()) {
994 return emitError(state.location, "'spv.AccessChain' op expected at "
995 "least one index ");
996 }
997
998 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
999 return failure();
1000
1001 // Check that the indices types list is not empty and that it has a one-to-one
1002 // mapping to the provided indices.
1003 if (indicesTypes.size() != indicesInfo.size()) {
1004 return emitError(state.location, "'spv.AccessChain' op indices "
1005 "types' count must be equal to indices "
1006 "info count");
1007 }
1008
1009 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
1010 return failure();
1011
1012 auto resultType = getElementPtrType(
1013 type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
1014 if (!resultType) {
1015 return failure();
1016 }
1017
1018 state.addTypes(resultType);
1019 return success();
1020 }
1021
print(spirv::AccessChainOp op,OpAsmPrinter & printer)1022 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
1023 printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
1024 << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
1025 << op.indices().getTypes();
1026 }
1027
verify(spirv::AccessChainOp accessChainOp)1028 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
1029 SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
1030 accessChainOp.indices().end());
1031 auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1032 indices, accessChainOp.getLoc());
1033 if (!resultType) {
1034 return failure();
1035 }
1036
1037 auto providedResultType =
1038 accessChainOp.getType().dyn_cast<spirv::PointerType>();
1039 if (!providedResultType) {
1040 return accessChainOp.emitOpError(
1041 "result type must be a pointer, but provided")
1042 << providedResultType;
1043 }
1044
1045 if (resultType != providedResultType) {
1046 return accessChainOp.emitOpError("invalid result type: expected ")
1047 << resultType << ", but provided " << providedResultType;
1048 }
1049
1050 return success();
1051 }
1052
1053 //===----------------------------------------------------------------------===//
1054 // spv.mlir.addressof
1055 //===----------------------------------------------------------------------===//
1056
build(OpBuilder & builder,OperationState & state,spirv::GlobalVariableOp var)1057 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1058 spirv::GlobalVariableOp var) {
1059 build(builder, state, var.type(), builder.getSymbolRefAttr(var));
1060 }
1061
verify(spirv::AddressOfOp addressOfOp)1062 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
1063 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1064 SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
1065 addressOfOp.variable()));
1066 if (!varOp) {
1067 return addressOfOp.emitOpError("expected spv.GlobalVariable symbol");
1068 }
1069 if (addressOfOp.pointer().getType() != varOp.type()) {
1070 return addressOfOp.emitOpError(
1071 "result type mismatch with the referenced global variable's type");
1072 }
1073 return success();
1074 }
1075
1076 //===----------------------------------------------------------------------===//
1077 // spv.AtomicCompareExchangeWeak
1078 //===----------------------------------------------------------------------===//
1079
parseAtomicCompareExchangeWeakOp(OpAsmParser & parser,OperationState & state)1080 static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
1081 OperationState &state) {
1082 spirv::Scope memoryScope;
1083 spirv::MemorySemantics equalSemantics, unequalSemantics;
1084 SmallVector<OpAsmParser::OperandType, 3> operandInfo;
1085 Type type;
1086 if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1087 parseEnumStrAttr(equalSemantics, parser, state,
1088 kEqualSemanticsAttrName) ||
1089 parseEnumStrAttr(unequalSemantics, parser, state,
1090 kUnequalSemanticsAttrName) ||
1091 parser.parseOperandList(operandInfo, 3))
1092 return failure();
1093
1094 auto loc = parser.getCurrentLocation();
1095 if (parser.parseColonType(type))
1096 return failure();
1097
1098 auto ptrType = type.dyn_cast<spirv::PointerType>();
1099 if (!ptrType)
1100 return parser.emitError(loc, "expected pointer type");
1101
1102 if (parser.resolveOperands(
1103 operandInfo,
1104 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1105 parser.getNameLoc(), state.operands))
1106 return failure();
1107
1108 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1109 }
1110
print(spirv::AtomicCompareExchangeWeakOp atomOp,OpAsmPrinter & printer)1111 static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
1112 OpAsmPrinter &printer) {
1113 printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
1114 << stringifyScope(atomOp.memory_scope()) << "\" \""
1115 << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1116 << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1117 << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1118 }
1119
verify(spirv::AtomicCompareExchangeWeakOp atomOp)1120 static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
1121 // According to the spec:
1122 // "The type of Value must be the same as Result Type. The type of the value
1123 // pointed to by Pointer must be the same as Result Type. This type must also
1124 // match the type of Comparator."
1125 if (atomOp.getType() != atomOp.value().getType())
1126 return atomOp.emitOpError("value operand must have the same type as the op "
1127 "result, but found ")
1128 << atomOp.value().getType() << " vs " << atomOp.getType();
1129
1130 if (atomOp.getType() != atomOp.comparator().getType())
1131 return atomOp.emitOpError(
1132 "comparator operand must have the same type as the op "
1133 "result, but found ")
1134 << atomOp.comparator().getType() << " vs " << atomOp.getType();
1135
1136 Type pointeeType =
1137 atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
1138 if (atomOp.getType() != pointeeType)
1139 return atomOp.emitOpError(
1140 "pointer operand's pointee type must have the same "
1141 "as the op result type, but found ")
1142 << pointeeType << " vs " << atomOp.getType();
1143
1144 // TODO: Unequal cannot be set to Release or Acquire and Release.
1145 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1146
1147 return success();
1148 }
1149
1150 //===----------------------------------------------------------------------===//
1151 // spv.BitcastOp
1152 //===----------------------------------------------------------------------===//
1153
verify(spirv::BitcastOp bitcastOp)1154 static LogicalResult verify(spirv::BitcastOp bitcastOp) {
1155 // TODO: The SPIR-V spec validation rules are different for different
1156 // versions.
1157 auto operandType = bitcastOp.operand().getType();
1158 auto resultType = bitcastOp.result().getType();
1159 if (operandType == resultType) {
1160 return bitcastOp.emitError(
1161 "result type must be different from operand type");
1162 }
1163 if (operandType.isa<spirv::PointerType>() &&
1164 !resultType.isa<spirv::PointerType>()) {
1165 return bitcastOp.emitError(
1166 "unhandled bit cast conversion from pointer type to non-pointer type");
1167 }
1168 if (!operandType.isa<spirv::PointerType>() &&
1169 resultType.isa<spirv::PointerType>()) {
1170 return bitcastOp.emitError(
1171 "unhandled bit cast conversion from non-pointer type to pointer type");
1172 }
1173 auto operandBitWidth = getBitWidth(operandType);
1174 auto resultBitWidth = getBitWidth(resultType);
1175 if (operandBitWidth != resultBitWidth) {
1176 return bitcastOp.emitOpError("mismatch in result type bitwidth ")
1177 << resultBitWidth << " and operand type bitwidth "
1178 << operandBitWidth;
1179 }
1180 return success();
1181 }
1182
1183 //===----------------------------------------------------------------------===//
1184 // spv.BranchOp
1185 //===----------------------------------------------------------------------===//
1186
1187 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1188 spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
1189 assert(index == 0 && "invalid successor index");
1190 return targetOperandsMutable();
1191 }
1192
1193 //===----------------------------------------------------------------------===//
1194 // spv.BranchConditionalOp
1195 //===----------------------------------------------------------------------===//
1196
1197 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1198 spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
1199 assert(index < 2 && "invalid successor index");
1200 return index == kTrueIndex ? trueTargetOperandsMutable()
1201 : falseTargetOperandsMutable();
1202 }
1203
parseBranchConditionalOp(OpAsmParser & parser,OperationState & state)1204 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
1205 OperationState &state) {
1206 auto &builder = parser.getBuilder();
1207 OpAsmParser::OperandType condInfo;
1208 Block *dest;
1209
1210 // Parse the condition.
1211 Type boolTy = builder.getI1Type();
1212 if (parser.parseOperand(condInfo) ||
1213 parser.resolveOperand(condInfo, boolTy, state.operands))
1214 return failure();
1215
1216 // Parse the optional branch weights.
1217 if (succeeded(parser.parseOptionalLSquare())) {
1218 IntegerAttr trueWeight, falseWeight;
1219 NamedAttrList weights;
1220
1221 auto i32Type = builder.getIntegerType(32);
1222 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1223 parser.parseComma() ||
1224 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1225 parser.parseRSquare())
1226 return failure();
1227
1228 state.addAttribute(kBranchWeightAttrName,
1229 builder.getArrayAttr({trueWeight, falseWeight}));
1230 }
1231
1232 // Parse the true branch.
1233 SmallVector<Value, 4> trueOperands;
1234 if (parser.parseComma() ||
1235 parser.parseSuccessorAndUseList(dest, trueOperands))
1236 return failure();
1237 state.addSuccessors(dest);
1238 state.addOperands(trueOperands);
1239
1240 // Parse the false branch.
1241 SmallVector<Value, 4> falseOperands;
1242 if (parser.parseComma() ||
1243 parser.parseSuccessorAndUseList(dest, falseOperands))
1244 return failure();
1245 state.addSuccessors(dest);
1246 state.addOperands(falseOperands);
1247 state.addAttribute(
1248 spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1249 builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1250 static_cast<int32_t>(falseOperands.size())}));
1251
1252 return success();
1253 }
1254
print(spirv::BranchConditionalOp branchOp,OpAsmPrinter & printer)1255 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
1256 printer << spirv::BranchConditionalOp::getOperationName() << ' '
1257 << branchOp.condition();
1258
1259 if (auto weights = branchOp.branch_weights()) {
1260 printer << " [";
1261 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1262 printer << a.cast<IntegerAttr>().getInt();
1263 });
1264 printer << "]";
1265 }
1266
1267 printer << ", ";
1268 printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
1269 branchOp.getTrueBlockArguments());
1270 printer << ", ";
1271 printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
1272 branchOp.getFalseBlockArguments());
1273 }
1274
verify(spirv::BranchConditionalOp branchOp)1275 static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
1276 if (auto weights = branchOp.branch_weights()) {
1277 if (weights->getValue().size() != 2) {
1278 return branchOp.emitOpError("must have exactly two branch weights");
1279 }
1280 if (llvm::all_of(*weights, [](Attribute attr) {
1281 return attr.cast<IntegerAttr>().getValue().isNullValue();
1282 }))
1283 return branchOp.emitOpError("branch weights cannot both be zero");
1284 }
1285
1286 return success();
1287 }
1288
1289 //===----------------------------------------------------------------------===//
1290 // spv.CompositeConstruct
1291 //===----------------------------------------------------------------------===//
1292
parseCompositeConstructOp(OpAsmParser & parser,OperationState & state)1293 static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
1294 OperationState &state) {
1295 SmallVector<OpAsmParser::OperandType, 4> operands;
1296 Type type;
1297 auto loc = parser.getCurrentLocation();
1298
1299 if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1300 return failure();
1301 }
1302 auto cType = type.dyn_cast<spirv::CompositeType>();
1303 if (!cType) {
1304 return parser.emitError(
1305 loc, "result type must be a composite type, but provided ")
1306 << type;
1307 }
1308
1309 if (cType.hasCompileTimeKnownNumElements() &&
1310 operands.size() != cType.getNumElements()) {
1311 return parser.emitError(loc, "has incorrect number of operands: expected ")
1312 << cType.getNumElements() << ", but provided " << operands.size();
1313 }
1314 // TODO: Add support for constructing a vector type from the vector operands.
1315 // According to the spec: "for constructing a vector, the operands may
1316 // also be vectors with the same component type as the Result Type component
1317 // type".
1318 SmallVector<Type, 4> elementTypes;
1319 elementTypes.reserve(operands.size());
1320 for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1321 elementTypes.push_back(cType.getElementType(index));
1322 }
1323 state.addTypes(type);
1324 return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1325 }
1326
print(spirv::CompositeConstructOp compositeConstructOp,OpAsmPrinter & printer)1327 static void print(spirv::CompositeConstructOp compositeConstructOp,
1328 OpAsmPrinter &printer) {
1329 printer << spirv::CompositeConstructOp::getOperationName() << " "
1330 << compositeConstructOp.constituents() << " : "
1331 << compositeConstructOp.getResult().getType();
1332 }
1333
verify(spirv::CompositeConstructOp compositeConstructOp)1334 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
1335 auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
1336 SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
1337
1338 if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1339 if (constituents.size() != 1)
1340 return compositeConstructOp.emitError(
1341 "has incorrect number of operands: expected ")
1342 << "1, but provided " << constituents.size();
1343 } else if (constituents.size() != cType.getNumElements()) {
1344 return compositeConstructOp.emitError(
1345 "has incorrect number of operands: expected ")
1346 << cType.getNumElements() << ", but provided "
1347 << constituents.size();
1348 }
1349
1350 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1351 if (constituents[index].getType() != cType.getElementType(index)) {
1352 return compositeConstructOp.emitError(
1353 "operand type mismatch: expected operand type ")
1354 << cType.getElementType(index) << ", but provided "
1355 << constituents[index].getType();
1356 }
1357 }
1358
1359 return success();
1360 }
1361
1362 //===----------------------------------------------------------------------===//
1363 // spv.CompositeExtractOp
1364 //===----------------------------------------------------------------------===//
1365
build(OpBuilder & builder,OperationState & state,Value composite,ArrayRef<int32_t> indices)1366 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1367 Value composite,
1368 ArrayRef<int32_t> indices) {
1369 auto indexAttr = builder.getI32ArrayAttr(indices);
1370 auto elementType =
1371 getElementType(composite.getType(), indexAttr, state.location);
1372 if (!elementType) {
1373 return;
1374 }
1375 build(builder, state, elementType, composite, indexAttr);
1376 }
1377
parseCompositeExtractOp(OpAsmParser & parser,OperationState & state)1378 static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
1379 OperationState &state) {
1380 OpAsmParser::OperandType compositeInfo;
1381 Attribute indicesAttr;
1382 Type compositeType;
1383 llvm::SMLoc attrLocation;
1384
1385 if (parser.parseOperand(compositeInfo) ||
1386 parser.getCurrentLocation(&attrLocation) ||
1387 parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1388 parser.parseColonType(compositeType) ||
1389 parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1390 return failure();
1391 }
1392
1393 Type resultType =
1394 getElementType(compositeType, indicesAttr, parser, attrLocation);
1395 if (!resultType) {
1396 return failure();
1397 }
1398 state.addTypes(resultType);
1399 return success();
1400 }
1401
print(spirv::CompositeExtractOp compositeExtractOp,OpAsmPrinter & printer)1402 static void print(spirv::CompositeExtractOp compositeExtractOp,
1403 OpAsmPrinter &printer) {
1404 printer << spirv::CompositeExtractOp::getOperationName() << ' '
1405 << compositeExtractOp.composite() << compositeExtractOp.indices()
1406 << " : " << compositeExtractOp.composite().getType();
1407 }
1408
verify(spirv::CompositeExtractOp compExOp)1409 static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
1410 auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
1411 auto resultType = getElementType(compExOp.composite().getType(),
1412 indicesArrayAttr, compExOp.getLoc());
1413 if (!resultType)
1414 return failure();
1415
1416 if (resultType != compExOp.getType()) {
1417 return compExOp.emitOpError("invalid result type: expected ")
1418 << resultType << " but provided " << compExOp.getType();
1419 }
1420
1421 return success();
1422 }
1423
1424 //===----------------------------------------------------------------------===//
1425 // spv.CompositeInsert
1426 //===----------------------------------------------------------------------===//
1427
build(OpBuilder & builder,OperationState & state,Value object,Value composite,ArrayRef<int32_t> indices)1428 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1429 Value object, Value composite,
1430 ArrayRef<int32_t> indices) {
1431 auto indexAttr = builder.getI32ArrayAttr(indices);
1432 build(builder, state, composite.getType(), object, composite, indexAttr);
1433 }
1434
parseCompositeInsertOp(OpAsmParser & parser,OperationState & state)1435 static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
1436 OperationState &state) {
1437 SmallVector<OpAsmParser::OperandType, 2> operands;
1438 Type objectType, compositeType;
1439 Attribute indicesAttr;
1440 auto loc = parser.getCurrentLocation();
1441
1442 return failure(
1443 parser.parseOperandList(operands, 2) ||
1444 parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1445 parser.parseColonType(objectType) ||
1446 parser.parseKeywordType("into", compositeType) ||
1447 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1448 state.operands) ||
1449 parser.addTypesToList(compositeType, state.types));
1450 }
1451
verify(spirv::CompositeInsertOp compositeInsertOp)1452 static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
1453 auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
1454 auto objectType =
1455 getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
1456 compositeInsertOp.getLoc());
1457 if (!objectType)
1458 return failure();
1459
1460 if (objectType != compositeInsertOp.object().getType()) {
1461 return compositeInsertOp.emitOpError("object operand type should be ")
1462 << objectType << ", but found "
1463 << compositeInsertOp.object().getType();
1464 }
1465
1466 if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
1467 return compositeInsertOp.emitOpError("result type should be the same as "
1468 "the composite type, but found ")
1469 << compositeInsertOp.composite().getType() << " vs "
1470 << compositeInsertOp.getType();
1471 }
1472
1473 return success();
1474 }
1475
print(spirv::CompositeInsertOp compositeInsertOp,OpAsmPrinter & printer)1476 static void print(spirv::CompositeInsertOp compositeInsertOp,
1477 OpAsmPrinter &printer) {
1478 printer << spirv::CompositeInsertOp::getOperationName() << " "
1479 << compositeInsertOp.object() << ", " << compositeInsertOp.composite()
1480 << compositeInsertOp.indices() << " : "
1481 << compositeInsertOp.object().getType() << " into "
1482 << compositeInsertOp.composite().getType();
1483 }
1484
1485 //===----------------------------------------------------------------------===//
1486 // spv.Constant
1487 //===----------------------------------------------------------------------===//
1488
parseConstantOp(OpAsmParser & parser,OperationState & state)1489 static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
1490 Attribute value;
1491 if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1492 return failure();
1493
1494 Type type = value.getType();
1495 if (type.isa<NoneType, TensorType>()) {
1496 if (parser.parseColonType(type))
1497 return failure();
1498 }
1499
1500 return parser.addTypeToList(type, state.types);
1501 }
1502
print(spirv::ConstantOp constOp,OpAsmPrinter & printer)1503 static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
1504 printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
1505 if (constOp.getType().isa<spirv::ArrayType>())
1506 printer << " : " << constOp.getType();
1507 }
1508
verify(spirv::ConstantOp constOp)1509 static LogicalResult verify(spirv::ConstantOp constOp) {
1510 auto opType = constOp.getType();
1511 auto value = constOp.value();
1512 auto valueType = value.getType();
1513
1514 // ODS already generates checks to make sure the result type is valid. We just
1515 // need to additionally check that the value's attribute type is consistent
1516 // with the result type.
1517 if (value.isa<IntegerAttr, FloatAttr>()) {
1518 if (valueType != opType)
1519 return constOp.emitOpError("result type (")
1520 << opType << ") does not match value type (" << valueType << ")";
1521 return success();
1522 }
1523 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1524 if (valueType == opType)
1525 return success();
1526 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1527 auto shapedType = valueType.dyn_cast<ShapedType>();
1528 if (!arrayType) {
1529 return constOp.emitOpError(
1530 "must have spv.array result type for array value");
1531 }
1532
1533 int numElements = arrayType.getNumElements();
1534 auto opElemType = arrayType.getElementType();
1535 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1536 numElements *= t.getNumElements();
1537 opElemType = t.getElementType();
1538 }
1539 if (!opElemType.isIntOrFloat())
1540 return constOp.emitOpError("only support nested array result type");
1541
1542 auto valueElemType = shapedType.getElementType();
1543 if (valueElemType != opElemType) {
1544 return constOp.emitOpError("result element type (")
1545 << opElemType << ") does not match value element type ("
1546 << valueElemType << ")";
1547 }
1548
1549 if (numElements != shapedType.getNumElements()) {
1550 return constOp.emitOpError("result number of elements (")
1551 << numElements << ") does not match value number of elements ("
1552 << shapedType.getNumElements() << ")";
1553 }
1554 return success();
1555 }
1556 if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
1557 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1558 if (!arrayType)
1559 return constOp.emitOpError(
1560 "must have spv.array result type for array value");
1561 Type elemType = arrayType.getElementType();
1562 for (Attribute element : attayAttr.getValue()) {
1563 if (element.getType() != elemType)
1564 return constOp.emitOpError("has array element whose type (")
1565 << element.getType()
1566 << ") does not match the result element type (" << elemType
1567 << ')';
1568 }
1569 return success();
1570 }
1571 return constOp.emitOpError("cannot have value of type ") << valueType;
1572 }
1573
isBuildableWith(Type type)1574 bool spirv::ConstantOp::isBuildableWith(Type type) {
1575 // Must be valid SPIR-V type first.
1576 if (!type.isa<spirv::SPIRVType>())
1577 return false;
1578
1579 if (isa<SPIRVDialect>(type.getDialect())) {
1580 // TODO: support constant struct
1581 return type.isa<spirv::ArrayType>();
1582 }
1583
1584 return true;
1585 }
1586
getZero(Type type,Location loc,OpBuilder & builder)1587 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1588 OpBuilder &builder) {
1589 if (auto intType = type.dyn_cast<IntegerType>()) {
1590 unsigned width = intType.getWidth();
1591 if (width == 1)
1592 return builder.create<spirv::ConstantOp>(loc, type,
1593 builder.getBoolAttr(false));
1594 return builder.create<spirv::ConstantOp>(
1595 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1596 }
1597 if (auto floatType = type.dyn_cast<FloatType>()) {
1598 return builder.create<spirv::ConstantOp>(
1599 loc, type, builder.getFloatAttr(floatType, 0.0));
1600 }
1601 if (auto vectorType = type.dyn_cast<VectorType>()) {
1602 Type elemType = vectorType.getElementType();
1603 if (elemType.isa<IntegerType>()) {
1604 return builder.create<spirv::ConstantOp>(
1605 loc, type,
1606 DenseElementsAttr::get(vectorType,
1607 IntegerAttr::get(elemType, 0.0).getValue()));
1608 }
1609 if (elemType.isa<FloatType>()) {
1610 return builder.create<spirv::ConstantOp>(
1611 loc, type,
1612 DenseFPElementsAttr::get(vectorType,
1613 FloatAttr::get(elemType, 0.0).getValue()));
1614 }
1615 }
1616
1617 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1618 }
1619
getOne(Type type,Location loc,OpBuilder & builder)1620 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1621 OpBuilder &builder) {
1622 if (auto intType = type.dyn_cast<IntegerType>()) {
1623 unsigned width = intType.getWidth();
1624 if (width == 1)
1625 return builder.create<spirv::ConstantOp>(loc, type,
1626 builder.getBoolAttr(true));
1627 return builder.create<spirv::ConstantOp>(
1628 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1629 }
1630 if (auto floatType = type.dyn_cast<FloatType>()) {
1631 return builder.create<spirv::ConstantOp>(
1632 loc, type, builder.getFloatAttr(floatType, 1.0));
1633 }
1634 if (auto vectorType = type.dyn_cast<VectorType>()) {
1635 Type elemType = vectorType.getElementType();
1636 if (elemType.isa<IntegerType>()) {
1637 return builder.create<spirv::ConstantOp>(
1638 loc, type,
1639 DenseElementsAttr::get(vectorType,
1640 IntegerAttr::get(elemType, 1.0).getValue()));
1641 }
1642 if (elemType.isa<FloatType>()) {
1643 return builder.create<spirv::ConstantOp>(
1644 loc, type,
1645 DenseFPElementsAttr::get(vectorType,
1646 FloatAttr::get(elemType, 1.0).getValue()));
1647 }
1648 }
1649
1650 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1651 }
1652
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)1653 void mlir::spirv::ConstantOp::getAsmResultNames(
1654 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1655 Type type = getType();
1656
1657 SmallString<32> specialNameBuffer;
1658 llvm::raw_svector_ostream specialName(specialNameBuffer);
1659 specialName << "cst";
1660
1661 IntegerType intTy = type.dyn_cast<IntegerType>();
1662
1663 if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1664 if (intTy && intTy.getWidth() == 1) {
1665 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1666 }
1667
1668 if (intTy.isSignless()) {
1669 specialName << intCst.getInt();
1670 } else {
1671 specialName << intCst.getSInt();
1672 }
1673 }
1674
1675 if (intTy || type.isa<FloatType>()) {
1676 specialName << '_' << type;
1677 }
1678
1679 if (auto vecType = type.dyn_cast<VectorType>()) {
1680 specialName << "_vec_";
1681 specialName << vecType.getDimSize(0);
1682
1683 Type elementType = vecType.getElementType();
1684
1685 if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1686 specialName << "x" << elementType;
1687 }
1688 }
1689
1690 setNameFn(getResult(), specialName.str());
1691 }
1692
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)1693 void mlir::spirv::AddressOfOp::getAsmResultNames(
1694 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1695 SmallString<32> specialNameBuffer;
1696 llvm::raw_svector_ostream specialName(specialNameBuffer);
1697 specialName << variable() << "_addr";
1698 setNameFn(getResult(), specialName.str());
1699 }
1700
1701 //===----------------------------------------------------------------------===//
1702 // spv.EntryPoint
1703 //===----------------------------------------------------------------------===//
1704
build(OpBuilder & builder,OperationState & state,spirv::ExecutionModel executionModel,spirv::FuncOp function,ArrayRef<Attribute> interfaceVars)1705 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
1706 spirv::ExecutionModel executionModel,
1707 spirv::FuncOp function,
1708 ArrayRef<Attribute> interfaceVars) {
1709 build(builder, state,
1710 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
1711 builder.getSymbolRefAttr(function),
1712 builder.getArrayAttr(interfaceVars));
1713 }
1714
parseEntryPointOp(OpAsmParser & parser,OperationState & state)1715 static ParseResult parseEntryPointOp(OpAsmParser &parser,
1716 OperationState &state) {
1717 spirv::ExecutionModel execModel;
1718 SmallVector<OpAsmParser::OperandType, 0> identifiers;
1719 SmallVector<Type, 0> idTypes;
1720 SmallVector<Attribute, 4> interfaceVars;
1721
1722 FlatSymbolRefAttr fn;
1723 if (parseEnumStrAttr(execModel, parser, state) ||
1724 parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
1725 return failure();
1726 }
1727
1728 if (!parser.parseOptionalComma()) {
1729 // Parse the interface variables
1730 do {
1731 // The name of the interface variable attribute isnt important
1732 auto attrName = "var_symbol";
1733 FlatSymbolRefAttr var;
1734 NamedAttrList attrs;
1735 if (parser.parseAttribute(var, Type(), attrName, attrs)) {
1736 return failure();
1737 }
1738 interfaceVars.push_back(var);
1739 } while (!parser.parseOptionalComma());
1740 }
1741 state.addAttribute(kInterfaceAttrName,
1742 parser.getBuilder().getArrayAttr(interfaceVars));
1743 return success();
1744 }
1745
print(spirv::EntryPointOp entryPointOp,OpAsmPrinter & printer)1746 static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
1747 printer << spirv::EntryPointOp::getOperationName() << " \""
1748 << stringifyExecutionModel(entryPointOp.execution_model()) << "\" ";
1749 printer.printSymbolName(entryPointOp.fn());
1750 auto interfaceVars = entryPointOp.interface().getValue();
1751 if (!interfaceVars.empty()) {
1752 printer << ", ";
1753 llvm::interleaveComma(interfaceVars, printer);
1754 }
1755 }
1756
verify(spirv::EntryPointOp entryPointOp)1757 static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
1758 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
1759 // verification.
1760 return success();
1761 }
1762
1763 //===----------------------------------------------------------------------===//
1764 // spv.ExecutionMode
1765 //===----------------------------------------------------------------------===//
1766
build(OpBuilder & builder,OperationState & state,spirv::FuncOp function,spirv::ExecutionMode executionMode,ArrayRef<int32_t> params)1767 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
1768 spirv::FuncOp function,
1769 spirv::ExecutionMode executionMode,
1770 ArrayRef<int32_t> params) {
1771 build(builder, state, builder.getSymbolRefAttr(function),
1772 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
1773 builder.getI32ArrayAttr(params));
1774 }
1775
parseExecutionModeOp(OpAsmParser & parser,OperationState & state)1776 static ParseResult parseExecutionModeOp(OpAsmParser &parser,
1777 OperationState &state) {
1778 spirv::ExecutionMode execMode;
1779 Attribute fn;
1780 if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
1781 parseEnumStrAttr(execMode, parser, state)) {
1782 return failure();
1783 }
1784
1785 SmallVector<int32_t, 4> values;
1786 Type i32Type = parser.getBuilder().getIntegerType(32);
1787 while (!parser.parseOptionalComma()) {
1788 NamedAttrList attr;
1789 Attribute value;
1790 if (parser.parseAttribute(value, i32Type, "value", attr)) {
1791 return failure();
1792 }
1793 values.push_back(value.cast<IntegerAttr>().getInt());
1794 }
1795 state.addAttribute(kValuesAttrName,
1796 parser.getBuilder().getI32ArrayAttr(values));
1797 return success();
1798 }
1799
print(spirv::ExecutionModeOp execModeOp,OpAsmPrinter & printer)1800 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
1801 printer << spirv::ExecutionModeOp::getOperationName() << " ";
1802 printer.printSymbolName(execModeOp.fn());
1803 printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
1804 << "\"";
1805 auto values = execModeOp.values();
1806 if (!values.size())
1807 return;
1808 printer << ", ";
1809 llvm::interleaveComma(values, printer, [&](Attribute a) {
1810 printer << a.cast<IntegerAttr>().getInt();
1811 });
1812 }
1813
1814 //===----------------------------------------------------------------------===//
1815 // spv.func
1816 //===----------------------------------------------------------------------===//
1817
parseFuncOp(OpAsmParser & parser,OperationState & state)1818 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
1819 SmallVector<OpAsmParser::OperandType, 4> entryArgs;
1820 SmallVector<NamedAttrList, 4> argAttrs;
1821 SmallVector<NamedAttrList, 4> resultAttrs;
1822 SmallVector<Type, 4> argTypes;
1823 SmallVector<Type, 4> resultTypes;
1824 auto &builder = parser.getBuilder();
1825
1826 // Parse the name as a symbol.
1827 StringAttr nameAttr;
1828 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1829 state.attributes))
1830 return failure();
1831
1832 // Parse the function signature.
1833 bool isVariadic = false;
1834 if (function_like_impl::parseFunctionSignature(
1835 parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
1836 isVariadic, resultTypes, resultAttrs))
1837 return failure();
1838
1839 auto fnType = builder.getFunctionType(argTypes, resultTypes);
1840 state.addAttribute(function_like_impl::getTypeAttrName(),
1841 TypeAttr::get(fnType));
1842
1843 // Parse the optional function control keyword.
1844 spirv::FunctionControl fnControl;
1845 if (parseEnumStrAttr(fnControl, parser, state))
1846 return failure();
1847
1848 // If additional attributes are present, parse them.
1849 if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1850 return failure();
1851
1852 // Add the attributes to the function arguments.
1853 assert(argAttrs.size() == argTypes.size());
1854 assert(resultAttrs.size() == resultTypes.size());
1855 function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
1856 resultAttrs);
1857
1858 // Parse the optional function body.
1859 auto *body = state.addRegion();
1860 OptionalParseResult result = parser.parseOptionalRegion(
1861 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1862 return failure(result.hasValue() && failed(*result));
1863 }
1864
print(spirv::FuncOp fnOp,OpAsmPrinter & printer)1865 static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
1866 // Print function name, signature, and control.
1867 printer << spirv::FuncOp::getOperationName() << " ";
1868 printer.printSymbolName(fnOp.sym_name());
1869 auto fnType = fnOp.getType();
1870 function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
1871 /*isVariadic=*/false,
1872 fnType.getResults());
1873 printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
1874 << "\"";
1875 function_like_impl::printFunctionAttributes(
1876 printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
1877 {spirv::attributeName<spirv::FunctionControl>()});
1878
1879 // Print the body if this is not an external function.
1880 Region &body = fnOp.body();
1881 if (!body.empty())
1882 printer.printRegion(body, /*printEntryBlockArgs=*/false,
1883 /*printBlockTerminators=*/true);
1884 }
1885
verifyType()1886 LogicalResult spirv::FuncOp::verifyType() {
1887 auto type = getTypeAttr().getValue();
1888 if (!type.isa<FunctionType>())
1889 return emitOpError("requires '" + getTypeAttrName() +
1890 "' attribute of function type");
1891 if (getType().getNumResults() > 1)
1892 return emitOpError("cannot have more than one result");
1893 return success();
1894 }
1895
verifyBody()1896 LogicalResult spirv::FuncOp::verifyBody() {
1897 FunctionType fnType = getType();
1898
1899 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1900 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1901 if (fnType.getNumResults() != 0)
1902 return retOp.emitOpError("cannot be used in functions returning value");
1903 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1904 if (fnType.getNumResults() != 1)
1905 return retOp.emitOpError(
1906 "returns 1 value but enclosing function requires ")
1907 << fnType.getNumResults() << " results";
1908
1909 auto retOperandType = retOp.value().getType();
1910 auto fnResultType = fnType.getResult(0);
1911 if (retOperandType != fnResultType)
1912 return retOp.emitOpError(" return value's type (")
1913 << retOperandType << ") mismatch with function's result type ("
1914 << fnResultType << ")";
1915 }
1916 return WalkResult::advance();
1917 });
1918
1919 // TODO: verify other bits like linkage type.
1920
1921 return failure(walkResult.wasInterrupted());
1922 }
1923
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,spirv::FunctionControl control,ArrayRef<NamedAttribute> attrs)1924 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1925 StringRef name, FunctionType type,
1926 spirv::FunctionControl control,
1927 ArrayRef<NamedAttribute> attrs) {
1928 state.addAttribute(SymbolTable::getSymbolAttrName(),
1929 builder.getStringAttr(name));
1930 state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
1931 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1932 builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
1933 state.attributes.append(attrs.begin(), attrs.end());
1934 state.addRegion();
1935 }
1936
1937 // CallableOpInterface
getCallableRegion()1938 Region *spirv::FuncOp::getCallableRegion() {
1939 return isExternal() ? nullptr : &body();
1940 }
1941
1942 // CallableOpInterface
getCallableResults()1943 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
1944 return getType().getResults();
1945 }
1946
1947 //===----------------------------------------------------------------------===//
1948 // spv.FunctionCall
1949 //===----------------------------------------------------------------------===//
1950
verify(spirv::FunctionCallOp functionCallOp)1951 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
1952 auto fnName = functionCallOp.callee();
1953
1954 auto funcOp =
1955 dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
1956 functionCallOp->getParentOp(), fnName));
1957 if (!funcOp) {
1958 return functionCallOp.emitOpError("callee function '")
1959 << fnName << "' not found in nearest symbol table";
1960 }
1961
1962 auto functionType = funcOp.getType();
1963
1964 if (functionCallOp.getNumResults() > 1) {
1965 return functionCallOp.emitOpError(
1966 "expected callee function to have 0 or 1 result, but provided ")
1967 << functionCallOp.getNumResults();
1968 }
1969
1970 if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
1971 return functionCallOp.emitOpError(
1972 "has incorrect number of operands for callee: expected ")
1973 << functionType.getNumInputs() << ", but provided "
1974 << functionCallOp.getNumOperands();
1975 }
1976
1977 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
1978 if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
1979 return functionCallOp.emitOpError(
1980 "operand type mismatch: expected operand type ")
1981 << functionType.getInput(i) << ", but provided "
1982 << functionCallOp.getOperand(i).getType() << " for operand number "
1983 << i;
1984 }
1985 }
1986
1987 if (functionType.getNumResults() != functionCallOp.getNumResults()) {
1988 return functionCallOp.emitOpError(
1989 "has incorrect number of results has for callee: expected ")
1990 << functionType.getNumResults() << ", but provided "
1991 << functionCallOp.getNumResults();
1992 }
1993
1994 if (functionCallOp.getNumResults() &&
1995 (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
1996 return functionCallOp.emitOpError("result type mismatch: expected ")
1997 << functionType.getResult(0) << ", but provided "
1998 << functionCallOp.getResult(0).getType();
1999 }
2000
2001 return success();
2002 }
2003
getCallableForCallee()2004 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2005 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2006 }
2007
getArgOperands()2008 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2009 return arguments();
2010 }
2011
2012 //===----------------------------------------------------------------------===//
2013 // spv.GlobalVariable
2014 //===----------------------------------------------------------------------===//
2015
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,unsigned descriptorSet,unsigned binding)2016 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2017 Type type, StringRef name,
2018 unsigned descriptorSet, unsigned binding) {
2019 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
2020 nullptr);
2021 state.addAttribute(
2022 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2023 builder.getI32IntegerAttr(descriptorSet));
2024 state.addAttribute(
2025 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2026 builder.getI32IntegerAttr(binding));
2027 }
2028
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,spirv::BuiltIn builtin)2029 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2030 Type type, StringRef name,
2031 spirv::BuiltIn builtin) {
2032 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
2033 nullptr);
2034 state.addAttribute(
2035 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2036 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2037 }
2038
parseGlobalVariableOp(OpAsmParser & parser,OperationState & state)2039 static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
2040 OperationState &state) {
2041 // Parse variable name.
2042 StringAttr nameAttr;
2043 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2044 state.attributes)) {
2045 return failure();
2046 }
2047
2048 // Parse optional initializer
2049 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2050 FlatSymbolRefAttr initSymbol;
2051 if (parser.parseLParen() ||
2052 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2053 state.attributes) ||
2054 parser.parseRParen())
2055 return failure();
2056 }
2057
2058 if (parseVariableDecorations(parser, state)) {
2059 return failure();
2060 }
2061
2062 Type type;
2063 auto loc = parser.getCurrentLocation();
2064 if (parser.parseColonType(type)) {
2065 return failure();
2066 }
2067 if (!type.isa<spirv::PointerType>()) {
2068 return parser.emitError(loc, "expected spv.ptr type");
2069 }
2070 state.addAttribute(kTypeAttrName, TypeAttr::get(type));
2071
2072 return success();
2073 }
2074
print(spirv::GlobalVariableOp varOp,OpAsmPrinter & printer)2075 static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
2076 auto *op = varOp.getOperation();
2077 SmallVector<StringRef, 4> elidedAttrs{
2078 spirv::attributeName<spirv::StorageClass>()};
2079 printer << spirv::GlobalVariableOp::getOperationName();
2080
2081 // Print variable name.
2082 printer << ' ';
2083 printer.printSymbolName(varOp.sym_name());
2084 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2085
2086 // Print optional initializer
2087 if (auto initializer = varOp.initializer()) {
2088 printer << " " << kInitializerAttrName << '(';
2089 printer.printSymbolName(initializer.getValue());
2090 printer << ')';
2091 elidedAttrs.push_back(kInitializerAttrName);
2092 }
2093
2094 elidedAttrs.push_back(kTypeAttrName);
2095 printVariableDecorations(op, printer, elidedAttrs);
2096 printer << " : " << varOp.type();
2097 }
2098
verify(spirv::GlobalVariableOp varOp)2099 static LogicalResult verify(spirv::GlobalVariableOp varOp) {
2100 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2101 // object. It cannot be Generic. It must be the same as the Storage Class
2102 // operand of the Result Type."
2103 // Also, Function storage class is reserved by spv.Variable.
2104 auto storageClass = varOp.storageClass();
2105 if (storageClass == spirv::StorageClass::Generic ||
2106 storageClass == spirv::StorageClass::Function) {
2107 return varOp.emitOpError("storage class cannot be '")
2108 << stringifyStorageClass(storageClass) << "'";
2109 }
2110
2111 if (auto init =
2112 varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2113 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2114 varOp->getParentOp(), init.getValue());
2115 // TODO: Currently only variable initialization with specialization
2116 // constants and other variables is supported. They could be normal
2117 // constants in the module scope as well.
2118 if (!initOp ||
2119 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2120 return varOp.emitOpError("initializer must be result of a "
2121 "spv.SpecConstant or spv.GlobalVariable op");
2122 }
2123 }
2124
2125 return success();
2126 }
2127
2128 //===----------------------------------------------------------------------===//
2129 // spv.GroupBroadcast
2130 //===----------------------------------------------------------------------===//
2131
verify(spirv::GroupBroadcastOp broadcastOp)2132 static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
2133 spirv::Scope scope = broadcastOp.execution_scope();
2134 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2135 return broadcastOp.emitOpError(
2136 "execution scope must be 'Workgroup' or 'Subgroup'");
2137
2138 if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
2139 if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
2140 return broadcastOp.emitOpError("localid is a vector and can be with only "
2141 " 2 or 3 components, actual number is ")
2142 << localIdTy.getNumElements();
2143
2144 return success();
2145 }
2146
2147 //===----------------------------------------------------------------------===//
2148 // spv.GroupNonUniformBallotOp
2149 //===----------------------------------------------------------------------===//
2150
verify(spirv::GroupNonUniformBallotOp ballotOp)2151 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
2152 spirv::Scope scope = ballotOp.execution_scope();
2153 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2154 return ballotOp.emitOpError(
2155 "execution scope must be 'Workgroup' or 'Subgroup'");
2156
2157 return success();
2158 }
2159
2160 //===----------------------------------------------------------------------===//
2161 // spv.GroupNonUniformBroadcast
2162 //===----------------------------------------------------------------------===//
2163
verify(spirv::GroupNonUniformBroadcastOp broadcastOp)2164 static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
2165 spirv::Scope scope = broadcastOp.execution_scope();
2166 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2167 return broadcastOp.emitOpError(
2168 "execution scope must be 'Workgroup' or 'Subgroup'");
2169
2170 // SPIR-V spec: "Before version 1.5, Id must come from a
2171 // constant instruction.
2172 auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
2173 if (auto spirvModule = broadcastOp->getParentOfType<spirv::ModuleOp>())
2174 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2175
2176 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2177 auto *idOp = broadcastOp.id().getDefiningOp();
2178 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2179 spirv::ReferenceOfOp>(idOp)) // for spec constant
2180 return broadcastOp.emitOpError("id must be the result of a constant op");
2181 }
2182
2183 return success();
2184 }
2185
2186 //===----------------------------------------------------------------------===//
2187 // spv.SubgroupBlockReadINTEL
2188 //===----------------------------------------------------------------------===//
2189
parseSubgroupBlockReadINTELOp(OpAsmParser & parser,OperationState & state)2190 static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
2191 OperationState &state) {
2192 // Parse the storage class specification
2193 spirv::StorageClass storageClass;
2194 OpAsmParser::OperandType ptrInfo;
2195 Type elementType;
2196 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2197 parser.parseColon() || parser.parseType(elementType)) {
2198 return failure();
2199 }
2200
2201 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2202 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2203 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2204
2205 if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2206 return failure();
2207 }
2208
2209 state.addTypes(elementType);
2210 return success();
2211 }
2212
print(spirv::SubgroupBlockReadINTELOp blockReadOp,OpAsmPrinter & printer)2213 static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
2214 OpAsmPrinter &printer) {
2215 SmallVector<StringRef, 4> elidedAttrs;
2216 printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
2217 << blockReadOp.ptr();
2218 printer << " : " << blockReadOp.getType();
2219 }
2220
verify(spirv::SubgroupBlockReadINTELOp blockReadOp)2221 static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
2222 if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
2223 blockReadOp.value())))
2224 return failure();
2225
2226 return success();
2227 }
2228
2229 //===----------------------------------------------------------------------===//
2230 // spv.SubgroupBlockWriteINTEL
2231 //===----------------------------------------------------------------------===//
2232
parseSubgroupBlockWriteINTELOp(OpAsmParser & parser,OperationState & state)2233 static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
2234 OperationState &state) {
2235 // Parse the storage class specification
2236 spirv::StorageClass storageClass;
2237 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2238 auto loc = parser.getCurrentLocation();
2239 Type elementType;
2240 if (parseEnumStrAttr(storageClass, parser) ||
2241 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2242 parser.parseType(elementType)) {
2243 return failure();
2244 }
2245
2246 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2247 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2248 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2249
2250 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2251 state.operands)) {
2252 return failure();
2253 }
2254 return success();
2255 }
2256
print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,OpAsmPrinter & printer)2257 static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
2258 OpAsmPrinter &printer) {
2259 SmallVector<StringRef, 4> elidedAttrs;
2260 printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
2261 << blockWriteOp.ptr() << ", " << blockWriteOp.value();
2262 printer << " : " << blockWriteOp.value().getType();
2263 }
2264
verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp)2265 static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
2266 if (failed(verifyBlockReadWritePtrAndValTypes(
2267 blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
2268 return failure();
2269
2270 return success();
2271 }
2272
2273 //===----------------------------------------------------------------------===//
2274 // spv.GroupNonUniformElectOp
2275 //===----------------------------------------------------------------------===//
2276
build(OpBuilder & builder,OperationState & state,spirv::Scope scope)2277 void spirv::GroupNonUniformElectOp::build(OpBuilder &builder,
2278 OperationState &state,
2279 spirv::Scope scope) {
2280 build(builder, state, builder.getI1Type(), scope);
2281 }
2282
verify(spirv::GroupNonUniformElectOp groupOp)2283 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
2284 spirv::Scope scope = groupOp.execution_scope();
2285 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2286 return groupOp.emitOpError(
2287 "execution scope must be 'Workgroup' or 'Subgroup'");
2288
2289 return success();
2290 }
2291
2292 //===----------------------------------------------------------------------===//
2293 // spv.LoadOp
2294 //===----------------------------------------------------------------------===//
2295
build(OpBuilder & builder,OperationState & state,Value basePtr,MemoryAccessAttr memoryAccess,IntegerAttr alignment)2296 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2297 Value basePtr, MemoryAccessAttr memoryAccess,
2298 IntegerAttr alignment) {
2299 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2300 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
2301 alignment);
2302 }
2303
parseLoadOp(OpAsmParser & parser,OperationState & state)2304 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
2305 // Parse the storage class specification
2306 spirv::StorageClass storageClass;
2307 OpAsmParser::OperandType ptrInfo;
2308 Type elementType;
2309 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2310 parseMemoryAccessAttributes(parser, state) ||
2311 parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2312 parser.parseType(elementType)) {
2313 return failure();
2314 }
2315
2316 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2317 if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2318 return failure();
2319 }
2320
2321 state.addTypes(elementType);
2322 return success();
2323 }
2324
print(spirv::LoadOp loadOp,OpAsmPrinter & printer)2325 static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
2326 auto *op = loadOp.getOperation();
2327 SmallVector<StringRef, 4> elidedAttrs;
2328 StringRef sc = stringifyStorageClass(
2329 loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2330 printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
2331 << loadOp.ptr();
2332
2333 printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
2334
2335 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2336 printer << " : " << loadOp.getType();
2337 }
2338
verify(spirv::LoadOp loadOp)2339 static LogicalResult verify(spirv::LoadOp loadOp) {
2340 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2341 // type with fixed size; i.e., it cannot be, nor include, any
2342 // OpTypeRuntimeArray types."
2343 if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
2344 loadOp.value()))) {
2345 return failure();
2346 }
2347 return verifyMemoryAccessAttribute(loadOp);
2348 }
2349
2350 //===----------------------------------------------------------------------===//
2351 // spv.mlir.loop
2352 //===----------------------------------------------------------------------===//
2353
build(OpBuilder & builder,OperationState & state)2354 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2355 state.addAttribute("loop_control",
2356 builder.getI32IntegerAttr(
2357 static_cast<uint32_t>(spirv::LoopControl::None)));
2358 state.addRegion();
2359 }
2360
parseLoopOp(OpAsmParser & parser,OperationState & state)2361 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
2362 if (parseControlAttribute<spirv::LoopControl>(parser, state))
2363 return failure();
2364 return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2365 /*argTypes=*/{});
2366 }
2367
print(spirv::LoopOp loopOp,OpAsmPrinter & printer)2368 static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
2369 auto *op = loopOp.getOperation();
2370
2371 printer << spirv::LoopOp::getOperationName();
2372 auto control = loopOp.loop_control();
2373 if (control != spirv::LoopControl::None)
2374 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2375 printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2376 /*printBlockTerminators=*/true);
2377 }
2378
2379 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2380 /// given `dstBlock`.
hasOneBranchOpTo(Block & srcBlock,Block & dstBlock)2381 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2382 // Check that there is only one op in the `srcBlock`.
2383 if (!llvm::hasSingleElement(srcBlock))
2384 return false;
2385
2386 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2387 return branchOp && branchOp.getSuccessor() == &dstBlock;
2388 }
2389
verify(spirv::LoopOp loopOp)2390 static LogicalResult verify(spirv::LoopOp loopOp) {
2391 auto *op = loopOp.getOperation();
2392
2393 // We need to verify that the blocks follow the following layout:
2394 //
2395 // +-------------+
2396 // | entry block |
2397 // +-------------+
2398 // |
2399 // v
2400 // +-------------+
2401 // | loop header | <-----+
2402 // +-------------+ |
2403 // |
2404 // ... |
2405 // \ | / |
2406 // v |
2407 // +---------------+ |
2408 // | loop continue | -----+
2409 // +---------------+
2410 //
2411 // ...
2412 // \ | /
2413 // v
2414 // +-------------+
2415 // | merge block |
2416 // +-------------+
2417
2418 auto ®ion = op->getRegion(0);
2419 // Allow empty region as a degenerated case, which can come from
2420 // optimizations.
2421 if (region.empty())
2422 return success();
2423
2424 // The last block is the merge block.
2425 Block &merge = region.back();
2426 if (!isMergeBlock(merge))
2427 return loopOp.emitOpError(
2428 "last block must be the merge block with only one 'spv.mlir.merge' op");
2429
2430 if (std::next(region.begin()) == region.end())
2431 return loopOp.emitOpError(
2432 "must have an entry block branching to the loop header block");
2433 // The first block is the entry block.
2434 Block &entry = region.front();
2435
2436 if (std::next(region.begin(), 2) == region.end())
2437 return loopOp.emitOpError(
2438 "must have a loop header block branched from the entry block");
2439 // The second block is the loop header block.
2440 Block &header = *std::next(region.begin(), 1);
2441
2442 if (!hasOneBranchOpTo(entry, header))
2443 return loopOp.emitOpError(
2444 "entry block must only have one 'spv.Branch' op to the second block");
2445
2446 if (std::next(region.begin(), 3) == region.end())
2447 return loopOp.emitOpError(
2448 "requires a loop continue block branching to the loop header block");
2449 // The second to last block is the loop continue block.
2450 Block &cont = *std::prev(region.end(), 2);
2451
2452 // Make sure that we have a branch from the loop continue block to the loop
2453 // header block.
2454 if (llvm::none_of(
2455 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
2456 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
2457 return loopOp.emitOpError("second to last block must be the loop continue "
2458 "block that branches to the loop header block");
2459
2460 // Make sure that no other blocks (except the entry and loop continue block)
2461 // branches to the loop header block.
2462 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
2463 std::prev(region.end(), 2))) {
2464 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
2465 if (block.getSuccessor(i) == &header) {
2466 return loopOp.emitOpError("can only have the entry and loop continue "
2467 "block branching to the loop header block");
2468 }
2469 }
2470 }
2471
2472 return success();
2473 }
2474
getEntryBlock()2475 Block *spirv::LoopOp::getEntryBlock() {
2476 assert(!body().empty() && "op region should not be empty!");
2477 return &body().front();
2478 }
2479
getHeaderBlock()2480 Block *spirv::LoopOp::getHeaderBlock() {
2481 assert(!body().empty() && "op region should not be empty!");
2482 // The second block is the loop header block.
2483 return &*std::next(body().begin());
2484 }
2485
getContinueBlock()2486 Block *spirv::LoopOp::getContinueBlock() {
2487 assert(!body().empty() && "op region should not be empty!");
2488 // The second to last block is the loop continue block.
2489 return &*std::prev(body().end(), 2);
2490 }
2491
getMergeBlock()2492 Block *spirv::LoopOp::getMergeBlock() {
2493 assert(!body().empty() && "op region should not be empty!");
2494 // The last block is the loop merge block.
2495 return &body().back();
2496 }
2497
addEntryAndMergeBlock()2498 void spirv::LoopOp::addEntryAndMergeBlock() {
2499 assert(body().empty() && "entry and merge block already exist");
2500 body().push_back(new Block());
2501 auto *mergeBlock = new Block();
2502 body().push_back(mergeBlock);
2503 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2504
2505 // Add a spv.mlir.merge op into the merge block.
2506 builder.create<spirv::MergeOp>(getLoc());
2507 }
2508
2509 //===----------------------------------------------------------------------===//
2510 // spv.mlir.merge
2511 //===----------------------------------------------------------------------===//
2512
verify(spirv::MergeOp mergeOp)2513 static LogicalResult verify(spirv::MergeOp mergeOp) {
2514 auto *parentOp = mergeOp->getParentOp();
2515 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
2516 return mergeOp.emitOpError(
2517 "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
2518
2519 Block &parentLastBlock = mergeOp->getParentRegion()->back();
2520 if (mergeOp.getOperation() != parentLastBlock.getTerminator())
2521 return mergeOp.emitOpError("can only be used in the last block of "
2522 "'spv.mlir.selection' or 'spv.mlir.loop'");
2523 return success();
2524 }
2525
2526 //===----------------------------------------------------------------------===//
2527 // spv.module
2528 //===----------------------------------------------------------------------===//
2529
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)2530 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2531 Optional<StringRef> name) {
2532 OpBuilder::InsertionGuard guard(builder);
2533 builder.createBlock(state.addRegion());
2534 if (name) {
2535 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2536 builder.getStringAttr(*name));
2537 }
2538 }
2539
build(OpBuilder & builder,OperationState & state,spirv::AddressingModel addressingModel,spirv::MemoryModel memoryModel,Optional<StringRef> name)2540 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2541 spirv::AddressingModel addressingModel,
2542 spirv::MemoryModel memoryModel,
2543 Optional<StringRef> name) {
2544 state.addAttribute(
2545 "addressing_model",
2546 builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
2547 state.addAttribute("memory_model", builder.getI32IntegerAttr(
2548 static_cast<int32_t>(memoryModel)));
2549 OpBuilder::InsertionGuard guard(builder);
2550 builder.createBlock(state.addRegion());
2551 if (name) {
2552 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2553 builder.getStringAttr(*name));
2554 }
2555 }
2556
parseModuleOp(OpAsmParser & parser,OperationState & state)2557 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
2558 Region *body = state.addRegion();
2559
2560 // If the name is present, parse it.
2561 StringAttr nameAttr;
2562 parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2563 state.attributes);
2564
2565 // Parse attributes
2566 spirv::AddressingModel addrModel;
2567 spirv::MemoryModel memoryModel;
2568 if (parseEnumKeywordAttr(addrModel, parser, state) ||
2569 parseEnumKeywordAttr(memoryModel, parser, state))
2570 return failure();
2571
2572 if (succeeded(parser.parseOptionalKeyword("requires"))) {
2573 spirv::VerCapExtAttr vceTriple;
2574 if (parser.parseAttribute(vceTriple,
2575 spirv::ModuleOp::getVCETripleAttrName(),
2576 state.attributes))
2577 return failure();
2578 }
2579
2580 if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
2581 return failure();
2582
2583 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2584 return failure();
2585
2586 // Make sure we have at least one block.
2587 if (body->empty())
2588 body->push_back(new Block());
2589
2590 return success();
2591 }
2592
print(spirv::ModuleOp moduleOp,OpAsmPrinter & printer)2593 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
2594 printer << spirv::ModuleOp::getOperationName();
2595
2596 if (Optional<StringRef> name = moduleOp.getName()) {
2597 printer << ' ';
2598 printer.printSymbolName(*name);
2599 }
2600
2601 SmallVector<StringRef, 2> elidedAttrs;
2602
2603 printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
2604 << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
2605 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
2606 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
2607 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
2608 SymbolTable::getSymbolAttrName()});
2609
2610 if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
2611 printer << " requires " << *triple;
2612 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
2613 }
2614
2615 printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
2616 printer.printRegion(moduleOp.getRegion());
2617 }
2618
verify(spirv::ModuleOp moduleOp)2619 static LogicalResult verify(spirv::ModuleOp moduleOp) {
2620 auto &op = *moduleOp.getOperation();
2621 auto *dialect = op.getDialect();
2622 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
2623 entryPoints;
2624 SymbolTable table(moduleOp);
2625
2626 for (auto &op : *moduleOp.getBody()) {
2627 if (op.getDialect() != dialect)
2628 return op.emitError("'spv.module' can only contain spv.* ops");
2629
2630 // For EntryPoint op, check that the function and execution model is not
2631 // duplicated in EntryPointOps. Also verify that the interface specified
2632 // comes from globalVariables here to make this check cheaper.
2633 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
2634 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
2635 if (!funcOp) {
2636 return entryPointOp.emitError("function '")
2637 << entryPointOp.fn() << "' not found in 'spv.module'";
2638 }
2639 if (auto interface = entryPointOp.interface()) {
2640 for (Attribute varRef : interface) {
2641 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
2642 if (!varSymRef) {
2643 return entryPointOp.emitError(
2644 "expected symbol reference for interface "
2645 "specification instead of '")
2646 << varRef;
2647 }
2648 auto variableOp =
2649 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
2650 if (!variableOp) {
2651 return entryPointOp.emitError("expected spv.GlobalVariable "
2652 "symbol reference instead of'")
2653 << varSymRef << "'";
2654 }
2655 }
2656 }
2657
2658 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
2659 funcOp, entryPointOp.execution_model());
2660 auto entryPtIt = entryPoints.find(key);
2661 if (entryPtIt != entryPoints.end()) {
2662 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
2663 }
2664 entryPoints[key] = entryPointOp;
2665 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
2666 if (funcOp.isExternal())
2667 return op.emitError("'spv.module' cannot contain external functions");
2668
2669 // TODO: move this check to spv.func.
2670 for (auto &block : funcOp)
2671 for (auto &op : block) {
2672 if (op.getDialect() != dialect)
2673 return op.emitError(
2674 "functions in 'spv.module' can only contain spv.* ops");
2675 }
2676 }
2677 }
2678
2679 return success();
2680 }
2681
2682 //===----------------------------------------------------------------------===//
2683 // spv.mlir.referenceof
2684 //===----------------------------------------------------------------------===//
2685
verify(spirv::ReferenceOfOp referenceOfOp)2686 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
2687 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
2688 referenceOfOp->getParentOp(), referenceOfOp.spec_const());
2689 Type constType;
2690
2691 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
2692 if (specConstOp)
2693 constType = specConstOp.default_value().getType();
2694
2695 auto specConstCompositeOp =
2696 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
2697 if (specConstCompositeOp)
2698 constType = specConstCompositeOp.type();
2699
2700 if (!specConstOp && !specConstCompositeOp)
2701 return referenceOfOp.emitOpError(
2702 "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
2703
2704 if (referenceOfOp.reference().getType() != constType)
2705 return referenceOfOp.emitOpError("result type mismatch with the referenced "
2706 "specialization constant's type");
2707
2708 return success();
2709 }
2710
2711 //===----------------------------------------------------------------------===//
2712 // spv.Return
2713 //===----------------------------------------------------------------------===//
2714
verify(spirv::ReturnOp returnOp)2715 static LogicalResult verify(spirv::ReturnOp returnOp) {
2716 // Verification is performed in spv.func op.
2717 return success();
2718 }
2719
2720 //===----------------------------------------------------------------------===//
2721 // spv.ReturnValue
2722 //===----------------------------------------------------------------------===//
2723
verify(spirv::ReturnValueOp retValOp)2724 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
2725 // Verification is performed in spv.func op.
2726 return success();
2727 }
2728
2729 //===----------------------------------------------------------------------===//
2730 // spv.Select
2731 //===----------------------------------------------------------------------===//
2732
build(OpBuilder & builder,OperationState & state,Value cond,Value trueValue,Value falseValue)2733 void spirv::SelectOp::build(OpBuilder &builder, OperationState &state,
2734 Value cond, Value trueValue, Value falseValue) {
2735 build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
2736 }
2737
verify(spirv::SelectOp op)2738 static LogicalResult verify(spirv::SelectOp op) {
2739 if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
2740 auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
2741 if (!resultVectorTy) {
2742 return op.emitOpError("result expected to be of vector type when "
2743 "condition is of vector type");
2744 }
2745 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
2746 return op.emitOpError("result should have the same number of elements as "
2747 "the condition when condition is of vector type");
2748 }
2749 }
2750 return success();
2751 }
2752
2753 //===----------------------------------------------------------------------===//
2754 // spv.mlir.selection
2755 //===----------------------------------------------------------------------===//
2756
parseSelectionOp(OpAsmParser & parser,OperationState & state)2757 static ParseResult parseSelectionOp(OpAsmParser &parser,
2758 OperationState &state) {
2759 if (parseControlAttribute<spirv::SelectionControl>(parser, state))
2760 return failure();
2761 return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2762 /*argTypes=*/{});
2763 }
2764
print(spirv::SelectionOp selectionOp,OpAsmPrinter & printer)2765 static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
2766 auto *op = selectionOp.getOperation();
2767
2768 printer << spirv::SelectionOp::getOperationName();
2769 auto control = selectionOp.selection_control();
2770 if (control != spirv::SelectionControl::None)
2771 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
2772 printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2773 /*printBlockTerminators=*/true);
2774 }
2775
verify(spirv::SelectionOp selectionOp)2776 static LogicalResult verify(spirv::SelectionOp selectionOp) {
2777 auto *op = selectionOp.getOperation();
2778
2779 // We need to verify that the blocks follow the following layout:
2780 //
2781 // +--------------+
2782 // | header block |
2783 // +--------------+
2784 // / | \
2785 // ...
2786 //
2787 //
2788 // +---------+ +---------+ +---------+
2789 // | case #0 | | case #1 | | case #2 | ...
2790 // +---------+ +---------+ +---------+
2791 //
2792 //
2793 // ...
2794 // \ | /
2795 // v
2796 // +-------------+
2797 // | merge block |
2798 // +-------------+
2799
2800 auto ®ion = op->getRegion(0);
2801 // Allow empty region as a degenerated case, which can come from
2802 // optimizations.
2803 if (region.empty())
2804 return success();
2805
2806 // The last block is the merge block.
2807 if (!isMergeBlock(region.back()))
2808 return selectionOp.emitOpError(
2809 "last block must be the merge block with only one 'spv.mlir.merge' op");
2810
2811 if (std::next(region.begin()) == region.end())
2812 return selectionOp.emitOpError("must have a selection header block");
2813
2814 return success();
2815 }
2816
getHeaderBlock()2817 Block *spirv::SelectionOp::getHeaderBlock() {
2818 assert(!body().empty() && "op region should not be empty!");
2819 // The first block is the loop header block.
2820 return &body().front();
2821 }
2822
getMergeBlock()2823 Block *spirv::SelectionOp::getMergeBlock() {
2824 assert(!body().empty() && "op region should not be empty!");
2825 // The last block is the loop merge block.
2826 return &body().back();
2827 }
2828
addMergeBlock()2829 void spirv::SelectionOp::addMergeBlock() {
2830 assert(body().empty() && "entry and merge block already exist");
2831 auto *mergeBlock = new Block();
2832 body().push_back(mergeBlock);
2833 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2834
2835 // Add a spv.mlir.merge op into the merge block.
2836 builder.create<spirv::MergeOp>(getLoc());
2837 }
2838
createIfThen(Location loc,Value condition,function_ref<void (OpBuilder & builder)> thenBody,OpBuilder & builder)2839 spirv::SelectionOp spirv::SelectionOp::createIfThen(
2840 Location loc, Value condition,
2841 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
2842 auto selectionOp =
2843 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
2844
2845 selectionOp.addMergeBlock();
2846 Block *mergeBlock = selectionOp.getMergeBlock();
2847 Block *thenBlock = nullptr;
2848
2849 // Build the "then" block.
2850 {
2851 OpBuilder::InsertionGuard guard(builder);
2852 thenBlock = builder.createBlock(mergeBlock);
2853 thenBody(builder);
2854 builder.create<spirv::BranchOp>(loc, mergeBlock);
2855 }
2856
2857 // Build the header block.
2858 {
2859 OpBuilder::InsertionGuard guard(builder);
2860 builder.createBlock(thenBlock);
2861 builder.create<spirv::BranchConditionalOp>(
2862 loc, condition, thenBlock,
2863 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
2864 /*falseArguments=*/ArrayRef<Value>());
2865 }
2866
2867 return selectionOp;
2868 }
2869
2870 //===----------------------------------------------------------------------===//
2871 // spv.SpecConstant
2872 //===----------------------------------------------------------------------===//
2873
parseSpecConstantOp(OpAsmParser & parser,OperationState & state)2874 static ParseResult parseSpecConstantOp(OpAsmParser &parser,
2875 OperationState &state) {
2876 StringAttr nameAttr;
2877 Attribute valueAttr;
2878
2879 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2880 state.attributes))
2881 return failure();
2882
2883 // Parse optional spec_id.
2884 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
2885 IntegerAttr specIdAttr;
2886 if (parser.parseLParen() ||
2887 parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
2888 parser.parseRParen())
2889 return failure();
2890 }
2891
2892 if (parser.parseEqual() ||
2893 parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
2894 return failure();
2895
2896 return success();
2897 }
2898
print(spirv::SpecConstantOp constOp,OpAsmPrinter & printer)2899 static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
2900 printer << spirv::SpecConstantOp::getOperationName() << ' ';
2901 printer.printSymbolName(constOp.sym_name());
2902 if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2903 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
2904 printer << " = " << constOp.default_value();
2905 }
2906
verify(spirv::SpecConstantOp constOp)2907 static LogicalResult verify(spirv::SpecConstantOp constOp) {
2908 if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2909 if (specID.getValue().isNegative())
2910 return constOp.emitOpError("SpecId cannot be negative");
2911
2912 auto value = constOp.default_value();
2913 if (value.isa<IntegerAttr, FloatAttr>()) {
2914 // Make sure bitwidth is allowed.
2915 if (!value.getType().isa<spirv::SPIRVType>())
2916 return constOp.emitOpError("default value bitwidth disallowed");
2917 return success();
2918 }
2919 return constOp.emitOpError(
2920 "default value can only be a bool, integer, or float scalar");
2921 }
2922
2923 //===----------------------------------------------------------------------===//
2924 // spv.StoreOp
2925 //===----------------------------------------------------------------------===//
2926
parseStoreOp(OpAsmParser & parser,OperationState & state)2927 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
2928 // Parse the storage class specification
2929 spirv::StorageClass storageClass;
2930 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2931 auto loc = parser.getCurrentLocation();
2932 Type elementType;
2933 if (parseEnumStrAttr(storageClass, parser) ||
2934 parser.parseOperandList(operandInfo, 2) ||
2935 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
2936 parser.parseType(elementType)) {
2937 return failure();
2938 }
2939
2940 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2941 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2942 state.operands)) {
2943 return failure();
2944 }
2945 return success();
2946 }
2947
print(spirv::StoreOp storeOp,OpAsmPrinter & printer)2948 static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
2949 auto *op = storeOp.getOperation();
2950 SmallVector<StringRef, 4> elidedAttrs;
2951 StringRef sc = stringifyStorageClass(
2952 storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2953 printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
2954 << storeOp.ptr() << ", " << storeOp.value();
2955
2956 printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
2957
2958 printer << " : " << storeOp.value().getType();
2959 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2960 }
2961
verify(spirv::StoreOp storeOp)2962 static LogicalResult verify(spirv::StoreOp storeOp) {
2963 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
2964 // OpTypePointer whose Type operand is the same as the type of Object."
2965 if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
2966 storeOp.value()))) {
2967 return failure();
2968 }
2969 return verifyMemoryAccessAttribute(storeOp);
2970 }
2971
2972 //===----------------------------------------------------------------------===//
2973 // spv.Unreachable
2974 //===----------------------------------------------------------------------===//
2975
verify(spirv::UnreachableOp unreachableOp)2976 static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
2977 auto *op = unreachableOp.getOperation();
2978 auto *block = op->getBlock();
2979 // Fast track: if this is in entry block, its invalid. Otherwise, if no
2980 // predecessors, it's valid.
2981 if (block->isEntryBlock())
2982 return unreachableOp.emitOpError("cannot be used in reachable block");
2983 if (block->hasNoPredecessors())
2984 return success();
2985
2986 // TODO: further verification needs to analyze reachability from
2987 // the entry block.
2988
2989 return success();
2990 }
2991
2992 //===----------------------------------------------------------------------===//
2993 // spv.Variable
2994 //===----------------------------------------------------------------------===//
2995
parseVariableOp(OpAsmParser & parser,OperationState & state)2996 static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
2997 // Parse optional initializer
2998 Optional<OpAsmParser::OperandType> initInfo;
2999 if (succeeded(parser.parseOptionalKeyword("init"))) {
3000 initInfo = OpAsmParser::OperandType();
3001 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3002 parser.parseRParen())
3003 return failure();
3004 }
3005
3006 if (parseVariableDecorations(parser, state)) {
3007 return failure();
3008 }
3009
3010 // Parse result pointer type
3011 Type type;
3012 if (parser.parseColon())
3013 return failure();
3014 auto loc = parser.getCurrentLocation();
3015 if (parser.parseType(type))
3016 return failure();
3017
3018 auto ptrType = type.dyn_cast<spirv::PointerType>();
3019 if (!ptrType)
3020 return parser.emitError(loc, "expected spv.ptr type");
3021 state.addTypes(ptrType);
3022
3023 // Resolve the initializer operand
3024 if (initInfo) {
3025 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3026 state.operands))
3027 return failure();
3028 }
3029
3030 auto attr = parser.getBuilder().getI32IntegerAttr(
3031 llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
3032 state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3033
3034 return success();
3035 }
3036
print(spirv::VariableOp varOp,OpAsmPrinter & printer)3037 static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
3038 SmallVector<StringRef, 4> elidedAttrs{
3039 spirv::attributeName<spirv::StorageClass>()};
3040 printer << spirv::VariableOp::getOperationName();
3041
3042 // Print optional initializer
3043 if (varOp.getNumOperands() != 0)
3044 printer << " init(" << varOp.initializer() << ")";
3045
3046 printVariableDecorations(varOp, printer, elidedAttrs);
3047 printer << " : " << varOp.getType();
3048 }
3049
verify(spirv::VariableOp varOp)3050 static LogicalResult verify(spirv::VariableOp varOp) {
3051 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3052 // object. It cannot be Generic. It must be the same as the Storage Class
3053 // operand of the Result Type."
3054 if (varOp.storage_class() != spirv::StorageClass::Function) {
3055 return varOp.emitOpError(
3056 "can only be used to model function-level variables. Use "
3057 "spv.GlobalVariable for module-level variables.");
3058 }
3059
3060 auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
3061 if (varOp.storage_class() != pointerType.getStorageClass())
3062 return varOp.emitOpError(
3063 "storage class must match result pointer's storage class");
3064
3065 if (varOp.getNumOperands() != 0) {
3066 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3067 // a global (module scope) OpVariable instruction".
3068 auto *initOp = varOp.getOperand(0).getDefiningOp();
3069 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3070 spirv::ReferenceOfOp, // for spec constant
3071 spirv::AddressOfOp>(initOp))
3072 return varOp.emitOpError("initializer must be the result of a "
3073 "constant or spv.GlobalVariable op");
3074 }
3075
3076 // TODO: generate these strings using ODS.
3077 auto *op = varOp.getOperation();
3078 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3079 stringifyDecoration(spirv::Decoration::DescriptorSet));
3080 auto bindingName = llvm::convertToSnakeFromCamelCase(
3081 stringifyDecoration(spirv::Decoration::Binding));
3082 auto builtInName = llvm::convertToSnakeFromCamelCase(
3083 stringifyDecoration(spirv::Decoration::BuiltIn));
3084
3085 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3086 if (op->getAttr(attr))
3087 return varOp.emitOpError("cannot have '")
3088 << attr << "' attribute (only allowed in spv.GlobalVariable)";
3089 }
3090
3091 return success();
3092 }
3093
3094 //===----------------------------------------------------------------------===//
3095 // spv.VectorShuffle
3096 //===----------------------------------------------------------------------===//
3097
verify(spirv::VectorShuffleOp shuffleOp)3098 static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) {
3099 VectorType resultType = shuffleOp.getType().cast<VectorType>();
3100
3101 size_t numResultElements = resultType.getNumElements();
3102 if (numResultElements != shuffleOp.components().size())
3103 return shuffleOp.emitOpError("result type element count (")
3104 << numResultElements
3105 << ") mismatch with the number of component selectors ("
3106 << shuffleOp.components().size() << ")";
3107
3108 size_t totalSrcElements =
3109 shuffleOp.vector1().getType().cast<VectorType>().getNumElements() +
3110 shuffleOp.vector2().getType().cast<VectorType>().getNumElements();
3111
3112 for (const auto &selector :
3113 shuffleOp.components().getAsValueRange<IntegerAttr>()) {
3114 uint32_t index = selector.getZExtValue();
3115 if (index >= totalSrcElements &&
3116 index != std::numeric_limits<uint32_t>().max())
3117 return shuffleOp.emitOpError("component selector ")
3118 << index << " out of range: expected to be in [0, "
3119 << totalSrcElements << ") or 0xffffffff";
3120 }
3121 return success();
3122 }
3123
3124 //===----------------------------------------------------------------------===//
3125 // spv.CooperativeMatrixLoadNV
3126 //===----------------------------------------------------------------------===//
3127
parseCooperativeMatrixLoadNVOp(OpAsmParser & parser,OperationState & state)3128 static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
3129 OperationState &state) {
3130 SmallVector<OpAsmParser::OperandType, 3> operandInfo;
3131 Type strideType = parser.getBuilder().getIntegerType(32);
3132 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3133 Type ptrType;
3134 Type elementType;
3135 if (parser.parseOperandList(operandInfo, 3) ||
3136 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3137 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3138 return failure();
3139 }
3140 if (parser.resolveOperands(operandInfo,
3141 {ptrType, strideType, columnMajorType},
3142 parser.getNameLoc(), state.operands)) {
3143 return failure();
3144 }
3145
3146 state.addTypes(elementType);
3147 return success();
3148 }
3149
print(spirv::CooperativeMatrixLoadNVOp M,OpAsmPrinter & printer)3150 static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
3151 printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
3152 << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
3153 // Print optional memory access attribute.
3154 if (auto memAccess = M.memory_access())
3155 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3156 printer << " : " << M.pointer().getType() << " as " << M.getType();
3157 }
3158
verifyPointerAndCoopMatrixType(Operation * op,Type pointer,Type coopMatrix)3159 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3160 Type coopMatrix) {
3161 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3162 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3163 return op->emitError(
3164 "Pointer must point to a scalar or vector type but provided ")
3165 << pointeeType;
3166 spirv::StorageClass storage =
3167 pointer.cast<spirv::PointerType>().getStorageClass();
3168 if (storage != spirv::StorageClass::Workgroup &&
3169 storage != spirv::StorageClass::StorageBuffer &&
3170 storage != spirv::StorageClass::PhysicalStorageBuffer)
3171 return op->emitError(
3172 "Pointer storage class must be Workgroup, StorageBuffer or "
3173 "PhysicalStorageBufferEXT but provided ")
3174 << stringifyStorageClass(storage);
3175 return success();
3176 }
3177
3178 //===----------------------------------------------------------------------===//
3179 // spv.CooperativeMatrixStoreNV
3180 //===----------------------------------------------------------------------===//
3181
parseCooperativeMatrixStoreNVOp(OpAsmParser & parser,OperationState & state)3182 static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
3183 OperationState &state) {
3184 SmallVector<OpAsmParser::OperandType, 4> operandInfo;
3185 Type strideType = parser.getBuilder().getIntegerType(32);
3186 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3187 Type ptrType;
3188 Type elementType;
3189 if (parser.parseOperandList(operandInfo, 4) ||
3190 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3191 parser.parseType(ptrType) || parser.parseComma() ||
3192 parser.parseType(elementType)) {
3193 return failure();
3194 }
3195 if (parser.resolveOperands(
3196 operandInfo, {ptrType, elementType, strideType, columnMajorType},
3197 parser.getNameLoc(), state.operands)) {
3198 return failure();
3199 }
3200
3201 return success();
3202 }
3203
print(spirv::CooperativeMatrixStoreNVOp coopMatrix,OpAsmPrinter & printer)3204 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
3205 OpAsmPrinter &printer) {
3206 printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
3207 << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
3208 << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
3209 // Print optional memory access attribute.
3210 if (auto memAccess = coopMatrix.memory_access())
3211 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3212 printer << " : " << coopMatrix.pointer().getType() << ", "
3213 << coopMatrix.getOperand(1).getType();
3214 }
3215
3216 //===----------------------------------------------------------------------===//
3217 // spv.CooperativeMatrixMulAddNV
3218 //===----------------------------------------------------------------------===//
3219
3220 static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)3221 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3222 if (op.c().getType() != op.result().getType())
3223 return op.emitOpError("result and third operand must have the same type");
3224 auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3225 auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3226 auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3227 auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3228 if (typeA.getRows() != typeR.getRows() ||
3229 typeA.getColumns() != typeB.getRows() ||
3230 typeB.getColumns() != typeR.getColumns())
3231 return op.emitOpError("matrix size must match");
3232 if (typeR.getScope() != typeA.getScope() ||
3233 typeR.getScope() != typeB.getScope() ||
3234 typeR.getScope() != typeC.getScope())
3235 return op.emitOpError("matrix scope must match");
3236 if (typeA.getElementType() != typeB.getElementType() ||
3237 typeR.getElementType() != typeC.getElementType())
3238 return op.emitOpError("matrix element type must match");
3239 return success();
3240 }
3241
3242 //===----------------------------------------------------------------------===//
3243 // spv.MatrixTimesScalar
3244 //===----------------------------------------------------------------------===//
3245
verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op)3246 static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
3247 // We already checked that result and matrix are both of matrix type in the
3248 // auto-generated verify method.
3249
3250 auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3251 auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3252
3253 // Check that the scalar type is the same as the matrix element type.
3254 if (op.scalar().getType() != inputMatrix.getElementType())
3255 return op.emitError("input matrix components' type and scaling value must "
3256 "have the same type");
3257
3258 // Note that the next three checks could be done using the AllTypesMatch
3259 // trait in the Op definition file but it generates a vague error message.
3260
3261 // Check that the input and result matrices have the same columns' count
3262 if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3263 return op.emitError("input and result matrices must have the same "
3264 "number of columns");
3265
3266 // Check that the input and result matrices' have the same rows count
3267 if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3268 return op.emitError("input and result matrices' columns must have "
3269 "the same size");
3270
3271 // Check that the input and result matrices' have the same component type
3272 if (inputMatrix.getElementType() != resultMatrix.getElementType())
3273 return op.emitError("input and result matrices' columns must have "
3274 "the same component type");
3275
3276 return success();
3277 }
3278
3279 //===----------------------------------------------------------------------===//
3280 // spv.CopyMemory
3281 //===----------------------------------------------------------------------===//
3282
print(spirv::CopyMemoryOp copyMemory,OpAsmPrinter & printer)3283 static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
3284 auto *op = copyMemory.getOperation();
3285 printer << spirv::CopyMemoryOp::getOperationName() << ' ';
3286
3287 StringRef targetStorageClass =
3288 stringifyStorageClass(copyMemory.target()
3289 .getType()
3290 .cast<spirv::PointerType>()
3291 .getStorageClass());
3292 printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
3293 << ", ";
3294
3295 StringRef sourceStorageClass =
3296 stringifyStorageClass(copyMemory.source()
3297 .getType()
3298 .cast<spirv::PointerType>()
3299 .getStorageClass());
3300 printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
3301
3302 SmallVector<StringRef, 4> elidedAttrs;
3303 printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
3304 printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
3305 copyMemory.source_memory_access(),
3306 copyMemory.source_alignment());
3307
3308 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3309
3310 Type pointeeType =
3311 copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3312 printer << " : " << pointeeType;
3313 }
3314
parseCopyMemoryOp(OpAsmParser & parser,OperationState & state)3315 static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
3316 OperationState &state) {
3317 spirv::StorageClass targetStorageClass;
3318 OpAsmParser::OperandType targetPtrInfo;
3319
3320 spirv::StorageClass sourceStorageClass;
3321 OpAsmParser::OperandType sourcePtrInfo;
3322
3323 Type elementType;
3324
3325 if (parseEnumStrAttr(targetStorageClass, parser) ||
3326 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3327 parseEnumStrAttr(sourceStorageClass, parser) ||
3328 parser.parseOperand(sourcePtrInfo) ||
3329 parseMemoryAccessAttributes(parser, state)) {
3330 return failure();
3331 }
3332
3333 if (!parser.parseOptionalComma()) {
3334 // Parse 2nd memory access attributes.
3335 if (parseSourceMemoryAccessAttributes(parser, state)) {
3336 return failure();
3337 }
3338 }
3339
3340 if (parser.parseColon() || parser.parseType(elementType))
3341 return failure();
3342
3343 if (parser.parseOptionalAttrDict(state.attributes))
3344 return failure();
3345
3346 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3347 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3348
3349 if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3350 parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3351 return failure();
3352 }
3353
3354 return success();
3355 }
3356
verifyCopyMemory(spirv::CopyMemoryOp copyMemory)3357 static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
3358 Type targetType =
3359 copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3360
3361 Type sourceType =
3362 copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
3363
3364 if (targetType != sourceType) {
3365 return copyMemory.emitOpError(
3366 "both operands must be pointers to the same type");
3367 }
3368
3369 if (failed(verifyMemoryAccessAttribute(copyMemory))) {
3370 return failure();
3371 }
3372
3373 // TODO - According to the spec:
3374 //
3375 // If two masks are present, the first applies to Target and cannot include
3376 // MakePointerVisible, and the second applies to Source and cannot include
3377 // MakePointerAvailable.
3378 //
3379 // Add such verification here.
3380
3381 return verifySourceMemoryAccessAttribute(copyMemory);
3382 }
3383
3384 //===----------------------------------------------------------------------===//
3385 // spv.Transpose
3386 //===----------------------------------------------------------------------===//
3387
verifyTranspose(spirv::TransposeOp op)3388 static LogicalResult verifyTranspose(spirv::TransposeOp op) {
3389 auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3390 auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3391
3392 // Verify that the input and output matrices have correct shapes.
3393 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3394 return op.emitError("input matrix rows count must be equal to "
3395 "output matrix columns count");
3396
3397 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3398 return op.emitError("input matrix columns count must be equal to "
3399 "output matrix rows count");
3400
3401 // Verify that the input and output matrices have the same component type
3402 if (inputMatrix.getElementType() != resultMatrix.getElementType())
3403 return op.emitError("input and output matrices must have the same "
3404 "component type");
3405
3406 return success();
3407 }
3408
3409 //===----------------------------------------------------------------------===//
3410 // spv.MatrixTimesMatrix
3411 //===----------------------------------------------------------------------===//
3412
verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op)3413 static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
3414 auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
3415 auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
3416 auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3417
3418 // left matrix columns' count and right matrix rows' count must be equal
3419 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
3420 return op.emitError("left matrix columns' count must be equal to "
3421 "the right matrix rows' count");
3422
3423 // right and result matrices columns' count must be the same
3424 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
3425 return op.emitError(
3426 "right and result matrices must have equal columns' count");
3427
3428 // right and result matrices component type must be the same
3429 if (rightMatrix.getElementType() != resultMatrix.getElementType())
3430 return op.emitError("right and result matrices' component type must"
3431 " be the same");
3432
3433 // left and result matrices component type must be the same
3434 if (leftMatrix.getElementType() != resultMatrix.getElementType())
3435 return op.emitError("left and result matrices' component type"
3436 " must be the same");
3437
3438 // left and result matrices rows count must be the same
3439 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
3440 return op.emitError("left and result matrices must have equal rows'"
3441 " count");
3442
3443 return success();
3444 }
3445
3446 //===----------------------------------------------------------------------===//
3447 // spv.SpecConstantComposite
3448 //===----------------------------------------------------------------------===//
3449
parseSpecConstantCompositeOp(OpAsmParser & parser,OperationState & state)3450 static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
3451 OperationState &state) {
3452
3453 StringAttr compositeName;
3454 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
3455 state.attributes))
3456 return failure();
3457
3458 if (parser.parseLParen())
3459 return failure();
3460
3461 SmallVector<Attribute, 4> constituents;
3462
3463 do {
3464 // The name of the constituent attribute isn't important
3465 const char *attrName = "spec_const";
3466 FlatSymbolRefAttr specConstRef;
3467 NamedAttrList attrs;
3468
3469 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
3470 return failure();
3471
3472 constituents.push_back(specConstRef);
3473 } while (!parser.parseOptionalComma());
3474
3475 if (parser.parseRParen())
3476 return failure();
3477
3478 state.addAttribute(kCompositeSpecConstituentsName,
3479 parser.getBuilder().getArrayAttr(constituents));
3480
3481 Type type;
3482 if (parser.parseColonType(type))
3483 return failure();
3484
3485 state.addAttribute(kTypeAttrName, TypeAttr::get(type));
3486
3487 return success();
3488 }
3489
print(spirv::SpecConstantCompositeOp op,OpAsmPrinter & printer)3490 static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
3491 printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
3492 printer.printSymbolName(op.sym_name());
3493 printer << " (";
3494 auto constituents = op.constituents().getValue();
3495
3496 if (!constituents.empty())
3497 llvm::interleaveComma(constituents, printer);
3498
3499 printer << ") : " << op.type();
3500 }
3501
verify(spirv::SpecConstantCompositeOp constOp)3502 static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
3503 auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
3504 auto constituents = constOp.constituents().getValue();
3505
3506 if (!cType)
3507 return constOp.emitError(
3508 "result type must be a composite type, but provided ")
3509 << constOp.type();
3510
3511 if (cType.isa<spirv::CooperativeMatrixNVType>())
3512 return constOp.emitError("unsupported composite type ") << cType;
3513 else if (constituents.size() != cType.getNumElements())
3514 return constOp.emitError("has incorrect number of operands: expected ")
3515 << cType.getNumElements() << ", but provided "
3516 << constituents.size();
3517
3518 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
3519 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
3520
3521 auto constituentSpecConstOp =
3522 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
3523 constOp->getParentOp(), constituent.getValue()));
3524
3525 if (constituentSpecConstOp.default_value().getType() !=
3526 cType.getElementType(index))
3527 return constOp.emitError("has incorrect types of operands: expected ")
3528 << cType.getElementType(index) << ", but provided "
3529 << constituentSpecConstOp.default_value().getType();
3530 }
3531
3532 return success();
3533 }
3534
3535 //===----------------------------------------------------------------------===//
3536 // spv.SpecConstantOperation
3537 //===----------------------------------------------------------------------===//
3538
parseSpecConstantOperationOp(OpAsmParser & parser,OperationState & state)3539 static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
3540 OperationState &state) {
3541 Region *body = state.addRegion();
3542
3543 if (parser.parseKeyword("wraps"))
3544 return failure();
3545
3546 body->push_back(new Block);
3547 Block &block = body->back();
3548 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
3549
3550 if (!wrappedOp)
3551 return failure();
3552
3553 OpBuilder builder(parser.getBuilder().getContext());
3554 builder.setInsertionPointToEnd(&block);
3555 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
3556 state.location = wrappedOp->getLoc();
3557
3558 state.addTypes(wrappedOp->getResult(0).getType());
3559
3560 if (parser.parseOptionalAttrDict(state.attributes))
3561 return failure();
3562
3563 return success();
3564 }
3565
print(spirv::SpecConstantOperationOp op,OpAsmPrinter & printer)3566 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
3567 printer << op.getOperationName() << " wraps ";
3568 printer.printGenericOp(&op.body().front().front());
3569 }
3570
verify(spirv::SpecConstantOperationOp constOp)3571 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
3572 Block &block = constOp.getRegion().getBlocks().front();
3573
3574 if (block.getOperations().size() != 2)
3575 return constOp.emitOpError("expected exactly 2 nested ops");
3576
3577 Operation &enclosedOp = block.getOperations().front();
3578
3579 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
3580 return constOp.emitOpError("invalid enclosed op");
3581
3582 for (auto operand : enclosedOp.getOperands())
3583 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
3584 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
3585 return constOp.emitOpError(
3586 "invalid operand, must be defined by a constant operation");
3587
3588 return success();
3589 }
3590
3591 //===----------------------------------------------------------------------===//
3592 // spv.GLSL.FrexpStruct
3593 //===----------------------------------------------------------------------===//
3594 static LogicalResult
verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp)3595 verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp) {
3596 spirv::StructType structTy =
3597 frexpStructOp.result().getType().dyn_cast<spirv::StructType>();
3598
3599 if (structTy.getNumElements() != 2)
3600 return frexpStructOp.emitError("result type must be a struct type "
3601 "with two memebers");
3602
3603 Type significandTy = structTy.getElementType(0);
3604 Type exponentTy = structTy.getElementType(1);
3605 VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
3606 IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
3607
3608 Type operandTy = frexpStructOp.operand().getType();
3609 VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
3610 FloatType operandFTy = operandTy.dyn_cast<FloatType>();
3611
3612 if (significandTy != operandTy)
3613 return frexpStructOp.emitError("member zero of the resulting struct type "
3614 "must be the same type as the operand");
3615
3616 if (exponentVecTy) {
3617 IntegerType componentIntTy =
3618 exponentVecTy.getElementType().dyn_cast<IntegerType>();
3619 if (!(componentIntTy && componentIntTy.getWidth() == 32))
3620 return frexpStructOp.emitError(
3621 "member one of the resulting struct type must"
3622 "be a scalar or vector of 32 bit integer type");
3623 } else if (!(exponentIntTy && exponentIntTy.getWidth() == 32)) {
3624 return frexpStructOp.emitError(
3625 "member one of the resulting struct type "
3626 "must be a scalar or vector of 32 bit integer type");
3627 }
3628
3629 // Check that the two member types have the same number of components
3630 if (operandVecTy && exponentVecTy &&
3631 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
3632 return success();
3633
3634 if (operandFTy && exponentIntTy)
3635 return success();
3636
3637 return frexpStructOp.emitError(
3638 "member one of the resulting struct type "
3639 "must have the same number of components as the operand type");
3640 }
3641
3642 //===----------------------------------------------------------------------===//
3643 // spv.GLSL.Ldexp
3644 //===----------------------------------------------------------------------===//
3645
verify(spirv::GLSLLdexpOp ldexpOp)3646 static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) {
3647 Type significandType = ldexpOp.x().getType();
3648 Type exponentType = ldexpOp.exp().getType();
3649
3650 if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
3651 return ldexpOp.emitOpError("operands must both be scalars or vectors");
3652
3653 auto getNumElements = [](Type type) -> unsigned {
3654 if (auto vectorType = type.dyn_cast<VectorType>())
3655 return vectorType.getNumElements();
3656 return 1;
3657 };
3658
3659 if (getNumElements(significandType) != getNumElements(exponentType))
3660 return ldexpOp.emitOpError(
3661 "operands must have the same number of elements");
3662
3663 return success();
3664 }
3665
3666 //===----------------------------------------------------------------------===//
3667 // spv.ImageDrefGather
3668 //===----------------------------------------------------------------------===//
3669
verify(spirv::ImageDrefGatherOp imageDrefGatherOp)3670 static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) {
3671 // TODO: Support optional operands.
3672 VectorType resultType =
3673 imageDrefGatherOp.result().getType().cast<VectorType>();
3674 auto sampledImageType = imageDrefGatherOp.sampledimage()
3675 .getType()
3676 .cast<spirv::SampledImageType>();
3677 auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
3678
3679 if (resultType.getNumElements() != 4)
3680 return imageDrefGatherOp.emitOpError(
3681 "result type must be a vector of four components");
3682
3683 Type elementType = resultType.getElementType();
3684 Type sampledElementType = imageType.getElementType();
3685 if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
3686 return imageDrefGatherOp.emitOpError(
3687 "the component type of result must be the same as sampled type of the "
3688 "underlying image type");
3689
3690 spirv::Dim imageDim = imageType.getDim();
3691 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
3692
3693 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
3694 imageDim != spirv::Dim::Rect)
3695 return imageDrefGatherOp.emitOpError(
3696 "the Dim operand of the underlying image type must be 2D, Cube, or "
3697 "Rect");
3698
3699 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
3700 return imageDrefGatherOp.emitOpError(
3701 "the MS operand of the underlying image type must be 0");
3702
3703 return success();
3704 }
3705
3706 //===----------------------------------------------------------------------===//
3707 // spv.ImageQuerySize
3708 //===----------------------------------------------------------------------===//
3709
verify(spirv::ImageQuerySizeOp imageQuerySizeOp)3710 static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
3711 spirv::ImageType imageType =
3712 imageQuerySizeOp.image().getType().cast<spirv::ImageType>();
3713 Type resultType = imageQuerySizeOp.result().getType();
3714
3715 spirv::Dim dim = imageType.getDim();
3716 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
3717 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
3718 switch (dim) {
3719 case spirv::Dim::Dim1D:
3720 case spirv::Dim::Dim2D:
3721 case spirv::Dim::Dim3D:
3722 case spirv::Dim::Cube:
3723 if (!(samplingInfo == spirv::ImageSamplingInfo::MultiSampled ||
3724 samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown ||
3725 samplerInfo == spirv::ImageSamplerUseInfo::NoSampler))
3726 return imageQuerySizeOp.emitError(
3727 "if Dim is 1D, 2D, 3D, or Cube, "
3728 "it must also have either an MS of 1 or a Sampled of 0 or 2");
3729 break;
3730 case spirv::Dim::Buffer:
3731 case spirv::Dim::Rect:
3732 break;
3733 default:
3734 return imageQuerySizeOp.emitError("the Dim operand of the image type must "
3735 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
3736 }
3737
3738 unsigned componentNumber = 0;
3739 switch (dim) {
3740 case spirv::Dim::Dim1D:
3741 case spirv::Dim::Buffer:
3742 componentNumber = 1;
3743 break;
3744 case spirv::Dim::Dim2D:
3745 case spirv::Dim::Cube:
3746 case spirv::Dim::Rect:
3747 componentNumber = 2;
3748 break;
3749 case spirv::Dim::Dim3D:
3750 componentNumber = 3;
3751 break;
3752 default:
3753 break;
3754 }
3755
3756 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
3757 componentNumber += 1;
3758
3759 unsigned resultComponentNumber = 1;
3760 if (auto resultVectorType = resultType.dyn_cast<VectorType>())
3761 resultComponentNumber = resultVectorType.getNumElements();
3762
3763 if (componentNumber != resultComponentNumber)
3764 return imageQuerySizeOp.emitError("expected the result to have ")
3765 << componentNumber << " component(s), but found "
3766 << resultComponentNumber << " component(s)";
3767
3768 return success();
3769 }
3770
namespace mlir {
namespace spirv {

// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
} // namespace spirv
} // namespace mlir

// TableGen'erated operation definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"

namespace mlir {
namespace spirv {
// TableGen'erated operation availability interface implementations.
#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"

} // namespace spirv
} // namespace mlir
3791