1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/SkSLSPIRVCodeGenerator.h"
9 
10 #include "src/sksl/GLSL.std.450.h"
11 
12 #include "src/sksl/SkSLCompiler.h"
13 #include "src/sksl/ir/SkSLExpressionStatement.h"
14 #include "src/sksl/ir/SkSLExtension.h"
15 #include "src/sksl/ir/SkSLIndexExpression.h"
16 #include "src/sksl/ir/SkSLVariableReference.h"
17 
18 #ifdef SK_VULKAN
19 #include "src/gpu/vk/GrVkCaps.h"
20 #endif
21 
22 namespace SkSL {
23 
// Generator magic number emitted in the SPIR-V module header.
// FIXME: we should probably register a magic number with Khronos instead of using 0.
static constexpr int32_t SKSL_MAGIC = 0x0;
25 
setupIntrinsics()26 void SPIRVCodeGenerator::setupIntrinsics() {
27 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
28                                     GLSLstd450 ## x, GLSLstd450 ## x)
29 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
30                                                              GLSLstd450 ## ifFloat, \
31                                                              GLSLstd450 ## ifInt, \
32                                                              GLSLstd450 ## ifUInt, \
33                                                              SpvOpUndef)
34 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
35                                                            SpvOp ## x)
36 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
37                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
38                                    k ## x ## _SpecialIntrinsic)
39     fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
40     fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
41     fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
42     fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
43     fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
44     fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
45     fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
46     fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
47     fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
48     fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
49     fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
50     fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
51     fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
52     fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
53     fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
54     fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
55     fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
56     fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
57     fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
58     fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
59     fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
60     fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
61     fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
62     fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
63     fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
64     fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
65     fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
66     fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
67     fIntrinsicMap[String("inverse")]       = ALL_GLSL(MatrixInverse);
68     fIntrinsicMap[String("transpose")]     = ALL_SPIRV(Transpose);
69     fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
70     fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
71     fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
72     fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
73     fIntrinsicMap[String("min")]           = SPECIAL(Min);
74     fIntrinsicMap[String("max")]           = SPECIAL(Max);
75     fIntrinsicMap[String("clamp")]         = SPECIAL(Clamp);
76     fIntrinsicMap[String("saturate")]      = SPECIAL(Saturate);
77     fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
78                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
79     fIntrinsicMap[String("mix")]           = SPECIAL(Mix);
80     fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
81     fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
82     fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
83     fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
84     fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
85 
86 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
87                    fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
88     PACK(Snorm4x8);
89     PACK(Unorm4x8);
90     PACK(Snorm2x16);
91     PACK(Unorm2x16);
92     PACK(Half2x16);
93     PACK(Double2x32);
94     fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
95     fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
96     fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
97     fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
98     fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
99     fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
100     fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
101     fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
102     fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
103     fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
104                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
105     fIntrinsicMap[String("dFdy")]        = SPECIAL(DFdy);
106     fIntrinsicMap[String("fwidth")]      = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFwidth,
107                                                            SpvOpUndef, SpvOpUndef, SpvOpUndef);
108     fIntrinsicMap[String("makeSampler2D")] = SPECIAL(SampledImage);
109 
110     fIntrinsicMap[String("sample")]      = SPECIAL(Texture);
111     fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
112 
113     fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
114                                                                 SpvOpUndef, SpvOpUndef, SpvOpAny);
115     fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
116                                                                 SpvOpUndef, SpvOpUndef, SpvOpAll);
117     fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
118                                                                 SpvOpFOrdEqual, SpvOpIEqual,
119                                                                 SpvOpIEqual, SpvOpLogicalEqual);
120     fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
121                                                                 SpvOpFOrdNotEqual, SpvOpINotEqual,
122                                                                 SpvOpINotEqual,
123                                                                 SpvOpLogicalNotEqual);
124     fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
125                                                                 SpvOpFOrdLessThan, SpvOpSLessThan,
126                                                                 SpvOpULessThan, SpvOpUndef);
127     fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
128                                                                 SpvOpFOrdLessThanEqual,
129                                                                 SpvOpSLessThanEqual,
130                                                                 SpvOpULessThanEqual,
131                                                                 SpvOpUndef);
132     fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
133                                                                 SpvOpFOrdGreaterThan,
134                                                                 SpvOpSGreaterThan,
135                                                                 SpvOpUGreaterThan,
136                                                                 SpvOpUndef);
137     fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
138                                                                 SpvOpFOrdGreaterThanEqual,
139                                                                 SpvOpSGreaterThanEqual,
140                                                                 SpvOpUGreaterThanEqual,
141                                                                 SpvOpUndef);
142     fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
143     fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
144 // interpolateAt* not yet supported...
145 }
146 
writeWord(int32_t word,OutputStream & out)147 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
148     out.write((const char*) &word, sizeof(word));
149 }
150 
is_float(const Context & context,const Type & type)151 static bool is_float(const Context& context, const Type& type) {
152     if (type.columns() > 1) {
153         return is_float(context, type.componentType());
154     }
155     return type == *context.fFloat_Type || type == *context.fHalf_Type ||
156            type == *context.fDouble_Type;
157 }
158 
is_signed(const Context & context,const Type & type)159 static bool is_signed(const Context& context, const Type& type) {
160     if (type.kind() == Type::kVector_Kind) {
161         return is_signed(context, type.componentType());
162     }
163     return type == *context.fInt_Type || type == *context.fShort_Type ||
164            type == *context.fByte_Type;
165 }
166 
is_unsigned(const Context & context,const Type & type)167 static bool is_unsigned(const Context& context, const Type& type) {
168     if (type.kind() == Type::kVector_Kind) {
169         return is_unsigned(context, type.componentType());
170     }
171     return type == *context.fUInt_Type || type == *context.fUShort_Type ||
172            type == *context.fUByte_Type;
173 }
174 
is_bool(const Context & context,const Type & type)175 static bool is_bool(const Context& context, const Type& type) {
176     if (type.kind() == Type::kVector_Kind) {
177         return is_bool(context, type.componentType());
178     }
179     return type == *context.fBool_Type;
180 }
181 
is_out(const Variable & var)182 static bool is_out(const Variable& var) {
183     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
184 }
185 
writeOpCode(SpvOp_ opCode,int length,OutputStream & out)186 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
187     SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
188     SkASSERT(opCode != SpvOpUndef);
189     switch (opCode) {
190         case SpvOpReturn:      // fall through
191         case SpvOpReturnValue: // fall through
192         case SpvOpKill:        // fall through
193         case SpvOpBranch:      // fall through
194         case SpvOpBranchConditional:
195             SkASSERT(fCurrentBlock);
196             fCurrentBlock = 0;
197             break;
198         case SpvOpConstant:          // fall through
199         case SpvOpConstantTrue:      // fall through
200         case SpvOpConstantFalse:     // fall through
201         case SpvOpConstantComposite: // fall through
202         case SpvOpTypeVoid:          // fall through
203         case SpvOpTypeInt:           // fall through
204         case SpvOpTypeFloat:         // fall through
205         case SpvOpTypeBool:          // fall through
206         case SpvOpTypeVector:        // fall through
207         case SpvOpTypeMatrix:        // fall through
208         case SpvOpTypeArray:         // fall through
209         case SpvOpTypePointer:       // fall through
210         case SpvOpTypeFunction:      // fall through
211         case SpvOpTypeRuntimeArray:  // fall through
212         case SpvOpTypeStruct:        // fall through
213         case SpvOpTypeImage:         // fall through
214         case SpvOpTypeSampledImage:  // fall through
215         case SpvOpTypeSampler:       // fall through
216         case SpvOpVariable:          // fall through
217         case SpvOpFunction:          // fall through
218         case SpvOpFunctionParameter: // fall through
219         case SpvOpFunctionEnd:       // fall through
220         case SpvOpExecutionMode:     // fall through
221         case SpvOpMemoryModel:       // fall through
222         case SpvOpCapability:        // fall through
223         case SpvOpExtInstImport:     // fall through
224         case SpvOpEntryPoint:        // fall through
225         case SpvOpSource:            // fall through
226         case SpvOpSourceExtension:   // fall through
227         case SpvOpName:              // fall through
228         case SpvOpMemberName:        // fall through
229         case SpvOpDecorate:          // fall through
230         case SpvOpMemberDecorate:
231             break;
232         default:
233             SkASSERT(fCurrentBlock);
234     }
235     this->writeWord((length << 16) | opCode, out);
236 }
237 
writeLabel(SpvId label,OutputStream & out)238 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
239     fCurrentBlock = label;
240     this->writeInstruction(SpvOpLabel, label, out);
241 }
242 
writeInstruction(SpvOp_ opCode,OutputStream & out)243 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
244     this->writeOpCode(opCode, 1, out);
245 }
246 
writeInstruction(SpvOp_ opCode,int32_t word1,OutputStream & out)247 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
248     this->writeOpCode(opCode, 2, out);
249     this->writeWord(word1, out);
250 }
251 
writeString(const char * string,size_t length,OutputStream & out)252 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
253     out.write(string, length);
254     switch (length % 4) {
255         case 1:
256             out.write8(0);
257             // fall through
258         case 2:
259             out.write8(0);
260             // fall through
261         case 3:
262             out.write8(0);
263             break;
264         default:
265             this->writeWord(0, out);
266     }
267 }
268 
writeInstruction(SpvOp_ opCode,StringFragment string,OutputStream & out)269 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
270     this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
271     this->writeString(string.fChars, string.fLength, out);
272 }
273 
274 
writeInstruction(SpvOp_ opCode,int32_t word1,StringFragment string,OutputStream & out)275 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
276                                           OutputStream& out) {
277     this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
278     this->writeWord(word1, out);
279     this->writeString(string.fChars, string.fLength, out);
280 }
281 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,StringFragment string,OutputStream & out)282 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
283                                           StringFragment string, OutputStream& out) {
284     this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
285     this->writeWord(word1, out);
286     this->writeWord(word2, out);
287     this->writeString(string.fChars, string.fLength, out);
288 }
289 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,OutputStream & out)290 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
291                                           OutputStream& out) {
292     this->writeOpCode(opCode, 3, out);
293     this->writeWord(word1, out);
294     this->writeWord(word2, out);
295 }
296 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,OutputStream & out)297 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
298                                           int32_t word3, OutputStream& out) {
299     this->writeOpCode(opCode, 4, out);
300     this->writeWord(word1, out);
301     this->writeWord(word2, out);
302     this->writeWord(word3, out);
303 }
304 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,OutputStream & out)305 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
306                                           int32_t word3, int32_t word4, OutputStream& out) {
307     this->writeOpCode(opCode, 5, out);
308     this->writeWord(word1, out);
309     this->writeWord(word2, out);
310     this->writeWord(word3, out);
311     this->writeWord(word4, out);
312 }
313 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,OutputStream & out)314 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
315                                           int32_t word3, int32_t word4, int32_t word5,
316                                           OutputStream& out) {
317     this->writeOpCode(opCode, 6, out);
318     this->writeWord(word1, out);
319     this->writeWord(word2, out);
320     this->writeWord(word3, out);
321     this->writeWord(word4, out);
322     this->writeWord(word5, out);
323 }
324 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,OutputStream & out)325 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
326                                           int32_t word3, int32_t word4, int32_t word5,
327                                           int32_t word6, OutputStream& out) {
328     this->writeOpCode(opCode, 7, out);
329     this->writeWord(word1, out);
330     this->writeWord(word2, out);
331     this->writeWord(word3, out);
332     this->writeWord(word4, out);
333     this->writeWord(word5, out);
334     this->writeWord(word6, out);
335 }
336 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,OutputStream & out)337 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
338                                           int32_t word3, int32_t word4, int32_t word5,
339                                           int32_t word6, int32_t word7, OutputStream& out) {
340     this->writeOpCode(opCode, 8, out);
341     this->writeWord(word1, out);
342     this->writeWord(word2, out);
343     this->writeWord(word3, out);
344     this->writeWord(word4, out);
345     this->writeWord(word5, out);
346     this->writeWord(word6, out);
347     this->writeWord(word7, out);
348 }
349 
writeInstruction(SpvOp_ opCode,int32_t word1,int32_t word2,int32_t word3,int32_t word4,int32_t word5,int32_t word6,int32_t word7,int32_t word8,OutputStream & out)350 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
351                                           int32_t word3, int32_t word4, int32_t word5,
352                                           int32_t word6, int32_t word7, int32_t word8,
353                                           OutputStream& out) {
354     this->writeOpCode(opCode, 9, out);
355     this->writeWord(word1, out);
356     this->writeWord(word2, out);
357     this->writeWord(word3, out);
358     this->writeWord(word4, out);
359     this->writeWord(word5, out);
360     this->writeWord(word6, out);
361     this->writeWord(word7, out);
362     this->writeWord(word8, out);
363 }
364 
writeCapabilities(OutputStream & out)365 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
366     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
367         if (fCapabilities & bit) {
368             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
369         }
370     }
371     if (fProgram.fKind == Program::kGeometry_Kind) {
372         this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
373     }
374     else {
375         this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
376     }
377 }
378 
nextId()379 SpvId SPIRVCodeGenerator::nextId() {
380     return fIdCount++;
381 }
382 
writeStruct(const Type & type,const MemoryLayout & memoryLayout,SpvId resultId)383 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
384                                      SpvId resultId) {
385     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
386     // go ahead and write all of the field types, so we don't inadvertently write them while we're
387     // in the middle of writing the struct instruction
388     std::vector<SpvId> types;
389     for (const auto& f : type.fields()) {
390         types.push_back(this->getType(*f.fType, memoryLayout));
391     }
392     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
393     this->writeWord(resultId, fConstantBuffer);
394     for (SpvId id : types) {
395         this->writeWord(id, fConstantBuffer);
396     }
397     size_t offset = 0;
398     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
399         const Type::Field& field = type.fields()[i];
400         size_t size = memoryLayout.size(*field.fType);
401         size_t alignment = memoryLayout.alignment(*field.fType);
402         const Layout& fieldLayout = field.fModifiers.fLayout;
403         if (fieldLayout.fOffset >= 0) {
404             if (fieldLayout.fOffset < (int) offset) {
405                 fErrors.error(type.fOffset,
406                               "offset of field '" + field.fName + "' must be at "
407                               "least " + to_string((int) offset));
408             }
409             if (fieldLayout.fOffset % alignment) {
410                 fErrors.error(type.fOffset,
411                               "offset of field '" + field.fName + "' must be a multiple"
412                               " of " + to_string((int) alignment));
413             }
414             offset = fieldLayout.fOffset;
415         } else {
416             size_t mod = offset % alignment;
417             if (mod) {
418                 offset += alignment - mod;
419             }
420         }
421         this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
422         this->writeLayout(fieldLayout, resultId, i);
423         if (field.fModifiers.fLayout.fBuiltin < 0) {
424             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
425                                    (SpvId) offset, fDecorationBuffer);
426         }
427         if (field.fType->kind() == Type::kMatrix_Kind) {
428             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
429                                    fDecorationBuffer);
430             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
431                                    (SpvId) memoryLayout.stride(*field.fType),
432                                    fDecorationBuffer);
433         }
434         if (!field.fType->highPrecision()) {
435             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
436                                    SpvDecorationRelaxedPrecision, fDecorationBuffer);
437         }
438         offset += size;
439         Type::Kind kind = field.fType->kind();
440         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
441             offset += alignment - offset % alignment;
442         }
443     }
444 }
445 
// Maps an SkSL type onto the type actually emitted in SPIR-V. Float-like, signed and unsigned
// scalars collapse to float, int and uint. Vectors and matrices of half, short/byte or
// ushort/ubyte are widened to the float, int or uint equivalent of the same shape. Any other
// type is returned unchanged.
Type SPIRVCodeGenerator::getActualType(const Type& type) {
    if (type.isFloat()) {
        return *fContext.fFloat_Type;
    }
    if (type.isSigned()) {
        return *fContext.fInt_Type;
    }
    if (type.isUnsigned()) {
        return *fContext.fUInt_Type;
    }
    const bool isCompound = type.kind() == Type::kMatrix_Kind ||
                            type.kind() == Type::kVector_Kind;
    if (!isCompound) {
        return type;
    }
    const Type& component = type.componentType();
    if (component == *fContext.fHalf_Type) {
        return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows());
    }
    if (component == *fContext.fShort_Type || component == *fContext.fByte_Type) {
        return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows());
    }
    if (component == *fContext.fUShort_Type || component == *fContext.fUByte_Type) {
        return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows());
    }
    return type;
}
471 
getType(const Type & type)472 SpvId SPIRVCodeGenerator::getType(const Type& type) {
473     return this->getType(type, fDefaultLayout);
474 }
475 
getType(const Type & rawType,const MemoryLayout & layout)476 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
477     Type type = this->getActualType(rawType);
478     String key = type.name() + to_string((int) layout.fStd);
479     auto entry = fTypeMap.find(key);
480     if (entry == fTypeMap.end()) {
481         SpvId result = this->nextId();
482         switch (type.kind()) {
483             case Type::kScalar_Kind:
484                 if (type == *fContext.fBool_Type) {
485                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
486                 } else if (type == *fContext.fInt_Type || type == *fContext.fShort_Type ||
487                            type == *fContext.fIntLiteral_Type) {
488                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
489                 } else if (type == *fContext.fUInt_Type || type == *fContext.fUShort_Type) {
490                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
491                 } else if (type == *fContext.fFloat_Type || type == *fContext.fHalf_Type ||
492                            type == *fContext.fFloatLiteral_Type) {
493                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
494                 } else if (type == *fContext.fDouble_Type) {
495                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
496                 } else {
497                     SkASSERT(false);
498                 }
499                 break;
500             case Type::kVector_Kind:
501                 this->writeInstruction(SpvOpTypeVector, result,
502                                        this->getType(type.componentType(), layout),
503                                        type.columns(), fConstantBuffer);
504                 break;
505             case Type::kMatrix_Kind:
506                 this->writeInstruction(SpvOpTypeMatrix, result,
507                                        this->getType(index_type(fContext, type), layout),
508                                        type.columns(), fConstantBuffer);
509                 break;
510             case Type::kStruct_Kind:
511                 this->writeStruct(type, layout, result);
512                 break;
513             case Type::kArray_Kind: {
514                 if (type.columns() > 0) {
515                     IntLiteral count(fContext, -1, type.columns());
516                     this->writeInstruction(SpvOpTypeArray, result,
517                                            this->getType(type.componentType(), layout),
518                                            this->writeIntLiteral(count), fConstantBuffer);
519                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
520                                            (int32_t) layout.stride(type),
521                                            fDecorationBuffer);
522                 } else {
523                     SkASSERT(false); // we shouldn't have any runtime-sized arrays right now
524                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
525                                            this->getType(type.componentType(), layout),
526                                            fConstantBuffer);
527                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
528                                            (int32_t) layout.stride(type),
529                                            fDecorationBuffer);
530                 }
531                 break;
532             }
533             case Type::kSampler_Kind: {
534                 SpvId image = result;
535                 if (SpvDimSubpassData != type.dimensions()) {
536                     image = this->getType(type.textureType(), layout);
537                 }
538                 if (SpvDimBuffer == type.dimensions()) {
539                     fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
540                 }
541                 if (SpvDimSubpassData != type.dimensions()) {
542                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
543                 }
544                 break;
545             }
546             case Type::kSeparateSampler_Kind: {
547                 this->writeInstruction(SpvOpTypeSampler, result, fConstantBuffer);
548                 break;
549             }
550             case Type::kTexture_Kind: {
551                 this->writeInstruction(SpvOpTypeImage, result,
552                                        this->getType(*fContext.fFloat_Type, layout),
553                                        type.dimensions(), type.isDepth(), type.isArrayed(),
554                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
555                                        SpvImageFormatUnknown, fConstantBuffer);
556                 fImageTypeMap[key] = result;
557                 break;
558             }
559             default:
560                 if (type == *fContext.fVoid_Type) {
561                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
562                 } else {
563                     ABORT("invalid type: %s", type.description().c_str());
564                 }
565         }
566         fTypeMap[key] = result;
567         return result;
568     }
569     return entry->second;
570 }
571 
getImageType(const Type & type)572 SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
573     SkASSERT(type.kind() == Type::kSampler_Kind);
574     this->getType(type);
575     String key = type.name() + to_string((int) fDefaultLayout.fStd);
576     SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
577     return fImageTypeMap[key];
578 }
579 
getFunctionType(const FunctionDeclaration & function)580 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
581     String key = function.fReturnType.description() + "(";
582     String separator;
583     for (size_t i = 0; i < function.fParameters.size(); i++) {
584         key += separator;
585         separator = ", ";
586         key += function.fParameters[i]->fType.description();
587     }
588     key += ")";
589     auto entry = fTypeMap.find(key);
590     if (entry == fTypeMap.end()) {
591         SpvId result = this->nextId();
592         int32_t length = 3 + (int32_t) function.fParameters.size();
593         SpvId returnType = this->getType(function.fReturnType);
594         std::vector<SpvId> parameterTypes;
595         for (size_t i = 0; i < function.fParameters.size(); i++) {
596             // glslang seems to treat all function arguments as pointers whether they need to be or
597             // not. I  was initially puzzled by this until I ran bizarre failures with certain
598             // patterns of function calls and control constructs, as exemplified by this minimal
599             // failure case:
600             //
601             // void sphere(float x) {
602             // }
603             //
604             // void map() {
605             //     sphere(1.0);
606             // }
607             //
608             // void main() {
609             //     for (int i = 0; i < 1; i++) {
610             //         map();
611             //     }
612             // }
613             //
614             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
615             // crashes. Making it take a float* and storing the argument in a temporary variable,
616             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
617             // the spec makes this make sense.
618 //            if (is_out(function->fParameters[i])) {
619                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
620                                                               SpvStorageClassFunction));
621 //            } else {
622 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
623 //            }
624         }
625         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
626         this->writeWord(result, fConstantBuffer);
627         this->writeWord(returnType, fConstantBuffer);
628         for (SpvId id : parameterTypes) {
629             this->writeWord(id, fConstantBuffer);
630         }
631         fTypeMap[key] = result;
632         return result;
633     }
634     return entry->second;
635 }
636 
getPointerType(const Type & type,SpvStorageClass_ storageClass)637 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
638     return this->getPointerType(type, fDefaultLayout, storageClass);
639 }
640 
getPointerType(const Type & rawType,const MemoryLayout & layout,SpvStorageClass_ storageClass)641 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
642                                          SpvStorageClass_ storageClass) {
643     Type type = this->getActualType(rawType);
644     String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
645     auto entry = fTypeMap.find(key);
646     if (entry == fTypeMap.end()) {
647         SpvId result = this->nextId();
648         this->writeInstruction(SpvOpTypePointer, result, storageClass,
649                                this->getType(type), fConstantBuffer);
650         fTypeMap[key] = result;
651         return result;
652     }
653     return entry->second;
654 }
655 
writeExpression(const Expression & expr,OutputStream & out)656 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
657     switch (expr.fKind) {
658         case Expression::kBinary_Kind:
659             return this->writeBinaryExpression((BinaryExpression&) expr, out);
660         case Expression::kBoolLiteral_Kind:
661             return this->writeBoolLiteral((BoolLiteral&) expr);
662         case Expression::kConstructor_Kind:
663             return this->writeConstructor((Constructor&) expr, out);
664         case Expression::kIntLiteral_Kind:
665             return this->writeIntLiteral((IntLiteral&) expr);
666         case Expression::kFieldAccess_Kind:
667             return this->writeFieldAccess(((FieldAccess&) expr), out);
668         case Expression::kFloatLiteral_Kind:
669             return this->writeFloatLiteral(((FloatLiteral&) expr));
670         case Expression::kFunctionCall_Kind:
671             return this->writeFunctionCall((FunctionCall&) expr, out);
672         case Expression::kPrefix_Kind:
673             return this->writePrefixExpression((PrefixExpression&) expr, out);
674         case Expression::kPostfix_Kind:
675             return this->writePostfixExpression((PostfixExpression&) expr, out);
676         case Expression::kSwizzle_Kind:
677             return this->writeSwizzle((Swizzle&) expr, out);
678         case Expression::kVariableReference_Kind:
679             return this->writeVariableReference((VariableReference&) expr, out);
680         case Expression::kTernary_Kind:
681             return this->writeTernaryExpression((TernaryExpression&) expr, out);
682         case Expression::kIndex_Kind:
683             return this->writeIndexExpression((IndexExpression&) expr, out);
684         default:
685             ABORT("unsupported expression: %s", expr.description().c_str());
686     }
687     return -1;
688 }
689 
writeIntrinsicCall(const FunctionCall & c,OutputStream & out)690 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
691     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
692     SkASSERT(intrinsic != fIntrinsicMap.end());
693     int32_t intrinsicId;
694     if (c.fArguments.size() > 0) {
695         const Type& type = c.fArguments[0]->fType;
696         if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
697             intrinsicId = std::get<1>(intrinsic->second);
698         } else if (is_signed(fContext, type)) {
699             intrinsicId = std::get<2>(intrinsic->second);
700         } else if (is_unsigned(fContext, type)) {
701             intrinsicId = std::get<3>(intrinsic->second);
702         } else if (is_bool(fContext, type)) {
703             intrinsicId = std::get<4>(intrinsic->second);
704         } else {
705             intrinsicId = std::get<1>(intrinsic->second);
706         }
707     } else {
708         intrinsicId = std::get<1>(intrinsic->second);
709     }
710     switch (std::get<0>(intrinsic->second)) {
711         case kGLSL_STD_450_IntrinsicKind: {
712             SpvId result = this->nextId();
713             std::vector<SpvId> arguments;
714             for (size_t i = 0; i < c.fArguments.size(); i++) {
715                 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
716                     arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
717                 } else {
718                     arguments.push_back(this->writeExpression(*c.fArguments[i], out));
719                 }
720             }
721             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
722             this->writeWord(this->getType(c.fType), out);
723             this->writeWord(result, out);
724             this->writeWord(fGLSLExtendedInstructions, out);
725             this->writeWord(intrinsicId, out);
726             for (SpvId id : arguments) {
727                 this->writeWord(id, out);
728             }
729             return result;
730         }
731         case kSPIRV_IntrinsicKind: {
732             SpvId result = this->nextId();
733             std::vector<SpvId> arguments;
734             for (size_t i = 0; i < c.fArguments.size(); i++) {
735                 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
736                     arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
737                 } else {
738                     arguments.push_back(this->writeExpression(*c.fArguments[i], out));
739                 }
740             }
741             if (c.fType != *fContext.fVoid_Type) {
742                 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
743                 this->writeWord(this->getType(c.fType), out);
744                 this->writeWord(result, out);
745             } else {
746                 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
747             }
748             for (SpvId id : arguments) {
749                 this->writeWord(id, out);
750             }
751             return result;
752         }
753         case kSpecial_IntrinsicKind:
754             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
755         default:
756             ABORT("unsupported intrinsic kind");
757     }
758 }
759 
vectorize(const std::vector<std::unique_ptr<Expression>> & args,OutputStream & out)760 std::vector<SpvId> SPIRVCodeGenerator::vectorize(
761                                                const std::vector<std::unique_ptr<Expression>>& args,
762                                                OutputStream& out) {
763     int vectorSize = 0;
764     for (const auto& a : args) {
765         if (a->fType.kind() == Type::kVector_Kind) {
766             if (vectorSize) {
767                 SkASSERT(a->fType.columns() == vectorSize);
768             }
769             else {
770                 vectorSize = a->fType.columns();
771             }
772         }
773     }
774     std::vector<SpvId> result;
775     for (const auto& a : args) {
776         SpvId raw = this->writeExpression(*a, out);
777         if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
778             SpvId vector = this->nextId();
779             this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
780             this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
781             this->writeWord(vector, out);
782             for (int i = 0; i < vectorSize; i++) {
783                 this->writeWord(raw, out);
784             }
785             this->writePrecisionModifier(a->fType, vector);
786             result.push_back(vector);
787         } else {
788             result.push_back(raw);
789         }
790     }
791     return result;
792 }
793 
writeGLSLExtendedInstruction(const Type & type,SpvId id,SpvId floatInst,SpvId signedInst,SpvId unsignedInst,const std::vector<SpvId> & args,OutputStream & out)794 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
795                                                       SpvId signedInst, SpvId unsignedInst,
796                                                       const std::vector<SpvId>& args,
797                                                       OutputStream& out) {
798     this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
799     this->writeWord(this->getType(type), out);
800     this->writeWord(id, out);
801     this->writeWord(fGLSLExtendedInstructions, out);
802 
803     if (is_float(fContext, type)) {
804         this->writeWord(floatInst, out);
805     } else if (is_signed(fContext, type)) {
806         this->writeWord(signedInst, out);
807     } else if (is_unsigned(fContext, type)) {
808         this->writeWord(unsignedInst, out);
809     } else {
810         SkASSERT(false);
811     }
812     for (SpvId a : args) {
813         this->writeWord(a, out);
814     }
815 }
816 
writeSpecialIntrinsic(const FunctionCall & c,SpecialIntrinsic kind,OutputStream & out)817 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
818                                                 OutputStream& out) {
819     SpvId result = this->nextId();
820     switch (kind) {
821         case kAtan_SpecialIntrinsic: {
822             std::vector<SpvId> arguments;
823             for (size_t i = 0; i < c.fArguments.size(); i++) {
824                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
825             }
826             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
827             this->writeWord(this->getType(c.fType), out);
828             this->writeWord(result, out);
829             this->writeWord(fGLSLExtendedInstructions, out);
830             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
831             for (SpvId id : arguments) {
832                 this->writeWord(id, out);
833             }
834             break;
835         }
836         case kSampledImage_SpecialIntrinsic: {
837             SkASSERT(2 == c.fArguments.size());
838             SpvId img = this->writeExpression(*c.fArguments[0], out);
839             SpvId sampler = this->writeExpression(*c.fArguments[1], out);
840             this->writeInstruction(SpvOpSampledImage,
841                                    this->getType(c.fType),
842                                    result,
843                                    img,
844                                    sampler,
845                                    out);
846             break;
847         }
848         case kSubpassLoad_SpecialIntrinsic: {
849             SpvId img = this->writeExpression(*c.fArguments[0], out);
850             std::vector<std::unique_ptr<Expression>> args;
851             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
852             args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
853             Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
854             SpvId coords = this->writeConstantVector(ctor);
855             if (1 == c.fArguments.size()) {
856                 this->writeInstruction(SpvOpImageRead,
857                                        this->getType(c.fType),
858                                        result,
859                                        img,
860                                        coords,
861                                        out);
862             } else {
863                 SkASSERT(2 == c.fArguments.size());
864                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
865                 this->writeInstruction(SpvOpImageRead,
866                                        this->getType(c.fType),
867                                        result,
868                                        img,
869                                        coords,
870                                        SpvImageOperandsSampleMask,
871                                        sample,
872                                        out);
873             }
874             break;
875         }
876         case kTexture_SpecialIntrinsic: {
877             SpvOp_ op = SpvOpImageSampleImplicitLod;
878             switch (c.fArguments[0]->fType.dimensions()) {
879                 case SpvDim1D:
880                     if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
881                         op = SpvOpImageSampleProjImplicitLod;
882                     } else {
883                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
884                     }
885                     break;
886                 case SpvDim2D:
887                     if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
888                         op = SpvOpImageSampleProjImplicitLod;
889                     } else {
890                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
891                     }
892                     break;
893                 case SpvDim3D:
894                     if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
895                         op = SpvOpImageSampleProjImplicitLod;
896                     } else {
897                         SkASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
898                     }
899                     break;
900                 case SpvDimCube:   // fall through
901                 case SpvDimRect:   // fall through
902                 case SpvDimBuffer: // fall through
903                 case SpvDimSubpassData:
904                     break;
905             }
906             SpvId type = this->getType(c.fType);
907             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
908             SpvId uv = this->writeExpression(*c.fArguments[1], out);
909             if (c.fArguments.size() == 3) {
910                 this->writeInstruction(op, type, result, sampler, uv,
911                                        SpvImageOperandsBiasMask,
912                                        this->writeExpression(*c.fArguments[2], out),
913                                        out);
914             } else {
915                 SkASSERT(c.fArguments.size() == 2);
916                 if (fProgram.fSettings.fSharpenTextures) {
917                     FloatLiteral lodBias(fContext, -1, -0.5);
918                     this->writeInstruction(op, type, result, sampler, uv,
919                                            SpvImageOperandsBiasMask,
920                                            this->writeFloatLiteral(lodBias),
921                                            out);
922                 } else {
923                     this->writeInstruction(op, type, result, sampler, uv,
924                                            out);
925                 }
926             }
927             break;
928         }
929         case kMod_SpecialIntrinsic: {
930             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
931             SkASSERT(args.size() == 2);
932             const Type& operandType = c.fArguments[0]->fType;
933             SpvOp_ op;
934             if (is_float(fContext, operandType)) {
935                 op = SpvOpFMod;
936             } else if (is_signed(fContext, operandType)) {
937                 op = SpvOpSMod;
938             } else if (is_unsigned(fContext, operandType)) {
939                 op = SpvOpUMod;
940             } else {
941                 SkASSERT(false);
942                 return 0;
943             }
944             this->writeOpCode(op, 5, out);
945             this->writeWord(this->getType(operandType), out);
946             this->writeWord(result, out);
947             this->writeWord(args[0], out);
948             this->writeWord(args[1], out);
949             break;
950         }
951         case kDFdy_SpecialIntrinsic: {
952             SpvId fn = this->writeExpression(*c.fArguments[0], out);
953             this->writeOpCode(SpvOpDPdy, 4, out);
954             this->writeWord(this->getType(c.fType), out);
955             this->writeWord(result, out);
956             this->writeWord(fn, out);
957             if (fProgram.fSettings.fFlipY) {
958                 // Flipping Y also negates the Y derivatives.
959                 SpvId flipped = this->nextId();
960                 this->writeInstruction(SpvOpFNegate, this->getType(c.fType), flipped, result, out);
961                 this->writePrecisionModifier(c.fType, flipped);
962                 return flipped;
963             }
964             break;
965         }
966         case kClamp_SpecialIntrinsic: {
967             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
968             SkASSERT(args.size() == 3);
969             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
970                                                GLSLstd450UClamp, args, out);
971             break;
972         }
973         case kMax_SpecialIntrinsic: {
974             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
975             SkASSERT(args.size() == 2);
976             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
977                                                GLSLstd450UMax, args, out);
978             break;
979         }
980         case kMin_SpecialIntrinsic: {
981             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
982             SkASSERT(args.size() == 2);
983             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
984                                                GLSLstd450UMin, args, out);
985             break;
986         }
987         case kMix_SpecialIntrinsic: {
988             std::vector<SpvId> args = this->vectorize(c.fArguments, out);
989             SkASSERT(args.size() == 3);
990             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
991                                                SpvOpUndef, args, out);
992             break;
993         }
994         case kSaturate_SpecialIntrinsic: {
995             SkASSERT(c.fArguments.size() == 1);
996             std::vector<std::unique_ptr<Expression>> finalArgs;
997             finalArgs.push_back(c.fArguments[0]->clone());
998             finalArgs.emplace_back(new FloatLiteral(fContext, -1, 0));
999             finalArgs.emplace_back(new FloatLiteral(fContext, -1, 1));
1000             std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
1001             this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1002                                                GLSLstd450UClamp, spvArgs, out);
1003             break;
1004         }
1005     }
1006     return result;
1007 }
1008 
writeFunctionCall(const FunctionCall & c,OutputStream & out)1009 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1010     const auto& entry = fFunctionMap.find(&c.fFunction);
1011     if (entry == fFunctionMap.end()) {
1012         return this->writeIntrinsicCall(c, out);
1013     }
1014     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
1015     std::vector<std::tuple<SpvId, const Type*, std::unique_ptr<LValue>>> lvalues;
1016     std::vector<SpvId> arguments;
1017     for (size_t i = 0; i < c.fArguments.size(); i++) {
1018         // id of temporary variable that we will use to hold this argument, or 0 if it is being
1019         // passed directly
1020         SpvId tmpVar;
1021         // if we need a temporary var to store this argument, this is the value to store in the var
1022         SpvId tmpValueId;
1023         if (is_out(*c.fFunction.fParameters[i])) {
1024             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
1025             SpvId ptr = lv->getPointer();
1026             if (ptr) {
1027                 arguments.push_back(ptr);
1028                 continue;
1029             } else {
1030                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
1031                 // copy it into a temp, call the function, read the value out of the temp, and then
1032                 // update the lvalue.
1033                 tmpValueId = lv->load(out);
1034                 tmpVar = this->nextId();
1035                 lvalues.push_back(std::make_tuple(tmpVar, &c.fArguments[i]->fType, std::move(lv)));
1036             }
1037         } else {
1038             // see getFunctionType for an explanation of why we're always using pointer parameters
1039             tmpValueId = this->writeExpression(*c.fArguments[i], out);
1040             tmpVar = this->nextId();
1041         }
1042         this->writeInstruction(SpvOpVariable,
1043                                this->getPointerType(c.fArguments[i]->fType,
1044                                                     SpvStorageClassFunction),
1045                                tmpVar,
1046                                SpvStorageClassFunction,
1047                                fVariableBuffer);
1048         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1049         arguments.push_back(tmpVar);
1050     }
1051     SpvId result = this->nextId();
1052     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
1053     this->writeWord(this->getType(c.fType), out);
1054     this->writeWord(result, out);
1055     this->writeWord(entry->second, out);
1056     for (SpvId id : arguments) {
1057         this->writeWord(id, out);
1058     }
1059     // now that the call is complete, we may need to update some lvalues with the new values of out
1060     // arguments
1061     for (const auto& tuple : lvalues) {
1062         SpvId load = this->nextId();
1063         this->writeInstruction(SpvOpLoad, getType(*std::get<1>(tuple)), load, std::get<0>(tuple),
1064                                out);
1065         this->writePrecisionModifier(*std::get<1>(tuple), load);
1066         std::get<2>(tuple)->store(load, out);
1067     }
1068     return result;
1069 }
1070 
writeConstantVector(const Constructor & c)1071 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1072     SkASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1073     SpvId result = this->nextId();
1074     std::vector<SpvId> arguments;
1075     for (size_t i = 0; i < c.fArguments.size(); i++) {
1076         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1077     }
1078     SpvId type = this->getType(c.fType);
1079     if (c.fArguments.size() == 1) {
1080         // with a single argument, a vector will have all of its entries equal to the argument
1081         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1082         this->writeWord(type, fConstantBuffer);
1083         this->writeWord(result, fConstantBuffer);
1084         for (int i = 0; i < c.fType.columns(); i++) {
1085             this->writeWord(arguments[0], fConstantBuffer);
1086         }
1087     } else {
1088         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1089                           fConstantBuffer);
1090         this->writeWord(type, fConstantBuffer);
1091         this->writeWord(result, fConstantBuffer);
1092         for (SpvId id : arguments) {
1093             this->writeWord(id, fConstantBuffer);
1094         }
1095     }
1096     return result;
1097 }
1098 
writeFloatConstructor(const Constructor & c,OutputStream & out)1099 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1100     SkASSERT(c.fType.isFloat());
1101     SkASSERT(c.fArguments.size() == 1);
1102     SkASSERT(c.fArguments[0]->fType.isNumber());
1103     SpvId result = this->nextId();
1104     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1105     if (c.fArguments[0]->fType.isSigned()) {
1106         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1107                                out);
1108     } else {
1109         SkASSERT(c.fArguments[0]->fType.isUnsigned());
1110         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1111                                out);
1112     }
1113     return result;
1114 }
1115 
writeIntConstructor(const Constructor & c,OutputStream & out)1116 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1117     SkASSERT(c.fType.isSigned());
1118     SkASSERT(c.fArguments.size() == 1);
1119     SkASSERT(c.fArguments[0]->fType.isNumber());
1120     SpvId result = this->nextId();
1121     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1122     if (c.fArguments[0]->fType.isFloat()) {
1123         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1124                                out);
1125     }
1126     else {
1127         SkASSERT(c.fArguments[0]->fType.isUnsigned());
1128         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1129                                out);
1130     }
1131     return result;
1132 }
1133 
writeUIntConstructor(const Constructor & c,OutputStream & out)1134 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1135     SkASSERT(c.fType.isUnsigned());
1136     SkASSERT(c.fArguments.size() == 1);
1137     SkASSERT(c.fArguments[0]->fType.isNumber());
1138     SpvId result = this->nextId();
1139     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1140     if (c.fArguments[0]->fType.isFloat()) {
1141         this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1142                                out);
1143     } else {
1144         SkASSERT(c.fArguments[0]->fType.isSigned());
1145         this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1146                                out);
1147     }
1148     return result;
1149 }
1150 
writeUniformScaleMatrix(SpvId id,SpvId diagonal,const Type & type,OutputStream & out)1151 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1152                                                  OutputStream& out) {
1153     FloatLiteral zero(fContext, -1, 0);
1154     SpvId zeroId = this->writeFloatLiteral(zero);
1155     std::vector<SpvId> columnIds;
1156     for (int column = 0; column < type.columns(); column++) {
1157         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1158                           out);
1159         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1160                         out);
1161         SpvId columnId = this->nextId();
1162         this->writeWord(columnId, out);
1163         columnIds.push_back(columnId);
1164         for (int row = 0; row < type.columns(); row++) {
1165             this->writeWord(row == column ? diagonal : zeroId, out);
1166         }
1167         this->writePrecisionModifier(type, columnId);
1168     }
1169     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1170                       out);
1171     this->writeWord(this->getType(type), out);
1172     this->writeWord(id, out);
1173     for (SpvId id : columnIds) {
1174         this->writeWord(id, out);
1175     }
1176     this->writePrecisionModifier(type, id);
1177 }
1178 
writeMatrixCopy(SpvId id,SpvId src,const Type & srcType,const Type & dstType,OutputStream & out)1179 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1180                                          const Type& dstType, OutputStream& out) {
1181     SkASSERT(srcType.kind() == Type::kMatrix_Kind);
1182     SkASSERT(dstType.kind() == Type::kMatrix_Kind);
1183     SkASSERT(srcType.componentType() == dstType.componentType());
1184     SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1185                                                                            srcType.rows(),
1186                                                                            1));
1187     SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1188                                                                            dstType.rows(),
1189                                                                            1));
1190     SpvId zeroId;
1191     if (dstType.componentType() == *fContext.fFloat_Type) {
1192         FloatLiteral zero(fContext, -1, 0.0);
1193         zeroId = this->writeFloatLiteral(zero);
1194     } else if (dstType.componentType() == *fContext.fInt_Type) {
1195         IntLiteral zero(fContext, -1, 0);
1196         zeroId = this->writeIntLiteral(zero);
1197     } else {
1198         ABORT("unsupported matrix component type");
1199     }
1200     SpvId zeroColumn = 0;
1201     SpvId columns[4];
1202     for (int i = 0; i < dstType.columns(); i++) {
1203         if (i < srcType.columns()) {
1204             // we're still inside the src matrix, copy the column
1205             SpvId srcColumn = this->nextId();
1206             this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1207             this->writePrecisionModifier(dstType, srcColumn);
1208             SpvId dstColumn;
1209             if (srcType.rows() == dstType.rows()) {
1210                 // columns are equal size, don't need to do anything
1211                 dstColumn = srcColumn;
1212             }
1213             else if (dstType.rows() > srcType.rows()) {
1214                 // dst column is bigger, need to zero-pad it
1215                 dstColumn = this->nextId();
1216                 int delta = dstType.rows() - srcType.rows();
1217                 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1218                 this->writeWord(dstColumnType, out);
1219                 this->writeWord(dstColumn, out);
1220                 this->writeWord(srcColumn, out);
1221                 for (int i = 0; i < delta; ++i) {
1222                     this->writeWord(zeroId, out);
1223                 }
1224                 this->writePrecisionModifier(dstType, dstColumn);
1225             }
1226             else {
1227                 // dst column is smaller, need to swizzle the src column
1228                 dstColumn = this->nextId();
1229                 int count = dstType.rows();
1230                 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1231                 this->writeWord(dstColumnType, out);
1232                 this->writeWord(dstColumn, out);
1233                 this->writeWord(srcColumn, out);
1234                 this->writeWord(srcColumn, out);
1235                 for (int i = 0; i < count; i++) {
1236                     this->writeWord(i, out);
1237                 }
1238                 this->writePrecisionModifier(dstType, dstColumn);
1239             }
1240             columns[i] = dstColumn;
1241         } else {
1242             // we're past the end of the src matrix, need a vector of zeroes
1243             if (!zeroColumn) {
1244                 zeroColumn = this->nextId();
1245                 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1246                 this->writeWord(dstColumnType, out);
1247                 this->writeWord(zeroColumn, out);
1248                 for (int i = 0; i < dstType.rows(); ++i) {
1249                     this->writeWord(zeroId, out);
1250                 }
1251                 this->writePrecisionModifier(dstType, zeroColumn);
1252             }
1253             columns[i] = zeroColumn;
1254         }
1255     }
1256     this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1257     this->writeWord(this->getType(dstType), out);
1258     this->writeWord(id, out);
1259     for (int i = 0; i < dstType.columns(); i++) {
1260         this->writeWord(columns[i], out);
1261     }
1262     this->writePrecisionModifier(dstType, id);
1263 }
1264 
addColumnEntry(SpvId columnType,Precision precision,std::vector<SpvId> * currentColumn,std::vector<SpvId> * columnIds,int * currentCount,int rows,SpvId entry,OutputStream & out)1265 void SPIRVCodeGenerator::addColumnEntry(SpvId columnType, Precision precision,
1266                                         std::vector<SpvId>* currentColumn,
1267                                         std::vector<SpvId>* columnIds,
1268                                         int* currentCount, int rows, SpvId entry,
1269                                         OutputStream& out) {
1270     SkASSERT(*currentCount < rows);
1271     ++(*currentCount);
1272     currentColumn->push_back(entry);
1273     if (*currentCount == rows) {
1274         *currentCount = 0;
1275         this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn->size(), out);
1276         this->writeWord(columnType, out);
1277         SpvId columnId = this->nextId();
1278         this->writeWord(columnId, out);
1279         columnIds->push_back(columnId);
1280         for (SpvId id : *currentColumn) {
1281             this->writeWord(id, out);
1282         }
1283         currentColumn->clear();
1284         this->writePrecisionModifier(precision, columnId);
1285     }
1286 }
1287 
writeMatrixConstructor(const Constructor & c,OutputStream & out)1288 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
1289     SkASSERT(c.fType.kind() == Type::kMatrix_Kind);
1290     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1291     // an instruction
1292     std::vector<SpvId> arguments;
1293     for (size_t i = 0; i < c.fArguments.size(); i++) {
1294         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1295     }
1296     SpvId result = this->nextId();
1297     int rows = c.fType.rows();
1298     int columns = c.fType.columns();
1299     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1300         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
1301     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
1302         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
1303     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
1304         SkASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
1305         SkASSERT(c.fArguments[0]->fType.columns() == 4);
1306         SpvId componentType = this->getType(c.fType.componentType());
1307         SpvId v[4];
1308         for (int i = 0; i < 4; ++i) {
1309             v[i] = this->nextId();
1310             this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
1311         }
1312         SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
1313         SpvId column1 = this->nextId();
1314         this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
1315         SpvId column2 = this->nextId();
1316         this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
1317         this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
1318                                column2, out);
1319     } else {
1320         SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, rows, 1));
1321         std::vector<SpvId> columnIds;
1322         // ids of vectors and scalars we have written to the current column so far
1323         std::vector<SpvId> currentColumn;
1324         // the total number of scalars represented by currentColumn's entries
1325         int currentCount = 0;
1326         Precision precision = c.fType.highPrecision() ? Precision::kHigh : Precision::kLow;
1327         for (size_t i = 0; i < arguments.size(); i++) {
1328             if (currentCount == 0 && c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
1329                     c.fArguments[i]->fType.columns() == c.fType.rows()) {
1330                 // this is a complete column by itself
1331                 columnIds.push_back(arguments[i]);
1332             } else {
1333                 if (c.fArguments[i]->fType.columns() == 1) {
1334                     this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
1335                                          &currentCount, rows, arguments[i], out);
1336                 } else {
1337                     SpvId componentType = this->getType(c.fArguments[i]->fType.componentType());
1338                     for (int j = 0; j < c.fArguments[i]->fType.columns(); ++j) {
1339                         SpvId swizzle = this->nextId();
1340                         this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
1341                                                arguments[i], j, out);
1342                         this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
1343                                              &currentCount, rows, swizzle, out);
1344                     }
1345                 }
1346             }
1347         }
1348         SkASSERT(columnIds.size() == (size_t) columns);
1349         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
1350         this->writeWord(this->getType(c.fType), out);
1351         this->writeWord(result, out);
1352         for (SpvId id : columnIds) {
1353             this->writeWord(id, out);
1354         }
1355     }
1356     this->writePrecisionModifier(c.fType, result);
1357     return result;
1358 }
1359 
writeVectorConstructor(const Constructor & c,OutputStream & out)1360 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1361     SkASSERT(c.fType.kind() == Type::kVector_Kind);
1362     if (c.isConstant()) {
1363         return this->writeConstantVector(c);
1364     }
1365     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1366     // an instruction
1367     std::vector<SpvId> arguments;
1368     for (size_t i = 0; i < c.fArguments.size(); i++) {
1369         if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1370             // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1371             // extract the components and convert them in that case manually. On top of that,
1372             // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1373             // doesn't handle vector arguments at all, so we always extract vector components and
1374             // pass them into OpCreateComposite individually.
1375             SpvId vec = this->writeExpression(*c.fArguments[i], out);
1376             SpvOp_ op = SpvOpUndef;
1377             const Type& src = c.fArguments[i]->fType.componentType();
1378             const Type& dst = c.fType.componentType();
1379             if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1380                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1381                     if (c.fArguments.size() == 1) {
1382                         return vec;
1383                     }
1384                 } else if (src == *fContext.fInt_Type ||
1385                            src == *fContext.fShort_Type ||
1386                            src == *fContext.fByte_Type) {
1387                     op = SpvOpConvertSToF;
1388                 } else if (src == *fContext.fUInt_Type ||
1389                            src == *fContext.fUShort_Type ||
1390                            src == *fContext.fUByte_Type) {
1391                     op = SpvOpConvertUToF;
1392                 } else {
1393                     SkASSERT(false);
1394                 }
1395             } else if (dst == *fContext.fInt_Type ||
1396                        dst == *fContext.fShort_Type ||
1397                        dst == *fContext.fByte_Type) {
1398                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1399                     op = SpvOpConvertFToS;
1400                 } else if (src == *fContext.fInt_Type ||
1401                            src == *fContext.fShort_Type ||
1402                            src == *fContext.fByte_Type) {
1403                     if (c.fArguments.size() == 1) {
1404                         return vec;
1405                     }
1406                 } else if (src == *fContext.fUInt_Type ||
1407                            src == *fContext.fUShort_Type ||
1408                            src == *fContext.fUByte_Type) {
1409                     op = SpvOpBitcast;
1410                 } else {
1411                     SkASSERT(false);
1412                 }
1413             } else if (dst == *fContext.fUInt_Type ||
1414                        dst == *fContext.fUShort_Type ||
1415                        dst == *fContext.fUByte_Type) {
1416                 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1417                     op = SpvOpConvertFToS;
1418                 } else if (src == *fContext.fInt_Type ||
1419                            src == *fContext.fShort_Type ||
1420                            src == *fContext.fByte_Type) {
1421                     op = SpvOpBitcast;
1422                 } else if (src == *fContext.fUInt_Type ||
1423                            src == *fContext.fUShort_Type ||
1424                            src == *fContext.fUByte_Type) {
1425                     if (c.fArguments.size() == 1) {
1426                         return vec;
1427                     }
1428                 } else {
1429                     SkASSERT(false);
1430                 }
1431             }
1432             for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1433                 SpvId swizzle = this->nextId();
1434                 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1435                                        out);
1436                 if (op != SpvOpUndef) {
1437                     SpvId cast = this->nextId();
1438                     this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1439                     arguments.push_back(cast);
1440                 } else {
1441                     arguments.push_back(swizzle);
1442                 }
1443             }
1444         } else {
1445             arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1446         }
1447     }
1448     SpvId result = this->nextId();
1449     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1450         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1451         this->writeWord(this->getType(c.fType), out);
1452         this->writeWord(result, out);
1453         for (int i = 0; i < c.fType.columns(); i++) {
1454             this->writeWord(arguments[0], out);
1455         }
1456     } else {
1457         SkASSERT(arguments.size() > 1);
1458         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1459         this->writeWord(this->getType(c.fType), out);
1460         this->writeWord(result, out);
1461         for (SpvId id : arguments) {
1462             this->writeWord(id, out);
1463         }
1464     }
1465     return result;
1466 }
1467 
writeArrayConstructor(const Constructor & c,OutputStream & out)1468 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1469     SkASSERT(c.fType.kind() == Type::kArray_Kind);
1470     // go ahead and write the arguments so we don't try to write new instructions in the middle of
1471     // an instruction
1472     std::vector<SpvId> arguments;
1473     for (size_t i = 0; i < c.fArguments.size(); i++) {
1474         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1475     }
1476     SpvId result = this->nextId();
1477     this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1478     this->writeWord(this->getType(c.fType), out);
1479     this->writeWord(result, out);
1480     for (SpvId id : arguments) {
1481         this->writeWord(id, out);
1482     }
1483     return result;
1484 }
1485 
writeConstructor(const Constructor & c,OutputStream & out)1486 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1487     if (c.fArguments.size() == 1 &&
1488         this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1489         return this->writeExpression(*c.fArguments[0], out);
1490     }
1491     if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1492         return this->writeFloatConstructor(c, out);
1493     } else if (c.fType == *fContext.fInt_Type ||
1494                c.fType == *fContext.fShort_Type ||
1495                c.fType == *fContext.fByte_Type) {
1496         return this->writeIntConstructor(c, out);
1497     } else if (c.fType == *fContext.fUInt_Type ||
1498                c.fType == *fContext.fUShort_Type ||
1499                c.fType == *fContext.fUByte_Type) {
1500         return this->writeUIntConstructor(c, out);
1501     }
1502     switch (c.fType.kind()) {
1503         case Type::kVector_Kind:
1504             return this->writeVectorConstructor(c, out);
1505         case Type::kMatrix_Kind:
1506             return this->writeMatrixConstructor(c, out);
1507         case Type::kArray_Kind:
1508             return this->writeArrayConstructor(c, out);
1509         default:
1510             ABORT("unsupported constructor: %s", c.description().c_str());
1511     }
1512 }
1513 
get_storage_class(const Modifiers & modifiers)1514 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1515     if (modifiers.fFlags & Modifiers::kIn_Flag) {
1516         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1517         return SpvStorageClassInput;
1518     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1519         SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1520         return SpvStorageClassOutput;
1521     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1522         if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1523             return SpvStorageClassPushConstant;
1524         }
1525         return SpvStorageClassUniform;
1526     } else {
1527         return SpvStorageClassFunction;
1528     }
1529 }
1530 
get_storage_class(const Expression & expr)1531 SpvStorageClass_ get_storage_class(const Expression& expr) {
1532     switch (expr.fKind) {
1533         case Expression::kVariableReference_Kind: {
1534             const Variable& var = ((VariableReference&) expr).fVariable;
1535             if (var.fStorage != Variable::kGlobal_Storage) {
1536                 return SpvStorageClassFunction;
1537             }
1538             SpvStorageClass_ result = get_storage_class(var.fModifiers);
1539             if (result == SpvStorageClassFunction) {
1540                 result = SpvStorageClassPrivate;
1541             }
1542             return result;
1543         }
1544         case Expression::kFieldAccess_Kind:
1545             return get_storage_class(*((FieldAccess&) expr).fBase);
1546         case Expression::kIndex_Kind:
1547             return get_storage_class(*((IndexExpression&) expr).fBase);
1548         default:
1549             return SpvStorageClassFunction;
1550     }
1551 }
1552 
getAccessChain(const Expression & expr,OutputStream & out)1553 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1554     std::vector<SpvId> chain;
1555     switch (expr.fKind) {
1556         case Expression::kIndex_Kind: {
1557             IndexExpression& indexExpr = (IndexExpression&) expr;
1558             chain = this->getAccessChain(*indexExpr.fBase, out);
1559             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1560             break;
1561         }
1562         case Expression::kFieldAccess_Kind: {
1563             FieldAccess& fieldExpr = (FieldAccess&) expr;
1564             chain = this->getAccessChain(*fieldExpr.fBase, out);
1565             IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1566             chain.push_back(this->writeIntLiteral(index));
1567             break;
1568         }
1569         default: {
1570             SpvId id = this->getLValue(expr, out)->getPointer();
1571             SkASSERT(id != 0);
1572             chain.push_back(id);
1573         }
1574     }
1575     return chain;
1576 }
1577 
1578 class PointerLValue : public SPIRVCodeGenerator::LValue {
1579 public:
PointerLValue(SPIRVCodeGenerator & gen,SpvId pointer,SpvId type,SPIRVCodeGenerator::Precision precision)1580     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type,
1581                   SPIRVCodeGenerator::Precision precision)
1582     : fGen(gen)
1583     , fPointer(pointer)
1584     , fType(type)
1585     , fPrecision(precision) {}
1586 
getPointer()1587     virtual SpvId getPointer() override {
1588         return fPointer;
1589     }
1590 
load(OutputStream & out)1591     virtual SpvId load(OutputStream& out) override {
1592         SpvId result = fGen.nextId();
1593         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1594         fGen.writePrecisionModifier(fPrecision, result);
1595         return result;
1596     }
1597 
store(SpvId value,OutputStream & out)1598     virtual void store(SpvId value, OutputStream& out) override {
1599         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1600     }
1601 
1602 private:
1603     SPIRVCodeGenerator& fGen;
1604     const SpvId fPointer;
1605     const SpvId fType;
1606     const SPIRVCodeGenerator::Precision fPrecision;
1607 };
1608 
1609 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1610 public:
SwizzleLValue(SPIRVCodeGenerator & gen,SpvId vecPointer,const std::vector<int> & components,const Type & baseType,const Type & swizzleType,SPIRVCodeGenerator::Precision precision)1611     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1612                   const Type& baseType, const Type& swizzleType,
1613                   SPIRVCodeGenerator::Precision precision)
1614     : fGen(gen)
1615     , fVecPointer(vecPointer)
1616     , fComponents(components)
1617     , fBaseType(baseType)
1618     , fSwizzleType(swizzleType)
1619     , fPrecision(precision) {}
1620 
getPointer()1621     virtual SpvId getPointer() override {
1622         return 0;
1623     }
1624 
load(OutputStream & out)1625     virtual SpvId load(OutputStream& out) override {
1626         SpvId base = fGen.nextId();
1627         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1628         fGen.writePrecisionModifier(fPrecision, base);
1629         SpvId result = fGen.nextId();
1630         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1631         fGen.writeWord(fGen.getType(fSwizzleType), out);
1632         fGen.writeWord(result, out);
1633         fGen.writeWord(base, out);
1634         fGen.writeWord(base, out);
1635         for (int component : fComponents) {
1636             fGen.writeWord(component, out);
1637         }
1638         fGen.writePrecisionModifier(fPrecision, result);
1639         return result;
1640     }
1641 
store(SpvId value,OutputStream & out)1642     virtual void store(SpvId value, OutputStream& out) override {
1643         // use OpVectorShuffle to mix and match the vector components. We effectively create
1644         // a virtual vector out of the concatenation of the left and right vectors, and then
1645         // select components from this virtual vector to make the result vector. For
1646         // instance, given:
1647         // float3L = ...;
1648         // float3R = ...;
1649         // L.xz = R.xy;
1650         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1651         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1652         // (3, 1, 4).
1653         SpvId base = fGen.nextId();
1654         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1655         SpvId shuffle = fGen.nextId();
1656         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1657         fGen.writeWord(fGen.getType(fBaseType), out);
1658         fGen.writeWord(shuffle, out);
1659         fGen.writeWord(base, out);
1660         fGen.writeWord(value, out);
1661         for (int i = 0; i < fBaseType.columns(); i++) {
1662             // current offset into the virtual vector, defaults to pulling the unmodified
1663             // value from the left side
1664             int offset = i;
1665             // check to see if we are writing this component
1666             for (size_t j = 0; j < fComponents.size(); j++) {
1667                 if (fComponents[j] == i) {
1668                     // we're writing to this component, so adjust the offset to pull from
1669                     // the correct component of the right side instead of preserving the
1670                     // value from the left
1671                     offset = (int) (j + fBaseType.columns());
1672                     break;
1673                 }
1674             }
1675             fGen.writeWord(offset, out);
1676         }
1677         fGen.writePrecisionModifier(fPrecision, shuffle);
1678         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1679     }
1680 
1681 private:
1682     SPIRVCodeGenerator& fGen;
1683     const SpvId fVecPointer;
1684     const std::vector<int>& fComponents;
1685     const Type& fBaseType;
1686     const Type& fSwizzleType;
1687     const SPIRVCodeGenerator::Precision fPrecision;
1688 };
1689 
getLValue(const Expression & expr,OutputStream & out)1690 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1691                                                                           OutputStream& out) {
1692     Precision precision = expr.fType.highPrecision() ? Precision::kHigh : Precision::kLow;
1693     switch (expr.fKind) {
1694         case Expression::kVariableReference_Kind: {
1695             SpvId type;
1696             const Variable& var = ((VariableReference&) expr).fVariable;
1697             if (var.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
1698                 type = this->getType(Type("sk_in", Type::kArray_Kind, var.fType.componentType(),
1699                                           fSkInCount));
1700             } else {
1701                 type = this->getType(expr.fType);
1702             }
1703             auto entry = fVariableMap.find(&var);
1704             SkASSERT(entry != fVariableMap.end());
1705             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(*this,
1706                                                                                  entry->second,
1707                                                                                  type,
1708                                                                                  precision));
1709         }
1710         case Expression::kIndex_Kind: // fall through
1711         case Expression::kFieldAccess_Kind: {
1712             std::vector<SpvId> chain = this->getAccessChain(expr, out);
1713             SpvId member = this->nextId();
1714             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1715             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1716             this->writeWord(member, out);
1717             for (SpvId idx : chain) {
1718                 this->writeWord(idx, out);
1719             }
1720             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1721                                                                         *this,
1722                                                                         member,
1723                                                                         this->getType(expr.fType),
1724                                                                         precision));
1725         }
1726         case Expression::kSwizzle_Kind: {
1727             Swizzle& swizzle = (Swizzle&) expr;
1728             size_t count = swizzle.fComponents.size();
1729             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1730             SkASSERT(base);
1731             if (count == 1) {
1732                 IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1733                 SpvId member = this->nextId();
1734                 this->writeInstruction(SpvOpAccessChain,
1735                                        this->getPointerType(swizzle.fType,
1736                                                             get_storage_class(*swizzle.fBase)),
1737                                        member,
1738                                        base,
1739                                        this->writeIntLiteral(index),
1740                                        out);
1741                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1742                                                                        *this,
1743                                                                        member,
1744                                                                        this->getType(expr.fType),
1745                                                                        precision));
1746             } else {
1747                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1748                                                                               *this,
1749                                                                               base,
1750                                                                               swizzle.fComponents,
1751                                                                               swizzle.fBase->fType,
1752                                                                               expr.fType,
1753                                                                               precision));
1754             }
1755         }
1756         case Expression::kTernary_Kind: {
1757             TernaryExpression& t = (TernaryExpression&) expr;
1758             SpvId test = this->writeExpression(*t.fTest, out);
1759             SpvId end = this->nextId();
1760             SpvId ifTrueLabel = this->nextId();
1761             SpvId ifFalseLabel = this->nextId();
1762             this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1763             this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1764             this->writeLabel(ifTrueLabel, out);
1765             SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1766             SkASSERT(ifTrue);
1767             this->writeInstruction(SpvOpBranch, end, out);
1768             ifTrueLabel = fCurrentBlock;
1769             SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1770             SkASSERT(ifFalse);
1771             ifFalseLabel = fCurrentBlock;
1772             this->writeInstruction(SpvOpBranch, end, out);
1773             SpvId result = this->nextId();
1774             this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1775                        ifTrueLabel, ifFalse, ifFalseLabel, out);
1776             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1777                                                                        *this,
1778                                                                        result,
1779                                                                        this->getType(expr.fType),
1780                                                                        precision));
1781         }
1782         default:
1783             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1784             // to the need to store values in temporary variables during function calls (see
1785             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1786             // caught by IRGenerator
1787             SpvId result = this->nextId();
1788             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1789             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1790                                    fVariableBuffer);
1791             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1792             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1793                                                                        *this,
1794                                                                        result,
1795                                                                        this->getType(expr.fType),
1796                                                                        precision));
1797     }
1798 }
1799 
writeVariableReference(const VariableReference & ref,OutputStream & out)1800 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1801     SpvId result = this->nextId();
1802     auto entry = fVariableMap.find(&ref.fVariable);
1803     SkASSERT(entry != fVariableMap.end());
1804     SpvId var = entry->second;
1805     this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1806     this->writePrecisionModifier(ref.fVariable.fType, result);
1807     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1808         fProgram.fSettings.fFlipY) {
1809         // need to remap to a top-left coordinate system
1810         if (fRTHeightStructId == (SpvId) -1) {
1811             // height variable hasn't been written yet
1812             std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1813             SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
1814             std::vector<Type::Field> fields;
1815             SkASSERT(fProgram.fSettings.fRTHeightOffset >= 0);
1816             fields.emplace_back(Modifiers(Layout(0, -1, fProgram.fSettings.fRTHeightOffset, -1,
1817                                                  -1, -1, -1, -1, Layout::Format::kUnspecified,
1818                                                  Layout::kUnspecified_Primitive, -1, -1, "",
1819                                                  Layout::kNo_Key, Layout::CType::kDefault), 0),
1820                                 SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1821             StringFragment name("sksl_synthetic_uniforms");
1822             Type intfStruct(-1, name, fields);
1823             int binding;
1824             int set;
1825 #ifdef SK_VULKAN
1826             const GrVkCaps* vkCaps = fProgram.fSettings.fVkCaps;
1827             SkASSERT(vkCaps);
1828             binding = vkCaps->getFragmentUniformBinding();
1829             set = vkCaps->getFragmentUniformSet();
1830 #else
1831             binding = 0;
1832             set = 0;
1833 #endif
1834             Layout layout(0, -1, -1, binding, -1, set, -1, -1, Layout::Format::kUnspecified,
1835                           Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1836                           Layout::CType::kDefault);
1837             Variable* intfVar = (Variable*) fSynthetics.takeOwnership(std::unique_ptr<Symbol>(
1838                                            new Variable(-1,
1839                                                         Modifiers(layout, Modifiers::kUniform_Flag),
1840                                                         name,
1841                                                         intfStruct,
1842                                                         Variable::kGlobal_Storage)));
1843             InterfaceBlock intf(-1, intfVar, name, String(""),
1844                                 std::vector<std::unique_ptr<Expression>>(), st);
1845             fRTHeightStructId = this->writeInterfaceBlock(intf);
1846             fRTHeightFieldIndex = 0;
1847         }
1848         SkASSERT(fRTHeightFieldIndex != (SpvId) -1);
1849         // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, gl_FragCoord.w)
1850         SpvId xId = this->nextId();
1851         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1852                                result, 0, out);
1853         IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1854         SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1855         SpvId heightPtr = this->nextId();
1856         this->writeOpCode(SpvOpAccessChain, 5, out);
1857         this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1858         this->writeWord(heightPtr, out);
1859         this->writeWord(fRTHeightStructId, out);
1860         this->writeWord(fieldIndexId, out);
1861         SpvId heightRead = this->nextId();
1862         this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1863                                heightPtr, out);
1864         SpvId rawYId = this->nextId();
1865         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1866                                result, 1, out);
1867         SpvId flippedYId = this->nextId();
1868         this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1869                                heightRead, rawYId, out);
1870         FloatLiteral zero(fContext, -1, 0.0);
1871         SpvId zeroId = writeFloatLiteral(zero);
1872         FloatLiteral one(fContext, -1, 1.0);
1873         SpvId wId = this->nextId();
1874         this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), wId,
1875                                result, 3, out);
1876         SpvId flipped = this->nextId();
1877         this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1878         this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1879         this->writeWord(flipped, out);
1880         this->writeWord(xId, out);
1881         this->writeWord(flippedYId, out);
1882         this->writeWord(zeroId, out);
1883         this->writeWord(wId, out);
1884         return flipped;
1885     }
1886     if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN &&
1887         !fProgram.fSettings.fFlipY) {
1888         // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia, we use
1889         // the default convention of "counter-clockwise face is front".
1890         SpvId inverse = this->nextId();
1891         this->writeInstruction(SpvOpLogicalNot, this->getType(*fContext.fBool_Type), inverse,
1892                                result, out);
1893         return inverse;
1894     }
1895     return result;
1896 }
1897 
writeIndexExpression(const IndexExpression & expr,OutputStream & out)1898 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1899     if (expr.fBase->fType.kind() == Type::Kind::kVector_Kind) {
1900         SpvId base = this->writeExpression(*expr.fBase, out);
1901         SpvId index = this->writeExpression(*expr.fIndex, out);
1902         SpvId result = this->nextId();
1903         this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.fType), result, base,
1904                                index, out);
1905         return result;
1906     }
1907     return getLValue(expr, out)->load(out);
1908 }
1909 
writeFieldAccess(const FieldAccess & f,OutputStream & out)1910 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1911     return getLValue(f, out)->load(out);
1912 }
1913 
writeSwizzle(const Swizzle & swizzle,OutputStream & out)1914 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1915     SpvId base = this->writeExpression(*swizzle.fBase, out);
1916     SpvId result = this->nextId();
1917     size_t count = swizzle.fComponents.size();
1918     if (count == 1) {
1919         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1920                                swizzle.fComponents[0], out);
1921     } else {
1922         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1923         this->writeWord(this->getType(swizzle.fType), out);
1924         this->writeWord(result, out);
1925         this->writeWord(base, out);
1926         SpvId other = base;
1927         for (int c : swizzle.fComponents) {
1928             if (c < 0) {
1929                 if (!fConstantZeroOneVector) {
1930                     FloatLiteral zero(fContext, -1, 0);
1931                     SpvId zeroId = this->writeFloatLiteral(zero);
1932                     FloatLiteral one(fContext, -1, 1);
1933                     SpvId oneId = this->writeFloatLiteral(one);
1934                     SpvId type = this->getType(*fContext.fFloat2_Type);
1935                     fConstantZeroOneVector = this->nextId();
1936                     this->writeOpCode(SpvOpConstantComposite, 5, fConstantBuffer);
1937                     this->writeWord(type, fConstantBuffer);
1938                     this->writeWord(fConstantZeroOneVector, fConstantBuffer);
1939                     this->writeWord(zeroId, fConstantBuffer);
1940                     this->writeWord(oneId, fConstantBuffer);
1941                 }
1942                 other = fConstantZeroOneVector;
1943                 break;
1944             }
1945         }
1946         this->writeWord(other, out);
1947         for (int component : swizzle.fComponents) {
1948             if (component == SKSL_SWIZZLE_0) {
1949                 this->writeWord(swizzle.fBase->fType.columns(), out);
1950             } else if (component == SKSL_SWIZZLE_1) {
1951                 this->writeWord(swizzle.fBase->fType.columns() + 1, out);
1952             } else {
1953                 this->writeWord(component, out);
1954             }
1955         }
1956     }
1957     return result;
1958 }
1959 
writeBinaryOperation(const Type & resultType,const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ ifFloat,SpvOp_ ifInt,SpvOp_ ifUInt,SpvOp_ ifBool,OutputStream & out)1960 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1961                                                const Type& operandType, SpvId lhs,
1962                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1963                                                SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1964     SpvId result = this->nextId();
1965     if (is_float(fContext, operandType)) {
1966         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1967     } else if (is_signed(fContext, operandType)) {
1968         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1969     } else if (is_unsigned(fContext, operandType)) {
1970         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1971     } else if (operandType == *fContext.fBool_Type) {
1972         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1973         return result; // skip RelaxedPrecision check
1974     } else {
1975         ABORT("invalid operandType: %s", operandType.description().c_str());
1976     }
1977     if (getActualType(resultType) == operandType && !resultType.highPrecision()) {
1978         this->writeInstruction(SpvOpDecorate, result, SpvDecorationRelaxedPrecision,
1979                                fDecorationBuffer);
1980     }
1981     return result;
1982 }
1983 
foldToBool(SpvId id,const Type & operandType,SpvOp op,OutputStream & out)1984 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
1985                                      OutputStream& out) {
1986     if (operandType.kind() == Type::kVector_Kind) {
1987         SpvId result = this->nextId();
1988         this->writeInstruction(op, this->getType(*fContext.fBool_Type), result, id, out);
1989         return result;
1990     }
1991     return id;
1992 }
1993 
writeMatrixComparison(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,SpvOp_ vectorMergeOperator,SpvOp_ mergeOperator,OutputStream & out)1994 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1995                                                 SpvOp_ floatOperator, SpvOp_ intOperator,
1996                                                 SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
1997                                                 OutputStream& out) {
1998     SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1999     SkASSERT(operandType.kind() == Type::kMatrix_Kind);
2000     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2001                                                                             operandType.rows(),
2002                                                                             1));
2003     SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
2004                                                                     operandType.rows(),
2005                                                                     1));
2006     SpvId boolType = this->getType(*fContext.fBool_Type);
2007     SpvId result = 0;
2008     for (int i = 0; i < operandType.columns(); i++) {
2009         SpvId columnL = this->nextId();
2010         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2011         SpvId columnR = this->nextId();
2012         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2013         SpvId compare = this->nextId();
2014         this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2015         SpvId merge = this->nextId();
2016         this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2017         if (result != 0) {
2018             SpvId next = this->nextId();
2019             this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2020             result = next;
2021         }
2022         else {
2023             result = merge;
2024         }
2025     }
2026     return result;
2027 }
2028 
writeComponentwiseMatrixBinary(const Type & operandType,SpvId lhs,SpvId rhs,SpvOp_ floatOperator,SpvOp_ intOperator,OutputStream & out)2029 SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2030                                                          SpvId rhs, SpvOp_ floatOperator,
2031                                                          SpvOp_ intOperator,
2032                                                          OutputStream& out) {
2033     SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
2034     SkASSERT(operandType.kind() == Type::kMatrix_Kind);
2035     SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2036                                                                             operandType.rows(),
2037                                                                             1));
2038     SpvId columns[4];
2039     for (int i = 0; i < operandType.columns(); i++) {
2040         SpvId columnL = this->nextId();
2041         this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2042         SpvId columnR = this->nextId();
2043         this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2044         columns[i] = this->nextId();
2045         this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
2046     }
2047     SpvId result = this->nextId();
2048     this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
2049     this->writeWord(this->getType(operandType), out);
2050     this->writeWord(result, out);
2051     for (int i = 0; i < operandType.columns(); i++) {
2052         this->writeWord(columns[i], out);
2053     }
2054     return result;
2055 }
2056 
create_literal_1(const Context & context,const Type & type)2057 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2058     if (type.isInteger()) {
2059         return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
2060     }
2061     else if (type.isFloat()) {
2062         return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
2063     } else {
2064         ABORT("math is unsupported on type '%s'", type.name().c_str());
2065     }
2066 }
2067 
writeBinaryExpression(const Type & leftType,SpvId lhs,Token::Kind op,const Type & rightType,SpvId rhs,const Type & resultType,OutputStream & out)2068 SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
2069                                                 const Type& rightType, SpvId rhs,
2070                                                 const Type& resultType, OutputStream& out) {
2071     Type tmp("<invalid>");
2072     // overall type we are operating on: float2, int, uint4...
2073     const Type* operandType;
2074     // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
2075     // handling in SPIR-V
2076     if (this->getActualType(leftType) != this->getActualType(rightType)) {
2077         if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
2078             if (op == Token::SLASH) {
2079                 SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
2080                 SpvId inverse = this->nextId();
2081                 this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
2082                 rhs = inverse;
2083                 op = Token::STAR;
2084             }
2085             if (op == Token::STAR) {
2086                 SpvId result = this->nextId();
2087                 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2088                                        result, lhs, rhs, out);
2089                 return result;
2090             }
2091             // promote number to vector
2092             SpvId vec = this->nextId();
2093             const Type& vecType = leftType;
2094             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2095             this->writeWord(this->getType(vecType), out);
2096             this->writeWord(vec, out);
2097             for (int i = 0; i < vecType.columns(); i++) {
2098                 this->writeWord(rhs, out);
2099             }
2100             rhs = vec;
2101             operandType = &leftType;
2102         } else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
2103             if (op == Token::STAR) {
2104                 SpvId result = this->nextId();
2105                 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2106                                        result, rhs, lhs, out);
2107                 return result;
2108             }
2109             // promote number to vector
2110             SpvId vec = this->nextId();
2111             const Type& vecType = rightType;
2112             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2113             this->writeWord(this->getType(vecType), out);
2114             this->writeWord(vec, out);
2115             for (int i = 0; i < vecType.columns(); i++) {
2116                 this->writeWord(lhs, out);
2117             }
2118             lhs = vec;
2119             operandType = &rightType;
2120         } else if (leftType.kind() == Type::kMatrix_Kind) {
2121             SpvOp_ spvop;
2122             if (rightType.kind() == Type::kMatrix_Kind) {
2123                 spvop = SpvOpMatrixTimesMatrix;
2124             } else if (rightType.kind() == Type::kVector_Kind) {
2125                 spvop = SpvOpMatrixTimesVector;
2126             } else {
2127                 SkASSERT(rightType.kind() == Type::kScalar_Kind);
2128                 spvop = SpvOpMatrixTimesScalar;
2129             }
2130             SpvId result = this->nextId();
2131             this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
2132             return result;
2133         } else if (rightType.kind() == Type::kMatrix_Kind) {
2134             SpvId result = this->nextId();
2135             if (leftType.kind() == Type::kVector_Kind) {
2136                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
2137                                        lhs, rhs, out);
2138             } else {
2139                 SkASSERT(leftType.kind() == Type::kScalar_Kind);
2140                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
2141                                        rhs, lhs, out);
2142             }
2143             return result;
2144         } else {
2145             SkASSERT(false);
2146             return -1;
2147         }
2148     } else {
2149         tmp = this->getActualType(leftType);
2150         operandType = &tmp;
2151         SkASSERT(*operandType == this->getActualType(rightType));
2152     }
2153     switch (op) {
2154         case Token::EQEQ: {
2155             if (operandType->kind() == Type::kMatrix_Kind) {
2156                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2157                                                    SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2158             }
2159             SkASSERT(resultType == *fContext.fBool_Type);
2160             const Type* tmpType;
2161             if (operandType->kind() == Type::kVector_Kind) {
2162                 tmpType = &fContext.fBool_Type->toCompound(fContext,
2163                                                            operandType->columns(),
2164                                                            operandType->rows());
2165             } else {
2166                 tmpType = &resultType;
2167             }
2168             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2169                                                                SpvOpFOrdEqual, SpvOpIEqual,
2170                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
2171                                     *operandType, SpvOpAll, out);
2172         }
2173         case Token::NEQ:
2174             if (operandType->kind() == Type::kMatrix_Kind) {
2175                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2176                                                    SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2177             }
2178             SkASSERT(resultType == *fContext.fBool_Type);
2179             const Type* tmpType;
2180             if (operandType->kind() == Type::kVector_Kind) {
2181                 tmpType = &fContext.fBool_Type->toCompound(fContext,
2182                                                            operandType->columns(),
2183                                                            operandType->rows());
2184             } else {
2185                 tmpType = &resultType;
2186             }
2187             return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2188                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
2189                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
2190                                                                out),
2191                                     *operandType, SpvOpAny, out);
2192         case Token::GT:
2193             SkASSERT(resultType == *fContext.fBool_Type);
2194             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2195                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2196                                               SpvOpUGreaterThan, SpvOpUndef, out);
2197         case Token::LT:
2198             SkASSERT(resultType == *fContext.fBool_Type);
2199             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2200                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2201         case Token::GTEQ:
2202             SkASSERT(resultType == *fContext.fBool_Type);
2203             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2204                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2205                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
2206         case Token::LTEQ:
2207             SkASSERT(resultType == *fContext.fBool_Type);
2208             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2209                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2210                                               SpvOpULessThanEqual, SpvOpUndef, out);
2211         case Token::PLUS:
2212             if (leftType.kind() == Type::kMatrix_Kind &&
2213                 rightType.kind() == Type::kMatrix_Kind) {
2214                 SkASSERT(leftType == rightType);
2215                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
2216                                                             SpvOpFAdd, SpvOpIAdd, out);
2217             }
2218             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2219                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2220         case Token::MINUS:
2221             if (leftType.kind() == Type::kMatrix_Kind &&
2222                 rightType.kind() == Type::kMatrix_Kind) {
2223                 SkASSERT(leftType == rightType);
2224                 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
2225                                                             SpvOpFSub, SpvOpISub, out);
2226             }
2227             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2228                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
2229         case Token::STAR:
2230             if (leftType.kind() == Type::kMatrix_Kind &&
2231                 rightType.kind() == Type::kMatrix_Kind) {
2232                 // matrix multiply
2233                 SpvId result = this->nextId();
2234                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2235                                        lhs, rhs, out);
2236                 return result;
2237             }
2238             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2239                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2240         case Token::SLASH:
2241             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2242                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2243         case Token::PERCENT:
2244             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2245                                               SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2246         case Token::SHL:
2247             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2248                                               SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2249                                               SpvOpUndef, out);
2250         case Token::SHR:
2251             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2252                                               SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2253                                               SpvOpUndef, out);
2254         case Token::BITWISEAND:
2255             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2256                                               SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2257         case Token::BITWISEOR:
2258             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2259                                               SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2260         case Token::BITWISEXOR:
2261             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2262                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2263         case Token::COMMA:
2264             return rhs;
2265         default:
2266             SkASSERT(false);
2267             return -1;
2268     }
2269 }
2270 
writeBinaryExpression(const BinaryExpression & b,OutputStream & out)2271 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2272     // handle cases where we don't necessarily evaluate both LHS and RHS
2273     switch (b.fOperator) {
2274         case Token::EQ: {
2275             SpvId rhs = this->writeExpression(*b.fRight, out);
2276             this->getLValue(*b.fLeft, out)->store(rhs, out);
2277             return rhs;
2278         }
2279         case Token::LOGICALAND:
2280             return this->writeLogicalAnd(b, out);
2281         case Token::LOGICALOR:
2282             return this->writeLogicalOr(b, out);
2283         default:
2284             break;
2285     }
2286 
2287     std::unique_ptr<LValue> lvalue;
2288     SpvId lhs;
2289     if (is_assignment(b.fOperator)) {
2290         lvalue = this->getLValue(*b.fLeft, out);
2291         lhs = lvalue->load(out);
2292     } else {
2293         lvalue = nullptr;
2294         lhs = this->writeExpression(*b.fLeft, out);
2295     }
2296     SpvId rhs = this->writeExpression(*b.fRight, out);
2297     SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
2298                                                b.fRight->fType, rhs, b.fType, out);
2299     if (lvalue) {
2300         lvalue->store(result, out);
2301     }
2302     return result;
2303 }
2304 
writeLogicalAnd(const BinaryExpression & a,OutputStream & out)2305 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2306     SkASSERT(a.fOperator == Token::LOGICALAND);
2307     BoolLiteral falseLiteral(fContext, -1, false);
2308     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2309     SpvId lhs = this->writeExpression(*a.fLeft, out);
2310     SpvId rhsLabel = this->nextId();
2311     SpvId end = this->nextId();
2312     SpvId lhsBlock = fCurrentBlock;
2313     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2314     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2315     this->writeLabel(rhsLabel, out);
2316     SpvId rhs = this->writeExpression(*a.fRight, out);
2317     SpvId rhsBlock = fCurrentBlock;
2318     this->writeInstruction(SpvOpBranch, end, out);
2319     this->writeLabel(end, out);
2320     SpvId result = this->nextId();
2321     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2322                            lhsBlock, rhs, rhsBlock, out);
2323     return result;
2324 }
2325 
writeLogicalOr(const BinaryExpression & o,OutputStream & out)2326 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2327     SkASSERT(o.fOperator == Token::LOGICALOR);
2328     BoolLiteral trueLiteral(fContext, -1, true);
2329     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2330     SpvId lhs = this->writeExpression(*o.fLeft, out);
2331     SpvId rhsLabel = this->nextId();
2332     SpvId end = this->nextId();
2333     SpvId lhsBlock = fCurrentBlock;
2334     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2335     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2336     this->writeLabel(rhsLabel, out);
2337     SpvId rhs = this->writeExpression(*o.fRight, out);
2338     SpvId rhsBlock = fCurrentBlock;
2339     this->writeInstruction(SpvOpBranch, end, out);
2340     this->writeLabel(end, out);
2341     SpvId result = this->nextId();
2342     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2343                            lhsBlock, rhs, rhsBlock, out);
2344     return result;
2345 }
2346 
writeTernaryExpression(const TernaryExpression & t,OutputStream & out)2347 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2348     SpvId test = this->writeExpression(*t.fTest, out);
2349     if (t.fIfTrue->fType.columns() == 1 && t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2350         // both true and false are constants, can just use OpSelect
2351         SpvId result = this->nextId();
2352         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2353         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2354         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2355                                out);
2356         return result;
2357     }
2358     // was originally using OpPhi to choose the result, but for some reason that is crashing on
2359     // Adreno. Switched to storing the result in a temp variable as glslang does.
2360     SpvId var = this->nextId();
2361     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2362                            var, SpvStorageClassFunction, fVariableBuffer);
2363     SpvId trueLabel = this->nextId();
2364     SpvId falseLabel = this->nextId();
2365     SpvId end = this->nextId();
2366     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2367     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2368     this->writeLabel(trueLabel, out);
2369     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2370     this->writeInstruction(SpvOpBranch, end, out);
2371     this->writeLabel(falseLabel, out);
2372     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2373     this->writeInstruction(SpvOpBranch, end, out);
2374     this->writeLabel(end, out);
2375     SpvId result = this->nextId();
2376     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2377     this->writePrecisionModifier(t.fType, result);
2378     return result;
2379 }
2380 
writePrefixExpression(const PrefixExpression & p,OutputStream & out)2381 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2382     if (p.fOperator == Token::MINUS) {
2383         SpvId result = this->nextId();
2384         SpvId typeId = this->getType(p.fType);
2385         SpvId expr = this->writeExpression(*p.fOperand, out);
2386         if (is_float(fContext, p.fType)) {
2387             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2388         } else if (is_signed(fContext, p.fType)) {
2389             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2390         } else {
2391             ABORT("unsupported prefix expression %s", p.description().c_str());
2392         }
2393         this->writePrecisionModifier(p.fType, result);
2394         return result;
2395     }
2396     switch (p.fOperator) {
2397         case Token::PLUS:
2398             return this->writeExpression(*p.fOperand, out);
2399         case Token::PLUSPLUS: {
2400             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2401             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2402             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2403                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2404                                                       out);
2405             lv->store(result, out);
2406             return result;
2407         }
2408         case Token::MINUSMINUS: {
2409             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2410             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2411             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2412                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2413                                                       out);
2414             lv->store(result, out);
2415             return result;
2416         }
2417         case Token::LOGICALNOT: {
2418             SkASSERT(p.fOperand->fType == *fContext.fBool_Type);
2419             SpvId result = this->nextId();
2420             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2421                                    this->writeExpression(*p.fOperand, out), out);
2422             return result;
2423         }
2424         case Token::BITWISENOT: {
2425             SpvId result = this->nextId();
2426             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2427                                    this->writeExpression(*p.fOperand, out), out);
2428             return result;
2429         }
2430         default:
2431             ABORT("unsupported prefix expression: %s", p.description().c_str());
2432     }
2433 }
2434 
writePostfixExpression(const PostfixExpression & p,OutputStream & out)2435 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2436     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2437     SpvId result = lv->load(out);
2438     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2439     switch (p.fOperator) {
2440         case Token::PLUSPLUS: {
2441             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2442                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2443             lv->store(temp, out);
2444             return result;
2445         }
2446         case Token::MINUSMINUS: {
2447             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2448                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
2449             lv->store(temp, out);
2450             return result;
2451         }
2452         default:
2453             ABORT("unsupported postfix expression %s", p.description().c_str());
2454     }
2455 }
2456 
writeBoolLiteral(const BoolLiteral & b)2457 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2458     if (b.fValue) {
2459         if (fBoolTrue == 0) {
2460             fBoolTrue = this->nextId();
2461             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2462                                    fConstantBuffer);
2463         }
2464         return fBoolTrue;
2465     } else {
2466         if (fBoolFalse == 0) {
2467             fBoolFalse = this->nextId();
2468             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2469                                    fConstantBuffer);
2470         }
2471         return fBoolFalse;
2472     }
2473 }
2474 
writeIntLiteral(const IntLiteral & i)2475 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2476     ConstantType type;
2477     if (i.fType == *fContext.fInt_Type) {
2478         type = ConstantType::kInt;
2479     } else if (i.fType == *fContext.fUInt_Type) {
2480         type = ConstantType::kUInt;
2481     } else if (i.fType == *fContext.fShort_Type || i.fType == *fContext.fByte_Type) {
2482         type = ConstantType::kShort;
2483     } else if (i.fType == *fContext.fUShort_Type || i.fType == *fContext.fUByte_Type) {
2484         type = ConstantType::kUShort;
2485     } else {
2486         SkASSERT(false);
2487     }
2488     std::pair<ConstantValue, ConstantType> key(i.fValue, type);
2489     auto entry = fNumberConstants.find(key);
2490     if (entry == fNumberConstants.end()) {
2491         SpvId result = this->nextId();
2492         this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2493                                fConstantBuffer);
2494         fNumberConstants[key] = result;
2495         return result;
2496     }
2497     return entry->second;
2498 }
2499 
writeFloatLiteral(const FloatLiteral & f)2500 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2501     if (f.fType != *fContext.fDouble_Type) {
2502         ConstantType type;
2503         if (f.fType == *fContext.fHalf_Type) {
2504             type = ConstantType::kHalf;
2505         } else {
2506             type = ConstantType::kFloat;
2507         }
2508         float value = (float) f.fValue;
2509         std::pair<ConstantValue, ConstantType> key(f.fValue, type);
2510         auto entry = fNumberConstants.find(key);
2511         if (entry == fNumberConstants.end()) {
2512             SpvId result = this->nextId();
2513             uint32_t bits;
2514             SkASSERT(sizeof(bits) == sizeof(value));
2515             memcpy(&bits, &value, sizeof(bits));
2516             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2517                                    fConstantBuffer);
2518             fNumberConstants[key] = result;
2519             return result;
2520         }
2521         return entry->second;
2522     } else {
2523         std::pair<ConstantValue, ConstantType> key(f.fValue, ConstantType::kDouble);
2524         auto entry = fNumberConstants.find(key);
2525         if (entry == fNumberConstants.end()) {
2526             SpvId result = this->nextId();
2527             uint64_t bits;
2528             SkASSERT(sizeof(bits) == sizeof(f.fValue));
2529             memcpy(&bits, &f.fValue, sizeof(bits));
2530             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2531                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
2532             fNumberConstants[key] = result;
2533             return result;
2534         }
2535         return entry->second;
2536     }
2537 }
2538 
writeFunctionStart(const FunctionDeclaration & f,OutputStream & out)2539 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2540     SpvId result = fFunctionMap[&f];
2541     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2542                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2543     this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2544     for (size_t i = 0; i < f.fParameters.size(); i++) {
2545         SpvId id = this->nextId();
2546         fVariableMap[f.fParameters[i]] = id;
2547         SpvId type;
2548         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2549         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2550     }
2551     return result;
2552 }
2553 
writeFunction(const FunctionDefinition & f,OutputStream & out)2554 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2555     fVariableBuffer.reset();
2556     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2557     this->writeLabel(this->nextId(), out);
2558     StringStream bodyBuffer;
2559     this->writeBlock((Block&) *f.fBody, bodyBuffer);
2560     write_stringstream(fVariableBuffer, out);
2561     if (f.fDeclaration.fName == "main") {
2562         write_stringstream(fGlobalInitializersBuffer, out);
2563     }
2564     write_stringstream(bodyBuffer, out);
2565     if (fCurrentBlock) {
2566         if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2567             this->writeInstruction(SpvOpReturn, out);
2568         } else {
2569             this->writeInstruction(SpvOpUnreachable, out);
2570         }
2571     }
2572     this->writeInstruction(SpvOpFunctionEnd, out);
2573     return result;
2574 }
2575 
writeLayout(const Layout & layout,SpvId target)2576 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2577     if (layout.fLocation >= 0) {
2578         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2579                                fDecorationBuffer);
2580     }
2581     if (layout.fBinding >= 0) {
2582         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2583                                fDecorationBuffer);
2584     }
2585     if (layout.fIndex >= 0) {
2586         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2587                                fDecorationBuffer);
2588     }
2589     if (layout.fSet >= 0) {
2590         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2591                                fDecorationBuffer);
2592     }
2593     if (layout.fInputAttachmentIndex >= 0) {
2594         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2595                                layout.fInputAttachmentIndex, fDecorationBuffer);
2596         fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2597     }
2598     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2599         layout.fBuiltin != SK_IN_BUILTIN && layout.fBuiltin != SK_OUT_BUILTIN) {
2600         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2601                                fDecorationBuffer);
2602     }
2603 }
2604 
writeLayout(const Layout & layout,SpvId target,int member)2605 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2606     if (layout.fLocation >= 0) {
2607         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2608                                layout.fLocation, fDecorationBuffer);
2609     }
2610     if (layout.fBinding >= 0) {
2611         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2612                                layout.fBinding, fDecorationBuffer);
2613     }
2614     if (layout.fIndex >= 0) {
2615         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2616                                layout.fIndex, fDecorationBuffer);
2617     }
2618     if (layout.fSet >= 0) {
2619         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2620                                layout.fSet, fDecorationBuffer);
2621     }
2622     if (layout.fInputAttachmentIndex >= 0) {
2623         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2624                                layout.fInputAttachmentIndex, fDecorationBuffer);
2625     }
2626     if (layout.fBuiltin >= 0) {
2627         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2628                                layout.fBuiltin, fDecorationBuffer);
2629     }
2630 }
2631 
update_sk_in_count(const Modifiers & m,int * outSkInCount)2632 static void update_sk_in_count(const Modifiers& m, int* outSkInCount) {
2633     switch (m.fLayout.fPrimitive) {
2634         case Layout::kPoints_Primitive:
2635             *outSkInCount = 1;
2636             break;
2637         case Layout::kLines_Primitive:
2638             *outSkInCount = 2;
2639             break;
2640         case Layout::kLinesAdjacency_Primitive:
2641             *outSkInCount = 4;
2642             break;
2643         case Layout::kTriangles_Primitive:
2644             *outSkInCount = 3;
2645             break;
2646         case Layout::kTrianglesAdjacency_Primitive:
2647             *outSkInCount = 6;
2648             break;
2649         default:
2650             return;
2651     }
2652 }
2653 
writeInterfaceBlock(const InterfaceBlock & intf)2654 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2655     bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2656     bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2657                                Layout::kPushConstant_Flag));
2658     MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
2659                                 MemoryLayout(MemoryLayout::k430_Standard) :
2660                                 fDefaultLayout;
2661     SpvId result = this->nextId();
2662     const Type* type = &intf.fVariable.fType;
2663     if (fProgram.fInputs.fRTHeight) {
2664         SkASSERT(fRTHeightStructId == (SpvId) -1);
2665         SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
2666         std::vector<Type::Field> fields = type->fields();
2667         fRTHeightStructId = result;
2668         fRTHeightFieldIndex = fields.size();
2669         fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2670         type = new Type(type->fOffset, type->name(), fields);
2671     }
2672     SpvId typeId;
2673     if (intf.fVariable.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2674         for (const auto& e : fProgram) {
2675             if (e.fKind == ProgramElement::kModifiers_Kind) {
2676                 const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
2677                 update_sk_in_count(m, &fSkInCount);
2678             }
2679         }
2680         typeId = this->getType(Type("sk_in", Type::kArray_Kind, intf.fVariable.fType.componentType(),
2681                                   fSkInCount), memoryLayout);
2682     } else {
2683         typeId = this->getType(*type, memoryLayout);
2684     }
2685     if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2686         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2687     } else if (intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
2688         this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2689     }
2690     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2691     SpvId ptrType = this->nextId();
2692     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2693     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2694     Layout layout = intf.fVariable.fModifiers.fLayout;
2695     if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
2696         layout.fSet = 0;
2697     }
2698     this->writeLayout(layout, result);
2699     fVariableMap[&intf.fVariable] = result;
2700     if (fProgram.fInputs.fRTHeight) {
2701         delete type;
2702     }
2703     return result;
2704 }
2705 
writePrecisionModifier(const Type & type,SpvId id)2706 void SPIRVCodeGenerator::writePrecisionModifier(const Type& type, SpvId id) {
2707     this->writePrecisionModifier(type.highPrecision() ? Precision::kHigh : Precision::kLow, id);
2708 }
2709 
writePrecisionModifier(Precision precision,SpvId id)2710 void SPIRVCodeGenerator::writePrecisionModifier(Precision precision, SpvId id) {
2711     if (precision == Precision::kLow) {
2712         this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2713     }
2714 }
2715 
2716 #define BUILTIN_IGNORE 9999
writeGlobalVars(Program::Kind kind,const VarDeclarations & decl,OutputStream & out)2717 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2718                                          OutputStream& out) {
2719     for (size_t i = 0; i < decl.fVars.size(); i++) {
2720         if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2721             continue;
2722         }
2723         const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2724         const Variable* var = varDecl.fVar;
2725         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2726         // in the OpenGL backend.
2727         SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2728                                            Modifiers::kWriteOnly_Flag |
2729                                            Modifiers::kCoherent_Flag |
2730                                            Modifiers::kVolatile_Flag |
2731                                            Modifiers::kRestrict_Flag)));
2732         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2733             continue;
2734         }
2735         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2736             kind != Program::kFragment_Kind) {
2737             SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
2738             continue;
2739         }
2740         if (!var->fReadCount && !var->fWriteCount &&
2741                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2742                                             Modifiers::kOut_Flag |
2743                                             Modifiers::kUniform_Flag |
2744                                             Modifiers::kBuffer_Flag))) {
2745             // variable is dead and not an input / output var (the Vulkan debug layers complain if
2746             // we elide an interface var, even if it's dead)
2747             continue;
2748         }
2749         SpvStorageClass_ storageClass;
2750         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2751             storageClass = SpvStorageClassInput;
2752         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2753             storageClass = SpvStorageClassOutput;
2754         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2755             if (var->fType.kind() == Type::kSampler_Kind ||
2756                 var->fType.kind() == Type::kSeparateSampler_Kind ||
2757                 var->fType.kind() == Type::kTexture_Kind) {
2758                 storageClass = SpvStorageClassUniformConstant;
2759             } else {
2760                 storageClass = SpvStorageClassUniform;
2761             }
2762         } else {
2763             storageClass = SpvStorageClassPrivate;
2764         }
2765         SpvId id = this->nextId();
2766         fVariableMap[var] = id;
2767         SpvId type;
2768         if (var->fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
2769             type = this->getPointerType(Type("sk_in", Type::kArray_Kind,
2770                                              var->fType.componentType(), fSkInCount),
2771                                         storageClass);
2772         } else {
2773             type = this->getPointerType(var->fType, storageClass);
2774         }
2775         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2776         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2777         this->writePrecisionModifier(var->fType, id);
2778         if (varDecl.fValue) {
2779             SkASSERT(!fCurrentBlock);
2780             fCurrentBlock = -1;
2781             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2782             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2783             fCurrentBlock = 0;
2784         }
2785         this->writeLayout(var->fModifiers.fLayout, id);
2786         if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2787             this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2788         }
2789         if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2790             this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2791                                    fDecorationBuffer);
2792         }
2793     }
2794 }
2795 
writeVarDeclarations(const VarDeclarations & decl,OutputStream & out)2796 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2797     for (const auto& stmt : decl.fVars) {
2798         SkASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2799         VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2800         const Variable* var = varDecl.fVar;
2801         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2802         // in the OpenGL backend.
2803         SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2804                                            Modifiers::kWriteOnly_Flag |
2805                                            Modifiers::kCoherent_Flag |
2806                                            Modifiers::kVolatile_Flag |
2807                                            Modifiers::kRestrict_Flag)));
2808         SpvId id = this->nextId();
2809         fVariableMap[var] = id;
2810         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2811         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2812         this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2813         if (varDecl.fValue) {
2814             SpvId value = this->writeExpression(*varDecl.fValue, out);
2815             this->writeInstruction(SpvOpStore, id, value, out);
2816         }
2817     }
2818 }
2819 
writeStatement(const Statement & s,OutputStream & out)2820 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2821     switch (s.fKind) {
2822         case Statement::kNop_Kind:
2823             break;
2824         case Statement::kBlock_Kind:
2825             this->writeBlock((Block&) s, out);
2826             break;
2827         case Statement::kExpression_Kind:
2828             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2829             break;
2830         case Statement::kReturn_Kind:
2831             this->writeReturnStatement((ReturnStatement&) s, out);
2832             break;
2833         case Statement::kVarDeclarations_Kind:
2834             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2835             break;
2836         case Statement::kIf_Kind:
2837             this->writeIfStatement((IfStatement&) s, out);
2838             break;
2839         case Statement::kFor_Kind:
2840             this->writeForStatement((ForStatement&) s, out);
2841             break;
2842         case Statement::kWhile_Kind:
2843             this->writeWhileStatement((WhileStatement&) s, out);
2844             break;
2845         case Statement::kDo_Kind:
2846             this->writeDoStatement((DoStatement&) s, out);
2847             break;
2848         case Statement::kSwitch_Kind:
2849             this->writeSwitchStatement((SwitchStatement&) s, out);
2850             break;
2851         case Statement::kBreak_Kind:
2852             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2853             break;
2854         case Statement::kContinue_Kind:
2855             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2856             break;
2857         case Statement::kDiscard_Kind:
2858             this->writeInstruction(SpvOpKill, out);
2859             break;
2860         default:
2861             ABORT("unsupported statement: %s", s.description().c_str());
2862     }
2863 }
2864 
writeBlock(const Block & b,OutputStream & out)2865 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2866     for (size_t i = 0; i < b.fStatements.size(); i++) {
2867         this->writeStatement(*b.fStatements[i], out);
2868     }
2869 }
2870 
writeIfStatement(const IfStatement & stmt,OutputStream & out)2871 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2872     SpvId test = this->writeExpression(*stmt.fTest, out);
2873     SpvId ifTrue = this->nextId();
2874     SpvId ifFalse = this->nextId();
2875     if (stmt.fIfFalse) {
2876         SpvId end = this->nextId();
2877         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2878         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2879         this->writeLabel(ifTrue, out);
2880         this->writeStatement(*stmt.fIfTrue, out);
2881         if (fCurrentBlock) {
2882             this->writeInstruction(SpvOpBranch, end, out);
2883         }
2884         this->writeLabel(ifFalse, out);
2885         this->writeStatement(*stmt.fIfFalse, out);
2886         if (fCurrentBlock) {
2887             this->writeInstruction(SpvOpBranch, end, out);
2888         }
2889         this->writeLabel(end, out);
2890     } else {
2891         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2892         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2893         this->writeLabel(ifTrue, out);
2894         this->writeStatement(*stmt.fIfTrue, out);
2895         if (fCurrentBlock) {
2896             this->writeInstruction(SpvOpBranch, ifFalse, out);
2897         }
2898         this->writeLabel(ifFalse, out);
2899     }
2900 }
2901 
writeForStatement(const ForStatement & f,OutputStream & out)2902 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2903     if (f.fInitializer) {
2904         this->writeStatement(*f.fInitializer, out);
2905     }
2906     SpvId header = this->nextId();
2907     SpvId start = this->nextId();
2908     SpvId body = this->nextId();
2909     SpvId next = this->nextId();
2910     fContinueTarget.push(next);
2911     SpvId end = this->nextId();
2912     fBreakTarget.push(end);
2913     this->writeInstruction(SpvOpBranch, header, out);
2914     this->writeLabel(header, out);
2915     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2916     this->writeInstruction(SpvOpBranch, start, out);
2917     this->writeLabel(start, out);
2918     if (f.fTest) {
2919         SpvId test = this->writeExpression(*f.fTest, out);
2920         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2921     }
2922     this->writeLabel(body, out);
2923     this->writeStatement(*f.fStatement, out);
2924     if (fCurrentBlock) {
2925         this->writeInstruction(SpvOpBranch, next, out);
2926     }
2927     this->writeLabel(next, out);
2928     if (f.fNext) {
2929         this->writeExpression(*f.fNext, out);
2930     }
2931     this->writeInstruction(SpvOpBranch, header, out);
2932     this->writeLabel(end, out);
2933     fBreakTarget.pop();
2934     fContinueTarget.pop();
2935 }
2936 
writeWhileStatement(const WhileStatement & w,OutputStream & out)2937 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2938     SpvId header = this->nextId();
2939     SpvId start = this->nextId();
2940     SpvId body = this->nextId();
2941     SpvId continueTarget = this->nextId();
2942     fContinueTarget.push(continueTarget);
2943     SpvId end = this->nextId();
2944     fBreakTarget.push(end);
2945     this->writeInstruction(SpvOpBranch, header, out);
2946     this->writeLabel(header, out);
2947     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
2948     this->writeInstruction(SpvOpBranch, start, out);
2949     this->writeLabel(start, out);
2950     SpvId test = this->writeExpression(*w.fTest, out);
2951     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2952     this->writeLabel(body, out);
2953     this->writeStatement(*w.fStatement, out);
2954     if (fCurrentBlock) {
2955         this->writeInstruction(SpvOpBranch, continueTarget, out);
2956     }
2957     this->writeLabel(continueTarget, out);
2958     this->writeInstruction(SpvOpBranch, header, out);
2959     this->writeLabel(end, out);
2960     fBreakTarget.pop();
2961     fContinueTarget.pop();
2962 }
2963 
writeDoStatement(const DoStatement & d,OutputStream & out)2964 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2965     // We believe the do loop code below will work, but Skia doesn't actually use them and
2966     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2967     // the time being, we just fail with an error due to the lack of testing. If you encounter this
2968     // message, simply remove the error call below to see whether our do loop support actually
2969     // works.
2970     fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2971                   "SkSLSPIRVCodeGenerator.cpp for details");
2972 
2973     SpvId header = this->nextId();
2974     SpvId start = this->nextId();
2975     SpvId next = this->nextId();
2976     SpvId continueTarget = this->nextId();
2977     fContinueTarget.push(continueTarget);
2978     SpvId end = this->nextId();
2979     fBreakTarget.push(end);
2980     this->writeInstruction(SpvOpBranch, header, out);
2981     this->writeLabel(header, out);
2982     this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
2983     this->writeInstruction(SpvOpBranch, start, out);
2984     this->writeLabel(start, out);
2985     this->writeStatement(*d.fStatement, out);
2986     if (fCurrentBlock) {
2987         this->writeInstruction(SpvOpBranch, next, out);
2988     }
2989     this->writeLabel(next, out);
2990     SpvId test = this->writeExpression(*d.fTest, out);
2991     this->writeInstruction(SpvOpBranchConditional, test, continueTarget, end, out);
2992     this->writeLabel(continueTarget, out);
2993     this->writeInstruction(SpvOpBranch, header, out);
2994     this->writeLabel(end, out);
2995     fBreakTarget.pop();
2996     fContinueTarget.pop();
2997 }
2998 
writeSwitchStatement(const SwitchStatement & s,OutputStream & out)2999 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
3000     SpvId value = this->writeExpression(*s.fValue, out);
3001     std::vector<SpvId> labels;
3002     SpvId end = this->nextId();
3003     SpvId defaultLabel = end;
3004     fBreakTarget.push(end);
3005     int size = 3;
3006     for (const auto& c : s.fCases) {
3007         SpvId label = this->nextId();
3008         labels.push_back(label);
3009         if (c->fValue) {
3010             size += 2;
3011         } else {
3012             defaultLabel = label;
3013         }
3014     }
3015     labels.push_back(end);
3016     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3017     this->writeOpCode(SpvOpSwitch, size, out);
3018     this->writeWord(value, out);
3019     this->writeWord(defaultLabel, out);
3020     for (size_t i = 0; i < s.fCases.size(); ++i) {
3021         if (!s.fCases[i]->fValue) {
3022             continue;
3023         }
3024         SkASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
3025         this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
3026         this->writeWord(labels[i], out);
3027     }
3028     for (size_t i = 0; i < s.fCases.size(); ++i) {
3029         this->writeLabel(labels[i], out);
3030         for (const auto& stmt : s.fCases[i]->fStatements) {
3031             this->writeStatement(*stmt, out);
3032         }
3033         if (fCurrentBlock) {
3034             this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3035         }
3036     }
3037     this->writeLabel(end, out);
3038     fBreakTarget.pop();
3039 }
3040 
writeReturnStatement(const ReturnStatement & r,OutputStream & out)3041 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3042     if (r.fExpression) {
3043         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
3044                                out);
3045     } else {
3046         this->writeInstruction(SpvOpReturn, out);
3047     }
3048 }
3049 
writeGeometryShaderExecutionMode(SpvId entryPoint,OutputStream & out)3050 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
3051     SkASSERT(fProgram.fKind == Program::kGeometry_Kind);
3052     int invocations = 1;
3053     for (const auto& e : fProgram) {
3054         if (e.fKind == ProgramElement::kModifiers_Kind) {
3055             const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3056             if (m.fFlags & Modifiers::kIn_Flag) {
3057                 if (m.fLayout.fInvocations != -1) {
3058                     invocations = m.fLayout.fInvocations;
3059                 }
3060                 SpvId input;
3061                 switch (m.fLayout.fPrimitive) {
3062                     case Layout::kPoints_Primitive:
3063                         input = SpvExecutionModeInputPoints;
3064                         break;
3065                     case Layout::kLines_Primitive:
3066                         input = SpvExecutionModeInputLines;
3067                         break;
3068                     case Layout::kLinesAdjacency_Primitive:
3069                         input = SpvExecutionModeInputLinesAdjacency;
3070                         break;
3071                     case Layout::kTriangles_Primitive:
3072                         input = SpvExecutionModeTriangles;
3073                         break;
3074                     case Layout::kTrianglesAdjacency_Primitive:
3075                         input = SpvExecutionModeInputTrianglesAdjacency;
3076                         break;
3077                     default:
3078                         input = 0;
3079                         break;
3080                 }
3081                 update_sk_in_count(m, &fSkInCount);
3082                 if (input) {
3083                     this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
3084                 }
3085             } else if (m.fFlags & Modifiers::kOut_Flag) {
3086                 SpvId output;
3087                 switch (m.fLayout.fPrimitive) {
3088                     case Layout::kPoints_Primitive:
3089                         output = SpvExecutionModeOutputPoints;
3090                         break;
3091                     case Layout::kLineStrip_Primitive:
3092                         output = SpvExecutionModeOutputLineStrip;
3093                         break;
3094                     case Layout::kTriangleStrip_Primitive:
3095                         output = SpvExecutionModeOutputTriangleStrip;
3096                         break;
3097                     default:
3098                         output = 0;
3099                         break;
3100                 }
3101                 if (output) {
3102                     this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
3103                 }
3104                 if (m.fLayout.fMaxVertices != -1) {
3105                     this->writeInstruction(SpvOpExecutionMode, entryPoint,
3106                                            SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
3107                                            out);
3108                 }
3109             }
3110         }
3111     }
3112     this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
3113                            invocations, out);
3114 }
3115 
writeInstructions(const Program & program,OutputStream & out)3116 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3117     fGLSLExtendedInstructions = this->nextId();
3118     StringStream body;
3119     std::set<SpvId> interfaceVars;
3120     // assign IDs to functions, determine sk_in size
3121     int skInSize = -1;
3122     for (const auto& e : program) {
3123         switch (e.fKind) {
3124             case ProgramElement::kFunction_Kind: {
3125                 FunctionDefinition& f = (FunctionDefinition&) e;
3126                 fFunctionMap[&f.fDeclaration] = this->nextId();
3127                 break;
3128             }
3129             case ProgramElement::kModifiers_Kind: {
3130                 Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
3131                 if (m.fFlags & Modifiers::kIn_Flag) {
3132                     switch (m.fLayout.fPrimitive) {
3133                         case Layout::kPoints_Primitive: // break
3134                         case Layout::kLines_Primitive:
3135                             skInSize = 1;
3136                             break;
3137                         case Layout::kLinesAdjacency_Primitive: // break
3138                             skInSize = 2;
3139                             break;
3140                         case Layout::kTriangles_Primitive: // break
3141                         case Layout::kTrianglesAdjacency_Primitive:
3142                             skInSize = 3;
3143                             break;
3144                         default:
3145                             break;
3146                     }
3147                 }
3148                 break;
3149             }
3150             default:
3151                 break;
3152         }
3153     }
3154     for (const auto& e : program) {
3155         if (e.fKind == ProgramElement::kInterfaceBlock_Kind) {
3156             InterfaceBlock& intf = (InterfaceBlock&) e;
3157             if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
3158                 SkASSERT(skInSize != -1);
3159                 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
3160             }
3161             SpvId id = this->writeInterfaceBlock(intf);
3162             if (((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
3163                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) &&
3164                 intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
3165                 interfaceVars.insert(id);
3166             }
3167         }
3168     }
3169     for (const auto& e : program) {
3170         if (e.fKind == ProgramElement::kVar_Kind) {
3171             this->writeGlobalVars(program.fKind, ((VarDeclarations&) e), body);
3172         }
3173     }
3174     for (const auto& e : program) {
3175         if (e.fKind == ProgramElement::kFunction_Kind) {
3176             this->writeFunction(((FunctionDefinition&) e), body);
3177         }
3178     }
3179     const FunctionDeclaration* main = nullptr;
3180     for (auto entry : fFunctionMap) {
3181         if (entry.first->fName == "main") {
3182             main = entry.first;
3183         }
3184     }
3185     if (!main) {
3186         fErrors.error(0, "program does not contain a main() function");
3187         return;
3188     }
3189     for (auto entry : fVariableMap) {
3190         const Variable* var = entry.first;
3191         if (var->fStorage == Variable::kGlobal_Storage &&
3192             ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
3193              (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
3194             interfaceVars.insert(entry.second);
3195         }
3196     }
3197     this->writeCapabilities(out);
3198     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3199     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3200     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
3201                       (int32_t) interfaceVars.size(), out);
3202     switch (program.fKind) {
3203         case Program::kVertex_Kind:
3204             this->writeWord(SpvExecutionModelVertex, out);
3205             break;
3206         case Program::kFragment_Kind:
3207             this->writeWord(SpvExecutionModelFragment, out);
3208             break;
3209         case Program::kGeometry_Kind:
3210             this->writeWord(SpvExecutionModelGeometry, out);
3211             break;
3212         default:
3213             ABORT("cannot write this kind of program to SPIR-V\n");
3214     }
3215     SpvId entryPoint = fFunctionMap[main];
3216     this->writeWord(entryPoint, out);
3217     this->writeString(main->fName.fChars, main->fName.fLength, out);
3218     for (int var : interfaceVars) {
3219         this->writeWord(var, out);
3220     }
3221     if (program.fKind == Program::kGeometry_Kind) {
3222         this->writeGeometryShaderExecutionMode(entryPoint, out);
3223     }
3224     if (program.fKind == Program::kFragment_Kind) {
3225         this->writeInstruction(SpvOpExecutionMode,
3226                                fFunctionMap[main],
3227                                SpvExecutionModeOriginUpperLeft,
3228                                out);
3229     }
3230     for (const auto& e : program) {
3231         if (e.fKind == ProgramElement::kExtension_Kind) {
3232             this->writeInstruction(SpvOpSourceExtension, ((Extension&) e).fName.c_str(), out);
3233         }
3234     }
3235 
3236     write_stringstream(fExtraGlobalsBuffer, out);
3237     write_stringstream(fNameBuffer, out);
3238     write_stringstream(fDecorationBuffer, out);
3239     write_stringstream(fConstantBuffer, out);
3240     write_stringstream(fExternalFunctionsBuffer, out);
3241     write_stringstream(body, out);
3242 }
3243 
generateCode()3244 bool SPIRVCodeGenerator::generateCode() {
3245     SkASSERT(!fErrors.errorCount());
3246     this->writeWord(SpvMagicNumber, *fOut);
3247     this->writeWord(SpvVersion, *fOut);
3248     this->writeWord(SKSL_MAGIC, *fOut);
3249     StringStream buffer;
3250     this->writeInstructions(fProgram, buffer);
3251     this->writeWord(fIdCount, *fOut);
3252     this->writeWord(0, *fOut); // reserved, always zero
3253     write_stringstream(buffer, *fOut);
3254     return 0 == fErrors.errorCount();
3255 }
3256 
3257 }
3258