1 // Copyright (c) 2016 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // This file provides a class hierarchy for representing SPIR-V types. 16 17 #ifndef SOURCE_OPT_TYPES_H_ 18 #define SOURCE_OPT_TYPES_H_ 19 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <unordered_map> 25 #include <unordered_set> 26 #include <utility> 27 #include <vector> 28 29 #include "source/latest_version_spirv_header.h" 30 #include "source/opt/instruction.h" 31 #include "spirv-tools/libspirv.h" 32 33 namespace spvtools { 34 namespace opt { 35 namespace analysis { 36 37 class Void; 38 class Bool; 39 class Integer; 40 class Float; 41 class Vector; 42 class Matrix; 43 class Image; 44 class Sampler; 45 class SampledImage; 46 class Array; 47 class RuntimeArray; 48 class Struct; 49 class Opaque; 50 class Pointer; 51 class Function; 52 class Event; 53 class DeviceEvent; 54 class ReserveId; 55 class Queue; 56 class Pipe; 57 class ForwardPointer; 58 class PipeStorage; 59 class NamedBarrier; 60 class AccelerationStructureNV; 61 class CooperativeMatrixNV; 62 class RayQueryProvisionalKHR; 63 64 // Abstract class for a SPIR-V type. It has a bunch of As<sublcass>() methods, 65 // which is used as a way to probe the actual <subclass>. 66 class Type { 67 public: 68 typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache; 69 70 // Available subtypes. 71 // 72 // When adding a new derived class of Type, please add an entry to the enum. 73 enum Kind { 74 kVoid, 75 kBool, 76 kInteger, 77 kFloat, 78 kVector, 79 kMatrix, 80 kImage, 81 kSampler, 82 kSampledImage, 83 kArray, 84 kRuntimeArray, 85 kStruct, 86 kOpaque, 87 kPointer, 88 kFunction, 89 kEvent, 90 kDeviceEvent, 91 kReserveId, 92 kQueue, 93 kPipe, 94 kForwardPointer, 95 kPipeStorage, 96 kNamedBarrier, 97 kAccelerationStructureNV, 98 kCooperativeMatrixNV, 99 kRayQueryProvisionalKHR 100 }; 101 Type(Kind k)102 Type(Kind k) : kind_(k) {} 103 ~Type()104 virtual ~Type() {} 105 106 // Attaches a decoration directly on this type. AddDecoration(std::vector<uint32_t> && d)107 void AddDecoration(std::vector<uint32_t>&& d) { 108 decorations_.push_back(std::move(d)); 109 } 110 // Returns the decorations on this type as a string. 111 std::string GetDecorationStr() const; 112 // Returns true if this type has exactly the same decorations as |that| type. 113 bool HasSameDecorations(const Type* that) const; 114 // Returns true if this type is exactly the same as |that| type, including 115 // decorations. IsSame(const Type * that)116 bool IsSame(const Type* that) const { 117 IsSameCache seen; 118 return IsSameImpl(that, &seen); 119 } 120 121 // Returns true if this type is exactly the same as |that| type, including 122 // decorations. |seen| is the set of |Pointer*| pair that are currently being 123 // compared in a parent call to |IsSameImpl|. 124 virtual bool IsSameImpl(const Type* that, IsSameCache* seen) const = 0; 125 126 // Returns a human-readable string to represent this type. 127 virtual std::string str() const = 0; 128 kind()129 Kind kind() const { return kind_; } decorations()130 const std::vector<std::vector<uint32_t>>& decorations() const { 131 return decorations_; 132 } 133 134 // Returns true if there is no decoration on this type. For struct types, 135 // returns true only when there is no decoration for both the struct type 136 // and the struct members. decoration_empty()137 virtual bool decoration_empty() const { return decorations_.empty(); } 138 139 // Creates a clone of |this|. 140 std::unique_ptr<Type> Clone() const; 141 142 // Returns a clone of |this| minus any decorations. 143 std::unique_ptr<Type> RemoveDecorations() const; 144 145 // Returns true if this type must be unique. 146 // 147 // If variable pointers are allowed, then pointers are not required to be 148 // unique. 149 // TODO(alanbaker): Update this if variable pointers become a core feature. 150 bool IsUniqueType(bool allowVariablePointers = false) const; 151 152 bool operator==(const Type& other) const; 153 154 // Returns the hash value of this type. 155 size_t HashValue() const; 156 157 // Adds the necessary words to compute a hash value of this type to |words|. GetHashWords(std::vector<uint32_t> * words)158 void GetHashWords(std::vector<uint32_t>* words) const { 159 std::unordered_set<const Type*> seen; 160 GetHashWords(words, &seen); 161 } 162 163 // Adds the necessary words to compute a hash value of this type to |words|. 164 void GetHashWords(std::vector<uint32_t>* words, 165 std::unordered_set<const Type*>* seen) const; 166 167 // Adds necessary extra words for a subtype to calculate a hash value into 168 // |words|. 169 virtual void GetExtraHashWords( 170 std::vector<uint32_t>* words, 171 std::unordered_set<const Type*>* pSet) const = 0; 172 173 // A bunch of methods for casting this type to a given type. Returns this if the 174 // cast can be done, nullptr otherwise. 175 // clang-format off 176 #define DeclareCastMethod(target) \ 177 virtual target* As##target() { return nullptr; } \ 178 virtual const target* As##target() const { return nullptr; } 179 DeclareCastMethod(Void) 180 DeclareCastMethod(Bool) 181 DeclareCastMethod(Integer) 182 DeclareCastMethod(Float) 183 DeclareCastMethod(Vector) 184 DeclareCastMethod(Matrix) 185 DeclareCastMethod(Image) 186 DeclareCastMethod(Sampler) 187 DeclareCastMethod(SampledImage) 188 DeclareCastMethod(Array) 189 DeclareCastMethod(RuntimeArray) 190 DeclareCastMethod(Struct) 191 DeclareCastMethod(Opaque) 192 DeclareCastMethod(Pointer) 193 DeclareCastMethod(Function) 194 DeclareCastMethod(Event) 195 DeclareCastMethod(DeviceEvent) 196 DeclareCastMethod(ReserveId) 197 DeclareCastMethod(Queue) 198 DeclareCastMethod(Pipe) 199 DeclareCastMethod(ForwardPointer) 200 DeclareCastMethod(PipeStorage) 201 DeclareCastMethod(NamedBarrier) 202 DeclareCastMethod(AccelerationStructureNV) 203 DeclareCastMethod(CooperativeMatrixNV) 204 DeclareCastMethod(RayQueryProvisionalKHR) 205 #undef DeclareCastMethod 206 207 protected: 208 // Decorations attached to this type. Each decoration is encoded as a vector 209 // of uint32_t numbers. The first uint32_t number is the decoration value, 210 // and the rest are the parameters to the decoration (if exists). 211 std::vector<std::vector<uint32_t>> decorations_; 212 213 private: 214 // Removes decorations on this type. For struct types, also removes element 215 // decorations. ClearDecorations()216 virtual void ClearDecorations() { decorations_.clear(); } 217 218 Kind kind_; 219 }; 220 // clang-format on 221 222 class Integer : public Type { 223 public: Integer(uint32_t w,bool is_signed)224 Integer(uint32_t w, bool is_signed) 225 : Type(kInteger), width_(w), signed_(is_signed) {} 226 Integer(const Integer&) = default; 227 228 std::string str() const override; 229 AsInteger()230 Integer* AsInteger() override { return this; } AsInteger()231 const Integer* AsInteger() const override { return this; } width()232 uint32_t width() const { return width_; } IsSigned()233 bool IsSigned() const { return signed_; } 234 235 void GetExtraHashWords(std::vector<uint32_t>* words, 236 std::unordered_set<const Type*>* pSet) const override; 237 238 private: 239 bool IsSameImpl(const Type* that, IsSameCache*) const override; 240 241 uint32_t width_; // bit width 242 bool signed_; // true if this integer is signed 243 }; 244 245 class Float : public Type { 246 public: Float(uint32_t w)247 Float(uint32_t w) : Type(kFloat), width_(w) {} 248 Float(const Float&) = default; 249 250 std::string str() const override; 251 AsFloat()252 Float* AsFloat() override { return this; } AsFloat()253 const Float* AsFloat() const override { return this; } width()254 uint32_t width() const { return width_; } 255 256 void GetExtraHashWords(std::vector<uint32_t>* words, 257 std::unordered_set<const Type*>* pSet) const override; 258 259 private: 260 bool IsSameImpl(const Type* that, IsSameCache*) const override; 261 262 uint32_t width_; // bit width 263 }; 264 265 class Vector : public Type { 266 public: 267 Vector(const Type* element_type, uint32_t count); 268 Vector(const Vector&) = default; 269 270 std::string str() const override; element_type()271 const Type* element_type() const { return element_type_; } element_count()272 uint32_t element_count() const { return count_; } 273 AsVector()274 Vector* AsVector() override { return this; } AsVector()275 const Vector* AsVector() const override { return this; } 276 277 void GetExtraHashWords(std::vector<uint32_t>* words, 278 std::unordered_set<const Type*>* pSet) const override; 279 280 private: 281 bool IsSameImpl(const Type* that, IsSameCache*) const override; 282 283 const Type* element_type_; 284 uint32_t count_; 285 }; 286 287 class Matrix : public Type { 288 public: 289 Matrix(const Type* element_type, uint32_t count); 290 Matrix(const Matrix&) = default; 291 292 std::string str() const override; element_type()293 const Type* element_type() const { return element_type_; } element_count()294 uint32_t element_count() const { return count_; } 295 AsMatrix()296 Matrix* AsMatrix() override { return this; } AsMatrix()297 const Matrix* AsMatrix() const override { return this; } 298 299 void GetExtraHashWords(std::vector<uint32_t>* words, 300 std::unordered_set<const Type*>* pSet) const override; 301 302 private: 303 bool IsSameImpl(const Type* that, IsSameCache*) const override; 304 305 const Type* element_type_; 306 uint32_t count_; 307 }; 308 309 class Image : public Type { 310 public: 311 Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample, 312 uint32_t sampling, SpvImageFormat f, 313 SpvAccessQualifier qualifier = SpvAccessQualifierReadOnly); 314 Image(const Image&) = default; 315 316 std::string str() const override; 317 AsImage()318 Image* AsImage() override { return this; } AsImage()319 const Image* AsImage() const override { return this; } 320 sampled_type()321 const Type* sampled_type() const { return sampled_type_; } dim()322 SpvDim dim() const { return dim_; } depth()323 uint32_t depth() const { return depth_; } is_arrayed()324 bool is_arrayed() const { return arrayed_; } is_multisampled()325 bool is_multisampled() const { return ms_; } sampled()326 uint32_t sampled() const { return sampled_; } format()327 SpvImageFormat format() const { return format_; } access_qualifier()328 SpvAccessQualifier access_qualifier() const { return access_qualifier_; } 329 330 void GetExtraHashWords(std::vector<uint32_t>* words, 331 std::unordered_set<const Type*>* pSet) const override; 332 333 private: 334 bool IsSameImpl(const Type* that, IsSameCache*) const override; 335 336 Type* sampled_type_; 337 SpvDim dim_; 338 uint32_t depth_; 339 bool arrayed_; 340 bool ms_; 341 uint32_t sampled_; 342 SpvImageFormat format_; 343 SpvAccessQualifier access_qualifier_; 344 }; 345 346 class SampledImage : public Type { 347 public: SampledImage(Type * image)348 SampledImage(Type* image) : Type(kSampledImage), image_type_(image) {} 349 SampledImage(const SampledImage&) = default; 350 351 std::string str() const override; 352 AsSampledImage()353 SampledImage* AsSampledImage() override { return this; } AsSampledImage()354 const SampledImage* AsSampledImage() const override { return this; } 355 image_type()356 const Type* image_type() const { return image_type_; } 357 358 void GetExtraHashWords(std::vector<uint32_t>* words, 359 std::unordered_set<const Type*>* pSet) const override; 360 361 private: 362 bool IsSameImpl(const Type* that, IsSameCache*) const override; 363 Type* image_type_; 364 }; 365 366 class Array : public Type { 367 public: 368 // Data about the length operand, that helps us distinguish between one 369 // array length and another. 370 struct LengthInfo { 371 // The result id of the instruction defining the length. 372 const uint32_t id; 373 enum Case : uint32_t { 374 kConstant = 0, 375 kConstantWithSpecId = 1, 376 kDefiningId = 2 377 }; 378 // Extra words used to distinshish one array length and another. 379 // - if OpConstant, then it's 0, then the words in the literal constant 380 // value. 381 // - if OpSpecConstant, then it's 1, then the SpecID decoration if there 382 // is one, followed by the words in the literal constant value. 383 // The spec might not be overridden, in which case we'll end up using 384 // the literal value. 385 // - Otherwise, it's an OpSpecConsant, and this 2, then the ID (again). 386 const std::vector<uint32_t> words; 387 }; 388 389 // Constructs an array type with given element and length. If the length 390 // is an OpSpecConstant, then |spec_id| should be its SpecId decoration. 391 Array(const Type* element_type, const LengthInfo& length_info_arg); 392 Array(const Array&) = default; 393 394 std::string str() const override; element_type()395 const Type* element_type() const { return element_type_; } LengthId()396 uint32_t LengthId() const { return length_info_.id; } length_info()397 const LengthInfo& length_info() const { return length_info_; } 398 AsArray()399 Array* AsArray() override { return this; } AsArray()400 const Array* AsArray() const override { return this; } 401 402 void GetExtraHashWords(std::vector<uint32_t>* words, 403 std::unordered_set<const Type*>* pSet) const override; 404 405 void ReplaceElementType(const Type* element_type); 406 407 private: 408 bool IsSameImpl(const Type* that, IsSameCache*) const override; 409 410 const Type* element_type_; 411 const LengthInfo length_info_; 412 }; 413 414 class RuntimeArray : public Type { 415 public: 416 RuntimeArray(const Type* element_type); 417 RuntimeArray(const RuntimeArray&) = default; 418 419 std::string str() const override; element_type()420 const Type* element_type() const { return element_type_; } 421 AsRuntimeArray()422 RuntimeArray* AsRuntimeArray() override { return this; } AsRuntimeArray()423 const RuntimeArray* AsRuntimeArray() const override { return this; } 424 425 void GetExtraHashWords(std::vector<uint32_t>* words, 426 std::unordered_set<const Type*>* pSet) const override; 427 428 void ReplaceElementType(const Type* element_type); 429 430 private: 431 bool IsSameImpl(const Type* that, IsSameCache*) const override; 432 433 const Type* element_type_; 434 }; 435 436 class Struct : public Type { 437 public: 438 Struct(const std::vector<const Type*>& element_types); 439 Struct(const Struct&) = default; 440 441 // Adds a decoration to the member at the given index. The first word is the 442 // decoration enum, and the remaining words, if any, are its operands. 443 void AddMemberDecoration(uint32_t index, std::vector<uint32_t>&& decoration); 444 445 std::string str() const override; element_types()446 const std::vector<const Type*>& element_types() const { 447 return element_types_; 448 } element_types()449 std::vector<const Type*>& element_types() { return element_types_; } decoration_empty()450 bool decoration_empty() const override { 451 return decorations_.empty() && element_decorations_.empty(); 452 } 453 454 const std::map<uint32_t, std::vector<std::vector<uint32_t>>>& element_decorations()455 element_decorations() const { 456 return element_decorations_; 457 } 458 AsStruct()459 Struct* AsStruct() override { return this; } AsStruct()460 const Struct* AsStruct() const override { return this; } 461 462 void GetExtraHashWords(std::vector<uint32_t>* words, 463 std::unordered_set<const Type*>* pSet) const override; 464 465 private: 466 bool IsSameImpl(const Type* that, IsSameCache*) const override; 467 ClearDecorations()468 void ClearDecorations() override { 469 decorations_.clear(); 470 element_decorations_.clear(); 471 } 472 473 std::vector<const Type*> element_types_; 474 // We can attach decorations to struct members and that should not affect the 475 // underlying element type. So we need an extra data structure here to keep 476 // track of element type decorations. They must be stored in an ordered map 477 // because |GetExtraHashWords| will traverse the structure. It must have a 478 // fixed order in order to hash to the same value every time. 479 std::map<uint32_t, std::vector<std::vector<uint32_t>>> element_decorations_; 480 }; 481 482 class Opaque : public Type { 483 public: Opaque(std::string n)484 Opaque(std::string n) : Type(kOpaque), name_(std::move(n)) {} 485 Opaque(const Opaque&) = default; 486 487 std::string str() const override; 488 AsOpaque()489 Opaque* AsOpaque() override { return this; } AsOpaque()490 const Opaque* AsOpaque() const override { return this; } 491 name()492 const std::string& name() const { return name_; } 493 494 void GetExtraHashWords(std::vector<uint32_t>* words, 495 std::unordered_set<const Type*>* pSet) const override; 496 497 private: 498 bool IsSameImpl(const Type* that, IsSameCache*) const override; 499 500 std::string name_; 501 }; 502 503 class Pointer : public Type { 504 public: 505 Pointer(const Type* pointee, SpvStorageClass sc); 506 Pointer(const Pointer&) = default; 507 508 std::string str() const override; pointee_type()509 const Type* pointee_type() const { return pointee_type_; } storage_class()510 SpvStorageClass storage_class() const { return storage_class_; } 511 AsPointer()512 Pointer* AsPointer() override { return this; } AsPointer()513 const Pointer* AsPointer() const override { return this; } 514 515 void GetExtraHashWords(std::vector<uint32_t>* words, 516 std::unordered_set<const Type*>* pSet) const override; 517 518 void SetPointeeType(const Type* type); 519 520 private: 521 bool IsSameImpl(const Type* that, IsSameCache*) const override; 522 523 const Type* pointee_type_; 524 SpvStorageClass storage_class_; 525 }; 526 527 class Function : public Type { 528 public: 529 Function(const Type* ret_type, const std::vector<const Type*>& params); 530 Function(const Type* ret_type, std::vector<const Type*>& params); 531 Function(const Function&) = default; 532 533 std::string str() const override; 534 AsFunction()535 Function* AsFunction() override { return this; } AsFunction()536 const Function* AsFunction() const override { return this; } 537 return_type()538 const Type* return_type() const { return return_type_; } param_types()539 const std::vector<const Type*>& param_types() const { return param_types_; } param_types()540 std::vector<const Type*>& param_types() { return param_types_; } 541 542 void GetExtraHashWords(std::vector<uint32_t>* words, 543 std::unordered_set<const Type*>*) const override; 544 545 void SetReturnType(const Type* type); 546 547 private: 548 bool IsSameImpl(const Type* that, IsSameCache*) const override; 549 550 const Type* return_type_; 551 std::vector<const Type*> param_types_; 552 }; 553 554 class Pipe : public Type { 555 public: Pipe(SpvAccessQualifier qualifier)556 Pipe(SpvAccessQualifier qualifier) 557 : Type(kPipe), access_qualifier_(qualifier) {} 558 Pipe(const Pipe&) = default; 559 560 std::string str() const override; 561 AsPipe()562 Pipe* AsPipe() override { return this; } AsPipe()563 const Pipe* AsPipe() const override { return this; } 564 access_qualifier()565 SpvAccessQualifier access_qualifier() const { return access_qualifier_; } 566 567 void GetExtraHashWords(std::vector<uint32_t>* words, 568 std::unordered_set<const Type*>* pSet) const override; 569 570 private: 571 bool IsSameImpl(const Type* that, IsSameCache*) const override; 572 573 SpvAccessQualifier access_qualifier_; 574 }; 575 576 class ForwardPointer : public Type { 577 public: ForwardPointer(uint32_t id,SpvStorageClass sc)578 ForwardPointer(uint32_t id, SpvStorageClass sc) 579 : Type(kForwardPointer), 580 target_id_(id), 581 storage_class_(sc), 582 pointer_(nullptr) {} 583 ForwardPointer(const ForwardPointer&) = default; 584 target_id()585 uint32_t target_id() const { return target_id_; } SetTargetPointer(const Pointer * pointer)586 void SetTargetPointer(const Pointer* pointer) { pointer_ = pointer; } storage_class()587 SpvStorageClass storage_class() const { return storage_class_; } target_pointer()588 const Pointer* target_pointer() const { return pointer_; } 589 590 std::string str() const override; 591 AsForwardPointer()592 ForwardPointer* AsForwardPointer() override { return this; } AsForwardPointer()593 const ForwardPointer* AsForwardPointer() const override { return this; } 594 595 void GetExtraHashWords(std::vector<uint32_t>* words, 596 std::unordered_set<const Type*>* pSet) const override; 597 598 private: 599 bool IsSameImpl(const Type* that, IsSameCache*) const override; 600 601 uint32_t target_id_; 602 SpvStorageClass storage_class_; 603 const Pointer* pointer_; 604 }; 605 606 class CooperativeMatrixNV : public Type { 607 public: 608 CooperativeMatrixNV(const Type* type, const uint32_t scope, 609 const uint32_t rows, const uint32_t columns); 610 CooperativeMatrixNV(const CooperativeMatrixNV&) = default; 611 612 std::string str() const override; 613 AsCooperativeMatrixNV()614 CooperativeMatrixNV* AsCooperativeMatrixNV() override { return this; } AsCooperativeMatrixNV()615 const CooperativeMatrixNV* AsCooperativeMatrixNV() const override { 616 return this; 617 } 618 619 void GetExtraHashWords(std::vector<uint32_t>*, 620 std::unordered_set<const Type*>*) const override; 621 component_type()622 const Type* component_type() const { return component_type_; } scope_id()623 uint32_t scope_id() const { return scope_id_; } rows_id()624 uint32_t rows_id() const { return rows_id_; } columns_id()625 uint32_t columns_id() const { return columns_id_; } 626 627 private: 628 bool IsSameImpl(const Type* that, IsSameCache*) const override; 629 630 const Type* component_type_; 631 const uint32_t scope_id_; 632 const uint32_t rows_id_; 633 const uint32_t columns_id_; 634 }; 635 636 #define DefineParameterlessType(type, name) \ 637 class type : public Type { \ 638 public: \ 639 type() : Type(k##type) {} \ 640 type(const type&) = default; \ 641 \ 642 std::string str() const override { return #name; } \ 643 \ 644 type* As##type() override { return this; } \ 645 const type* As##type() const override { return this; } \ 646 \ 647 void GetExtraHashWords(std::vector<uint32_t>*, \ 648 std::unordered_set<const Type*>*) const override {} \ 649 \ 650 private: \ 651 bool IsSameImpl(const Type* that, IsSameCache*) const override { \ 652 return that->As##type() && HasSameDecorations(that); \ 653 } \ 654 } 655 DefineParameterlessType(Void, void); 656 DefineParameterlessType(Bool, bool); 657 DefineParameterlessType(Sampler, sampler); 658 DefineParameterlessType(Event, event); 659 DefineParameterlessType(DeviceEvent, device_event); 660 DefineParameterlessType(ReserveId, reserve_id); 661 DefineParameterlessType(Queue, queue); 662 DefineParameterlessType(PipeStorage, pipe_storage); 663 DefineParameterlessType(NamedBarrier, named_barrier); 664 DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV); 665 DefineParameterlessType(RayQueryProvisionalKHR, rayQueryProvisionalKHR); 666 #undef DefineParameterlessType 667 668 } // namespace analysis 669 } // namespace opt 670 } // namespace spvtools 671 672 #endif // SOURCE_OPT_TYPES_H_ 673