1 /* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8 #ifndef SKSL_SPIRVCODEGENERATOR 9 #define SKSL_SPIRVCODEGENERATOR 10 11 #include <stack> 12 #include <tuple> 13 #include <unordered_map> 14 15 #include "src/sksl/SkSLCodeGenerator.h" 16 #include "src/sksl/SkSLMemoryLayout.h" 17 #include "src/sksl/SkSLStringStream.h" 18 #include "src/sksl/ir/SkSLBinaryExpression.h" 19 #include "src/sksl/ir/SkSLBoolLiteral.h" 20 #include "src/sksl/ir/SkSLConstructor.h" 21 #include "src/sksl/ir/SkSLDoStatement.h" 22 #include "src/sksl/ir/SkSLFieldAccess.h" 23 #include "src/sksl/ir/SkSLFloatLiteral.h" 24 #include "src/sksl/ir/SkSLForStatement.h" 25 #include "src/sksl/ir/SkSLFunctionCall.h" 26 #include "src/sksl/ir/SkSLFunctionDeclaration.h" 27 #include "src/sksl/ir/SkSLFunctionDefinition.h" 28 #include "src/sksl/ir/SkSLIfStatement.h" 29 #include "src/sksl/ir/SkSLIndexExpression.h" 30 #include "src/sksl/ir/SkSLIntLiteral.h" 31 #include "src/sksl/ir/SkSLInterfaceBlock.h" 32 #include "src/sksl/ir/SkSLPostfixExpression.h" 33 #include "src/sksl/ir/SkSLPrefixExpression.h" 34 #include "src/sksl/ir/SkSLProgramElement.h" 35 #include "src/sksl/ir/SkSLReturnStatement.h" 36 #include "src/sksl/ir/SkSLStatement.h" 37 #include "src/sksl/ir/SkSLSwitchStatement.h" 38 #include "src/sksl/ir/SkSLSwizzle.h" 39 #include "src/sksl/ir/SkSLTernaryExpression.h" 40 #include "src/sksl/ir/SkSLVarDeclarations.h" 41 #include "src/sksl/ir/SkSLVariableReference.h" 42 #include "src/sksl/ir/SkSLWhileStatement.h" 43 #include "src/sksl/spirv.h" 44 45 union ConstantValue { ConstantValue(int64_t i)46 ConstantValue(int64_t i) 47 : fInt(i) { 48 SkASSERT(sizeof(*this) == sizeof(int64_t)); 49 } 50 ConstantValue(SKSL_FLOAT f)51 ConstantValue(SKSL_FLOAT f) { 52 memset(this, 0, sizeof(*this)); 53 fFloat = f; 54 } 55 56 bool operator==(const ConstantValue& other) const { 57 return fInt == other.fInt; 58 } 59 60 int64_t fInt; 61 SKSL_FLOAT fFloat; 62 }; 63 64 enum class ConstantType { 65 kInt, 66 kUInt, 67 kShort, 68 kUShort, 69 kFloat, 70 kDouble, 71 kHalf, 72 }; 73 74 namespace std { 75 76 template <> 77 struct hash<std::pair<ConstantValue, ConstantType>> { 78 size_t operator()(const std::pair<ConstantValue, ConstantType>& key) const { 79 return key.first.fInt ^ (int) key.second; 80 } 81 }; 82 83 } // namespace std 84 85 namespace SkSL { 86 87 #define kLast_Capability SpvCapabilityMultiViewport 88 89 /** 90 * Converts a Program into a SPIR-V binary. 91 */ 92 class SPIRVCodeGenerator : public CodeGenerator { 93 public: 94 class LValue { 95 public: 96 virtual ~LValue() {} 97 98 // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced 99 // by a pointer (e.g. vector swizzles), returns 0. 100 virtual SpvId getPointer() = 0; 101 102 virtual SpvId load(OutputStream& out) = 0; 103 104 virtual void store(SpvId value, OutputStream& out) = 0; 105 }; 106 107 SPIRVCodeGenerator(const Context* context, 108 const Program* program, 109 ErrorReporter* errors, 110 OutputStream* out) 111 : INHERITED(program, errors, out) 112 , fContext(*context) 113 , fDefaultLayout(MemoryLayout::k140_Standard) 114 , fCapabilities(0) 115 , fIdCount(1) 116 , fBoolTrue(0) 117 , fBoolFalse(0) 118 , fSetupFragPosition(false) 119 , fCurrentBlock(0) 120 , fSynthetics(errors, /*builtin=*/true) { 121 this->setupIntrinsics(); 122 } 123 124 bool generateCode() override; 125 126 private: 127 enum IntrinsicKind { 128 kGLSL_STD_450_IntrinsicKind, 129 kSPIRV_IntrinsicKind, 130 kSpecial_IntrinsicKind 131 }; 132 133 enum SpecialIntrinsic { 134 kAtan_SpecialIntrinsic, 135 kClamp_SpecialIntrinsic, 136 kMax_SpecialIntrinsic, 137 kMin_SpecialIntrinsic, 138 kMix_SpecialIntrinsic, 139 kMod_SpecialIntrinsic, 140 kDFdy_SpecialIntrinsic, 141 kSaturate_SpecialIntrinsic, 142 kSampledImage_SpecialIntrinsic, 143 kSubpassLoad_SpecialIntrinsic, 144 kTexture_SpecialIntrinsic, 145 }; 146 147 enum class Precision { 148 kLow, 149 kHigh, 150 }; 151 152 void setupIntrinsics(); 153 154 SpvId nextId(); 155 156 const Type& getActualType(const Type& type); 157 158 SpvId getType(const Type& type); 159 160 SpvId getType(const Type& type, const MemoryLayout& layout); 161 162 SpvId getImageType(const Type& type); 163 164 SpvId getFunctionType(const FunctionDeclaration& function); 165 166 SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass); 167 168 SpvId getPointerType(const Type& type, const MemoryLayout& layout, 169 SpvStorageClass_ storageClass); 170 171 void writePrecisionModifier(Precision precision, SpvId id); 172 173 void writePrecisionModifier(const Type& type, SpvId id); 174 175 std::vector<SpvId> getAccessChain(const Expression& expr, OutputStream& out); 176 177 void writeLayout(const Layout& layout, SpvId target); 178 179 void writeLayout(const Layout& layout, SpvId target, int member); 180 181 void writeStruct(const Type& type, const MemoryLayout& layout, SpvId resultId); 182 183 void writeProgramElement(const ProgramElement& pe, OutputStream& out); 184 185 SpvId writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTHeight = true); 186 187 SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out); 188 189 SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out); 190 191 SpvId writeFunction(const FunctionDefinition& f, OutputStream& out); 192 193 void writeGlobalVar(Program::Kind kind, const VarDeclaration& v, OutputStream& out); 194 195 void writeVarDeclaration(const VarDeclaration& var, OutputStream& out); 196 197 SpvId writeVariableReference(const VariableReference& ref, OutputStream& out); 198 199 std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out); 200 201 SpvId writeExpression(const Expression& expr, OutputStream& out); 202 203 SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out); 204 205 SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out); 206 207 208 void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst, 209 SpvId signedInst, SpvId unsignedInst, 210 const std::vector<SpvId>& args, OutputStream& out); 211 212 /** 213 * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the 214 * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2), 215 * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float, 216 * vec2, vec3). 217 */ 218 std::vector<SpvId> vectorize(const ExpressionArray& args, OutputStream& out); 219 220 SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out); 221 222 SpvId writeConstantVector(const Constructor& c); 223 224 SpvId writeFloatConstructor(const Constructor& c, OutputStream& out); 225 226 SpvId writeIntConstructor(const Constructor& c, OutputStream& out); 227 228 SpvId writeUIntConstructor(const Constructor& c, OutputStream& out); 229 230 /** 231 * Writes a matrix with the diagonal entries all equal to the provided expression, and all other 232 * entries equal to zero. 233 */ 234 void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out); 235 236 /** 237 * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the 238 * source matrix are filled with zero; entries which do not exist in the destination matrix are 239 * ignored. 240 */ 241 void writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, const Type& dstType, 242 OutputStream& out); 243 244 void addColumnEntry(SpvId columnType, Precision precision, std::vector<SpvId>* currentColumn, 245 std::vector<SpvId>* columnIds, int* currentCount, int rows, SpvId entry, 246 OutputStream& out); 247 248 SpvId writeMatrixConstructor(const Constructor& c, OutputStream& out); 249 250 SpvId writeVectorConstructor(const Constructor& c, OutputStream& out); 251 252 SpvId writeArrayConstructor(const Constructor& c, OutputStream& out); 253 254 SpvId writeConstructor(const Constructor& c, OutputStream& out); 255 256 SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out); 257 258 SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out); 259 260 /** 261 * Folds the potentially-vector result of a logical operation down to a single bool. If 262 * operandType is a vector type, assumes that the intermediate result in id is a bvec of the 263 * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise, 264 * returns the original id value. 265 */ 266 SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out); 267 268 SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator, 269 SpvOp_ intOperator, SpvOp_ vectorMergeOperator, 270 SpvOp_ mergeOperator, OutputStream& out); 271 272 SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs, 273 SpvOp_ floatOperator, SpvOp_ intOperator, 274 OutputStream& out); 275 276 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs, 277 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt, 278 SpvOp_ ifBool, OutputStream& out); 279 280 SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt, 281 SpvOp_ ifUInt, OutputStream& out); 282 283 SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op, 284 const Type& rightType, SpvId rhs, const Type& resultType, 285 OutputStream& out); 286 287 SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out); 288 289 SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out); 290 291 SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out); 292 293 SpvId writeLogicalAnd(const BinaryExpression& b, OutputStream& out); 294 295 SpvId writeLogicalOr(const BinaryExpression& o, OutputStream& out); 296 297 SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out); 298 299 SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out); 300 301 SpvId writeBoolLiteral(const BoolLiteral& b); 302 303 SpvId writeIntLiteral(const IntLiteral& i); 304 305 SpvId writeFloatLiteral(const FloatLiteral& f); 306 307 void writeStatement(const Statement& s, OutputStream& out); 308 309 void writeBlock(const Block& b, OutputStream& out); 310 311 void writeIfStatement(const IfStatement& stmt, OutputStream& out); 312 313 void writeForStatement(const ForStatement& f, OutputStream& out); 314 315 void writeWhileStatement(const WhileStatement& w, OutputStream& out); 316 317 void writeDoStatement(const DoStatement& d, OutputStream& out); 318 319 void writeSwitchStatement(const SwitchStatement& s, OutputStream& out); 320 321 void writeReturnStatement(const ReturnStatement& r, OutputStream& out); 322 323 void writeCapabilities(OutputStream& out); 324 325 void writeInstructions(const Program& program, OutputStream& out); 326 327 void writeOpCode(SpvOp_ opCode, int length, OutputStream& out); 328 329 void writeWord(int32_t word, OutputStream& out); 330 331 void writeString(const char* string, size_t length, OutputStream& out); 332 333 void writeLabel(SpvId id, OutputStream& out); 334 335 void writeInstruction(SpvOp_ opCode, OutputStream& out); 336 337 void writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out); 338 339 void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out); 340 341 void writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string, OutputStream& out); 342 343 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, StringFragment string, 344 OutputStream& out); 345 346 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out); 347 348 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, 349 OutputStream& out); 350 351 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 352 OutputStream& out); 353 354 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 355 int32_t word5, OutputStream& out); 356 357 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 358 int32_t word5, int32_t word6, OutputStream& out); 359 360 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 361 int32_t word5, int32_t word6, int32_t word7, OutputStream& out); 362 363 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4, 364 int32_t word5, int32_t word6, int32_t word7, int32_t word8, 365 OutputStream& out); 366 367 void writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out); 368 369 const Context& fContext; 370 const MemoryLayout fDefaultLayout; 371 372 uint64_t fCapabilities; 373 SpvId fIdCount; 374 SpvId fGLSLExtendedInstructions; 375 typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic; 376 std::unordered_map<String, Intrinsic> fIntrinsicMap; 377 std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap; 378 std::unordered_map<const Variable*, SpvId> fVariableMap; 379 std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap; 380 std::unordered_map<String, SpvId> fImageTypeMap; 381 std::unordered_map<String, SpvId> fTypeMap; 382 StringStream fCapabilitiesBuffer; 383 StringStream fGlobalInitializersBuffer; 384 StringStream fConstantBuffer; 385 StringStream fExtraGlobalsBuffer; 386 StringStream fExternalFunctionsBuffer; 387 StringStream fVariableBuffer; 388 StringStream fNameBuffer; 389 StringStream fDecorationBuffer; 390 391 SpvId fBoolTrue; 392 SpvId fBoolFalse; 393 std::unordered_map<std::pair<ConstantValue, ConstantType>, SpvId> fNumberConstants; 394 bool fSetupFragPosition; 395 // label of the current block, or 0 if we are not in a block 396 SpvId fCurrentBlock; 397 std::stack<SpvId> fBreakTarget; 398 std::stack<SpvId> fContinueTarget; 399 SpvId fRTHeightStructId = (SpvId) -1; 400 SpvId fRTHeightFieldIndex = (SpvId) -1; 401 SpvStorageClass_ fRTHeightStorageClass; 402 // holds variables synthesized during output, for lifetime purposes 403 SymbolTable fSynthetics; 404 int fSkInCount = 1; 405 406 friend class PointerLValue; 407 friend class SwizzleLValue; 408 409 using INHERITED = CodeGenerator; 410 }; 411 412 } // namespace SkSL 413 414 #endif 415