1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Identifier.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20
21 using namespace mlir;
22 using namespace mlir::spirv;
23
24 //===----------------------------------------------------------------------===//
25 // ArrayType
26 //===----------------------------------------------------------------------===//
27
28 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
29 using KeyTy = std::tuple<Type, unsigned, unsigned>;
30
constructspirv::detail::ArrayTypeStorage31 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
32 const KeyTy &key) {
33 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
34 }
35
operator ==spirv::detail::ArrayTypeStorage36 bool operator==(const KeyTy &key) const {
37 return key == KeyTy(elementType, elementCount, stride);
38 }
39
ArrayTypeStoragespirv::detail::ArrayTypeStorage40 ArrayTypeStorage(const KeyTy &key)
41 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
42 stride(std::get<2>(key)) {}
43
44 Type elementType;
45 unsigned elementCount;
46 unsigned stride;
47 };
48
get(Type elementType,unsigned elementCount)49 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
50 assert(elementCount && "ArrayType needs at least one element");
51 return Base::get(elementType.getContext(), elementType, elementCount,
52 /*stride=*/0);
53 }
54
get(Type elementType,unsigned elementCount,unsigned stride)55 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
56 unsigned stride) {
57 assert(elementCount && "ArrayType needs at least one element");
58 return Base::get(elementType.getContext(), elementType, elementCount, stride);
59 }
60
getNumElements() const61 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
62
getElementType() const63 Type ArrayType::getElementType() const { return getImpl()->elementType; }
64
getArrayStride() const65 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
66
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)67 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
68 Optional<StorageClass> storage) {
69 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
70 }
71
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)72 void ArrayType::getCapabilities(
73 SPIRVType::CapabilityArrayRefVector &capabilities,
74 Optional<StorageClass> storage) {
75 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
76 }
77
getSizeInBytes()78 Optional<int64_t> ArrayType::getSizeInBytes() {
79 auto elementType = getElementType().cast<SPIRVType>();
80 Optional<int64_t> size = elementType.getSizeInBytes();
81 if (!size)
82 return llvm::None;
83 return (*size + getArrayStride()) * getNumElements();
84 }
85
86 //===----------------------------------------------------------------------===//
87 // CompositeType
88 //===----------------------------------------------------------------------===//
89
classof(Type type)90 bool CompositeType::classof(Type type) {
91 if (auto vectorType = type.dyn_cast<VectorType>())
92 return isValid(vectorType);
93 return type
94 .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
95 spirv::RuntimeArrayType, spirv::StructType>();
96 }
97
isValid(VectorType type)98 bool CompositeType::isValid(VectorType type) {
99 switch (type.getNumElements()) {
100 case 2:
101 case 3:
102 case 4:
103 case 8:
104 case 16:
105 break;
106 default:
107 return false;
108 }
109 return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
110 }
111
getElementType(unsigned index) const112 Type CompositeType::getElementType(unsigned index) const {
113 return TypeSwitch<Type, Type>(*this)
114 .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
115 [](auto type) { return type.getElementType(); })
116 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
117 .Case<StructType>(
118 [index](StructType type) { return type.getElementType(index); })
119 .Default(
120 [](Type) -> Type { llvm_unreachable("invalid composite type"); });
121 }
122
getNumElements() const123 unsigned CompositeType::getNumElements() const {
124 if (auto arrayType = dyn_cast<ArrayType>())
125 return arrayType.getNumElements();
126 if (auto matrixType = dyn_cast<MatrixType>())
127 return matrixType.getNumColumns();
128 if (auto structType = dyn_cast<StructType>())
129 return structType.getNumElements();
130 if (auto vectorType = dyn_cast<VectorType>())
131 return vectorType.getNumElements();
132 if (isa<CooperativeMatrixNVType>()) {
133 llvm_unreachable(
134 "invalid to query number of elements of spirv::CooperativeMatrix type");
135 }
136 if (isa<RuntimeArrayType>()) {
137 llvm_unreachable(
138 "invalid to query number of elements of spirv::RuntimeArray type");
139 }
140 llvm_unreachable("invalid composite type");
141 }
142
hasCompileTimeKnownNumElements() const143 bool CompositeType::hasCompileTimeKnownNumElements() const {
144 return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
145 }
146
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)147 void CompositeType::getExtensions(
148 SPIRVType::ExtensionArrayRefVector &extensions,
149 Optional<StorageClass> storage) {
150 TypeSwitch<Type>(*this)
151 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
152 StructType>(
153 [&](auto type) { type.getExtensions(extensions, storage); })
154 .Case<VectorType>([&](VectorType type) {
155 return type.getElementType().cast<ScalarType>().getExtensions(
156 extensions, storage);
157 })
158 .Default([](Type) { llvm_unreachable("invalid composite type"); });
159 }
160
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)161 void CompositeType::getCapabilities(
162 SPIRVType::CapabilityArrayRefVector &capabilities,
163 Optional<StorageClass> storage) {
164 TypeSwitch<Type>(*this)
165 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
166 StructType>(
167 [&](auto type) { type.getCapabilities(capabilities, storage); })
168 .Case<VectorType>([&](VectorType type) {
169 auto vecSize = getNumElements();
170 if (vecSize == 8 || vecSize == 16) {
171 static const Capability caps[] = {Capability::Vector16};
172 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
173 capabilities.push_back(ref);
174 }
175 return type.getElementType().cast<ScalarType>().getCapabilities(
176 capabilities, storage);
177 })
178 .Default([](Type) { llvm_unreachable("invalid composite type"); });
179 }
180
getSizeInBytes()181 Optional<int64_t> CompositeType::getSizeInBytes() {
182 if (auto arrayType = dyn_cast<ArrayType>())
183 return arrayType.getSizeInBytes();
184 if (auto structType = dyn_cast<StructType>())
185 return structType.getSizeInBytes();
186 if (auto vectorType = dyn_cast<VectorType>()) {
187 Optional<int64_t> elementSize =
188 vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
189 if (!elementSize)
190 return llvm::None;
191 return *elementSize * vectorType.getNumElements();
192 }
193 return llvm::None;
194 }
195
196 //===----------------------------------------------------------------------===//
197 // CooperativeMatrixType
198 //===----------------------------------------------------------------------===//
199
200 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
201 using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
202
203 static CooperativeMatrixTypeStorage *
constructspirv::detail::CooperativeMatrixTypeStorage204 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
205 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
206 CooperativeMatrixTypeStorage(key);
207 }
208
operator ==spirv::detail::CooperativeMatrixTypeStorage209 bool operator==(const KeyTy &key) const {
210 return key == KeyTy(elementType, scope, rows, columns);
211 }
212
CooperativeMatrixTypeStoragespirv::detail::CooperativeMatrixTypeStorage213 CooperativeMatrixTypeStorage(const KeyTy &key)
214 : elementType(std::get<0>(key)), rows(std::get<2>(key)),
215 columns(std::get<3>(key)), scope(std::get<1>(key)) {}
216
217 Type elementType;
218 unsigned rows;
219 unsigned columns;
220 Scope scope;
221 };
222
get(Type elementType,Scope scope,unsigned rows,unsigned columns)223 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
224 Scope scope, unsigned rows,
225 unsigned columns) {
226 return Base::get(elementType.getContext(), elementType, scope, rows, columns);
227 }
228
getElementType() const229 Type CooperativeMatrixNVType::getElementType() const {
230 return getImpl()->elementType;
231 }
232
getScope() const233 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
234
getRows() const235 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
236
getColumns() const237 unsigned CooperativeMatrixNVType::getColumns() const {
238 return getImpl()->columns;
239 }
240
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)241 void CooperativeMatrixNVType::getExtensions(
242 SPIRVType::ExtensionArrayRefVector &extensions,
243 Optional<StorageClass> storage) {
244 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
245 static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
246 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
247 extensions.push_back(ref);
248 }
249
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)250 void CooperativeMatrixNVType::getCapabilities(
251 SPIRVType::CapabilityArrayRefVector &capabilities,
252 Optional<StorageClass> storage) {
253 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
254 static const Capability caps[] = {Capability::CooperativeMatrixNV};
255 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
256 capabilities.push_back(ref);
257 }
258
259 //===----------------------------------------------------------------------===//
260 // ImageType
261 //===----------------------------------------------------------------------===//
262
getNumBits()263 template <typename T> static constexpr unsigned getNumBits() { return 0; }
getNumBits()264 template <> constexpr unsigned getNumBits<Dim>() {
265 static_assert((1 << 3) > getMaxEnumValForDim(),
266 "Not enough bits to encode Dim value");
267 return 3;
268 }
getNumBits()269 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
270 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
271 "Not enough bits to encode ImageDepthInfo value");
272 return 2;
273 }
getNumBits()274 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
275 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
276 "Not enough bits to encode ImageArrayedInfo value");
277 return 1;
278 }
getNumBits()279 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
280 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
281 "Not enough bits to encode ImageSamplingInfo value");
282 return 1;
283 }
getNumBits()284 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
285 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
286 "Not enough bits to encode ImageSamplerUseInfo value");
287 return 2;
288 }
getNumBits()289 template <> constexpr unsigned getNumBits<ImageFormat>() {
290 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
291 "Not enough bits to encode ImageFormat value");
292 return 6;
293 }
294
295 struct spirv::detail::ImageTypeStorage : public TypeStorage {
296 public:
297 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
298 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
299
constructspirv::detail::ImageTypeStorage300 static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
301 const KeyTy &key) {
302 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
303 }
304
operator ==spirv::detail::ImageTypeStorage305 bool operator==(const KeyTy &key) const {
306 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
307 samplerUseInfo, format);
308 }
309
ImageTypeStoragespirv::detail::ImageTypeStorage310 ImageTypeStorage(const KeyTy &key)
311 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
312 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
313 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
314 format(std::get<6>(key)) {}
315
316 Type elementType;
317 Dim dim : getNumBits<Dim>();
318 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
319 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
320 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
321 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
322 ImageFormat format : getNumBits<ImageFormat>();
323 };
324
325 ImageType
get(std::tuple<Type,Dim,ImageDepthInfo,ImageArrayedInfo,ImageSamplingInfo,ImageSamplerUseInfo,ImageFormat> value)326 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
327 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
328 value) {
329 return Base::get(std::get<0>(value).getContext(), value);
330 }
331
getElementType() const332 Type ImageType::getElementType() const { return getImpl()->elementType; }
333
getDim() const334 Dim ImageType::getDim() const { return getImpl()->dim; }
335
getDepthInfo() const336 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
337
getArrayedInfo() const338 ImageArrayedInfo ImageType::getArrayedInfo() const {
339 return getImpl()->arrayedInfo;
340 }
341
getSamplingInfo() const342 ImageSamplingInfo ImageType::getSamplingInfo() const {
343 return getImpl()->samplingInfo;
344 }
345
getSamplerUseInfo() const346 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
347 return getImpl()->samplerUseInfo;
348 }
349
getImageFormat() const350 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
351
getExtensions(SPIRVType::ExtensionArrayRefVector &,Optional<StorageClass>)352 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
353 Optional<StorageClass>) {
354 // Image types do not require extra extensions thus far.
355 }
356
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass>)357 void ImageType::getCapabilities(
358 SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
359 if (auto dimCaps = spirv::getCapabilities(getDim()))
360 capabilities.push_back(*dimCaps);
361
362 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
363 capabilities.push_back(*fmtCaps);
364 }
365
366 //===----------------------------------------------------------------------===//
367 // PointerType
368 //===----------------------------------------------------------------------===//
369
370 struct spirv::detail::PointerTypeStorage : public TypeStorage {
371 // (Type, StorageClass) as the key: Type stored in this struct, and
372 // StorageClass stored as TypeStorage's subclass data.
373 using KeyTy = std::pair<Type, StorageClass>;
374
constructspirv::detail::PointerTypeStorage375 static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
376 const KeyTy &key) {
377 return new (allocator.allocate<PointerTypeStorage>())
378 PointerTypeStorage(key);
379 }
380
operator ==spirv::detail::PointerTypeStorage381 bool operator==(const KeyTy &key) const {
382 return key == KeyTy(pointeeType, storageClass);
383 }
384
PointerTypeStoragespirv::detail::PointerTypeStorage385 PointerTypeStorage(const KeyTy &key)
386 : pointeeType(key.first), storageClass(key.second) {}
387
388 Type pointeeType;
389 StorageClass storageClass;
390 };
391
get(Type pointeeType,StorageClass storageClass)392 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
393 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
394 }
395
getPointeeType() const396 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
397
getStorageClass() const398 StorageClass PointerType::getStorageClass() const {
399 return getImpl()->storageClass;
400 }
401
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)402 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
403 Optional<StorageClass> storage) {
404 // Use this pointer type's storage class because this pointer indicates we are
405 // using the pointee type in that specific storage class.
406 getPointeeType().cast<SPIRVType>().getExtensions(extensions,
407 getStorageClass());
408
409 if (auto scExts = spirv::getExtensions(getStorageClass()))
410 extensions.push_back(*scExts);
411 }
412
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)413 void PointerType::getCapabilities(
414 SPIRVType::CapabilityArrayRefVector &capabilities,
415 Optional<StorageClass> storage) {
416 // Use this pointer type's storage class because this pointer indicates we are
417 // using the pointee type in that specific storage class.
418 getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
419 getStorageClass());
420
421 if (auto scCaps = spirv::getCapabilities(getStorageClass()))
422 capabilities.push_back(*scCaps);
423 }
424
425 //===----------------------------------------------------------------------===//
426 // RuntimeArrayType
427 //===----------------------------------------------------------------------===//
428
429 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
430 using KeyTy = std::pair<Type, unsigned>;
431
constructspirv::detail::RuntimeArrayTypeStorage432 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
433 const KeyTy &key) {
434 return new (allocator.allocate<RuntimeArrayTypeStorage>())
435 RuntimeArrayTypeStorage(key);
436 }
437
operator ==spirv::detail::RuntimeArrayTypeStorage438 bool operator==(const KeyTy &key) const {
439 return key == KeyTy(elementType, stride);
440 }
441
RuntimeArrayTypeStoragespirv::detail::RuntimeArrayTypeStorage442 RuntimeArrayTypeStorage(const KeyTy &key)
443 : elementType(key.first), stride(key.second) {}
444
445 Type elementType;
446 unsigned stride;
447 };
448
get(Type elementType)449 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
450 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
451 }
452
get(Type elementType,unsigned stride)453 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
454 return Base::get(elementType.getContext(), elementType, stride);
455 }
456
getElementType() const457 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
458
getArrayStride() const459 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
460
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)461 void RuntimeArrayType::getExtensions(
462 SPIRVType::ExtensionArrayRefVector &extensions,
463 Optional<StorageClass> storage) {
464 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
465 }
466
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)467 void RuntimeArrayType::getCapabilities(
468 SPIRVType::CapabilityArrayRefVector &capabilities,
469 Optional<StorageClass> storage) {
470 {
471 static const Capability caps[] = {Capability::Shader};
472 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
473 capabilities.push_back(ref);
474 }
475 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
476 }
477
478 //===----------------------------------------------------------------------===//
479 // ScalarType
480 //===----------------------------------------------------------------------===//
481
classof(Type type)482 bool ScalarType::classof(Type type) {
483 if (auto floatType = type.dyn_cast<FloatType>()) {
484 return isValid(floatType);
485 }
486 if (auto intType = type.dyn_cast<IntegerType>()) {
487 return isValid(intType);
488 }
489 return false;
490 }
491
isValid(FloatType type)492 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
493
isValid(IntegerType type)494 bool ScalarType::isValid(IntegerType type) {
495 switch (type.getWidth()) {
496 case 1:
497 case 8:
498 case 16:
499 case 32:
500 case 64:
501 return true;
502 default:
503 return false;
504 }
505 }
506
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)507 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
508 Optional<StorageClass> storage) {
509 // 8- or 16-bit integer/floating-point numbers will require extra extensions
510 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
511 // SPV_KHR_8bit_storage for more details.
512 if (!storage)
513 return;
514
515 switch (*storage) {
516 case StorageClass::PushConstant:
517 case StorageClass::StorageBuffer:
518 case StorageClass::Uniform:
519 if (getIntOrFloatBitWidth() == 8) {
520 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
521 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
522 extensions.push_back(ref);
523 }
524 LLVM_FALLTHROUGH;
525 case StorageClass::Input:
526 case StorageClass::Output:
527 if (getIntOrFloatBitWidth() == 16) {
528 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
529 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
530 extensions.push_back(ref);
531 }
532 break;
533 default:
534 break;
535 }
536 }
537
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)538 void ScalarType::getCapabilities(
539 SPIRVType::CapabilityArrayRefVector &capabilities,
540 Optional<StorageClass> storage) {
541 unsigned bitwidth = getIntOrFloatBitWidth();
542
543 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
544 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
545 // SPV_KHR_8bit_storage for more details.
546
547 #define STORAGE_CASE(storage, cap8, cap16) \
548 case StorageClass::storage: { \
549 if (bitwidth == 8) { \
550 static const Capability caps[] = {Capability::cap8}; \
551 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
552 capabilities.push_back(ref); \
553 } else if (bitwidth == 16) { \
554 static const Capability caps[] = {Capability::cap16}; \
555 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
556 capabilities.push_back(ref); \
557 } \
558 /* No requirements for other bitwidths */ \
559 return; \
560 }
561
562 // This part only handles the cases where special bitwidths appearing in
563 // interface storage classes.
564 if (storage) {
565 switch (*storage) {
566 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
567 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
568 StorageBuffer16BitAccess);
569 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
570 StorageUniform16);
571 case StorageClass::Input:
572 case StorageClass::Output: {
573 if (bitwidth == 16) {
574 static const Capability caps[] = {Capability::StorageInputOutput16};
575 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
576 capabilities.push_back(ref);
577 }
578 return;
579 }
580 default:
581 break;
582 }
583 }
584 #undef STORAGE_CASE
585
586 // For other non-interface storage classes, require a different set of
587 // capabilities for special bitwidths.
588
589 #define WIDTH_CASE(type, width) \
590 case width: { \
591 static const Capability caps[] = {Capability::type##width}; \
592 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
593 capabilities.push_back(ref); \
594 } break
595
596 if (auto intType = dyn_cast<IntegerType>()) {
597 switch (bitwidth) {
598 case 32:
599 case 1:
600 break;
601 WIDTH_CASE(Int, 8);
602 WIDTH_CASE(Int, 16);
603 WIDTH_CASE(Int, 64);
604 default:
605 llvm_unreachable("invalid bitwidth to getCapabilities");
606 }
607 } else {
608 assert(isa<FloatType>());
609 switch (bitwidth) {
610 case 32:
611 break;
612 WIDTH_CASE(Float, 16);
613 WIDTH_CASE(Float, 64);
614 default:
615 llvm_unreachable("invalid bitwidth to getCapabilities");
616 }
617 }
618
619 #undef WIDTH_CASE
620 }
621
getSizeInBytes()622 Optional<int64_t> ScalarType::getSizeInBytes() {
623 auto bitWidth = getIntOrFloatBitWidth();
624 // According to the SPIR-V spec:
625 // "There is no physical size or bit pattern defined for values with boolean
626 // type. If they are stored (in conjunction with OpVariable), they can only
627 // be used with logical addressing operations, not physical, and only with
628 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
629 // Private, Function, Input, and Output."
630 if (bitWidth == 1)
631 return llvm::None;
632 return bitWidth / 8;
633 }
634
635 //===----------------------------------------------------------------------===//
636 // SPIRVType
637 //===----------------------------------------------------------------------===//
638
classof(Type type)639 bool SPIRVType::classof(Type type) {
640 // Allow SPIR-V dialect types
641 if (llvm::isa<SPIRVDialect>(type.getDialect()))
642 return true;
643 if (type.isa<ScalarType>())
644 return true;
645 if (auto vectorType = type.dyn_cast<VectorType>())
646 return CompositeType::isValid(vectorType);
647 return false;
648 }
649
isScalarOrVector()650 bool SPIRVType::isScalarOrVector() {
651 return isIntOrFloat() || isa<VectorType>();
652 }
653
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)654 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
655 Optional<StorageClass> storage) {
656 if (auto scalarType = dyn_cast<ScalarType>()) {
657 scalarType.getExtensions(extensions, storage);
658 } else if (auto compositeType = dyn_cast<CompositeType>()) {
659 compositeType.getExtensions(extensions, storage);
660 } else if (auto imageType = dyn_cast<ImageType>()) {
661 imageType.getExtensions(extensions, storage);
662 } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
663 sampledImageType.getExtensions(extensions, storage);
664 } else if (auto matrixType = dyn_cast<MatrixType>()) {
665 matrixType.getExtensions(extensions, storage);
666 } else if (auto ptrType = dyn_cast<PointerType>()) {
667 ptrType.getExtensions(extensions, storage);
668 } else {
669 llvm_unreachable("invalid SPIR-V Type to getExtensions");
670 }
671 }
672
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)673 void SPIRVType::getCapabilities(
674 SPIRVType::CapabilityArrayRefVector &capabilities,
675 Optional<StorageClass> storage) {
676 if (auto scalarType = dyn_cast<ScalarType>()) {
677 scalarType.getCapabilities(capabilities, storage);
678 } else if (auto compositeType = dyn_cast<CompositeType>()) {
679 compositeType.getCapabilities(capabilities, storage);
680 } else if (auto imageType = dyn_cast<ImageType>()) {
681 imageType.getCapabilities(capabilities, storage);
682 } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
683 sampledImageType.getCapabilities(capabilities, storage);
684 } else if (auto matrixType = dyn_cast<MatrixType>()) {
685 matrixType.getCapabilities(capabilities, storage);
686 } else if (auto ptrType = dyn_cast<PointerType>()) {
687 ptrType.getCapabilities(capabilities, storage);
688 } else {
689 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
690 }
691 }
692
getSizeInBytes()693 Optional<int64_t> SPIRVType::getSizeInBytes() {
694 if (auto scalarType = dyn_cast<ScalarType>())
695 return scalarType.getSizeInBytes();
696 if (auto compositeType = dyn_cast<CompositeType>())
697 return compositeType.getSizeInBytes();
698 return llvm::None;
699 }
700
701 //===----------------------------------------------------------------------===//
702 // SampledImageType
703 //===----------------------------------------------------------------------===//
704 struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
705 using KeyTy = Type;
706
SampledImageTypeStoragespirv::detail::SampledImageTypeStorage707 SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
708
operator ==spirv::detail::SampledImageTypeStorage709 bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
710
constructspirv::detail::SampledImageTypeStorage711 static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
712 const KeyTy &key) {
713 return new (allocator.allocate<SampledImageTypeStorage>())
714 SampledImageTypeStorage(key);
715 }
716
717 Type imageType;
718 };
719
get(Type imageType)720 SampledImageType SampledImageType::get(Type imageType) {
721 return Base::get(imageType.getContext(), imageType);
722 }
723
724 SampledImageType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type imageType)725 SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
726 Type imageType) {
727 return Base::getChecked(emitError, imageType.getContext(), imageType);
728 }
729
getImageType() const730 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
731
732 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type imageType)733 SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
734 Type imageType) {
735 if (!imageType.isa<ImageType>())
736 return emitError() << "expected image type";
737
738 return success();
739 }
740
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)741 void SampledImageType::getExtensions(
742 SPIRVType::ExtensionArrayRefVector &extensions,
743 Optional<StorageClass> storage) {
744 getImageType().cast<ImageType>().getExtensions(extensions, storage);
745 }
746
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)747 void SampledImageType::getCapabilities(
748 SPIRVType::CapabilityArrayRefVector &capabilities,
749 Optional<StorageClass> storage) {
750 getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
751 }
752
753 //===----------------------------------------------------------------------===//
754 // StructType
755 //===----------------------------------------------------------------------===//
756
757 /// Type storage for SPIR-V structure types:
758 ///
759 /// Structures are uniqued using:
760 /// - for identified structs:
761 /// - a string identifier;
762 /// - for literal structs:
763 /// - a list of member types;
764 /// - a list of member offset info;
765 /// - a list of member decoration info.
766 ///
767 /// Identified structures only have a mutable component consisting of:
768 /// - a list of member types;
769 /// - a list of member offset info;
770 /// - a list of member decoration info.
771 struct spirv::detail::StructTypeStorage : public TypeStorage {
772 /// Construct a storage object for an identified struct type. A struct type
773 /// associated with such storage must call StructType::trySetBody(...) later
774 /// in order to mutate the storage object providing the actual content.
StructTypeStoragespirv::detail::StructTypeStorage775 StructTypeStorage(StringRef identifier)
776 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
777 numMemberDecorations(0), memberDecorationsInfo(nullptr),
778 identifier(identifier) {}
779
780 /// Construct a storage object for a literal struct type. A struct type
781 /// associated with such storage is immutable.
StructTypeStoragespirv::detail::StructTypeStorage782 StructTypeStorage(
783 unsigned numMembers, Type const *memberTypes,
784 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
785 StructType::MemberDecorationInfo const *memberDecorationsInfo)
786 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
787 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
788 memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
789
790 /// A storage key is divided into 2 parts:
791 /// - for identified structs:
792 /// - a StringRef representing the struct identifier;
793 /// - for literal structs:
794 /// - an ArrayRef<Type> for member types;
795 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
796 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
797 /// info.
798 ///
799 /// An identified struct type is uniqued only by the first part (field 0)
800 /// of the key.
801 ///
802 /// A literal struct type is uniqued only by the second part (fields 1, 2, and
803 /// 3) of the key. The identifier field (field 0) must be empty.
804 using KeyTy =
805 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
806 ArrayRef<StructType::MemberDecorationInfo>>;
807
808 /// For identified structs, return true if the given key contains the same
809 /// identifier.
810 ///
811 /// For literal structs, return true if the given key contains a matching list
812 /// of member types + offset info + decoration info.
operator ==spirv::detail::StructTypeStorage813 bool operator==(const KeyTy &key) const {
814 if (isIdentified()) {
815 // Identified types are uniqued by their identifier.
816 return getIdentifier() == std::get<0>(key);
817 }
818
819 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
820 getMemberDecorationsInfo());
821 }
822
823 /// If the given key contains a non-empty identifier, this method constructs
824 /// an identified struct and leaves the rest of the struct type data to be set
825 /// through a later call to StructType::trySetBody(...).
826 ///
827 /// If, on the other hand, the key contains an empty identifier, a literal
828 /// struct is constructed using the other fields of the key.
constructspirv::detail::StructTypeStorage829 static StructTypeStorage *construct(TypeStorageAllocator &allocator,
830 const KeyTy &key) {
831 StringRef keyIdentifier = std::get<0>(key);
832
833 if (!keyIdentifier.empty()) {
834 StringRef identifier = allocator.copyInto(keyIdentifier);
835
836 // Identified StructType body/members will be set through trySetBody(...)
837 // later.
838 return new (allocator.allocate<StructTypeStorage>())
839 StructTypeStorage(identifier);
840 }
841
842 ArrayRef<Type> keyTypes = std::get<1>(key);
843
844 // Copy the member type and layout information into the bump pointer
845 const Type *typesList = nullptr;
846 if (!keyTypes.empty()) {
847 typesList = allocator.copyInto(keyTypes).data();
848 }
849
850 const StructType::OffsetInfo *offsetInfoList = nullptr;
851 if (!std::get<2>(key).empty()) {
852 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
853 assert(keyOffsetInfo.size() == keyTypes.size() &&
854 "size of offset information must be same as the size of number of "
855 "elements");
856 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
857 }
858
859 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
860 unsigned numMemberDecorations = 0;
861 if (!std::get<3>(key).empty()) {
862 auto keyMemberDecorations = std::get<3>(key);
863 numMemberDecorations = keyMemberDecorations.size();
864 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
865 }
866
867 return new (allocator.allocate<StructTypeStorage>())
868 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
869 numMemberDecorations, memberDecorationList);
870 }
871
getMemberTypesspirv::detail::StructTypeStorage872 ArrayRef<Type> getMemberTypes() const {
873 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
874 }
875
getOffsetInfospirv::detail::StructTypeStorage876 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
877 if (offsetInfo) {
878 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
879 }
880 return {};
881 }
882
getMemberDecorationsInfospirv::detail::StructTypeStorage883 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
884 if (memberDecorationsInfo) {
885 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
886 numMemberDecorations);
887 }
888 return {};
889 }
890
getIdentifierspirv::detail::StructTypeStorage891 StringRef getIdentifier() const { return identifier; }
892
isIdentifiedspirv::detail::StructTypeStorage893 bool isIdentified() const { return !identifier.empty(); }
894
895 /// Sets the struct type content for identified structs. Calling this method
896 /// is only valid for identified structs.
897 ///
898 /// Fails under the following conditions:
899 /// - If called for a literal struct;
900 /// - If called for an identified struct whose body was set before (through a
901 /// call to this method) but with different contents from the passed
902 /// arguments.
mutatespirv::detail::StructTypeStorage903 LogicalResult mutate(
904 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
905 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
906 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
907 if (!isIdentified())
908 return failure();
909
910 if (memberTypesAndIsBodySet.getInt() &&
911 (getMemberTypes() != structMemberTypes ||
912 getOffsetInfo() != structOffsetInfo ||
913 getMemberDecorationsInfo() != structMemberDecorationInfo))
914 return failure();
915
916 memberTypesAndIsBodySet.setInt(true);
917 numMembers = structMemberTypes.size();
918
919 // Copy the member type and layout information into the bump pointer.
920 if (!structMemberTypes.empty())
921 memberTypesAndIsBodySet.setPointer(
922 allocator.copyInto(structMemberTypes).data());
923
924 if (!structOffsetInfo.empty()) {
925 assert(structOffsetInfo.size() == structMemberTypes.size() &&
926 "size of offset information must be same as the size of number of "
927 "elements");
928 offsetInfo = allocator.copyInto(structOffsetInfo).data();
929 }
930
931 if (!structMemberDecorationInfo.empty()) {
932 numMemberDecorations = structMemberDecorationInfo.size();
933 memberDecorationsInfo =
934 allocator.copyInto(structMemberDecorationInfo).data();
935 }
936
937 return success();
938 }
939
940 llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
941 StructType::OffsetInfo const *offsetInfo;
942 unsigned numMembers;
943 unsigned numMemberDecorations;
944 StructType::MemberDecorationInfo const *memberDecorationsInfo;
945 StringRef identifier;
946 };
947
948 StructType
get(ArrayRef<Type> memberTypes,ArrayRef<StructType::OffsetInfo> offsetInfo,ArrayRef<StructType::MemberDecorationInfo> memberDecorations)949 StructType::get(ArrayRef<Type> memberTypes,
950 ArrayRef<StructType::OffsetInfo> offsetInfo,
951 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
952 assert(!memberTypes.empty() && "Struct needs at least one member type");
953 // Sort the decorations.
954 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
955 memberDecorations.begin(), memberDecorations.end());
956 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
957 return Base::get(memberTypes.vec().front().getContext(),
958 /*identifier=*/StringRef(), memberTypes, offsetInfo,
959 sortedDecorations);
960 }
961
getIdentified(MLIRContext * context,StringRef identifier)962 StructType StructType::getIdentified(MLIRContext *context,
963 StringRef identifier) {
964 assert(!identifier.empty() &&
965 "StructType identifier must be non-empty string");
966
967 return Base::get(context, identifier, ArrayRef<Type>(),
968 ArrayRef<StructType::OffsetInfo>(),
969 ArrayRef<StructType::MemberDecorationInfo>());
970 }
971
getEmpty(MLIRContext * context,StringRef identifier)972 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
973 StructType newStructType = Base::get(
974 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
975 ArrayRef<StructType::MemberDecorationInfo>());
976 // Set an empty body in case this is a identified struct.
977 if (newStructType.isIdentified() &&
978 failed(newStructType.trySetBody(
979 ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
980 ArrayRef<StructType::MemberDecorationInfo>())))
981 return StructType();
982
983 return newStructType;
984 }
985
getIdentifier() const986 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
987
isIdentified() const988 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
989
getNumElements() const990 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
991
getElementType(unsigned index) const992 Type StructType::getElementType(unsigned index) const {
993 assert(getNumElements() > index && "member index out of range");
994 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
995 }
996
getElementTypes() const997 StructType::ElementTypeRange StructType::getElementTypes() const {
998 return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
999 getNumElements());
1000 }
1001
hasOffset() const1002 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1003
getMemberOffset(unsigned index) const1004 uint64_t StructType::getMemberOffset(unsigned index) const {
1005 assert(getNumElements() > index && "member index out of range");
1006 return getImpl()->offsetInfo[index];
1007 }
1008
getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorations) const1009 void StructType::getMemberDecorations(
1010 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1011 const {
1012 memberDecorations.clear();
1013 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1014 memberDecorations.append(implMemberDecorations.begin(),
1015 implMemberDecorations.end());
1016 }
1017
getMemberDecorations(unsigned index,SmallVectorImpl<StructType::MemberDecorationInfo> & decorationsInfo) const1018 void StructType::getMemberDecorations(
1019 unsigned index,
1020 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1021 assert(getNumElements() > index && "member index out of range");
1022 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1023 decorationsInfo.clear();
1024 for (const auto &memberDecoration : memberDecorations) {
1025 if (memberDecoration.memberIndex == index) {
1026 decorationsInfo.push_back(memberDecoration);
1027 }
1028 if (memberDecoration.memberIndex > index) {
1029 // Early exit since the decorations are stored sorted.
1030 return;
1031 }
1032 }
1033 }
1034
1035 LogicalResult
trySetBody(ArrayRef<Type> memberTypes,ArrayRef<OffsetInfo> offsetInfo,ArrayRef<MemberDecorationInfo> memberDecorations)1036 StructType::trySetBody(ArrayRef<Type> memberTypes,
1037 ArrayRef<OffsetInfo> offsetInfo,
1038 ArrayRef<MemberDecorationInfo> memberDecorations) {
1039 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1040 }
1041
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1042 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1043 Optional<StorageClass> storage) {
1044 for (Type elementType : getElementTypes())
1045 elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1046 }
1047
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1048 void StructType::getCapabilities(
1049 SPIRVType::CapabilityArrayRefVector &capabilities,
1050 Optional<StorageClass> storage) {
1051 for (Type elementType : getElementTypes())
1052 elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1053 }
1054
hash_value(const StructType::MemberDecorationInfo & memberDecorationInfo)1055 llvm::hash_code spirv::hash_value(
1056 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1057 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1058 memberDecorationInfo.decoration);
1059 }
1060
1061 //===----------------------------------------------------------------------===//
1062 // MatrixType
1063 //===----------------------------------------------------------------------===//
1064
1065 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStoragespirv::detail::MatrixTypeStorage1066 MatrixTypeStorage(Type columnType, uint32_t columnCount)
1067 : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1068
1069 using KeyTy = std::tuple<Type, uint32_t>;
1070
constructspirv::detail::MatrixTypeStorage1071 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1072 const KeyTy &key) {
1073
1074 // Initialize the memory using placement new.
1075 return new (allocator.allocate<MatrixTypeStorage>())
1076 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1077 }
1078
operator ==spirv::detail::MatrixTypeStorage1079 bool operator==(const KeyTy &key) const {
1080 return key == KeyTy(columnType, columnCount);
1081 }
1082
1083 Type columnType;
1084 const uint32_t columnCount;
1085 };
1086
get(Type columnType,uint32_t columnCount)1087 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1088 return Base::get(columnType.getContext(), columnType, columnCount);
1089 }
1090
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1091 MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1092 Type columnType, uint32_t columnCount) {
1093 return Base::getChecked(emitError, columnType.getContext(), columnType,
1094 columnCount);
1095 }
1096
verify(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1097 LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
1098 Type columnType, uint32_t columnCount) {
1099 if (columnCount < 2 || columnCount > 4)
1100 return emitError() << "matrix can have 2, 3, or 4 columns only";
1101
1102 if (!isValidColumnType(columnType))
1103 return emitError() << "matrix columns must be vectors of floats";
1104
1105 /// The underlying vectors (columns) must be of size 2, 3, or 4
1106 ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1107 if (columnShape.size() != 1)
1108 return emitError() << "matrix columns must be 1D vectors";
1109
1110 if (columnShape[0] < 2 || columnShape[0] > 4)
1111 return emitError() << "matrix columns must be of size 2, 3, or 4";
1112
1113 return success();
1114 }
1115
1116 /// Returns true if the matrix elements are vectors of float elements
isValidColumnType(Type columnType)1117 bool MatrixType::isValidColumnType(Type columnType) {
1118 if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1119 if (vectorType.getElementType().isa<FloatType>())
1120 return true;
1121 }
1122 return false;
1123 }
1124
getColumnType() const1125 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1126
getElementType() const1127 Type MatrixType::getElementType() const {
1128 return getImpl()->columnType.cast<VectorType>().getElementType();
1129 }
1130
getNumColumns() const1131 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1132
getNumRows() const1133 unsigned MatrixType::getNumRows() const {
1134 return getImpl()->columnType.cast<VectorType>().getShape()[0];
1135 }
1136
getNumElements() const1137 unsigned MatrixType::getNumElements() const {
1138 return (getImpl()->columnCount) * getNumRows();
1139 }
1140
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1141 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1142 Optional<StorageClass> storage) {
1143 getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1144 }
1145
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1146 void MatrixType::getCapabilities(
1147 SPIRVType::CapabilityArrayRefVector &capabilities,
1148 Optional<StorageClass> storage) {
1149 {
1150 static const Capability caps[] = {Capability::Matrix};
1151 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1152 capabilities.push_back(ref);
1153 }
1154 // Add any capabilities associated with the underlying vectors (i.e., columns)
1155 getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1156 }
1157
1158 //===----------------------------------------------------------------------===//
1159 // SPIR-V Dialect
1160 //===----------------------------------------------------------------------===//
1161
registerTypes()1162 void SPIRVDialect::registerTypes() {
1163 addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
1164 PointerType, RuntimeArrayType, SampledImageType, StructType>();
1165 }
1166