1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // This file provides a class hierarchy for representing SPIR-V types.
16 
17 #ifndef SOURCE_OPT_TYPES_H_
18 #define SOURCE_OPT_TYPES_H_
19 
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <utility>
27 #include <vector>
28 
29 #include "source/latest_version_spirv_header.h"
30 #include "source/opt/instruction.h"
31 #include "spirv-tools/libspirv.h"
32 
33 namespace spvtools {
34 namespace opt {
35 namespace analysis {
36 
// Forward declarations of every concrete Type subclass, so that the As<...>()
// cast helpers declared on Type can name them before their definitions appear
// further down. Listed alphabetically.
class AccelerationStructureNV;
class Array;
class Bool;
class CooperativeMatrixNV;
class DeviceEvent;
class Event;
class Float;
class ForwardPointer;
class Function;
class Image;
class Integer;
class Matrix;
class NamedBarrier;
class Opaque;
class Pipe;
class PipeStorage;
class Pointer;
class Queue;
class RayQueryProvisionalKHR;
class ReserveId;
class RuntimeArray;
class SampledImage;
class Sampler;
class Struct;
class Vector;
class Void;
63 
64 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods,
65 // which is used as a way to probe the actual <subclass>.
66 class Type {
67  public:
68   typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache;
69 
70   // Available subtypes.
71   //
72   // When adding a new derived class of Type, please add an entry to the enum.
73   enum Kind {
74     kVoid,
75     kBool,
76     kInteger,
77     kFloat,
78     kVector,
79     kMatrix,
80     kImage,
81     kSampler,
82     kSampledImage,
83     kArray,
84     kRuntimeArray,
85     kStruct,
86     kOpaque,
87     kPointer,
88     kFunction,
89     kEvent,
90     kDeviceEvent,
91     kReserveId,
92     kQueue,
93     kPipe,
94     kForwardPointer,
95     kPipeStorage,
96     kNamedBarrier,
97     kAccelerationStructureNV,
98     kCooperativeMatrixNV,
99     kRayQueryProvisionalKHR
100   };
101 
Type(Kind k)102   Type(Kind k) : kind_(k) {}
103 
~Type()104   virtual ~Type() {}
105 
106   // Attaches a decoration directly on this type.
AddDecoration(std::vector<uint32_t> && d)107   void AddDecoration(std::vector<uint32_t>&& d) {
108     decorations_.push_back(std::move(d));
109   }
110   // Returns the decorations on this type as a string.
111   std::string GetDecorationStr() const;
112   // Returns true if this type has exactly the same decorations as |that| type.
113   bool HasSameDecorations(const Type* that) const;
114   // Returns true if this type is exactly the same as |that| type, including
115   // decorations.
IsSame(const Type * that)116   bool IsSame(const Type* that) const {
117     IsSameCache seen;
118     return IsSameImpl(that, &seen);
119   }
120 
121   // Returns true if this type is exactly the same as |that| type, including
122   // decorations.  |seen| is the set of |Pointer*| pair that are currently being
123   // compared in a parent call to |IsSameImpl|.
124   virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0;
125 
126   // Returns a human-readable string to represent this type.
127   virtual std::string str() const = 0;
128 
kind()129   Kind kind() const { return kind_; }
decorations()130   const std::vector<std::vector<uint32_t>>& decorations() const {
131     return decorations_;
132   }
133 
134   // Returns true if there is no decoration on this type. For struct types,
135   // returns true only when there is no decoration for both the struct type
136   // and the struct members.
decoration_empty()137   virtual bool decoration_empty() const { return decorations_.empty(); }
138 
139   // Creates a clone of |this|.
140   std::unique_ptr<Type> Clone() const;
141 
142   // Returns a clone of |this| minus any decorations.
143   std::unique_ptr<Type> RemoveDecorations() const;
144 
145   // Returns true if this type must be unique.
146   //
147   // If variable pointers are allowed, then pointers are not required to be
148   // unique.
149   // TODO(alanbaker): Update this if variable pointers become a core feature.
150   bool IsUniqueType(bool allowVariablePointers = false) const;
151 
152   bool operator==(const Type& other) const;
153 
154   // Returns the hash value of this type.
155   size_t HashValue() const;
156 
157   // Adds the necessary words to compute a hash value of this type to |words|.
GetHashWords(std::vector<uint32_t> * words)158   void GetHashWords(std::vector<uint32_t>* words) const {
159     std::unordered_set<const Type*> seen;
160     GetHashWords(words, &seen);
161   }
162 
163   // Adds the necessary words to compute a hash value of this type to |words|.
164   void GetHashWords(std::vector<uint32_t>* words,
165                     std::unordered_set<const Type*>* seen) const;
166 
167   // Adds necessary extra words for a subtype to calculate a hash value into
168   // |words|.
169   virtual void GetExtraHashWords(
170       std::vector<uint32_t>* words,
171       std::unordered_set<const Type*>* pSet) const = 0;
172 
173 // A bunch of methods for casting this type to a given type. Returns this if the
174 // cast can be done, nullptr otherwise.
175 // clang-format off
176 #define DeclareCastMethod(target)                  \
177   virtual target* As##target() { return nullptr; } \
178   virtual const target* As##target() const { return nullptr; }
179   DeclareCastMethod(Void)
180   DeclareCastMethod(Bool)
181   DeclareCastMethod(Integer)
182   DeclareCastMethod(Float)
183   DeclareCastMethod(Vector)
184   DeclareCastMethod(Matrix)
185   DeclareCastMethod(Image)
186   DeclareCastMethod(Sampler)
187   DeclareCastMethod(SampledImage)
188   DeclareCastMethod(Array)
189   DeclareCastMethod(RuntimeArray)
190   DeclareCastMethod(Struct)
191   DeclareCastMethod(Opaque)
192   DeclareCastMethod(Pointer)
193   DeclareCastMethod(Function)
194   DeclareCastMethod(Event)
195   DeclareCastMethod(DeviceEvent)
196   DeclareCastMethod(ReserveId)
197   DeclareCastMethod(Queue)
198   DeclareCastMethod(Pipe)
199   DeclareCastMethod(ForwardPointer)
200   DeclareCastMethod(PipeStorage)
201   DeclareCastMethod(NamedBarrier)
202   DeclareCastMethod(AccelerationStructureNV)
203   DeclareCastMethod(CooperativeMatrixNV)
204   DeclareCastMethod(RayQueryProvisionalKHR)
205 #undef DeclareCastMethod
206 
207  protected:
208   // Decorations attached to this type. Each decoration is encoded as a vector
209   // of uint32_t numbers. The first uint32_t number is the decoration value,
210   // and the rest are the parameters to the decoration (if exists).
211   std::vector<std::vector<uint32_t>> decorations_;
212 
213  private:
214   // Removes decorations on this type. For struct types, also removes element
215   // decorations.
ClearDecorations()216   virtual void ClearDecorations() { decorations_.clear(); }
217 
218   Kind kind_;
219 };
220 // clang-format on
221 
222 class Integer : public Type {
223  public:
Integer(uint32_t w,bool is_signed)224   Integer(uint32_t w, bool is_signed)
225       : Type(kInteger), width_(w), signed_(is_signed) {}
226   Integer(const Integer&) = default;
227 
228   std::string str() const override;
229 
AsInteger()230   Integer* AsInteger() override { return this; }
AsInteger()231   const Integer* AsInteger() const override { return this; }
width()232   uint32_t width() const { return width_; }
IsSigned()233   bool IsSigned() const { return signed_; }
234 
235   void GetExtraHashWords(std::vector<uint32_t>* words,
236                          std::unordered_set<const Type*>* pSet) const override;
237 
238  private:
239   bool IsSameImpl(const Type* that, IsSameCache*) const override;
240 
241   uint32_t width_;  // bit width
242   bool signed_;     // true if this integer is signed
243 };
244 
245 class Float : public Type {
246  public:
Float(uint32_t w)247   Float(uint32_t w) : Type(kFloat), width_(w) {}
248   Float(const Float&) = default;
249 
250   std::string str() const override;
251 
AsFloat()252   Float* AsFloat() override { return this; }
AsFloat()253   const Float* AsFloat() const override { return this; }
width()254   uint32_t width() const { return width_; }
255 
256   void GetExtraHashWords(std::vector<uint32_t>* words,
257                          std::unordered_set<const Type*>* pSet) const override;
258 
259  private:
260   bool IsSameImpl(const Type* that, IsSameCache*) const override;
261 
262   uint32_t width_;  // bit width
263 };
264 
265 class Vector : public Type {
266  public:
267   Vector(const Type* element_type, uint32_t count);
268   Vector(const Vector&) = default;
269 
270   std::string str() const override;
element_type()271   const Type* element_type() const { return element_type_; }
element_count()272   uint32_t element_count() const { return count_; }
273 
AsVector()274   Vector* AsVector() override { return this; }
AsVector()275   const Vector* AsVector() const override { return this; }
276 
277   void GetExtraHashWords(std::vector<uint32_t>* words,
278                          std::unordered_set<const Type*>* pSet) const override;
279 
280  private:
281   bool IsSameImpl(const Type* that, IsSameCache*) const override;
282 
283   const Type* element_type_;
284   uint32_t count_;
285 };
286 
287 class Matrix : public Type {
288  public:
289   Matrix(const Type* element_type, uint32_t count);
290   Matrix(const Matrix&) = default;
291 
292   std::string str() const override;
element_type()293   const Type* element_type() const { return element_type_; }
element_count()294   uint32_t element_count() const { return count_; }
295 
AsMatrix()296   Matrix* AsMatrix() override { return this; }
AsMatrix()297   const Matrix* AsMatrix() const override { return this; }
298 
299   void GetExtraHashWords(std::vector<uint32_t>* words,
300                          std::unordered_set<const Type*>* pSet) const override;
301 
302  private:
303   bool IsSameImpl(const Type* that, IsSameCache*) const override;
304 
305   const Type* element_type_;
306   uint32_t count_;
307 };
308 
309 class Image : public Type {
310  public:
311   Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
312         uint32_t sampling, SpvImageFormat f,
313         SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly);
314   Image(const Image&) = default;
315 
316   std::string str() const override;
317 
AsImage()318   Image* AsImage() override { return this; }
AsImage()319   const Image* AsImage() const override { return this; }
320 
sampled_type()321   const Type* sampled_type() const { return sampled_type_; }
dim()322   SpvDim dim() const { return dim_; }
depth()323   uint32_t depth() const { return depth_; }
is_arrayed()324   bool is_arrayed() const { return arrayed_; }
is_multisampled()325   bool is_multisampled() const { return ms_; }
sampled()326   uint32_t sampled() const { return sampled_; }
format()327   SpvImageFormat format() const { return format_; }
access_qualifier()328   SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
329 
330   void GetExtraHashWords(std::vector<uint32_t>* words,
331                          std::unordered_set<const Type*>* pSet) const override;
332 
333  private:
334   bool IsSameImpl(const Type* that, IsSameCache*) const override;
335 
336   Type* sampled_type_;
337   SpvDim dim_;
338   uint32_t depth_;
339   bool arrayed_;
340   bool ms_;
341   uint32_t sampled_;
342   SpvImageFormat format_;
343   SpvAccessQualifier access_qualifier_;
344 };
345 
346 class SampledImage : public Type {
347  public:
SampledImage(Type * image)348   SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {}
349   SampledImage(const SampledImage&) = default;
350 
351   std::string str() const override;
352 
AsSampledImage()353   SampledImage* AsSampledImage() override { return this; }
AsSampledImage()354   const SampledImage* AsSampledImage() const override { return this; }
355 
image_type()356   const Type* image_type() const { return image_type_; }
357 
358   void GetExtraHashWords(std::vector<uint32_t>* words,
359                          std::unordered_set<const Type*>* pSet) const override;
360 
361  private:
362   bool IsSameImpl(const Type* that, IsSameCache*) const override;
363   Type* image_type_;
364 };
365 
366 class Array : public Type {
367  public:
368   // Data about the length operand, that helps us distinguish between one
369   // array length and another.
370   struct LengthInfo {
371     // The result id of the instruction defining the length.
372     const uint32_t id;
373     enum Case : uint32_t {
374       kConstant = 0,
375       kConstantWithSpecId = 1,
376       kDefiningId = 2
377     };
378     // Extra words used to distinshish one array length and another.
379     //  - if OpConstant, then it's 0, then the words in the literal constant
380     //    value.
381     //  - if OpSpecConstant, then it's 1, then the SpecID decoration if there
382     //    is one, followed by the words in the literal constant value.
383     //    The spec might not be overridden, in which case we'll end up using
384     //    the literal value.
385     //  - Otherwise, it's an OpSpecConsant, and this 2, then the ID (again).
386     const std::vector<uint32_t> words;
387   };
388 
389   // Constructs an array type with given element and length.  If the length
390   // is an OpSpecConstant, then |spec_id| should be its SpecId decoration.
391   Array(const Type* element_type, const LengthInfo& length_info_arg);
392   Array(const Array&) = default;
393 
394   std::string str() const override;
element_type()395   const Type* element_type() const { return element_type_; }
LengthId()396   uint32_t LengthId() const { return length_info_.id; }
length_info()397   const LengthInfo& length_info() const { return length_info_; }
398 
AsArray()399   Array* AsArray() override { return this; }
AsArray()400   const Array* AsArray() const override { return this; }
401 
402   void GetExtraHashWords(std::vector<uint32_t>* words,
403                          std::unordered_set<const Type*>* pSet) const override;
404 
405   void ReplaceElementType(const Type* element_type);
406 
407  private:
408   bool IsSameImpl(const Type* that, IsSameCache*) const override;
409 
410   const Type* element_type_;
411   const LengthInfo length_info_;
412 };
413 
414 class RuntimeArray : public Type {
415  public:
416   RuntimeArray(const Type* element_type);
417   RuntimeArray(const RuntimeArray&) = default;
418 
419   std::string str() const override;
element_type()420   const Type* element_type() const { return element_type_; }
421 
AsRuntimeArray()422   RuntimeArray* AsRuntimeArray() override { return this; }
AsRuntimeArray()423   const RuntimeArray* AsRuntimeArray() const override { return this; }
424 
425   void GetExtraHashWords(std::vector<uint32_t>* words,
426                          std::unordered_set<const Type*>* pSet) const override;
427 
428   void ReplaceElementType(const Type* element_type);
429 
430  private:
431   bool IsSameImpl(const Type* that, IsSameCache*) const override;
432 
433   const Type* element_type_;
434 };
435 
436 class Struct : public Type {
437  public:
438   Struct(const std::vector<const Type*>& element_types);
439   Struct(const Struct&) = default;
440 
441   // Adds a decoration to the member at the given index.  The first word is the
442   // decoration enum, and the remaining words, if any, are its operands.
443   void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration);
444 
445   std::string str() const override;
element_types()446   const std::vector<const Type*>& element_types() const {
447     return element_types_;
448   }
element_types()449   std::vector<const Type*>& element_types() { return element_types_; }
decoration_empty()450   bool decoration_empty() const override {
451     return decorations_.empty() && element_decorations_.empty();
452   }
453 
454   const std::map<uint32_t, std::vector<std::vector<uint32_t>>>&
element_decorations()455   element_decorations() const {
456     return element_decorations_;
457   }
458 
AsStruct()459   Struct* AsStruct() override { return this; }
AsStruct()460   const Struct* AsStruct() const override { return this; }
461 
462   void GetExtraHashWords(std::vector<uint32_t>* words,
463                          std::unordered_set<const Type*>* pSet) const override;
464 
465  private:
466   bool IsSameImpl(const Type* that, IsSameCache*) const override;
467 
ClearDecorations()468   void ClearDecorations() override {
469     decorations_.clear();
470     element_decorations_.clear();
471   }
472 
473   std::vector<const Type*> element_types_;
474   // We can attach decorations to struct members and that should not affect the
475   // underlying element type. So we need an extra data structure here to keep
476   // track of element type decorations.  They must be stored in an ordered map
477   // because |GetExtraHashWords| will traverse the structure.  It must have a
478   // fixed order in order to hash to the same value every time.
479   std::map<uint32_t, std::vector<std::vector<uint32_t>>> element_decorations_;
480 };
481 
482 class Opaque : public Type {
483  public:
Opaque(std::string n)484   Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {}
485   Opaque(const Opaque&) = default;
486 
487   std::string str() const override;
488 
AsOpaque()489   Opaque* AsOpaque() override { return this; }
AsOpaque()490   const Opaque* AsOpaque() const override { return this; }
491 
name()492   const std::string& name() const { return name_; }
493 
494   void GetExtraHashWords(std::vector<uint32_t>* words,
495                          std::unordered_set<const Type*>* pSet) const override;
496 
497  private:
498   bool IsSameImpl(const Type* that, IsSameCache*) const override;
499 
500   std::string name_;
501 };
502 
503 class Pointer : public Type {
504  public:
505   Pointer(const Type* pointee, SpvStorageClass sc);
506   Pointer(const Pointer&) = default;
507 
508   std::string str() const override;
pointee_type()509   const Type* pointee_type() const { return pointee_type_; }
storage_class()510   SpvStorageClass storage_class() const { return storage_class_; }
511 
AsPointer()512   Pointer* AsPointer() override { return this; }
AsPointer()513   const Pointer* AsPointer() const override { return this; }
514 
515   void GetExtraHashWords(std::vector<uint32_t>* words,
516                          std::unordered_set<const Type*>* pSet) const override;
517 
518   void SetPointeeType(const Type* type);
519 
520  private:
521   bool IsSameImpl(const Type* that, IsSameCache*) const override;
522 
523   const Type* pointee_type_;
524   SpvStorageClass storage_class_;
525 };
526 
527 class Function : public Type {
528  public:
529   Function(const Type* ret_type, const std::vector<const Type*>& params);
530   Function(const Type* ret_type, std::vector<const Type*>& params);
531   Function(const Function&) = default;
532 
533   std::string str() const override;
534 
AsFunction()535   Function* AsFunction() override { return this; }
AsFunction()536   const Function* AsFunction() const override { return this; }
537 
return_type()538   const Type* return_type() const { return return_type_; }
param_types()539   const std::vector<const Type*>& param_types() const { return param_types_; }
param_types()540   std::vector<const Type*>& param_types() { return param_types_; }
541 
542   void GetExtraHashWords(std::vector<uint32_t>* words,
543                          std::unordered_set<const Type*>*) const override;
544 
545   void SetReturnType(const Type* type);
546 
547  private:
548   bool IsSameImpl(const Type* that, IsSameCache*) const override;
549 
550   const Type* return_type_;
551   std::vector<const Type*> param_types_;
552 };
553 
554 class Pipe : public Type {
555  public:
Pipe(SpvAccessQualifier qualifier)556   Pipe(SpvAccessQualifier qualifier)
557       : Type(kPipe), access_qualifier_(qualifier) {}
558   Pipe(const Pipe&) = default;
559 
560   std::string str() const override;
561 
AsPipe()562   Pipe* AsPipe() override { return this; }
AsPipe()563   const Pipe* AsPipe() const override { return this; }
564 
access_qualifier()565   SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
566 
567   void GetExtraHashWords(std::vector<uint32_t>* words,
568                          std::unordered_set<const Type*>* pSet) const override;
569 
570  private:
571   bool IsSameImpl(const Type* that, IsSameCache*) const override;
572 
573   SpvAccessQualifier access_qualifier_;
574 };
575 
576 class ForwardPointer : public Type {
577  public:
ForwardPointer(uint32_t id,SpvStorageClass sc)578   ForwardPointer(uint32_t id, SpvStorageClass sc)
579       : Type(kForwardPointer),
580         target_id_(id),
581         storage_class_(sc),
582         pointer_(nullptr) {}
583   ForwardPointer(const ForwardPointer&) = default;
584 
target_id()585   uint32_t target_id() const { return target_id_; }
SetTargetPointer(const Pointer * pointer)586   void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; }
storage_class()587   SpvStorageClass storage_class() const { return storage_class_; }
target_pointer()588   const Pointer* target_pointer() const { return pointer_; }
589 
590   std::string str() const override;
591 
AsForwardPointer()592   ForwardPointer* AsForwardPointer() override { return this; }
AsForwardPointer()593   const ForwardPointer* AsForwardPointer() const override { return this; }
594 
595   void GetExtraHashWords(std::vector<uint32_t>* words,
596                          std::unordered_set<const Type*>* pSet) const override;
597 
598  private:
599   bool IsSameImpl(const Type* that, IsSameCache*) const override;
600 
601   uint32_t target_id_;
602   SpvStorageClass storage_class_;
603   const Pointer* pointer_;
604 };
605 
606 class CooperativeMatrixNV : public Type {
607  public:
608   CooperativeMatrixNV(const Type* type, const uint32_t scope,
609                       const uint32_t rows, const uint32_t columns);
610   CooperativeMatrixNV(const CooperativeMatrixNV&) = default;
611 
612   std::string str() const override;
613 
AsCooperativeMatrixNV()614   CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; }
AsCooperativeMatrixNV()615   const CooperativeMatrixNV* AsCooperativeMatrixNV() const override {
616     return this;
617   }
618 
619   void GetExtraHashWords(std::vector<uint32_t>*,
620                          std::unordered_set<const Type*>*) const override;
621 
component_type()622   const Type* component_type() const { return component_type_; }
scope_id()623   uint32_t scope_id() const { return scope_id_; }
rows_id()624   uint32_t rows_id() const { return rows_id_; }
columns_id()625   uint32_t columns_id() const { return columns_id_; }
626 
627  private:
628   bool IsSameImpl(const Type* that, IsSameCache*) const override;
629 
630   const Type* component_type_;
631   const uint32_t scope_id_;
632   const uint32_t rows_id_;
633   const uint32_t columns_id_;
634 };
635 
636 #define DefineParameterlessType(type, name)                                    \
637   class type : public Type {                                                   \
638    public:                                                                     \
639     type() : Type(k##type) {}                                                  \
640     type(const type&) = default;                                               \
641                                                                                \
642     std::string str() const override { return #name; }                         \
643                                                                                \
644     type* As##type() override { return this; }                                 \
645     const type* As##type() const override { return this; }                     \
646                                                                                \
647     void GetExtraHashWords(std::vector<uint32_t>*,                             \
648                            std::unordered_set<const Type*>*) const override {} \
649                                                                                \
650    private:                                                                    \
651     bool IsSameImpl(const Type* that, IsSameCache*) const override {           \
652       return that->As##type() && HasSameDecorations(that);                     \
653     }                                                                          \
654   }
655 DefineParameterlessType(Void, void);
656 DefineParameterlessType(Bool, bool);
657 DefineParameterlessType(Sampler, sampler);
658 DefineParameterlessType(Event, event);
659 DefineParameterlessType(DeviceEvent, device_event);
660 DefineParameterlessType(ReserveId, reserve_id);
661 DefineParameterlessType(Queue, queue);
662 DefineParameterlessType(PipeStorage, pipe_storage);
663 DefineParameterlessType(NamedBarrier, named_barrier);
664 DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV);
665 DefineParameterlessType(RayQueryProvisionalKHR, rayQueryProvisionalKHR);
666 #undef DefineParameterlessType
667 
668 }  // namespace analysis
669 }  // namespace opt
670 }  // namespace spvtools
671 
672 #endif  // SOURCE_OPT_TYPES_H_
673