1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_SPIRVCODEGENERATOR
9 #define SKSL_SPIRVCODEGENERATOR
10 
11 #include <stack>
12 #include <tuple>
13 #include <unordered_map>
14 
15 #include "src/sksl/SkSLCodeGenerator.h"
16 #include "src/sksl/SkSLMemoryLayout.h"
17 #include "src/sksl/SkSLStringStream.h"
18 #include "src/sksl/ir/SkSLBinaryExpression.h"
19 #include "src/sksl/ir/SkSLBoolLiteral.h"
20 #include "src/sksl/ir/SkSLConstructor.h"
21 #include "src/sksl/ir/SkSLDoStatement.h"
22 #include "src/sksl/ir/SkSLFieldAccess.h"
23 #include "src/sksl/ir/SkSLFloatLiteral.h"
24 #include "src/sksl/ir/SkSLForStatement.h"
25 #include "src/sksl/ir/SkSLFunctionCall.h"
26 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
27 #include "src/sksl/ir/SkSLFunctionDefinition.h"
28 #include "src/sksl/ir/SkSLIfStatement.h"
29 #include "src/sksl/ir/SkSLIndexExpression.h"
30 #include "src/sksl/ir/SkSLIntLiteral.h"
31 #include "src/sksl/ir/SkSLInterfaceBlock.h"
32 #include "src/sksl/ir/SkSLPostfixExpression.h"
33 #include "src/sksl/ir/SkSLPrefixExpression.h"
34 #include "src/sksl/ir/SkSLProgramElement.h"
35 #include "src/sksl/ir/SkSLReturnStatement.h"
36 #include "src/sksl/ir/SkSLStatement.h"
37 #include "src/sksl/ir/SkSLSwitchStatement.h"
38 #include "src/sksl/ir/SkSLSwizzle.h"
39 #include "src/sksl/ir/SkSLTernaryExpression.h"
40 #include "src/sksl/ir/SkSLVarDeclarations.h"
41 #include "src/sksl/ir/SkSLVariableReference.h"
42 #include "src/sksl/ir/SkSLWhileStatement.h"
43 #include "src/sksl/spirv.h"
44 
45 union ConstantValue {
ConstantValue(int64_t i)46     ConstantValue(int64_t i)
47         : fInt(i) {
48         SkASSERT(sizeof(*this) == sizeof(int64_t));
49     }
50 
ConstantValue(SKSL_FLOAT f)51     ConstantValue(SKSL_FLOAT f) {
52         memset(this, 0, sizeof(*this));
53         fFloat = f;
54     }
55 
56     bool operator==(const ConstantValue& other) const {
57         return fInt == other.fInt;
58     }
59 
60     int64_t fInt;
61     SKSL_FLOAT fFloat;
62 };
63 
/**
 * The numeric type associated with a ConstantValue. Together with the value's bit pattern it
 * forms the key of SPIRVCodeGenerator::fNumberConstants, so the same bits used with different
 * types map to different entries.
 */
enum class ConstantType : int {
    kInt    = 0,
    kUInt   = 1,
    kShort  = 2,
    kUShort = 3,
    kFloat  = 4,
    kDouble = 5,
    kHalf   = 6,
};
73 
74 namespace std {
75 
76 template <>
77 struct hash<std::pair<ConstantValue, ConstantType>> {
78     size_t operator()(const std::pair<ConstantValue, ConstantType>& key) const {
79         return key.first.fInt ^ (int) key.second;
80     }
81 };
82 
83 }  // namespace std
84 
85 namespace SkSL {
86 
// The last SpvCapability value the generator considers, going by this macro's name.
// NOTE(review): presumably used as the upper bound when walking the capability bits recorded in
// SPIRVCodeGenerator::fCapabilities (a uint64_t) — confirm in the .cpp. Any capability whose
// enumerator is 64 or higher could not be represented in that field.
#define kLast_Capability SpvCapabilityMultiViewport
88 
89 /**
90  * Converts a Program into a SPIR-V binary.
91  */
92 class SPIRVCodeGenerator : public CodeGenerator {
93 public:
94     class LValue {
95     public:
96         virtual ~LValue() {}
97 
98         // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
99         // by a pointer (e.g. vector swizzles), returns 0.
100         virtual SpvId getPointer() = 0;
101 
102         virtual SpvId load(OutputStream& out) = 0;
103 
104         virtual void store(SpvId value, OutputStream& out) = 0;
105     };
106 
107     SPIRVCodeGenerator(const Context* context,
108                        const Program* program,
109                        ErrorReporter* errors,
110                        OutputStream* out)
111             : INHERITED(program, errors, out)
112             , fContext(*context)
113             , fDefaultLayout(MemoryLayout::k140_Standard)
114             , fCapabilities(0)
115             , fIdCount(1)
116             , fBoolTrue(0)
117             , fBoolFalse(0)
118             , fSetupFragPosition(false)
119             , fCurrentBlock(0)
120             , fSynthetics(errors, /*builtin=*/true) {
121         this->setupIntrinsics();
122     }
123 
124     bool generateCode() override;
125 
126 private:
127     enum IntrinsicKind {
128         kGLSL_STD_450_IntrinsicKind,
129         kSPIRV_IntrinsicKind,
130         kSpecial_IntrinsicKind
131     };
132 
133     enum SpecialIntrinsic {
134         kAtan_SpecialIntrinsic,
135         kClamp_SpecialIntrinsic,
136         kMax_SpecialIntrinsic,
137         kMin_SpecialIntrinsic,
138         kMix_SpecialIntrinsic,
139         kMod_SpecialIntrinsic,
140         kDFdy_SpecialIntrinsic,
141         kSaturate_SpecialIntrinsic,
142         kSampledImage_SpecialIntrinsic,
143         kSubpassLoad_SpecialIntrinsic,
144         kTexture_SpecialIntrinsic,
145     };
146 
147     enum class Precision {
148         kLow,
149         kHigh,
150     };
151 
152     void setupIntrinsics();
153 
154     SpvId nextId();
155 
156     const Type& getActualType(const Type& type);
157 
158     SpvId getType(const Type& type);
159 
160     SpvId getType(const Type& type, const MemoryLayout& layout);
161 
162     SpvId getImageType(const Type& type);
163 
164     SpvId getFunctionType(const FunctionDeclaration& function);
165 
166     SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
167 
168     SpvId getPointerType(const Type& type, const MemoryLayout& layout,
169                          SpvStorageClass_ storageClass);
170 
171     void writePrecisionModifier(Precision precision, SpvId id);
172 
173     void writePrecisionModifier(const Type& type, SpvId id);
174 
175     std::vector<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
176 
177     void writeLayout(const Layout& layout, SpvId target);
178 
179     void writeLayout(const Layout& layout, SpvId target, int member);
180 
181     void writeStruct(const Type& type, const MemoryLayout& layout, SpvId resultId);
182 
183     void writeProgramElement(const ProgramElement& pe, OutputStream& out);
184 
185     SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTHeight = true);
186 
187     SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
188 
189     SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
190 
191     SpvId writeFunction(const FunctionDefinition& f, OutputStream& out);
192 
193     void writeGlobalVar(Program::Kind kind, const VarDeclaration& v, OutputStream& out);
194 
195     void writeVarDeclaration(const VarDeclaration& var, OutputStream& out);
196 
197     SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
198 
199     std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
200 
201     SpvId writeExpression(const Expression& expr, OutputStream& out);
202 
203     SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
204 
205     SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
206 
207 
208     void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
209                                       SpvId signedInst, SpvId unsignedInst,
210                                       const std::vector<SpvId>& args, OutputStream& out);
211 
212     /**
213      * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
214      * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
215      * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
216      * vec2, vec3).
217      */
218     std::vector<SpvId> vectorize(const ExpressionArray& args, OutputStream& out);
219 
220     SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
221 
222     SpvId writeConstantVector(const Constructor& c);
223 
224     SpvId writeFloatConstructor(const Constructor& c, OutputStream& out);
225 
226     SpvId writeIntConstructor(const Constructor& c, OutputStream& out);
227 
228     SpvId writeUIntConstructor(const Constructor& c, OutputStream& out);
229 
230     /**
231      * Writes a matrix with the diagonal entries all equal to the provided expression, and all other
232      * entries equal to zero.
233      */
234     void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out);
235 
236     /**
237      * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
238      * source matrix are filled with zero; entries which do not exist in the destination matrix are
239      * ignored.
240      */
241     void writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, const Type& dstType,
242                          OutputStream& out);
243 
244     void addColumnEntry(SpvId columnType, Precision precision, std::vector<SpvId>* currentColumn,
245                         std::vector<SpvId>* columnIds, int* currentCount, int rows, SpvId entry,
246                         OutputStream& out);
247 
248     SpvId writeMatrixConstructor(const Constructor& c, OutputStream& out);
249 
250     SpvId writeVectorConstructor(const Constructor& c, OutputStream& out);
251 
252     SpvId writeArrayConstructor(const Constructor& c, OutputStream& out);
253 
254     SpvId writeConstructor(const Constructor& c, OutputStream& out);
255 
256     SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
257 
258     SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
259 
260     /**
261      * Folds the potentially-vector result of a logical operation down to a single bool. If
262      * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
263      * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
264      * returns the original id value.
265      */
266     SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
267 
268     SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
269                                 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
270                                 SpvOp_ mergeOperator, OutputStream& out);
271 
272     SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
273                                          SpvOp_ floatOperator, SpvOp_ intOperator,
274                                          OutputStream& out);
275 
276     SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
277                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
278                                SpvOp_ ifBool, OutputStream& out);
279 
280     SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
281                                SpvOp_ ifUInt, OutputStream& out);
282 
283     SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
284                                 const Type& rightType, SpvId rhs, const Type& resultType,
285                                 OutputStream& out);
286 
287     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
288 
289     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
290 
291     SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
292 
293     SpvId writeLogicalAnd(const BinaryExpression& b, OutputStream& out);
294 
295     SpvId writeLogicalOr(const BinaryExpression& o, OutputStream& out);
296 
297     SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
298 
299     SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
300 
301     SpvId writeBoolLiteral(const BoolLiteral& b);
302 
303     SpvId writeIntLiteral(const IntLiteral& i);
304 
305     SpvId writeFloatLiteral(const FloatLiteral& f);
306 
307     void writeStatement(const Statement& s, OutputStream& out);
308 
309     void writeBlock(const Block& b, OutputStream& out);
310 
311     void writeIfStatement(const IfStatement& stmt, OutputStream& out);
312 
313     void writeForStatement(const ForStatement& f, OutputStream& out);
314 
315     void writeWhileStatement(const WhileStatement& w, OutputStream& out);
316 
317     void writeDoStatement(const DoStatement& d, OutputStream& out);
318 
319     void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
320 
321     void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
322 
323     void writeCapabilities(OutputStream& out);
324 
325     void writeInstructions(const Program& program, OutputStream& out);
326 
327     void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
328 
329     void writeWord(int32_t word, OutputStream& out);
330 
331     void writeString(const char* string, size_t length, OutputStream& out);
332 
333     void writeLabel(SpvId id, OutputStream& out);
334 
335     void writeInstruction(SpvOp_ opCode, OutputStream& out);
336 
337     void writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out);
338 
339     void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
340 
341     void writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string, OutputStream& out);
342 
343     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, StringFragment string,
344                           OutputStream& out);
345 
346     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
347 
348     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
349                           OutputStream& out);
350 
351     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
352                           OutputStream& out);
353 
354     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
355                           int32_t word5, OutputStream& out);
356 
357     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
358                           int32_t word5, int32_t word6, OutputStream& out);
359 
360     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
361                           int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
362 
363     void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
364                           int32_t word5, int32_t word6, int32_t word7, int32_t word8,
365                           OutputStream& out);
366 
367     void writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out);
368 
369     const Context& fContext;
370     const MemoryLayout fDefaultLayout;
371 
372     uint64_t fCapabilities;
373     SpvId fIdCount;
374     SpvId fGLSLExtendedInstructions;
375     typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic;
376     std::unordered_map<String, Intrinsic> fIntrinsicMap;
377     std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap;
378     std::unordered_map<const Variable*, SpvId> fVariableMap;
379     std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap;
380     std::unordered_map<String, SpvId> fImageTypeMap;
381     std::unordered_map<String, SpvId> fTypeMap;
382     StringStream fCapabilitiesBuffer;
383     StringStream fGlobalInitializersBuffer;
384     StringStream fConstantBuffer;
385     StringStream fExtraGlobalsBuffer;
386     StringStream fExternalFunctionsBuffer;
387     StringStream fVariableBuffer;
388     StringStream fNameBuffer;
389     StringStream fDecorationBuffer;
390 
391     SpvId fBoolTrue;
392     SpvId fBoolFalse;
393     std::unordered_map<std::pair<ConstantValue, ConstantType>, SpvId> fNumberConstants;
394     bool fSetupFragPosition;
395     // label of the current block, or 0 if we are not in a block
396     SpvId fCurrentBlock;
397     std::stack<SpvId> fBreakTarget;
398     std::stack<SpvId> fContinueTarget;
399     SpvId fRTHeightStructId = (SpvId) -1;
400     SpvId fRTHeightFieldIndex = (SpvId) -1;
401     SpvStorageClass_ fRTHeightStorageClass;
402     // holds variables synthesized during output, for lifetime purposes
403     SymbolTable fSynthetics;
404     int fSkInCount = 1;
405 
406     friend class PointerLValue;
407     friend class SwizzleLValue;
408 
409     using INHERITED = CodeGenerator;
410 };
411 
412 }  // namespace SkSL
413 
414 #endif
415