1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines utilities to use while converting to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H
14 #define MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/SmallSet.h"
21 
22 namespace mlir {
23 
/// Type conversion from builtin types to SPIR-V types for shader interface.
///
/// Non-32-bit scalar types require special hardware support that may not exist
/// on all GPUs. This is reflected in SPIR-V in that non-32-bit scalar types
/// require special capabilities or extensions. Right now if a scalar type of a
/// certain bitwidth is not supported in the target environment, we use 32-bit
/// ones unconditionally. This requires the runtime to also feed in data with
/// a matched bitwidth and layout for interface types. The runtime can do that
/// by inspecting the SPIR-V module.
///
/// For memref types, this converter additionally performs type wrapping to
/// satisfy shader interface requirements: shader interface types must be
/// pointers to structs.
///
/// TODO: We might want to introduce a way to control how unsupported bitwidths
/// are handled and explicitly fail if wanted.
class SPIRVTypeConverter : public TypeConverter {
public:
  /// Creates a type converter for the SPIR-V target environment described by
  /// `targetAttr`.
  explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);

  /// Gets the number of bytes used for a type when converted to a SPIR-V
  /// type. Note that it doesn't account for whether the type is legal for a
  /// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on
  /// failure.
  static Optional<int64_t> getConvertedTypeNumBytes(Type);

  /// Gets the SPIR-V correspondence for the standard index type.
  static Type getIndexType(MLIRContext *context);

  /// Returns the corresponding memory space for memref given a SPIR-V storage
  /// class.
  static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);

  /// Returns the SPIR-V storage class given a memory space for memref. Returns
  /// llvm::None if the memory space does not map to any SPIR-V storage class.
  static Optional<spirv::StorageClass>
  getStorageClassForMemorySpace(unsigned space);

private:
  /// Target environment this converter was created for. It is presumably
  /// consulted to decide which scalar bitwidths are supported (see the class
  /// comment above) — verify against the implementation.
  spirv::TargetEnv targetEnv;
};
65 
/// Appends to a pattern list additional patterns for translating the builtin
/// `func` op to the SPIR-V dialect. These patterns do not handle shader
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
/// types.
///
/// The new patterns are appended to `patterns`. `typeConverter` is presumably
/// what the patterns use to compute the SPIR-V parameter types; since it is
/// passed by non-const reference, assume it must stay alive while the patterns
/// are in use (NOTE(review): confirm against the implementation).
void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
                                        SPIRVTypeConverter &typeConverter,
                                        OwningRewritePatternList &patterns);
73 
74 namespace spirv {
// Forward declarations of SPIR-V op classes that appear in the function
// signatures of this namespace (`getElementPtr` returns an AccessChainOp and
// `setABIAttrs` takes a FuncOp). Their full definitions presumably come from
// the SPIR-V ops header, which this header does not include.
class AccessChainOp;
class FuncOp;
77 
78 class SPIRVConversionTarget : public ConversionTarget {
79 public:
80   /// Creates a SPIR-V conversion target for the given target environment.
81   static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr);
82 
83 private:
84   explicit SPIRVConversionTarget(TargetEnvAttr targetAttr);
85 
86   // Be explicit that instance of this class cannot be copied or moved: there
87   // are lambdas capturing fields of the instance.
88   SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
89   SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
90   SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
91   SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
92 
93   /// Returns true if the given `op` is legal to use under the current target
94   /// environment.
95   bool isLegalOp(Operation *op);
96 
97   TargetEnv targetEnv;
98 };
99 
/// Returns the value for the given `builtin` variable. This function gets or
/// inserts the global variable associated with the builtin within the nearest
/// enclosing op that has a symbol table. Returns a null Value if such an
/// enclosing op cannot be found.
///
/// The enclosing symbol-table op is searched for starting from `op`. `builder`
/// is presumably used to create the global variable and the ops that read it
/// (NOTE(review): confirm insertion-point handling in the implementation).
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                              OpBuilder &builder);
106 
/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `basePtr`, using the layout map of `baseType`, and
/// returns the resulting spirv::AccessChainOp. `loc` and `builder` are
/// presumably used for any ops emitted during the index computation, and
/// `typeConverter` for obtaining SPIR-V types (TODO: confirm in the
/// implementation).
///
/// TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
/// that has static strides. Extend to handle dynamic strides.
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
                                   MemRefType baseType, Value basePtr,
                                   ValueRange indices, Location loc,
                                   OpBuilder &builder);
116 
/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
/// arguments.
///
/// `entryPointInfo` is the EntryPointABIAttr for `funcOp` itself, and
/// `argABIInfo` holds InterfaceVarABIAttrs for its arguments (presumably one
/// per argument, in argument order — verify against the implementation). The
/// returned LogicalResult reports whether the attributes could be set; the
/// exact failure conditions are defined in the implementation.
LogicalResult setABIAttrs(spirv::FuncOp funcOp,
                          EntryPointABIAttr entryPointInfo,
                          ArrayRef<InterfaceVarABIAttr> argABIInfo);
122 } // namespace spirv
123 } // namespace mlir
124 
125 #endif // MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H
126