1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Identifier.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::spirv;
23 
24 //===----------------------------------------------------------------------===//
25 // ArrayType
26 //===----------------------------------------------------------------------===//
27 
28 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
29   using KeyTy = std::tuple<Type, unsigned, unsigned>;
30 
constructspirv::detail::ArrayTypeStorage31   static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
32                                      const KeyTy &key) {
33     return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
34   }
35 
operator ==spirv::detail::ArrayTypeStorage36   bool operator==(const KeyTy &key) const {
37     return key == KeyTy(elementType, elementCount, stride);
38   }
39 
ArrayTypeStoragespirv::detail::ArrayTypeStorage40   ArrayTypeStorage(const KeyTy &key)
41       : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
42         stride(std::get<2>(key)) {}
43 
44   Type elementType;
45   unsigned elementCount;
46   unsigned stride;
47 };
48 
get(Type elementType,unsigned elementCount)49 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
50   assert(elementCount && "ArrayType needs at least one element");
51   return Base::get(elementType.getContext(), elementType, elementCount,
52                    /*stride=*/0);
53 }
54 
get(Type elementType,unsigned elementCount,unsigned stride)55 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
56                          unsigned stride) {
57   assert(elementCount && "ArrayType needs at least one element");
58   return Base::get(elementType.getContext(), elementType, elementCount, stride);
59 }
60 
getNumElements() const61 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
62 
getElementType() const63 Type ArrayType::getElementType() const { return getImpl()->elementType; }
64 
getArrayStride() const65 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
66 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)67 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
68                               Optional<StorageClass> storage) {
69   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
70 }
71 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)72 void ArrayType::getCapabilities(
73     SPIRVType::CapabilityArrayRefVector &capabilities,
74     Optional<StorageClass> storage) {
75   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
76 }
77 
getSizeInBytes()78 Optional<int64_t> ArrayType::getSizeInBytes() {
79   auto elementType = getElementType().cast<SPIRVType>();
80   Optional<int64_t> size = elementType.getSizeInBytes();
81   if (!size)
82     return llvm::None;
83   return (*size + getArrayStride()) * getNumElements();
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // CompositeType
88 //===----------------------------------------------------------------------===//
89 
classof(Type type)90 bool CompositeType::classof(Type type) {
91   if (auto vectorType = type.dyn_cast<VectorType>())
92     return isValid(vectorType);
93   return type
94       .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
95            spirv::RuntimeArrayType, spirv::StructType>();
96 }
97 
isValid(VectorType type)98 bool CompositeType::isValid(VectorType type) {
99   switch (type.getNumElements()) {
100   case 2:
101   case 3:
102   case 4:
103   case 8:
104   case 16:
105     break;
106   default:
107     return false;
108   }
109   return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
110 }
111 
getElementType(unsigned index) const112 Type CompositeType::getElementType(unsigned index) const {
113   return TypeSwitch<Type, Type>(*this)
114       .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
115           [](auto type) { return type.getElementType(); })
116       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
117       .Case<StructType>(
118           [index](StructType type) { return type.getElementType(index); })
119       .Default(
120           [](Type) -> Type { llvm_unreachable("invalid composite type"); });
121 }
122 
getNumElements() const123 unsigned CompositeType::getNumElements() const {
124   if (auto arrayType = dyn_cast<ArrayType>())
125     return arrayType.getNumElements();
126   if (auto matrixType = dyn_cast<MatrixType>())
127     return matrixType.getNumColumns();
128   if (auto structType = dyn_cast<StructType>())
129     return structType.getNumElements();
130   if (auto vectorType = dyn_cast<VectorType>())
131     return vectorType.getNumElements();
132   if (isa<CooperativeMatrixNVType>()) {
133     llvm_unreachable(
134         "invalid to query number of elements of spirv::CooperativeMatrix type");
135   }
136   if (isa<RuntimeArrayType>()) {
137     llvm_unreachable(
138         "invalid to query number of elements of spirv::RuntimeArray type");
139   }
140   llvm_unreachable("invalid composite type");
141 }
142 
hasCompileTimeKnownNumElements() const143 bool CompositeType::hasCompileTimeKnownNumElements() const {
144   return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
145 }
146 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)147 void CompositeType::getExtensions(
148     SPIRVType::ExtensionArrayRefVector &extensions,
149     Optional<StorageClass> storage) {
150   TypeSwitch<Type>(*this)
151       .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
152             StructType>(
153           [&](auto type) { type.getExtensions(extensions, storage); })
154       .Case<VectorType>([&](VectorType type) {
155         return type.getElementType().cast<ScalarType>().getExtensions(
156             extensions, storage);
157       })
158       .Default([](Type) { llvm_unreachable("invalid composite type"); });
159 }
160 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)161 void CompositeType::getCapabilities(
162     SPIRVType::CapabilityArrayRefVector &capabilities,
163     Optional<StorageClass> storage) {
164   TypeSwitch<Type>(*this)
165       .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
166             StructType>(
167           [&](auto type) { type.getCapabilities(capabilities, storage); })
168       .Case<VectorType>([&](VectorType type) {
169         auto vecSize = getNumElements();
170         if (vecSize == 8 || vecSize == 16) {
171           static const Capability caps[] = {Capability::Vector16};
172           ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
173           capabilities.push_back(ref);
174         }
175         return type.getElementType().cast<ScalarType>().getCapabilities(
176             capabilities, storage);
177       })
178       .Default([](Type) { llvm_unreachable("invalid composite type"); });
179 }
180 
getSizeInBytes()181 Optional<int64_t> CompositeType::getSizeInBytes() {
182   if (auto arrayType = dyn_cast<ArrayType>())
183     return arrayType.getSizeInBytes();
184   if (auto structType = dyn_cast<StructType>())
185     return structType.getSizeInBytes();
186   if (auto vectorType = dyn_cast<VectorType>()) {
187     Optional<int64_t> elementSize =
188         vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
189     if (!elementSize)
190       return llvm::None;
191     return *elementSize * vectorType.getNumElements();
192   }
193   return llvm::None;
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // CooperativeMatrixType
198 //===----------------------------------------------------------------------===//
199 
200 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
201   using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
202 
203   static CooperativeMatrixTypeStorage *
constructspirv::detail::CooperativeMatrixTypeStorage204   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
205     return new (allocator.allocate<CooperativeMatrixTypeStorage>())
206         CooperativeMatrixTypeStorage(key);
207   }
208 
operator ==spirv::detail::CooperativeMatrixTypeStorage209   bool operator==(const KeyTy &key) const {
210     return key == KeyTy(elementType, scope, rows, columns);
211   }
212 
CooperativeMatrixTypeStoragespirv::detail::CooperativeMatrixTypeStorage213   CooperativeMatrixTypeStorage(const KeyTy &key)
214       : elementType(std::get<0>(key)), rows(std::get<2>(key)),
215         columns(std::get<3>(key)), scope(std::get<1>(key)) {}
216 
217   Type elementType;
218   unsigned rows;
219   unsigned columns;
220   Scope scope;
221 };
222 
get(Type elementType,Scope scope,unsigned rows,unsigned columns)223 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
224                                                      Scope scope, unsigned rows,
225                                                      unsigned columns) {
226   return Base::get(elementType.getContext(), elementType, scope, rows, columns);
227 }
228 
getElementType() const229 Type CooperativeMatrixNVType::getElementType() const {
230   return getImpl()->elementType;
231 }
232 
getScope() const233 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
234 
getRows() const235 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
236 
getColumns() const237 unsigned CooperativeMatrixNVType::getColumns() const {
238   return getImpl()->columns;
239 }
240 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)241 void CooperativeMatrixNVType::getExtensions(
242     SPIRVType::ExtensionArrayRefVector &extensions,
243     Optional<StorageClass> storage) {
244   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
245   static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
246   ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
247   extensions.push_back(ref);
248 }
249 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)250 void CooperativeMatrixNVType::getCapabilities(
251     SPIRVType::CapabilityArrayRefVector &capabilities,
252     Optional<StorageClass> storage) {
253   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
254   static const Capability caps[] = {Capability::CooperativeMatrixNV};
255   ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
256   capabilities.push_back(ref);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ImageType
261 //===----------------------------------------------------------------------===//
262 
getNumBits()263 template <typename T> static constexpr unsigned getNumBits() { return 0; }
getNumBits()264 template <> constexpr unsigned getNumBits<Dim>() {
265   static_assert((1 << 3) > getMaxEnumValForDim(),
266                 "Not enough bits to encode Dim value");
267   return 3;
268 }
getNumBits()269 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
270   static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
271                 "Not enough bits to encode ImageDepthInfo value");
272   return 2;
273 }
getNumBits()274 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
275   static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
276                 "Not enough bits to encode ImageArrayedInfo value");
277   return 1;
278 }
getNumBits()279 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
280   static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
281                 "Not enough bits to encode ImageSamplingInfo value");
282   return 1;
283 }
getNumBits()284 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
285   static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
286                 "Not enough bits to encode ImageSamplerUseInfo value");
287   return 2;
288 }
getNumBits()289 template <> constexpr unsigned getNumBits<ImageFormat>() {
290   static_assert((1 << 6) > getMaxEnumValForImageFormat(),
291                 "Not enough bits to encode ImageFormat value");
292   return 6;
293 }
294 
295 struct spirv::detail::ImageTypeStorage : public TypeStorage {
296 public:
297   using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
298                            ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
299 
constructspirv::detail::ImageTypeStorage300   static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
301                                      const KeyTy &key) {
302     return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
303   }
304 
operator ==spirv::detail::ImageTypeStorage305   bool operator==(const KeyTy &key) const {
306     return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
307                         samplerUseInfo, format);
308   }
309 
ImageTypeStoragespirv::detail::ImageTypeStorage310   ImageTypeStorage(const KeyTy &key)
311       : elementType(std::get<0>(key)), dim(std::get<1>(key)),
312         depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
313         samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
314         format(std::get<6>(key)) {}
315 
316   Type elementType;
317   Dim dim : getNumBits<Dim>();
318   ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
319   ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
320   ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
321   ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
322   ImageFormat format : getNumBits<ImageFormat>();
323 };
324 
325 ImageType
get(std::tuple<Type,Dim,ImageDepthInfo,ImageArrayedInfo,ImageSamplingInfo,ImageSamplerUseInfo,ImageFormat> value)326 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
327                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
328                    value) {
329   return Base::get(std::get<0>(value).getContext(), value);
330 }
331 
getElementType() const332 Type ImageType::getElementType() const { return getImpl()->elementType; }
333 
getDim() const334 Dim ImageType::getDim() const { return getImpl()->dim; }
335 
getDepthInfo() const336 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
337 
getArrayedInfo() const338 ImageArrayedInfo ImageType::getArrayedInfo() const {
339   return getImpl()->arrayedInfo;
340 }
341 
getSamplingInfo() const342 ImageSamplingInfo ImageType::getSamplingInfo() const {
343   return getImpl()->samplingInfo;
344 }
345 
getSamplerUseInfo() const346 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
347   return getImpl()->samplerUseInfo;
348 }
349 
getImageFormat() const350 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
351 
getExtensions(SPIRVType::ExtensionArrayRefVector &,Optional<StorageClass>)352 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
353                               Optional<StorageClass>) {
354   // Image types do not require extra extensions thus far.
355 }
356 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass>)357 void ImageType::getCapabilities(
358     SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
359   if (auto dimCaps = spirv::getCapabilities(getDim()))
360     capabilities.push_back(*dimCaps);
361 
362   if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
363     capabilities.push_back(*fmtCaps);
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // PointerType
368 //===----------------------------------------------------------------------===//
369 
370 struct spirv::detail::PointerTypeStorage : public TypeStorage {
371   // (Type, StorageClass) as the key: Type stored in this struct, and
372   // StorageClass stored as TypeStorage's subclass data.
373   using KeyTy = std::pair<Type, StorageClass>;
374 
constructspirv::detail::PointerTypeStorage375   static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
376                                        const KeyTy &key) {
377     return new (allocator.allocate<PointerTypeStorage>())
378         PointerTypeStorage(key);
379   }
380 
operator ==spirv::detail::PointerTypeStorage381   bool operator==(const KeyTy &key) const {
382     return key == KeyTy(pointeeType, storageClass);
383   }
384 
PointerTypeStoragespirv::detail::PointerTypeStorage385   PointerTypeStorage(const KeyTy &key)
386       : pointeeType(key.first), storageClass(key.second) {}
387 
388   Type pointeeType;
389   StorageClass storageClass;
390 };
391 
get(Type pointeeType,StorageClass storageClass)392 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
393   return Base::get(pointeeType.getContext(), pointeeType, storageClass);
394 }
395 
getPointeeType() const396 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
397 
getStorageClass() const398 StorageClass PointerType::getStorageClass() const {
399   return getImpl()->storageClass;
400 }
401 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)402 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
403                                 Optional<StorageClass> storage) {
404   // Use this pointer type's storage class because this pointer indicates we are
405   // using the pointee type in that specific storage class.
406   getPointeeType().cast<SPIRVType>().getExtensions(extensions,
407                                                    getStorageClass());
408 
409   if (auto scExts = spirv::getExtensions(getStorageClass()))
410     extensions.push_back(*scExts);
411 }
412 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)413 void PointerType::getCapabilities(
414     SPIRVType::CapabilityArrayRefVector &capabilities,
415     Optional<StorageClass> storage) {
416   // Use this pointer type's storage class because this pointer indicates we are
417   // using the pointee type in that specific storage class.
418   getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
419                                                      getStorageClass());
420 
421   if (auto scCaps = spirv::getCapabilities(getStorageClass()))
422     capabilities.push_back(*scCaps);
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // RuntimeArrayType
427 //===----------------------------------------------------------------------===//
428 
429 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
430   using KeyTy = std::pair<Type, unsigned>;
431 
constructspirv::detail::RuntimeArrayTypeStorage432   static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
433                                             const KeyTy &key) {
434     return new (allocator.allocate<RuntimeArrayTypeStorage>())
435         RuntimeArrayTypeStorage(key);
436   }
437 
operator ==spirv::detail::RuntimeArrayTypeStorage438   bool operator==(const KeyTy &key) const {
439     return key == KeyTy(elementType, stride);
440   }
441 
RuntimeArrayTypeStoragespirv::detail::RuntimeArrayTypeStorage442   RuntimeArrayTypeStorage(const KeyTy &key)
443       : elementType(key.first), stride(key.second) {}
444 
445   Type elementType;
446   unsigned stride;
447 };
448 
get(Type elementType)449 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
450   return Base::get(elementType.getContext(), elementType, /*stride=*/0);
451 }
452 
get(Type elementType,unsigned stride)453 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
454   return Base::get(elementType.getContext(), elementType, stride);
455 }
456 
getElementType() const457 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
458 
getArrayStride() const459 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
460 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)461 void RuntimeArrayType::getExtensions(
462     SPIRVType::ExtensionArrayRefVector &extensions,
463     Optional<StorageClass> storage) {
464   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
465 }
466 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)467 void RuntimeArrayType::getCapabilities(
468     SPIRVType::CapabilityArrayRefVector &capabilities,
469     Optional<StorageClass> storage) {
470   {
471     static const Capability caps[] = {Capability::Shader};
472     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
473     capabilities.push_back(ref);
474   }
475   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // ScalarType
480 //===----------------------------------------------------------------------===//
481 
classof(Type type)482 bool ScalarType::classof(Type type) {
483   if (auto floatType = type.dyn_cast<FloatType>()) {
484     return isValid(floatType);
485   }
486   if (auto intType = type.dyn_cast<IntegerType>()) {
487     return isValid(intType);
488   }
489   return false;
490 }
491 
isValid(FloatType type)492 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
493 
isValid(IntegerType type)494 bool ScalarType::isValid(IntegerType type) {
495   switch (type.getWidth()) {
496   case 1:
497   case 8:
498   case 16:
499   case 32:
500   case 64:
501     return true;
502   default:
503     return false;
504   }
505 }
506 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)507 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
508                                Optional<StorageClass> storage) {
509   // 8- or 16-bit integer/floating-point numbers will require extra extensions
510   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
511   // SPV_KHR_8bit_storage for more details.
512   if (!storage)
513     return;
514 
515   switch (*storage) {
516   case StorageClass::PushConstant:
517   case StorageClass::StorageBuffer:
518   case StorageClass::Uniform:
519     if (getIntOrFloatBitWidth() == 8) {
520       static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
521       ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
522       extensions.push_back(ref);
523     }
524     LLVM_FALLTHROUGH;
525   case StorageClass::Input:
526   case StorageClass::Output:
527     if (getIntOrFloatBitWidth() == 16) {
528       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
529       ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
530       extensions.push_back(ref);
531     }
532     break;
533   default:
534     break;
535   }
536 }
537 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)538 void ScalarType::getCapabilities(
539     SPIRVType::CapabilityArrayRefVector &capabilities,
540     Optional<StorageClass> storage) {
541   unsigned bitwidth = getIntOrFloatBitWidth();
542 
543   // 8- or 16-bit integer/floating-point numbers will require extra capabilities
544   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
545   // SPV_KHR_8bit_storage for more details.
546 
547 #define STORAGE_CASE(storage, cap8, cap16)                                     \
548   case StorageClass::storage: {                                                \
549     if (bitwidth == 8) {                                                       \
550       static const Capability caps[] = {Capability::cap8};                     \
551       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
552       capabilities.push_back(ref);                                             \
553     } else if (bitwidth == 16) {                                               \
554       static const Capability caps[] = {Capability::cap16};                    \
555       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
556       capabilities.push_back(ref);                                             \
557     }                                                                          \
558     /* No requirements for other bitwidths */                                  \
559     return;                                                                    \
560   }
561 
562   // This part only handles the cases where special bitwidths appearing in
563   // interface storage classes.
564   if (storage) {
565     switch (*storage) {
566       STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
567       STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
568                    StorageBuffer16BitAccess);
569       STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
570                    StorageUniform16);
571     case StorageClass::Input:
572     case StorageClass::Output: {
573       if (bitwidth == 16) {
574         static const Capability caps[] = {Capability::StorageInputOutput16};
575         ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
576         capabilities.push_back(ref);
577       }
578       return;
579     }
580     default:
581       break;
582     }
583   }
584 #undef STORAGE_CASE
585 
586   // For other non-interface storage classes, require a different set of
587   // capabilities for special bitwidths.
588 
589 #define WIDTH_CASE(type, width)                                                \
590   case width: {                                                                \
591     static const Capability caps[] = {Capability::type##width};                \
592     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));                \
593     capabilities.push_back(ref);                                               \
594   } break
595 
596   if (auto intType = dyn_cast<IntegerType>()) {
597     switch (bitwidth) {
598     case 32:
599     case 1:
600       break;
601       WIDTH_CASE(Int, 8);
602       WIDTH_CASE(Int, 16);
603       WIDTH_CASE(Int, 64);
604     default:
605       llvm_unreachable("invalid bitwidth to getCapabilities");
606     }
607   } else {
608     assert(isa<FloatType>());
609     switch (bitwidth) {
610     case 32:
611       break;
612       WIDTH_CASE(Float, 16);
613       WIDTH_CASE(Float, 64);
614     default:
615       llvm_unreachable("invalid bitwidth to getCapabilities");
616     }
617   }
618 
619 #undef WIDTH_CASE
620 }
621 
getSizeInBytes()622 Optional<int64_t> ScalarType::getSizeInBytes() {
623   auto bitWidth = getIntOrFloatBitWidth();
624   // According to the SPIR-V spec:
625   // "There is no physical size or bit pattern defined for values with boolean
626   // type. If they are stored (in conjunction with OpVariable), they can only
627   // be used with logical addressing operations, not physical, and only with
628   // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
629   // Private, Function, Input, and Output."
630   if (bitWidth == 1)
631     return llvm::None;
632   return bitWidth / 8;
633 }
634 
635 //===----------------------------------------------------------------------===//
636 // SPIRVType
637 //===----------------------------------------------------------------------===//
638 
classof(Type type)639 bool SPIRVType::classof(Type type) {
640   // Allow SPIR-V dialect types
641   if (llvm::isa<SPIRVDialect>(type.getDialect()))
642     return true;
643   if (type.isa<ScalarType>())
644     return true;
645   if (auto vectorType = type.dyn_cast<VectorType>())
646     return CompositeType::isValid(vectorType);
647   return false;
648 }
649 
isScalarOrVector()650 bool SPIRVType::isScalarOrVector() {
651   return isIntOrFloat() || isa<VectorType>();
652 }
653 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)654 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
655                               Optional<StorageClass> storage) {
656   if (auto scalarType = dyn_cast<ScalarType>()) {
657     scalarType.getExtensions(extensions, storage);
658   } else if (auto compositeType = dyn_cast<CompositeType>()) {
659     compositeType.getExtensions(extensions, storage);
660   } else if (auto imageType = dyn_cast<ImageType>()) {
661     imageType.getExtensions(extensions, storage);
662   } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
663     sampledImageType.getExtensions(extensions, storage);
664   } else if (auto matrixType = dyn_cast<MatrixType>()) {
665     matrixType.getExtensions(extensions, storage);
666   } else if (auto ptrType = dyn_cast<PointerType>()) {
667     ptrType.getExtensions(extensions, storage);
668   } else {
669     llvm_unreachable("invalid SPIR-V Type to getExtensions");
670   }
671 }
672 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)673 void SPIRVType::getCapabilities(
674     SPIRVType::CapabilityArrayRefVector &capabilities,
675     Optional<StorageClass> storage) {
676   if (auto scalarType = dyn_cast<ScalarType>()) {
677     scalarType.getCapabilities(capabilities, storage);
678   } else if (auto compositeType = dyn_cast<CompositeType>()) {
679     compositeType.getCapabilities(capabilities, storage);
680   } else if (auto imageType = dyn_cast<ImageType>()) {
681     imageType.getCapabilities(capabilities, storage);
682   } else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
683     sampledImageType.getCapabilities(capabilities, storage);
684   } else if (auto matrixType = dyn_cast<MatrixType>()) {
685     matrixType.getCapabilities(capabilities, storage);
686   } else if (auto ptrType = dyn_cast<PointerType>()) {
687     ptrType.getCapabilities(capabilities, storage);
688   } else {
689     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
690   }
691 }
692 
getSizeInBytes()693 Optional<int64_t> SPIRVType::getSizeInBytes() {
694   if (auto scalarType = dyn_cast<ScalarType>())
695     return scalarType.getSizeInBytes();
696   if (auto compositeType = dyn_cast<CompositeType>())
697     return compositeType.getSizeInBytes();
698   return llvm::None;
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // SampledImageType
703 //===----------------------------------------------------------------------===//
704 struct spirv::detail::SampledImageTypeStorage : public TypeStorage {
705   using KeyTy = Type;
706 
SampledImageTypeStoragespirv::detail::SampledImageTypeStorage707   SampledImageTypeStorage(const KeyTy &key) : imageType{key} {}
708 
operator ==spirv::detail::SampledImageTypeStorage709   bool operator==(const KeyTy &key) const { return key == KeyTy(imageType); }
710 
constructspirv::detail::SampledImageTypeStorage711   static SampledImageTypeStorage *construct(TypeStorageAllocator &allocator,
712                                             const KeyTy &key) {
713     return new (allocator.allocate<SampledImageTypeStorage>())
714         SampledImageTypeStorage(key);
715   }
716 
717   Type imageType;
718 };
719 
get(Type imageType)720 SampledImageType SampledImageType::get(Type imageType) {
721   return Base::get(imageType.getContext(), imageType);
722 }
723 
724 SampledImageType
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type imageType)725 SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
726                              Type imageType) {
727   return Base::getChecked(emitError, imageType.getContext(), imageType);
728 }
729 
getImageType() const730 Type SampledImageType::getImageType() const { return getImpl()->imageType; }
731 
732 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,Type imageType)733 SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
734                          Type imageType) {
735   if (!imageType.isa<ImageType>())
736     return emitError() << "expected image type";
737 
738   return success();
739 }
740 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)741 void SampledImageType::getExtensions(
742     SPIRVType::ExtensionArrayRefVector &extensions,
743     Optional<StorageClass> storage) {
744   getImageType().cast<ImageType>().getExtensions(extensions, storage);
745 }
746 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)747 void SampledImageType::getCapabilities(
748     SPIRVType::CapabilityArrayRefVector &capabilities,
749     Optional<StorageClass> storage) {
750   getImageType().cast<ImageType>().getCapabilities(capabilities, storage);
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // StructType
755 //===----------------------------------------------------------------------===//
756 
757 /// Type storage for SPIR-V structure types:
758 ///
759 /// Structures are uniqued using:
760 /// - for identified structs:
761 ///   - a string identifier;
762 /// - for literal structs:
763 ///   - a list of member types;
764 ///   - a list of member offset info;
765 ///   - a list of member decoration info.
766 ///
767 /// Identified structures only have a mutable component consisting of:
768 /// - a list of member types;
769 /// - a list of member offset info;
770 /// - a list of member decoration info.
771 struct spirv::detail::StructTypeStorage : public TypeStorage {
772   /// Construct a storage object for an identified struct type. A struct type
773   /// associated with such storage must call StructType::trySetBody(...) later
774   /// in order to mutate the storage object providing the actual content.
StructTypeStoragespirv::detail::StructTypeStorage775   StructTypeStorage(StringRef identifier)
776       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
777         numMemberDecorations(0), memberDecorationsInfo(nullptr),
778         identifier(identifier) {}
779 
780   /// Construct a storage object for a literal struct type. A struct type
781   /// associated with such storage is immutable.
StructTypeStoragespirv::detail::StructTypeStorage782   StructTypeStorage(
783       unsigned numMembers, Type const *memberTypes,
784       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
785       StructType::MemberDecorationInfo const *memberDecorationsInfo)
786       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
787         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
788         memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
789 
790   /// A storage key is divided into 2 parts:
791   /// - for identified structs:
792   ///   - a StringRef representing the struct identifier;
793   /// - for literal structs:
794   ///   - an ArrayRef<Type> for member types;
795   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
796   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
797   ///     info.
798   ///
799   /// An identified struct type is uniqued only by the first part (field 0)
800   /// of the key.
801   ///
802   /// A literal struct type is uniqued only by the second part (fields 1, 2, and
803   /// 3) of the key. The identifier field (field 0) must be empty.
804   using KeyTy =
805       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
806                  ArrayRef<StructType::MemberDecorationInfo>>;
807 
808   /// For identified structs, return true if the given key contains the same
809   /// identifier.
810   ///
811   /// For literal structs, return true if the given key contains a matching list
812   /// of member types + offset info + decoration info.
operator ==spirv::detail::StructTypeStorage813   bool operator==(const KeyTy &key) const {
814     if (isIdentified()) {
815       // Identified types are uniqued by their identifier.
816       return getIdentifier() == std::get<0>(key);
817     }
818 
819     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
820                         getMemberDecorationsInfo());
821   }
822 
823   /// If the given key contains a non-empty identifier, this method constructs
824   /// an identified struct and leaves the rest of the struct type data to be set
825   /// through a later call to StructType::trySetBody(...).
826   ///
827   /// If, on the other hand, the key contains an empty identifier, a literal
828   /// struct is constructed using the other fields of the key.
constructspirv::detail::StructTypeStorage829   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
830                                       const KeyTy &key) {
831     StringRef keyIdentifier = std::get<0>(key);
832 
833     if (!keyIdentifier.empty()) {
834       StringRef identifier = allocator.copyInto(keyIdentifier);
835 
836       // Identified StructType body/members will be set through trySetBody(...)
837       // later.
838       return new (allocator.allocate<StructTypeStorage>())
839           StructTypeStorage(identifier);
840     }
841 
842     ArrayRef<Type> keyTypes = std::get<1>(key);
843 
844     // Copy the member type and layout information into the bump pointer
845     const Type *typesList = nullptr;
846     if (!keyTypes.empty()) {
847       typesList = allocator.copyInto(keyTypes).data();
848     }
849 
850     const StructType::OffsetInfo *offsetInfoList = nullptr;
851     if (!std::get<2>(key).empty()) {
852       ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
853       assert(keyOffsetInfo.size() == keyTypes.size() &&
854              "size of offset information must be same as the size of number of "
855              "elements");
856       offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
857     }
858 
859     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
860     unsigned numMemberDecorations = 0;
861     if (!std::get<3>(key).empty()) {
862       auto keyMemberDecorations = std::get<3>(key);
863       numMemberDecorations = keyMemberDecorations.size();
864       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
865     }
866 
867     return new (allocator.allocate<StructTypeStorage>())
868         StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
869                           numMemberDecorations, memberDecorationList);
870   }
871 
getMemberTypesspirv::detail::StructTypeStorage872   ArrayRef<Type> getMemberTypes() const {
873     return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
874   }
875 
getOffsetInfospirv::detail::StructTypeStorage876   ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
877     if (offsetInfo) {
878       return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
879     }
880     return {};
881   }
882 
getMemberDecorationsInfospirv::detail::StructTypeStorage883   ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
884     if (memberDecorationsInfo) {
885       return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
886                                                         numMemberDecorations);
887     }
888     return {};
889   }
890 
getIdentifierspirv::detail::StructTypeStorage891   StringRef getIdentifier() const { return identifier; }
892 
isIdentifiedspirv::detail::StructTypeStorage893   bool isIdentified() const { return !identifier.empty(); }
894 
895   /// Sets the struct type content for identified structs. Calling this method
896   /// is only valid for identified structs.
897   ///
898   /// Fails under the following conditions:
899   /// - If called for a literal struct;
900   /// - If called for an identified struct whose body was set before (through a
901   /// call to this method) but with different contents from the passed
902   /// arguments.
mutatespirv::detail::StructTypeStorage903   LogicalResult mutate(
904       TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
905       ArrayRef<StructType::OffsetInfo> structOffsetInfo,
906       ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
907     if (!isIdentified())
908       return failure();
909 
910     if (memberTypesAndIsBodySet.getInt() &&
911         (getMemberTypes() != structMemberTypes ||
912          getOffsetInfo() != structOffsetInfo ||
913          getMemberDecorationsInfo() != structMemberDecorationInfo))
914       return failure();
915 
916     memberTypesAndIsBodySet.setInt(true);
917     numMembers = structMemberTypes.size();
918 
919     // Copy the member type and layout information into the bump pointer.
920     if (!structMemberTypes.empty())
921       memberTypesAndIsBodySet.setPointer(
922           allocator.copyInto(structMemberTypes).data());
923 
924     if (!structOffsetInfo.empty()) {
925       assert(structOffsetInfo.size() == structMemberTypes.size() &&
926              "size of offset information must be same as the size of number of "
927              "elements");
928       offsetInfo = allocator.copyInto(structOffsetInfo).data();
929     }
930 
931     if (!structMemberDecorationInfo.empty()) {
932       numMemberDecorations = structMemberDecorationInfo.size();
933       memberDecorationsInfo =
934           allocator.copyInto(structMemberDecorationInfo).data();
935     }
936 
937     return success();
938   }
939 
940   llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
941   StructType::OffsetInfo const *offsetInfo;
942   unsigned numMembers;
943   unsigned numMemberDecorations;
944   StructType::MemberDecorationInfo const *memberDecorationsInfo;
945   StringRef identifier;
946 };
947 
948 StructType
get(ArrayRef<Type> memberTypes,ArrayRef<StructType::OffsetInfo> offsetInfo,ArrayRef<StructType::MemberDecorationInfo> memberDecorations)949 StructType::get(ArrayRef<Type> memberTypes,
950                 ArrayRef<StructType::OffsetInfo> offsetInfo,
951                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
952   assert(!memberTypes.empty() && "Struct needs at least one member type");
953   // Sort the decorations.
954   SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
955       memberDecorations.begin(), memberDecorations.end());
956   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
957   return Base::get(memberTypes.vec().front().getContext(),
958                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
959                    sortedDecorations);
960 }
961 
getIdentified(MLIRContext * context,StringRef identifier)962 StructType StructType::getIdentified(MLIRContext *context,
963                                      StringRef identifier) {
964   assert(!identifier.empty() &&
965          "StructType identifier must be non-empty string");
966 
967   return Base::get(context, identifier, ArrayRef<Type>(),
968                    ArrayRef<StructType::OffsetInfo>(),
969                    ArrayRef<StructType::MemberDecorationInfo>());
970 }
971 
getEmpty(MLIRContext * context,StringRef identifier)972 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
973   StructType newStructType = Base::get(
974       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
975       ArrayRef<StructType::MemberDecorationInfo>());
976   // Set an empty body in case this is a identified struct.
977   if (newStructType.isIdentified() &&
978       failed(newStructType.trySetBody(
979           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
980           ArrayRef<StructType::MemberDecorationInfo>())))
981     return StructType();
982 
983   return newStructType;
984 }
985 
getIdentifier() const986 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
987 
isIdentified() const988 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
989 
getNumElements() const990 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
991 
getElementType(unsigned index) const992 Type StructType::getElementType(unsigned index) const {
993   assert(getNumElements() > index && "member index out of range");
994   return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
995 }
996 
getElementTypes() const997 StructType::ElementTypeRange StructType::getElementTypes() const {
998   return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
999                           getNumElements());
1000 }
1001 
hasOffset() const1002 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1003 
getMemberOffset(unsigned index) const1004 uint64_t StructType::getMemberOffset(unsigned index) const {
1005   assert(getNumElements() > index && "member index out of range");
1006   return getImpl()->offsetInfo[index];
1007 }
1008 
getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorations) const1009 void StructType::getMemberDecorations(
1010     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1011     const {
1012   memberDecorations.clear();
1013   auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1014   memberDecorations.append(implMemberDecorations.begin(),
1015                            implMemberDecorations.end());
1016 }
1017 
getMemberDecorations(unsigned index,SmallVectorImpl<StructType::MemberDecorationInfo> & decorationsInfo) const1018 void StructType::getMemberDecorations(
1019     unsigned index,
1020     SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1021   assert(getNumElements() > index && "member index out of range");
1022   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1023   decorationsInfo.clear();
1024   for (const auto &memberDecoration : memberDecorations) {
1025     if (memberDecoration.memberIndex == index) {
1026       decorationsInfo.push_back(memberDecoration);
1027     }
1028     if (memberDecoration.memberIndex > index) {
1029       // Early exit since the decorations are stored sorted.
1030       return;
1031     }
1032   }
1033 }
1034 
1035 LogicalResult
trySetBody(ArrayRef<Type> memberTypes,ArrayRef<OffsetInfo> offsetInfo,ArrayRef<MemberDecorationInfo> memberDecorations)1036 StructType::trySetBody(ArrayRef<Type> memberTypes,
1037                        ArrayRef<OffsetInfo> offsetInfo,
1038                        ArrayRef<MemberDecorationInfo> memberDecorations) {
1039   return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1040 }
1041 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1042 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1043                                Optional<StorageClass> storage) {
1044   for (Type elementType : getElementTypes())
1045     elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1046 }
1047 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1048 void StructType::getCapabilities(
1049     SPIRVType::CapabilityArrayRefVector &capabilities,
1050     Optional<StorageClass> storage) {
1051   for (Type elementType : getElementTypes())
1052     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1053 }
1054 
hash_value(const StructType::MemberDecorationInfo & memberDecorationInfo)1055 llvm::hash_code spirv::hash_value(
1056     const StructType::MemberDecorationInfo &memberDecorationInfo) {
1057   return llvm::hash_combine(memberDecorationInfo.memberIndex,
1058                             memberDecorationInfo.decoration);
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // MatrixType
1063 //===----------------------------------------------------------------------===//
1064 
1065 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStoragespirv::detail::MatrixTypeStorage1066   MatrixTypeStorage(Type columnType, uint32_t columnCount)
1067       : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1068 
1069   using KeyTy = std::tuple<Type, uint32_t>;
1070 
constructspirv::detail::MatrixTypeStorage1071   static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1072                                       const KeyTy &key) {
1073 
1074     // Initialize the memory using placement new.
1075     return new (allocator.allocate<MatrixTypeStorage>())
1076         MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1077   }
1078 
operator ==spirv::detail::MatrixTypeStorage1079   bool operator==(const KeyTy &key) const {
1080     return key == KeyTy(columnType, columnCount);
1081   }
1082 
1083   Type columnType;
1084   const uint32_t columnCount;
1085 };
1086 
get(Type columnType,uint32_t columnCount)1087 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1088   return Base::get(columnType.getContext(), columnType, columnCount);
1089 }
1090 
getChecked(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1091 MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1092                                   Type columnType, uint32_t columnCount) {
1093   return Base::getChecked(emitError, columnType.getContext(), columnType,
1094                           columnCount);
1095 }
1096 
verify(function_ref<InFlightDiagnostic ()> emitError,Type columnType,uint32_t columnCount)1097 LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
1098                                  Type columnType, uint32_t columnCount) {
1099   if (columnCount < 2 || columnCount > 4)
1100     return emitError() << "matrix can have 2, 3, or 4 columns only";
1101 
1102   if (!isValidColumnType(columnType))
1103     return emitError() << "matrix columns must be vectors of floats";
1104 
1105   /// The underlying vectors (columns) must be of size 2, 3, or 4
1106   ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1107   if (columnShape.size() != 1)
1108     return emitError() << "matrix columns must be 1D vectors";
1109 
1110   if (columnShape[0] < 2 || columnShape[0] > 4)
1111     return emitError() << "matrix columns must be of size 2, 3, or 4";
1112 
1113   return success();
1114 }
1115 
1116 /// Returns true if the matrix elements are vectors of float elements
isValidColumnType(Type columnType)1117 bool MatrixType::isValidColumnType(Type columnType) {
1118   if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1119     if (vectorType.getElementType().isa<FloatType>())
1120       return true;
1121   }
1122   return false;
1123 }
1124 
getColumnType() const1125 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1126 
getElementType() const1127 Type MatrixType::getElementType() const {
1128   return getImpl()->columnType.cast<VectorType>().getElementType();
1129 }
1130 
getNumColumns() const1131 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1132 
getNumRows() const1133 unsigned MatrixType::getNumRows() const {
1134   return getImpl()->columnType.cast<VectorType>().getShape()[0];
1135 }
1136 
getNumElements() const1137 unsigned MatrixType::getNumElements() const {
1138   return (getImpl()->columnCount) * getNumRows();
1139 }
1140 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1141 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1142                                Optional<StorageClass> storage) {
1143   getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1144 }
1145 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1146 void MatrixType::getCapabilities(
1147     SPIRVType::CapabilityArrayRefVector &capabilities,
1148     Optional<StorageClass> storage) {
1149   {
1150     static const Capability caps[] = {Capability::Matrix};
1151     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1152     capabilities.push_back(ref);
1153   }
1154   // Add any capabilities associated with the underlying vectors (i.e., columns)
1155   getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // SPIR-V Dialect
1160 //===----------------------------------------------------------------------===//
1161 
registerTypes()1162 void SPIRVDialect::registerTypes() {
1163   addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
1164            PointerType, RuntimeArrayType, SampledImageType, StructType>();
1165 }
1166