//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
16 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/MLIRContext.h"
20 #include "mlir/IR/StandardTypes.h"
21 #include "mlir/Parser.h"
22 #include "mlir/Support/StringExtras.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringMap.h"
28 #include "llvm/ADT/StringSwitch.h"
29 #include "llvm/Support/raw_ostream.h"
30
31 namespace mlir {
32 namespace spirv {
33 #include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
34 } // namespace spirv
35 } // namespace mlir
36
37 using namespace mlir;
38 using namespace mlir::spirv;
39
40 //===----------------------------------------------------------------------===//
41 // InlinerInterface
42 //===----------------------------------------------------------------------===//
43
44 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
containsReturn(Region & region)45 static inline bool containsReturn(Region ®ion) {
46 return llvm::any_of(region, [](Block &block) {
47 Operation *terminator = block.getTerminator();
48 return isa<spirv::ReturnOp>(terminator) ||
49 isa<spirv::ReturnValueOp>(terminator);
50 });
51 }
52
53 namespace {
54 /// This class defines the interface for inlining within the SPIR-V dialect.
55 struct SPIRVInlinerInterface : public DialectInlinerInterface {
56 using DialectInlinerInterface::DialectInlinerInterface;
57
58 /// Returns true if the given region 'src' can be inlined into the region
59 /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anonf10db8e40211::SPIRVInlinerInterface60 bool isLegalToInline(Region *dest, Region *src,
61 BlockAndValueMapping &) const final {
62 // Return true here when inlining into spv.selection and spv.loop
63 // operations.
64 auto op = dest->getParentOp();
65 return isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op);
66 }
67
68 /// Returns true if the given operation 'op', that is registered to this
69 /// dialect, can be inlined into the region 'dest' that is attached to an
70 /// operation registered to the current dialect.
isLegalToInline__anonf10db8e40211::SPIRVInlinerInterface71 bool isLegalToInline(Operation *op, Region *dest,
72 BlockAndValueMapping &) const final {
73 // TODO(antiagainst): Enable inlining structured control flows with return.
74 if ((isa<spirv::SelectionOp>(op) || isa<spirv::LoopOp>(op)) &&
75 containsReturn(op->getRegion(0)))
76 return false;
77 // TODO(antiagainst): we need to filter OpKill here to avoid inlining it to
78 // a loop continue construct:
79 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
80 // However OpKill is fragment shader specific and we don't support it yet.
81 return true;
82 }
83
84 /// Handle the given inlined terminator by replacing it with a new operation
85 /// as necessary.
handleTerminator__anonf10db8e40211::SPIRVInlinerInterface86 void handleTerminator(Operation *op, Block *newDest) const final {
87 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
88 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
89 op->erase();
90 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
91 llvm_unreachable("unimplemented spv.ReturnValue in inliner");
92 }
93 }
94
95 /// Handle the given inlined terminator by replacing it with a new operation
96 /// as necessary.
handleTerminator__anonf10db8e40211::SPIRVInlinerInterface97 void handleTerminator(Operation *op,
98 ArrayRef<Value> valuesToRepl) const final {
99 // Only spv.ReturnValue needs to be handled here.
100 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
101 if (!retValOp)
102 return;
103
104 // Replace the values directly with the return operands.
105 assert(valuesToRepl.size() == 1 &&
106 "spv.ReturnValue expected to only handle one result");
107 valuesToRepl.front().replaceAllUsesWith(retValOp.value());
108 }
109 };
110 } // namespace
111
112 //===----------------------------------------------------------------------===//
113 // SPIR-V Dialect
114 //===----------------------------------------------------------------------===//
115
SPIRVDialect(MLIRContext * context)116 SPIRVDialect::SPIRVDialect(MLIRContext *context)
117 : Dialect(getDialectNamespace(), context) {
118 addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
119
120 // Add SPIR-V ops.
121 addOperations<
122 #define GET_OP_LIST
123 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
124 >();
125
126 addInterfaces<SPIRVInlinerInterface>();
127
128 // Allow unknown operations because SPIR-V is extensible.
129 allowUnknownOperations();
130 }
131
getAttributeName(Decoration decoration)132 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
133 return convertToSnakeCase(stringifyDecoration(decoration));
134 }
135
136 //===----------------------------------------------------------------------===//
137 // Type Parsing
138 //===----------------------------------------------------------------------===//
139
140 // Forward declarations.
141 template <typename ValTy>
142 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
143 DialectAsmParser &parser);
144 template <>
145 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
146 DialectAsmParser &parser);
147
148 template <>
149 Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
150 DialectAsmParser &parser);
151
isValidSPIRVIntType(IntegerType type)152 static bool isValidSPIRVIntType(IntegerType type) {
153 return llvm::is_contained(ArrayRef<unsigned>({1, 8, 16, 32, 64}),
154 type.getWidth());
155 }
156
isValidScalarType(Type type)157 bool SPIRVDialect::isValidScalarType(Type type) {
158 if (type.isa<FloatType>()) {
159 return !type.isBF16();
160 }
161 if (auto intType = type.dyn_cast<IntegerType>()) {
162 return isValidSPIRVIntType(intType);
163 }
164 return false;
165 }
166
isValidSPIRVVectorType(VectorType type)167 static bool isValidSPIRVVectorType(VectorType type) {
168 return type.getRank() == 1 &&
169 SPIRVDialect::isValidScalarType(type.getElementType()) &&
170 type.getNumElements() >= 2 && type.getNumElements() <= 4;
171 }
172
isValidType(Type type)173 bool SPIRVDialect::isValidType(Type type) {
174 // Allow SPIR-V dialect types
175 if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
176 type.getKind() <= TypeKind::LAST_SPIRV_TYPE) {
177 return true;
178 }
179 if (SPIRVDialect::isValidScalarType(type)) {
180 return true;
181 }
182 if (auto vectorType = type.dyn_cast<VectorType>()) {
183 return isValidSPIRVVectorType(vectorType);
184 }
185 return false;
186 }
187
parseAndVerifyType(SPIRVDialect const & dialect,DialectAsmParser & parser)188 static Type parseAndVerifyType(SPIRVDialect const &dialect,
189 DialectAsmParser &parser) {
190 Type type;
191 llvm::SMLoc typeLoc = parser.getCurrentLocation();
192 if (parser.parseType(type))
193 return Type();
194
195 // Allow SPIR-V dialect types
196 if (&type.getDialect() == &dialect)
197 return type;
198
199 // Check other allowed types
200 if (auto t = type.dyn_cast<FloatType>()) {
201 if (type.isBF16()) {
202 parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
203 return Type();
204 }
205 } else if (auto t = type.dyn_cast<IntegerType>()) {
206 if (!isValidSPIRVIntType(t)) {
207 parser.emitError(typeLoc,
208 "only 1/8/16/32/64-bit integer type allowed but found ")
209 << type;
210 return Type();
211 }
212 } else if (auto t = type.dyn_cast<VectorType>()) {
213 if (t.getRank() != 1) {
214 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215 return Type();
216 }
217 if (t.getNumElements() > 4) {
218 parser.emitError(
219 typeLoc, "vector length has to be less than or equal to 4 but found ")
220 << t.getNumElements();
221 return Type();
222 }
223 } else {
224 parser.emitError(typeLoc, "cannot use ")
225 << type << " to compose SPIR-V types";
226 return Type();
227 }
228
229 return type;
230 }
231
232 // element-type ::= integer-type
233 // | floating-point-type
234 // | vector-type
235 // | spirv-type
236 //
237 // array-type ::= `!spv.array<` integer-literal `x` element-type
238 // (`[` integer-literal `]`)? `>`
parseArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)239 static Type parseArrayType(SPIRVDialect const &dialect,
240 DialectAsmParser &parser) {
241 if (parser.parseLess())
242 return Type();
243
244 SmallVector<int64_t, 1> countDims;
245 llvm::SMLoc countLoc = parser.getCurrentLocation();
246 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
247 return Type();
248 if (countDims.size() != 1) {
249 parser.emitError(countLoc,
250 "expected single integer for array element count");
251 return Type();
252 }
253
254 // According to the SPIR-V spec:
255 // "Length is the number of elements in the array. It must be at least 1."
256 int64_t count = countDims[0];
257 if (count == 0) {
258 parser.emitError(countLoc, "expected array length greater than 0");
259 return Type();
260 }
261
262 Type elementType = parseAndVerifyType(dialect, parser);
263 if (!elementType)
264 return Type();
265
266 ArrayType::LayoutInfo layoutInfo = 0;
267 if (succeeded(parser.parseOptionalLSquare())) {
268 llvm::SMLoc layoutLoc = parser.getCurrentLocation();
269 auto layout = parseAndVerify<ArrayType::LayoutInfo>(dialect, parser);
270 if (!layout)
271 return Type();
272
273 if (!(layoutInfo = layout.getValue())) {
274 parser.emitError(layoutLoc, "ArrayStride must be greater than zero");
275 return Type();
276 }
277
278 if (parser.parseRSquare())
279 return Type();
280 }
281
282 if (parser.parseGreater())
283 return Type();
284 return ArrayType::get(elementType, count, layoutInfo);
285 }
286
287 // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
288 // methods in alphabetical order
289 //
290 // storage-class ::= `UniformConstant`
291 // | `Uniform`
292 // | `Workgroup`
293 // | <and other storage classes...>
294 //
295 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
parsePointerType(SPIRVDialect const & dialect,DialectAsmParser & parser)296 static Type parsePointerType(SPIRVDialect const &dialect,
297 DialectAsmParser &parser) {
298 if (parser.parseLess())
299 return Type();
300
301 auto pointeeType = parseAndVerifyType(dialect, parser);
302 if (!pointeeType)
303 return Type();
304
305 StringRef storageClassSpec;
306 llvm::SMLoc storageClassLoc = parser.getCurrentLocation();
307 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
308 return Type();
309
310 auto storageClass = symbolizeStorageClass(storageClassSpec);
311 if (!storageClass) {
312 parser.emitError(storageClassLoc, "unknown storage class: ")
313 << storageClassSpec;
314 return Type();
315 }
316 if (parser.parseGreater())
317 return Type();
318 return PointerType::get(pointeeType, *storageClass);
319 }
320
321 // runtime-array-type ::= `!spv.rtarray<` element-type `>`
parseRuntimeArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)322 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
323 DialectAsmParser &parser) {
324 if (parser.parseLess())
325 return Type();
326
327 Type elementType = parseAndVerifyType(dialect, parser);
328 if (!elementType)
329 return Type();
330
331 if (parser.parseGreater())
332 return Type();
333 return RuntimeArrayType::get(elementType);
334 }
335
336 // Specialize this function to parse each of the parameters that define an
337 // ImageType. By default it assumes this is an enum type.
338 template <typename ValTy>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)339 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
340 DialectAsmParser &parser) {
341 StringRef enumSpec;
342 llvm::SMLoc enumLoc = parser.getCurrentLocation();
343 if (parser.parseKeyword(&enumSpec)) {
344 return llvm::None;
345 }
346
347 auto val = spirv::symbolizeEnum<ValTy>()(enumSpec);
348 if (!val)
349 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
350 return val;
351 }
352
353 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)354 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
355 DialectAsmParser &parser) {
356 // TODO(ravishankarm): Further verify that the element type can be sampled
357 auto ty = parseAndVerifyType(dialect, parser);
358 if (!ty)
359 return llvm::None;
360 return ty;
361 }
362
363 template <typename IntTy>
parseAndVerifyInteger(SPIRVDialect const & dialect,DialectAsmParser & parser)364 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
365 DialectAsmParser &parser) {
366 IntTy offsetVal = std::numeric_limits<IntTy>::max();
367 if (parser.parseInteger(offsetVal))
368 return llvm::None;
369 return offsetVal;
370 }
371
372 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)373 Optional<uint64_t> parseAndVerify<uint64_t>(SPIRVDialect const &dialect,
374 DialectAsmParser &parser) {
375 return parseAndVerifyInteger<uint64_t>(dialect, parser);
376 }
377
378 namespace {
379 // Functor object to parse a comma separated list of specs. The function
380 // parseAndVerify does the actual parsing and verification of individual
381 // elements. This is a functor since parsing the last element of the list
382 // (termination condition) needs partial specialization.
383 template <typename ParseType, typename... Args> struct parseCommaSeparatedList {
384 Optional<std::tuple<ParseType, Args...>>
operator ()__anonf10db8e40311::parseCommaSeparatedList385 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
386 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
387 if (!parseVal)
388 return llvm::None;
389
390 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
391 if (numArgs != 0 && failed(parser.parseComma()))
392 return llvm::None;
393 auto remainingValues = parseCommaSeparatedList<Args...>{}(dialect, parser);
394 if (!remainingValues)
395 return llvm::None;
396 return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
397 remainingValues.getValue());
398 }
399 };
400
401 // Partial specialization of the function to parse a comma separated list of
402 // specs to parse the last element of the list.
403 template <typename ParseType> struct parseCommaSeparatedList<ParseType> {
operator ()__anonf10db8e40311::parseCommaSeparatedList404 Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
405 DialectAsmParser &parser) const {
406 if (auto value = parseAndVerify<ParseType>(dialect, parser))
407 return std::tuple<ParseType>(value.getValue());
408 return llvm::None;
409 }
410 };
411 } // namespace
412
413 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
414 //
415 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
416 //
417 // arrayed-info ::= `NonArrayed` | `Arrayed`
418 //
419 // sampling-info ::= `SingleSampled` | `MultiSampled`
420 //
421 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
422 //
423 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
424 //
425 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
426 // arrayed-info `,` sampling-info `,`
427 // sampler-use-info `,` format `>`
parseImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)428 static Type parseImageType(SPIRVDialect const &dialect,
429 DialectAsmParser &parser) {
430 if (parser.parseLess())
431 return Type();
432
433 auto value =
434 parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
435 ImageSamplingInfo, ImageSamplerUseInfo,
436 ImageFormat>{}(dialect, parser);
437 if (!value)
438 return Type();
439
440 if (parser.parseGreater())
441 return Type();
442 return ImageType::get(value.getValue());
443 }
444
445 // Parse decorations associated with a member.
parseStructMemberDecorations(SPIRVDialect const & dialect,DialectAsmParser & parser,ArrayRef<Type> memberTypes,SmallVectorImpl<StructType::LayoutInfo> & layoutInfo,SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorationInfo)446 static ParseResult parseStructMemberDecorations(
447 SPIRVDialect const &dialect, DialectAsmParser &parser,
448 ArrayRef<Type> memberTypes,
449 SmallVectorImpl<StructType::LayoutInfo> &layoutInfo,
450 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
451
452 // Check if the first element is offset.
453 llvm::SMLoc layoutLoc = parser.getCurrentLocation();
454 StructType::LayoutInfo layout = 0;
455 OptionalParseResult layoutParseResult = parser.parseOptionalInteger(layout);
456 if (layoutParseResult.hasValue()) {
457 if (failed(*layoutParseResult))
458 return failure();
459
460 if (layoutInfo.size() != memberTypes.size() - 1) {
461 return parser.emitError(
462 layoutLoc, "layout specification must be given for all members");
463 }
464 layoutInfo.push_back(layout);
465 }
466
467 // Check for no spirv::Decorations.
468 if (succeeded(parser.parseOptionalRSquare()))
469 return success();
470
471 // If there was a layout, make sure to parse the comma.
472 if (layoutParseResult.hasValue() && parser.parseComma())
473 return failure();
474
475 // Check for spirv::Decorations.
476 do {
477 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
478 if (!memberDecoration)
479 return failure();
480
481 memberDecorationInfo.emplace_back(
482 static_cast<uint32_t>(memberTypes.size() - 1),
483 memberDecoration.getValue());
484 } while (succeeded(parser.parseOptionalComma()));
485
486 return parser.parseRSquare();
487 }
488
489 // struct-member-decoration ::= integer-literal? spirv-decoration*
490 // struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)?
491 // (`, ` spirv-type (`[` struct-member-decoration `]`)? `>`
parseStructType(SPIRVDialect const & dialect,DialectAsmParser & parser)492 static Type parseStructType(SPIRVDialect const &dialect,
493 DialectAsmParser &parser) {
494 if (parser.parseLess())
495 return Type();
496
497 if (succeeded(parser.parseOptionalGreater()))
498 return StructType::getEmpty(dialect.getContext());
499
500 SmallVector<Type, 4> memberTypes;
501 SmallVector<StructType::LayoutInfo, 4> layoutInfo;
502 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
503
504 do {
505 Type memberType;
506 if (parser.parseType(memberType))
507 return Type();
508 memberTypes.push_back(memberType);
509
510 if (succeeded(parser.parseOptionalLSquare())) {
511 if (parseStructMemberDecorations(dialect, parser, memberTypes, layoutInfo,
512 memberDecorationInfo)) {
513 return Type();
514 }
515 }
516 } while (succeeded(parser.parseOptionalComma()));
517
518 if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) {
519 parser.emitError(parser.getNameLoc(),
520 "layout specification must be given for all members");
521 return Type();
522 }
523 if (parser.parseGreater())
524 return Type();
525 return StructType::get(memberTypes, layoutInfo, memberDecorationInfo);
526 }
527
528 // spirv-type ::= array-type
529 // | element-type
530 // | image-type
531 // | pointer-type
532 // | runtime-array-type
533 // | struct-type
parseType(DialectAsmParser & parser) const534 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
535 StringRef keyword;
536 if (parser.parseKeyword(&keyword))
537 return Type();
538
539 if (keyword == "array")
540 return parseArrayType(*this, parser);
541 if (keyword == "image")
542 return parseImageType(*this, parser);
543 if (keyword == "ptr")
544 return parsePointerType(*this, parser);
545 if (keyword == "rtarray")
546 return parseRuntimeArrayType(*this, parser);
547 if (keyword == "struct")
548 return parseStructType(*this, parser);
549
550 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
551 return Type();
552 }
553
554 //===----------------------------------------------------------------------===//
555 // Type Printing
556 //===----------------------------------------------------------------------===//
557
print(ArrayType type,DialectAsmPrinter & os)558 static void print(ArrayType type, DialectAsmPrinter &os) {
559 os << "array<" << type.getNumElements() << " x " << type.getElementType();
560 if (type.hasLayout()) {
561 os << " [" << type.getArrayStride() << "]";
562 }
563 os << ">";
564 }
565
print(RuntimeArrayType type,DialectAsmPrinter & os)566 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
567 os << "rtarray<" << type.getElementType() << ">";
568 }
569
print(PointerType type,DialectAsmPrinter & os)570 static void print(PointerType type, DialectAsmPrinter &os) {
571 os << "ptr<" << type.getPointeeType() << ", "
572 << stringifyStorageClass(type.getStorageClass()) << ">";
573 }
574
print(ImageType type,DialectAsmPrinter & os)575 static void print(ImageType type, DialectAsmPrinter &os) {
576 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
577 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
578 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
579 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
580 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
581 << stringifyImageFormat(type.getImageFormat()) << ">";
582 }
583
print(StructType type,DialectAsmPrinter & os)584 static void print(StructType type, DialectAsmPrinter &os) {
585 os << "struct<";
586 auto printMember = [&](unsigned i) {
587 os << type.getElementType(i);
588 SmallVector<spirv::Decoration, 0> decorations;
589 type.getMemberDecorations(i, decorations);
590 if (type.hasLayout() || !decorations.empty()) {
591 os << " [";
592 if (type.hasLayout()) {
593 os << type.getOffset(i);
594 if (!decorations.empty())
595 os << ", ";
596 }
597 auto each_fn = [&os](spirv::Decoration decoration) {
598 os << stringifyDecoration(decoration);
599 };
600 interleaveComma(decorations, os, each_fn);
601 os << "]";
602 }
603 };
604 interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
605 printMember);
606 os << ">";
607 }
608
printType(Type type,DialectAsmPrinter & os) const609 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
610 switch (type.getKind()) {
611 case TypeKind::Array:
612 print(type.cast<ArrayType>(), os);
613 return;
614 case TypeKind::Pointer:
615 print(type.cast<PointerType>(), os);
616 return;
617 case TypeKind::RuntimeArray:
618 print(type.cast<RuntimeArrayType>(), os);
619 return;
620 case TypeKind::Image:
621 print(type.cast<ImageType>(), os);
622 return;
623 case TypeKind::Struct:
624 print(type.cast<StructType>(), os);
625 return;
626 default:
627 llvm_unreachable("unhandled SPIR-V type");
628 }
629 }
630
631 //===----------------------------------------------------------------------===//
632 // Constant
633 //===----------------------------------------------------------------------===//
634
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)635 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
636 Attribute value, Type type,
637 Location loc) {
638 if (!ConstantOp::isBuildableWith(type))
639 return nullptr;
640
641 return builder.create<spirv::ConstantOp>(loc, type, value);
642 }
643
644 //===----------------------------------------------------------------------===//
645 // Shader Interface ABI
646 //===----------------------------------------------------------------------===//
647
verifyOperationAttribute(Operation * op,NamedAttribute attribute)648 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
649 NamedAttribute attribute) {
650 StringRef symbol = attribute.first.strref();
651 Attribute attr = attribute.second;
652
653 // TODO(antiagainst): figure out a way to generate the description from the
654 // StructAttr definition.
655 if (symbol == spirv::getEntryPointABIAttrName()) {
656 if (!attr.isa<spirv::EntryPointABIAttr>())
657 return op->emitError("'")
658 << symbol
659 << "' attribute must be a dictionary attribute containing one "
660 "32-bit integer elements attribute: 'local_size'";
661 } else if (symbol == spirv::getTargetEnvAttrName()) {
662 if (!attr.isa<spirv::TargetEnvAttr>())
663 return op->emitError("'")
664 << symbol
665 << "' must be a dictionary attribute containing one 32-bit "
666 "integer attribute 'version', one string array attribute "
667 "'extensions', and one 32-bit integer array attribute "
668 "'capabilities'";
669 } else {
670 return op->emitError("found unsupported '")
671 << symbol << "' attribute on operation";
672 }
673
674 return success();
675 }
676
677 // Verifies the given SPIR-V `attribute` attached to a region's argument or
678 // result and reports error to the given location if invalid.
679 static LogicalResult
verifyRegionAttribute(Location loc,NamedAttribute attribute,bool forArg)680 verifyRegionAttribute(Location loc, NamedAttribute attribute, bool forArg) {
681 StringRef symbol = attribute.first.strref();
682 Attribute attr = attribute.second;
683
684 if (symbol != spirv::getInterfaceVarABIAttrName())
685 return emitError(loc, "found unsupported '")
686 << symbol << "' attribute on region "
687 << (forArg ? "argument" : "result");
688
689 if (!attr.isa<spirv::InterfaceVarABIAttr>())
690 return emitError(loc, "'")
691 << symbol
692 << "' attribute must be a dictionary attribute containing three "
693 "32-bit integer attributes: 'descriptor_set', 'binding', and "
694 "'storage_class'";
695
696 return success();
697 }
698
verifyRegionArgAttribute(Operation * op,unsigned,unsigned,NamedAttribute attribute)699 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
700 unsigned /*regionIndex*/,
701 unsigned /*argIndex*/,
702 NamedAttribute attribute) {
703 return verifyRegionAttribute(op->getLoc(), attribute,
704 /*forArg=*/true);
705 }
706
verifyRegionResultAttribute(Operation * op,unsigned,unsigned,NamedAttribute attribute)707 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
708 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
709 NamedAttribute attribute) {
710 return verifyRegionAttribute(op->getLoc(), attribute,
711 /*forArg=*/false);
712 }
713