1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker text for image2D helper declarations. This chunk does not show where it is consumed;
// presumably it is searched for and replaced later (TODO confirm against its users).
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped * node)89 const TInterfaceBlock *GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped *node)
90 {
91     const TIntermBinary *binaryNode = node->getAsBinaryNode();
92     if (binaryNode)
93     {
94         if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
95         {
96             return binaryNode->getLeft()->getType().getInterfaceBlock();
97         }
98     }
99 
100     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
101     if (symbolNode)
102     {
103         const TVariable &variable = symbolNode->variable();
104         const TType &variableType = variable.getType();
105 
106         if (variableType.getQualifier() == EvqUniform &&
107             variable.symbolType() == SymbolType::UserDefined)
108         {
109             return variableType.getInterfaceBlock();
110         }
111     }
112 
113     return nullptr;
114 }
115 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)116 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
117 {
118     switch (op)
119     {
120         case EOpAtomicAdd:
121             return "InterlockedAdd(";
122         case EOpAtomicMin:
123             return "InterlockedMin(";
124         case EOpAtomicMax:
125             return "InterlockedMax(";
126         case EOpAtomicAnd:
127             return "InterlockedAnd(";
128         case EOpAtomicOr:
129             return "InterlockedOr(";
130         case EOpAtomicXor:
131             return "InterlockedXor(";
132         case EOpAtomicExchange:
133             return "InterlockedExchange(";
134         case EOpAtomicCompSwap:
135             return "InterlockedCompareExchange(";
136         default:
137             UNREACHABLE();
138             return "";
139     }
140 }
141 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)142 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
143 {
144     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
145     if (aggregateNode == nullptr)
146     {
147         return false;
148     }
149 
150     if (node.getOp() == EOpAssign && IsAtomicFunction(aggregateNode->getOp()))
151     {
152         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
153     }
154 
155     return false;
156 }
157 
// Name and length of the zero-filled "static uint" array declared by DefineZeroArray().
// GetZeroInitializer() emits one reference to this array per kZeroCount zero elements.
// A constexpr array (rather than a mutable global pointer) keeps the name truly immutable.
constexpr char kZeros[]  = "_ANGLE_ZEROS_";
constexpr int kZeroCount = 256;
DefineZeroArray()160 std::string DefineZeroArray()
161 {
162     std::stringstream ss = sh::InitializeStream<std::stringstream>();
163     // For 'static', if the declaration does not include an initializer, the value is set to zero.
164     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
165     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
166     return ss.str();
167 }
168 
GetZeroInitializer(size_t size)169 std::string GetZeroInitializer(size_t size)
170 {
171     std::stringstream ss = sh::InitializeStream<std::stringstream>();
172     size_t quotient      = size / kZeroCount;
173     size_t reminder      = size % kZeroCount;
174 
175     for (size_t i = 0; i < quotient; ++i)
176     {
177         if (i != 0)
178         {
179             ss << ", ";
180         }
181         ss << kZeros;
182     }
183 
184     for (size_t i = 0; i < reminder; ++i)
185     {
186         if (quotient != 0 || i != 0)
187         {
188             ss << ", ";
189         }
190         ss << "0";
191     }
192 
193     return ss.str();
194 }
195 
196 }  // anonymous namespace
197 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)198 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
199                                    const TVariable *aInstanceVariable)
200     : block(aBlock), instanceVariable(aInstanceVariable)
201 {}
202 
needStructMapping(TIntermTyped * node)203 bool OutputHLSL::needStructMapping(TIntermTyped *node)
204 {
205     ASSERT(node->getBasicType() == EbtStruct);
206     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
207     {
208         TIntermNode *ancestor               = getAncestorNode(n);
209         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
210         if (ancestorBinary)
211         {
212             switch (ancestorBinary->getOp())
213             {
214                 case EOpIndexDirectStruct:
215                 {
216                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
217                     const TIntermConstantUnion *index =
218                         ancestorBinary->getRight()->getAsConstantUnion();
219                     const TField *field = structure->fields()[index->getIConst(0)];
220                     if (field->type()->getStruct() == nullptr)
221                     {
222                         return false;
223                     }
224                     break;
225                 }
226                 case EOpIndexDirect:
227                 case EOpIndexIndirect:
228                     break;
229                 default:
230                     return true;
231             }
232         }
233         else
234         {
235             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
236             if (ancestorAggregate)
237             {
238                 return true;
239             }
240             return false;
241         }
242     }
243     return true;
244 }
245 
writeFloat(TInfoSinkBase & out,float f)246 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
247 {
248     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
249     // regardless.
250     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
251         mOutputType == SH_HLSL_4_1_OUTPUT)
252     {
253         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
254     }
255     else
256     {
257         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
258     }
259 }
260 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)261 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
262 {
263     ASSERT(constUnion != nullptr);
264     switch (constUnion->getType())
265     {
266         case EbtFloat:
267             writeFloat(out, constUnion->getFConst());
268             break;
269         case EbtInt:
270             out << constUnion->getIConst();
271             break;
272         case EbtUInt:
273             out << constUnion->getUConst();
274             break;
275         case EbtBool:
276             out << constUnion->getBConst();
277             break;
278         default:
279             UNREACHABLE();
280     }
281 }
282 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)283 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
284                                                           const TConstantUnion *const constUnion,
285                                                           const size_t size)
286 {
287     const TConstantUnion *constUnionIterated = constUnion;
288     for (size_t i = 0; i < size; i++, constUnionIterated++)
289     {
290         writeSingleConstant(out, constUnionIterated);
291 
292         if (i != size - 1)
293         {
294             out << ", ";
295         }
296     }
297     return constUnionIterated;
298 }
299 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::map<int,const TInterfaceBlock * > & uniformBlocksTranslatedToStructuredBuffers,const std::vector<InterfaceBlock> & shaderStorageBlocks)300 OutputHLSL::OutputHLSL(
301     sh::GLenum shaderType,
302     ShShaderSpec shaderSpec,
303     int shaderVersion,
304     const TExtensionBehavior &extensionBehavior,
305     const char *sourcePath,
306     ShShaderOutput outputType,
307     int numRenderTargets,
308     int maxDualSourceDrawBuffers,
309     const std::vector<ShaderVariable> &uniforms,
310     ShCompileOptions compileOptions,
311     sh::WorkGroupSize workGroupSize,
312     TSymbolTable *symbolTable,
313     PerformanceDiagnostics *perfDiagnostics,
314     const std::map<int, const TInterfaceBlock *> &uniformBlocksTranslatedToStructuredBuffers,
315     const std::vector<InterfaceBlock> &shaderStorageBlocks)
316     : TIntermTraverser(true, true, true, symbolTable),
317       mShaderType(shaderType),
318       mShaderSpec(shaderSpec),
319       mShaderVersion(shaderVersion),
320       mExtensionBehavior(extensionBehavior),
321       mSourcePath(sourcePath),
322       mOutputType(outputType),
323       mCompileOptions(compileOptions),
324       mInsideFunction(false),
325       mInsideMain(false),
326       mUniformBlocksTranslatedToStructuredBuffers(uniformBlocksTranslatedToStructuredBuffers),
327       mNumRenderTargets(numRenderTargets),
328       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
329       mCurrentFunctionMetadata(nullptr),
330       mWorkGroupSize(workGroupSize),
331       mPerfDiagnostics(perfDiagnostics),
332       mNeedStructMapping(false)
333 {
334     mUsesFragColor        = false;
335     mUsesFragData         = false;
336     mUsesDepthRange       = false;
337     mUsesFragCoord        = false;
338     mUsesPointCoord       = false;
339     mUsesFrontFacing      = false;
340     mUsesHelperInvocation = false;
341     mUsesPointSize        = false;
342     mUsesInstanceID       = false;
343     mHasMultiviewExtensionEnabled =
344         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
345         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
346     mUsesViewID                  = false;
347     mUsesVertexID                = false;
348     mUsesFragDepth               = false;
349     mUsesNumWorkGroups           = false;
350     mUsesWorkGroupID             = false;
351     mUsesLocalInvocationID       = false;
352     mUsesGlobalInvocationID      = false;
353     mUsesLocalInvocationIndex    = false;
354     mUsesXor                     = false;
355     mUsesDiscardRewriting        = false;
356     mUsesNestedBreak             = false;
357     mRequiresIEEEStrictCompiling = false;
358     mUseZeroArray                = false;
359     mUsesSecondaryColor          = false;
360 
361     mUniqueIndex = 0;
362 
363     mOutputLod0Function      = false;
364     mInsideDiscontinuousLoop = false;
365     mNestedLoopDepth         = 0;
366 
367     mExcessiveLoopIndex = nullptr;
368 
369     mStructureHLSL       = new StructureHLSL;
370     mTextureFunctionHLSL = new TextureFunctionHLSL;
371     mImageFunctionHLSL   = new ImageFunctionHLSL;
372     mAtomicCounterFunctionHLSL =
373         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
374 
375     unsigned int firstUniformRegister =
376         ((compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0) ? 1u : 0u;
377     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, uniforms, firstUniformRegister);
378 
379     if (mOutputType == SH_HLSL_3_0_OUTPUT)
380     {
381         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
382         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
383         // dx_ViewAdjust.
384         // In both cases total 3 uniform registers need to be reserved.
385         mResourcesHLSL->reserveUniformRegisters(3);
386     }
387 
388     // Reserve registers for the default uniform block and driver constants
389     mResourcesHLSL->reserveUniformBlockRegisters(2);
390 
391     mSSBOOutputHLSL =
392         new ShaderStorageBlockOutputHLSL(this, symbolTable, mResourcesHLSL, shaderStorageBlocks);
393 }
394 
~OutputHLSL()395 OutputHLSL::~OutputHLSL()
396 {
397     SafeDelete(mSSBOOutputHLSL);
398     SafeDelete(mStructureHLSL);
399     SafeDelete(mResourcesHLSL);
400     SafeDelete(mTextureFunctionHLSL);
401     SafeDelete(mImageFunctionHLSL);
402     SafeDelete(mAtomicCounterFunctionHLSL);
403     for (auto &eqFunction : mStructEqualityFunctions)
404     {
405         SafeDelete(eqFunction);
406     }
407     for (auto &eqFunction : mArrayEqualityFunctions)
408     {
409         SafeDelete(eqFunction);
410     }
411 }
412 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)413 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
414 {
415     BuiltInFunctionEmulator builtInFunctionEmulator;
416     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
417     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
418     {
419         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
420                                                            mShaderVersion);
421     }
422 
423     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
424 
425     // Now that we are done changing the AST, do the analyses need for HLSL generation
426     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
427     ASSERT(success == CallDAG::INITDAG_SUCCESS);
428     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
429 
430     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
431     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
432     // the shader code. When we add shader storage blocks we might also consider an alternative
433     // solution, since the struct mapping won't work very well for shader storage blocks.
434 
435     // Output the body and footer first to determine what has to go in the header
436     mInfoSinkStack.push(&mBody);
437     treeRoot->traverse(this);
438     mInfoSinkStack.pop();
439 
440     mInfoSinkStack.push(&mFooter);
441     mInfoSinkStack.pop();
442 
443     mInfoSinkStack.push(&mHeader);
444     header(mHeader, std140Structs, &builtInFunctionEmulator);
445     mInfoSinkStack.pop();
446 
447     objSink << mHeader.c_str();
448     objSink << mBody.c_str();
449     objSink << mFooter.c_str();
450 
451     builtInFunctionEmulator.cleanup();
452 }
453 
getShaderStorageBlockRegisterMap() const454 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
455 {
456     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
457 }
458 
getUniformBlockRegisterMap() const459 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
460 {
461     return mResourcesHLSL->getUniformBlockRegisterMap();
462 }
463 
getUniformBlockUseStructuredBufferMap() const464 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
465 {
466     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
467 }
468 
getUniformRegisterMap() const469 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
470 {
471     return mResourcesHLSL->getUniformRegisterMap();
472 }
473 
getReadonlyImage2DRegisterIndex() const474 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
475 {
476     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
477 }
478 
getImage2DRegisterIndex() const479 unsigned int OutputHLSL::getImage2DRegisterIndex() const
480 {
481     return mResourcesHLSL->getImage2DRegisterIndex();
482 }
483 
getUsedImage2DFunctionNames() const484 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
485 {
486     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
487 }
488 
structInitializerString(int indent,const TType & type,const TString & name) const489 TString OutputHLSL::structInitializerString(int indent,
490                                             const TType &type,
491                                             const TString &name) const
492 {
493     TString init;
494 
495     TString indentString;
496     for (int spaces = 0; spaces < indent; spaces++)
497     {
498         indentString += "    ";
499     }
500 
501     if (type.isArray())
502     {
503         init += indentString + "{\n";
504         for (unsigned int arrayIndex = 0u; arrayIndex < type.getOutermostArraySize(); ++arrayIndex)
505         {
506             TStringStream indexedString = sh::InitializeStream<TStringStream>();
507             indexedString << name << "[" << arrayIndex << "]";
508             TType elementType = type;
509             elementType.toArrayElementType();
510             init += structInitializerString(indent + 1, elementType, indexedString.str());
511             if (arrayIndex < type.getOutermostArraySize() - 1)
512             {
513                 init += ",";
514             }
515             init += "\n";
516         }
517         init += indentString + "}";
518     }
519     else if (type.getBasicType() == EbtStruct)
520     {
521         init += indentString + "{\n";
522         const TStructure &structure = *type.getStruct();
523         const TFieldList &fields    = structure.fields();
524         for (unsigned int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++)
525         {
526             const TField &field      = *fields[fieldIndex];
527             const TString &fieldName = name + "." + Decorate(field.name());
528             const TType &fieldType   = *field.type();
529 
530             init += structInitializerString(indent + 1, fieldType, fieldName);
531             if (fieldIndex < fields.size() - 1)
532             {
533                 init += ",";
534             }
535             init += "\n";
536         }
537         init += indentString + "}";
538     }
539     else
540     {
541         init += indentString + name;
542     }
543 
544     return init;
545 }
546 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const547 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
548 {
549     TString mappedStructs;
550 
551     for (auto &mappedStruct : std140Structs)
552     {
553         const TInterfaceBlock *interfaceBlock =
554             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
555         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
556         switch (qualifier)
557         {
558             case EvqUniform:
559                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
560                 {
561                     continue;
562                 }
563                 break;
564             case EvqBuffer:
565                 continue;
566             default:
567                 UNREACHABLE();
568                 return mappedStructs;
569         }
570 
571         unsigned int instanceCount = 1u;
572         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
573         if (isInstanceArray)
574         {
575             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
576         }
577 
578         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
579              ++instanceArrayIndex)
580         {
581             TString originalName;
582             TString mappedName("map");
583 
584             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
585             {
586                 const ImmutableString &instanceName =
587                     mappedStruct.blockDeclarator->variable().name();
588                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
589                 if (isInstanceArray)
590                     instanceStringArrayIndex = instanceArrayIndex;
591                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
592                     instanceName, instanceStringArrayIndex);
593                 originalName += instanceString;
594                 mappedName += instanceString;
595                 originalName += ".";
596                 mappedName += "_";
597             }
598 
599             TString fieldName = Decorate(mappedStruct.field->name());
600             originalName += fieldName;
601             mappedName += fieldName;
602 
603             TType *structType = mappedStruct.field->type();
604             mappedStructs +=
605                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
606 
607             if (structType->isArray())
608             {
609                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
610             }
611 
612             mappedStructs += " =\n";
613             mappedStructs += structInitializerString(0, *structType, originalName);
614             mappedStructs += ";\n";
615         }
616     }
617     return mappedStructs;
618 }
619 
writeReferencedAttributes(TInfoSinkBase & out) const620 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
621 {
622     for (const auto &attribute : mReferencedAttributes)
623     {
624         const TType &type           = attribute.second->getType();
625         const ImmutableString &name = attribute.second->name();
626 
627         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
628             << zeroInitializer(type) << ";\n";
629     }
630 }
631 
writeReferencedVaryings(TInfoSinkBase & out) const632 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
633 {
634     for (const auto &varying : mReferencedVaryings)
635     {
636         const TType &type = varying.second->getType();
637 
638         // Program linking depends on this exact format
639         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
640             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
641             << zeroInitializer(type) << ";\n";
642     }
643 }
644 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const645 void OutputHLSL::header(TInfoSinkBase &out,
646                         const std::vector<MappedStruct> &std140Structs,
647                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
648 {
649     TString mappedStructs;
650     if (mNeedStructMapping)
651     {
652         mappedStructs = generateStructMapping(std140Structs);
653     }
654 
655     // Suppress some common warnings:
656     // 3556 : Integer divides might be much slower, try using uints if possible.
657     // 3571 : The pow(f, e) intrinsic function won't work for negative f, use abs(f) or
658     //        conditionally handle negative values if you expect them.
659     out << "#pragma warning( disable: 3556 3571 )\n";
660 
661     out << mStructureHLSL->structsHeader();
662 
663     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
664     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks,
665                                                mUniformBlocksTranslatedToStructuredBuffers);
666     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
667 
668     if (!mEqualityFunctions.empty())
669     {
670         out << "\n// Equality functions\n\n";
671         for (const auto &eqFunction : mEqualityFunctions)
672         {
673             out << eqFunction->functionDefinition << "\n";
674         }
675     }
676     if (!mArrayAssignmentFunctions.empty())
677     {
678         out << "\n// Assignment functions\n\n";
679         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
680         {
681             out << assignmentFunction.functionDefinition << "\n";
682         }
683     }
684     if (!mArrayConstructIntoFunctions.empty())
685     {
686         out << "\n// Array constructor functions\n\n";
687         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
688         {
689             out << constructIntoFunction.functionDefinition << "\n";
690         }
691     }
692 
693     if (mUsesDiscardRewriting)
694     {
695         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
696     }
697 
698     if (mUsesNestedBreak)
699     {
700         out << "#define ANGLE_USES_NESTED_BREAK\n";
701     }
702 
703     if (mRequiresIEEEStrictCompiling)
704     {
705         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
706     }
707 
708     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
709            "#define LOOP [loop]\n"
710            "#define FLATTEN [flatten]\n"
711            "#else\n"
712            "#define LOOP\n"
713            "#define FLATTEN\n"
714            "#endif\n";
715 
716     // array stride for atomic counter buffers is always 4 per original extension
717     // ARB_shader_atomic_counters and discussion on
718     // https://github.com/KhronosGroup/OpenGL-API/issues/5
719     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
720 
721     if (mUseZeroArray)
722     {
723         out << DefineZeroArray() << "\n";
724     }
725 
726     if (mShaderType == GL_FRAGMENT_SHADER)
727     {
728         const bool usingMRTExtension =
729             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
730         const bool usingBFEExtension =
731             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
732 
733         out << "// Varyings\n";
734         writeReferencedVaryings(out);
735         out << "\n";
736 
737         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
738             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
739         {
740             for (const auto &outputVariable : mReferencedOutputVariables)
741             {
742                 const ImmutableString &variableName = outputVariable.second->name();
743                 const TType &variableType           = outputVariable.second->getType();
744 
745                 out << "static " << TypeString(variableType) << " out_" << variableName
746                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
747             }
748         }
749         else
750         {
751             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
752 
753             out << "static float4 gl_Color[" << numColorValues
754                 << "] =\n"
755                    "{\n";
756             for (unsigned int i = 0; i < numColorValues; i++)
757             {
758                 out << "    float4(0, 0, 0, 0)";
759                 if (i + 1 != numColorValues)
760                 {
761                     out << ",";
762                 }
763                 out << "\n";
764             }
765 
766             out << "};\n";
767 
768             if (usingBFEExtension && mUsesSecondaryColor)
769             {
770                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
771                     << "] = \n"
772                        "{\n";
773                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
774                 {
775                     out << "    float4(0, 0, 0, 0)";
776                     if (i + 1 != mMaxDualSourceDrawBuffers)
777                     {
778                         out << ",";
779                     }
780                     out << "\n";
781                 }
782                 out << "};\n";
783             }
784         }
785 
786         if (mUsesFragDepth)
787         {
788             out << "static float gl_Depth = 0.0;\n";
789         }
790 
791         if (mUsesFragCoord)
792         {
793             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
794         }
795 
796         if (mUsesPointCoord)
797         {
798             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
799         }
800 
801         if (mUsesFrontFacing)
802         {
803             out << "static bool gl_FrontFacing = false;\n";
804         }
805 
806         if (mUsesHelperInvocation)
807         {
808             out << "static bool gl_HelperInvocation = false;\n";
809         }
810 
811         out << "\n";
812 
813         if (mUsesDepthRange)
814         {
815             out << "struct gl_DepthRangeParameters\n"
816                    "{\n"
817                    "    float near;\n"
818                    "    float far;\n"
819                    "    float diff;\n"
820                    "};\n"
821                    "\n";
822         }
823 
824         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
825         {
826             out << "cbuffer DriverConstants : register(b1)\n"
827                    "{\n";
828 
829             if (mUsesDepthRange)
830             {
831                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
832             }
833 
834             if (mUsesFragCoord)
835             {
836                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
837             }
838 
839             if (mUsesFragCoord || mUsesFrontFacing)
840             {
841                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
842             }
843 
844             if (mUsesFragCoord)
845             {
846                 // dx_ViewScale is only used in the fragment shader to correct
847                 // the value for glFragCoord if necessary
848                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
849             }
850 
851             if (mHasMultiviewExtensionEnabled)
852             {
853                 // We have to add a value which we can use to keep track of which multi-view code
854                 // path is to be selected in the GS.
855                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
856             }
857 
858             if (mOutputType == SH_HLSL_4_1_OUTPUT)
859             {
860                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
861             }
862 
863             out << "};\n";
864         }
865         else
866         {
867             if (mUsesDepthRange)
868             {
869                 out << "uniform float3 dx_DepthRange : register(c0);";
870             }
871 
872             if (mUsesFragCoord)
873             {
874                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
875             }
876 
877             if (mUsesFragCoord || mUsesFrontFacing)
878             {
879                 out << "uniform float3 dx_DepthFront : register(c2);\n";
880             }
881         }
882 
883         out << "\n";
884 
885         if (mUsesDepthRange)
886         {
887             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
888                    "dx_DepthRange.y, dx_DepthRange.z};\n"
889                    "\n";
890         }
891 
892         if (usingMRTExtension && mNumRenderTargets > 1)
893         {
894             out << "#define GL_USES_MRT\n";
895         }
896 
897         if (mUsesFragColor)
898         {
899             out << "#define GL_USES_FRAG_COLOR\n";
900         }
901 
902         if (mUsesFragData)
903         {
904             out << "#define GL_USES_FRAG_DATA\n";
905         }
906 
907         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
908         {
909             out << "#define GL_USES_SECONDARY_COLOR\n";
910         }
911     }
912     else if (mShaderType == GL_VERTEX_SHADER)
913     {
914         out << "// Attributes\n";
915         writeReferencedAttributes(out);
916         out << "\n"
917                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
918 
919         if (mUsesPointSize)
920         {
921             out << "static float gl_PointSize = float(1);\n";
922         }
923 
924         if (mUsesInstanceID)
925         {
926             out << "static int gl_InstanceID;";
927         }
928 
929         if (mUsesVertexID)
930         {
931             out << "static int gl_VertexID;";
932         }
933 
934         out << "\n"
935                "// Varyings\n";
936         writeReferencedVaryings(out);
937         out << "\n";
938 
939         if (mUsesDepthRange)
940         {
941             out << "struct gl_DepthRangeParameters\n"
942                    "{\n"
943                    "    float near;\n"
944                    "    float far;\n"
945                    "    float diff;\n"
946                    "};\n"
947                    "\n";
948         }
949 
950         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
951         {
952             out << "cbuffer DriverConstants : register(b1)\n"
953                    "{\n";
954 
955             if (mUsesDepthRange)
956             {
957                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
958             }
959 
960             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
961             // shaders. However, we declare it for all shaders (including Feature Level 10+).
962             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
963             // if it's unused.
964             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
965             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
966             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
967 
968             if (mHasMultiviewExtensionEnabled)
969             {
970                 // We have to add a value which we can use to keep track of which multi-view code
971                 // path is to be selected in the GS.
972                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
973             }
974 
975             if (mOutputType == SH_HLSL_4_1_OUTPUT)
976             {
977                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
978             }
979 
980             if (mUsesVertexID)
981             {
982                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
983             }
984 
985             out << "};\n"
986                    "\n";
987         }
988         else
989         {
990             if (mUsesDepthRange)
991             {
992                 out << "uniform float3 dx_DepthRange : register(c0);\n";
993             }
994 
995             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
996             out << "uniform float2 dx_ViewCoords : register(c2);\n"
997                    "\n";
998         }
999 
1000         if (mUsesDepthRange)
1001         {
1002             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
1003                    "dx_DepthRange.y, dx_DepthRange.z};\n"
1004                    "\n";
1005         }
1006     }
1007     else  // Compute shader
1008     {
1009         ASSERT(mShaderType == GL_COMPUTE_SHADER);
1010 
1011         out << "cbuffer DriverConstants : register(b1)\n"
1012                "{\n";
1013         if (mUsesNumWorkGroups)
1014         {
1015             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1016         }
1017         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1018         unsigned int registerIndex = 1;
1019         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1020         // Sampler metadata struct must be two 4-vec, 32 bytes.
1021         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1022         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1023         out << "};\n";
1024 
1025         out << kImage2DFunctionString << "\n";
1026 
1027         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1028         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1029 
1030         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1031         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
1032                                 << "{\n";
1033 
1034         if (mUsesWorkGroupID)
1035         {
1036             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1037             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1038                                    << "SV_GroupID;\n";
1039             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1040         }
1041 
1042         if (mUsesLocalInvocationID)
1043         {
1044             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1045             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1046                                    << "SV_GroupThreadID;\n";
1047             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1048         }
1049 
1050         if (mUsesGlobalInvocationID)
1051         {
1052             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1053             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1054                                    << "SV_DispatchThreadID;\n";
1055             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1056         }
1057 
1058         if (mUsesLocalInvocationIndex)
1059         {
1060             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1061             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1062                                    << "SV_GroupIndex;\n";
1063             glBuiltinInitialization
1064                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1065         }
1066 
1067         systemValueDeclaration << "};\n\n";
1068         glBuiltinInitialization << "};\n\n";
1069 
1070         out << systemValueDeclaration.str();
1071         out << glBuiltinInitialization.str();
1072     }
1073 
1074     if (!mappedStructs.empty())
1075     {
1076         out << "// Structures from std140 blocks with padding removed\n";
1077         out << "\n";
1078         out << mappedStructs;
1079         out << "\n";
1080     }
1081 
1082     bool getDimensionsIgnoresBaseLevel =
1083         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1084     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1085     mImageFunctionHLSL->imageFunctionHeader(out);
1086     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1087 
1088     if (mUsesFragCoord)
1089     {
1090         out << "#define GL_USES_FRAG_COORD\n";
1091     }
1092 
1093     if (mUsesPointCoord)
1094     {
1095         out << "#define GL_USES_POINT_COORD\n";
1096     }
1097 
1098     if (mUsesFrontFacing)
1099     {
1100         out << "#define GL_USES_FRONT_FACING\n";
1101     }
1102 
1103     if (mUsesHelperInvocation)
1104     {
1105         out << "#define GL_USES_HELPER_INVOCATION\n";
1106     }
1107 
1108     if (mUsesPointSize)
1109     {
1110         out << "#define GL_USES_POINT_SIZE\n";
1111     }
1112 
1113     if (mHasMultiviewExtensionEnabled)
1114     {
1115         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1116     }
1117 
1118     if (mUsesVertexID)
1119     {
1120         out << "#define GL_USES_VERTEX_ID\n";
1121     }
1122 
1123     if (mUsesViewID)
1124     {
1125         out << "#define GL_USES_VIEW_ID\n";
1126     }
1127 
1128     if (mUsesFragDepth)
1129     {
1130         out << "#define GL_USES_FRAG_DEPTH\n";
1131     }
1132 
1133     if (mUsesDepthRange)
1134     {
1135         out << "#define GL_USES_DEPTH_RANGE\n";
1136     }
1137 
1138     if (mUsesXor)
1139     {
1140         out << "bool xor(bool p, bool q)\n"
1141                "{\n"
1142                "    return (p || q) && !(p && q);\n"
1143                "}\n"
1144                "\n";
1145     }
1146 
1147     builtInFunctionEmulator->outputEmulatedFunctions(out);
1148 }
1149 
visitSymbol(TIntermSymbol * node)1150 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1151 {
1152     const TVariable &variable = node->variable();
1153 
1154     // Empty symbols can only appear in declarations and function arguments, and in either of those
1155     // cases the symbol nodes are not visited.
1156     ASSERT(variable.symbolType() != SymbolType::Empty);
1157 
1158     TInfoSinkBase &out = getInfoSink();
1159 
1160     // Handle accessing std140 structs by value
1161     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1162         needStructMapping(node))
1163     {
1164         mNeedStructMapping = true;
1165         out << "map";
1166     }
1167 
1168     const ImmutableString &name     = variable.name();
1169     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1170 
1171     if (name == "gl_DepthRange")
1172     {
1173         mUsesDepthRange = true;
1174         out << name;
1175     }
1176     else if (IsAtomicCounter(variable.getType().getBasicType()))
1177     {
1178         const TType &variableType = variable.getType();
1179         if (variableType.getQualifier() == EvqUniform)
1180         {
1181             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1182             mReferencedUniforms[uniqueId.get()] = &variable;
1183             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1184         }
1185         else
1186         {
1187             TString varName = DecorateVariableIfNeeded(variable);
1188             out << varName << ", " << varName << "_offset";
1189         }
1190     }
1191     else
1192     {
1193         const TType &variableType = variable.getType();
1194         TQualifier qualifier      = variable.getType().getQualifier();
1195 
1196         ensureStructDefined(variableType);
1197 
1198         if (qualifier == EvqUniform)
1199         {
1200             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1201 
1202             if (interfaceBlock)
1203             {
1204                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1205                 {
1206                     const TVariable *instanceVariable = nullptr;
1207                     if (variableType.isInterfaceBlock())
1208                     {
1209                         instanceVariable = &variable;
1210                     }
1211                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1212                         new TReferencedBlock(interfaceBlock, instanceVariable);
1213                 }
1214             }
1215             else
1216             {
1217                 mReferencedUniforms[uniqueId.get()] = &variable;
1218             }
1219 
1220             out << DecorateVariableIfNeeded(variable);
1221         }
1222         else if (qualifier == EvqBuffer)
1223         {
1224             UNREACHABLE();
1225         }
1226         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1227         {
1228             mReferencedAttributes[uniqueId.get()] = &variable;
1229             out << Decorate(name);
1230         }
1231         else if (IsVarying(qualifier))
1232         {
1233             mReferencedVaryings[uniqueId.get()] = &variable;
1234             out << DecorateVariableIfNeeded(variable);
1235             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1236             {
1237                 mUsesViewID = true;
1238             }
1239         }
1240         else if (qualifier == EvqFragmentOut)
1241         {
1242             mReferencedOutputVariables[uniqueId.get()] = &variable;
1243             out << "out_" << name;
1244         }
1245         else if (qualifier == EvqFragColor)
1246         {
1247             out << "gl_Color[0]";
1248             mUsesFragColor = true;
1249         }
1250         else if (qualifier == EvqFragData)
1251         {
1252             out << "gl_Color";
1253             mUsesFragData = true;
1254         }
1255         else if (qualifier == EvqSecondaryFragColorEXT)
1256         {
1257             out << "gl_SecondaryColor[0]";
1258             mUsesSecondaryColor = true;
1259         }
1260         else if (qualifier == EvqSecondaryFragDataEXT)
1261         {
1262             out << "gl_SecondaryColor";
1263             mUsesSecondaryColor = true;
1264         }
1265         else if (qualifier == EvqFragCoord)
1266         {
1267             mUsesFragCoord = true;
1268             out << name;
1269         }
1270         else if (qualifier == EvqPointCoord)
1271         {
1272             mUsesPointCoord = true;
1273             out << name;
1274         }
1275         else if (qualifier == EvqFrontFacing)
1276         {
1277             mUsesFrontFacing = true;
1278             out << name;
1279         }
1280         else if (qualifier == EvqHelperInvocation)
1281         {
1282             mUsesHelperInvocation = true;
1283             out << name;
1284         }
1285         else if (qualifier == EvqPointSize)
1286         {
1287             mUsesPointSize = true;
1288             out << name;
1289         }
1290         else if (qualifier == EvqInstanceID)
1291         {
1292             mUsesInstanceID = true;
1293             out << name;
1294         }
1295         else if (qualifier == EvqVertexID)
1296         {
1297             mUsesVertexID = true;
1298             out << name;
1299         }
1300         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1301         {
1302             mUsesFragDepth = true;
1303             out << "gl_Depth";
1304         }
1305         else if (qualifier == EvqNumWorkGroups)
1306         {
1307             mUsesNumWorkGroups = true;
1308             out << name;
1309         }
1310         else if (qualifier == EvqWorkGroupID)
1311         {
1312             mUsesWorkGroupID = true;
1313             out << name;
1314         }
1315         else if (qualifier == EvqLocalInvocationID)
1316         {
1317             mUsesLocalInvocationID = true;
1318             out << name;
1319         }
1320         else if (qualifier == EvqGlobalInvocationID)
1321         {
1322             mUsesGlobalInvocationID = true;
1323             out << name;
1324         }
1325         else if (qualifier == EvqLocalInvocationIndex)
1326         {
1327             mUsesLocalInvocationIndex = true;
1328             out << name;
1329         }
1330         else
1331         {
1332             out << DecorateVariableIfNeeded(variable);
1333         }
1334     }
1335 }
1336 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1337 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1338 {
1339     if (type.isScalar() && !type.isArray())
1340     {
1341         if (op == EOpEqual)
1342         {
1343             outputTriplet(out, visit, "(", " == ", ")");
1344         }
1345         else
1346         {
1347             outputTriplet(out, visit, "(", " != ", ")");
1348         }
1349     }
1350     else
1351     {
1352         if (visit == PreVisit && op == EOpNotEqual)
1353         {
1354             out << "!";
1355         }
1356 
1357         if (type.isArray())
1358         {
1359             const TString &functionName = addArrayEqualityFunction(type);
1360             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1361         }
1362         else if (type.getBasicType() == EbtStruct)
1363         {
1364             const TStructure &structure = *type.getStruct();
1365             const TString &functionName = addStructEqualityFunction(structure);
1366             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1367         }
1368         else
1369         {
1370             ASSERT(type.isMatrix() || type.isVector());
1371             outputTriplet(out, visit, "all(", " == ", ")");
1372         }
1373     }
1374 }
1375 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1376 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1377 {
1378     if (type.isArray())
1379     {
1380         const TString &functionName = addArrayAssignmentFunction(type);
1381         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1382     }
1383     else
1384     {
1385         outputTriplet(out, visit, "(", " = ", ")");
1386     }
1387 }
1388 
ancestorEvaluatesToSamplerInStruct()1389 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1390 {
1391     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1392     {
1393         TIntermNode *ancestor               = getAncestorNode(n);
1394         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1395         if (ancestorBinary == nullptr)
1396         {
1397             return false;
1398         }
1399         switch (ancestorBinary->getOp())
1400         {
1401             case EOpIndexDirectStruct:
1402             {
1403                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1404                 const TIntermConstantUnion *index =
1405                     ancestorBinary->getRight()->getAsConstantUnion();
1406                 const TField *field = structure->fields()[index->getIConst(0)];
1407                 if (IsSampler(field->type()->getBasicType()))
1408                 {
1409                     return true;
1410                 }
1411                 break;
1412             }
1413             case EOpIndexDirect:
1414                 break;
1415             default:
1416                 // Returning a sampler from indirect indexing is not supported.
1417                 return false;
1418         }
1419     }
1420     return false;
1421 }
1422 
visitSwizzle(Visit visit,TIntermSwizzle * node)1423 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1424 {
1425     TInfoSinkBase &out = getInfoSink();
1426     if (visit == PostVisit)
1427     {
1428         out << ".";
1429         node->writeOffsetsAsXYZW(&out);
1430     }
1431     return true;
1432 }
1433 
visitBinary(Visit visit,TIntermBinary * node)1434 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1435 {
1436     TInfoSinkBase &out = getInfoSink();
1437 
1438     switch (node->getOp())
1439     {
1440         case EOpComma:
1441             outputTriplet(out, visit, "(", ", ", ")");
1442             break;
1443         case EOpAssign:
1444             if (node->isArray())
1445             {
1446                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1447                 if (rightAgg != nullptr && rightAgg->isConstructor())
1448                 {
1449                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1450                     out << functionName << "(";
1451                     node->getLeft()->traverse(this);
1452                     TIntermSequence *seq = rightAgg->getSequence();
1453                     for (auto &arrayElement : *seq)
1454                     {
1455                         out << ", ";
1456                         arrayElement->traverse(this);
1457                     }
1458                     out << ")";
1459                     return false;
1460                 }
1461                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1462                 // function call is assigned.
1463                 ASSERT(rightAgg == nullptr);
1464             }
1465             // Assignment expressions with atomic functions should be transformed into atomic
1466             // function calls in HLSL.
1467             // e.g. original_value = atomicAdd(dest, value) should be translated into
1468             //      InterlockedAdd(dest, value, original_value);
1469             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1470             {
1471                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1472                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1473                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1474                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1475                 ASSERT(argumentSeq->size() >= 2u);
1476                 for (auto &argument : *argumentSeq)
1477                 {
1478                     argument->traverse(this);
1479                     out << ", ";
1480                 }
1481                 node->getLeft()->traverse(this);
1482                 out << ")";
1483                 return false;
1484             }
1485             else if (IsInShaderStorageBlock(node->getLeft()))
1486             {
1487                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1488                 out << ", ";
1489                 if (IsInShaderStorageBlock(node->getRight()))
1490                 {
1491                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1492                 }
1493                 else
1494                 {
1495                     node->getRight()->traverse(this);
1496                 }
1497 
1498                 out << ")";
1499                 return false;
1500             }
1501             else if (IsInShaderStorageBlock(node->getRight()))
1502             {
1503                 node->getLeft()->traverse(this);
1504                 out << " = ";
1505                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1506                 return false;
1507             }
1508 
1509             outputAssign(visit, node->getType(), out);
1510             break;
1511         case EOpInitialize:
1512             if (visit == PreVisit)
1513             {
1514                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1515                 ASSERT(symbolNode);
1516                 TIntermTyped *initializer = node->getRight();
1517 
1518                 // Global initializers must be constant at this point.
1519                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1520 
1521                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1522                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1523                 // new variable is created before the assignment is evaluated), so we need to
1524                 // convert
1525                 // this to "float t = x, x = t;".
1526                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1527                 {
1528                     // Skip initializing the rest of the expression
1529                     return false;
1530                 }
1531                 else if (writeConstantInitialization(out, symbolNode, initializer))
1532                 {
1533                     return false;
1534                 }
1535             }
1536             else if (visit == InVisit)
1537             {
1538                 out << " = ";
1539                 if (IsInShaderStorageBlock(node->getRight()))
1540                 {
1541                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1542                     return false;
1543                 }
1544             }
1545             break;
1546         case EOpAddAssign:
1547             outputTriplet(out, visit, "(", " += ", ")");
1548             break;
1549         case EOpSubAssign:
1550             outputTriplet(out, visit, "(", " -= ", ")");
1551             break;
1552         case EOpMulAssign:
1553             outputTriplet(out, visit, "(", " *= ", ")");
1554             break;
1555         case EOpVectorTimesScalarAssign:
1556             outputTriplet(out, visit, "(", " *= ", ")");
1557             break;
1558         case EOpMatrixTimesScalarAssign:
1559             outputTriplet(out, visit, "(", " *= ", ")");
1560             break;
1561         case EOpVectorTimesMatrixAssign:
1562             if (visit == PreVisit)
1563             {
1564                 out << "(";
1565             }
1566             else if (visit == InVisit)
1567             {
1568                 out << " = mul(";
1569                 node->getLeft()->traverse(this);
1570                 out << ", transpose(";
1571             }
1572             else
1573             {
1574                 out << ")))";
1575             }
1576             break;
1577         case EOpMatrixTimesMatrixAssign:
1578             if (visit == PreVisit)
1579             {
1580                 out << "(";
1581             }
1582             else if (visit == InVisit)
1583             {
1584                 out << " = transpose(mul(transpose(";
1585                 node->getLeft()->traverse(this);
1586                 out << "), transpose(";
1587             }
1588             else
1589             {
1590                 out << "))))";
1591             }
1592             break;
1593         case EOpDivAssign:
1594             outputTriplet(out, visit, "(", " /= ", ")");
1595             break;
1596         case EOpIModAssign:
1597             outputTriplet(out, visit, "(", " %= ", ")");
1598             break;
1599         case EOpBitShiftLeftAssign:
1600             outputTriplet(out, visit, "(", " <<= ", ")");
1601             break;
1602         case EOpBitShiftRightAssign:
1603             outputTriplet(out, visit, "(", " >>= ", ")");
1604             break;
1605         case EOpBitwiseAndAssign:
1606             outputTriplet(out, visit, "(", " &= ", ")");
1607             break;
1608         case EOpBitwiseXorAssign:
1609             outputTriplet(out, visit, "(", " ^= ", ")");
1610             break;
1611         case EOpBitwiseOrAssign:
1612             outputTriplet(out, visit, "(", " |= ", ")");
1613             break;
1614         case EOpIndexDirect:
1615         {
1616             const TType &leftType = node->getLeft()->getType();
1617             if (leftType.isInterfaceBlock())
1618             {
1619                 if (visit == PreVisit)
1620                 {
1621                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1622                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1623 
1624                     ASSERT(leftType.getQualifier() == EvqUniform);
1625                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1626                     {
1627                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1628                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1629                     }
1630                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1631                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1632                         instanceArraySymbol->getName(), arrayIndex);
1633                     return false;
1634                 }
1635             }
1636             else if (ancestorEvaluatesToSamplerInStruct())
1637             {
1638                 // All parts of an expression that access a sampler in a struct need to use _ as
1639                 // separator to access the sampler variable that has been moved out of the struct.
1640                 outputTriplet(out, visit, "", "_", "");
1641             }
1642             else if (IsAtomicCounter(leftType.getBasicType()))
1643             {
1644                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1645             }
1646             else
1647             {
1648                 outputTriplet(out, visit, "", "[", "]");
1649                 if (visit == PostVisit)
1650                 {
1651                     const TInterfaceBlock *interfaceBlock =
1652                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1653                     if (interfaceBlock && mUniformBlocksTranslatedToStructuredBuffers.count(
1654                                               interfaceBlock->uniqueId().get()) != 0)
1655                     {
1656                         // If the uniform block member's type is not structure, we had explicitly
1657                         // packed the member into a structure, so need to add an operator of field
1658                         // slection.
1659                         const TField *field    = interfaceBlock->fields()[0];
1660                         const TType *fieldType = field->type();
1661                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1662                             fieldType->isScalarArray())
1663                         {
1664                             out << "." << Decorate(field->name());
1665                         }
1666                     }
1667                 }
1668             }
1669         }
1670         break;
1671         case EOpIndexIndirect:
1672         {
1673             // We do not currently support indirect references to interface blocks
1674             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1675 
1676             const TType &leftType = node->getLeft()->getType();
1677             if (IsAtomicCounter(leftType.getBasicType()))
1678             {
1679                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1680             }
1681             else
1682             {
1683                 outputTriplet(out, visit, "", "[", "]");
1684                 if (visit == PostVisit)
1685                 {
1686                     const TInterfaceBlock *interfaceBlock =
1687                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1688                     if (interfaceBlock && mUniformBlocksTranslatedToStructuredBuffers.count(
1689                                               interfaceBlock->uniqueId().get()) != 0)
1690                     {
1691                         // If the uniform block member's type is not structure, we had explicitly
1692                         // packed the member into a structure, so need to add an operator of field
1693                         // slection.
1694                         const TField *field    = interfaceBlock->fields()[0];
1695                         const TType *fieldType = field->type();
1696                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1697                             fieldType->isScalarArray())
1698                         {
1699                             out << "." << Decorate(field->name());
1700                         }
1701                     }
1702                 }
1703             }
1704             break;
1705         }
1706         case EOpIndexDirectStruct:
1707         {
1708             const TStructure *structure       = node->getLeft()->getType().getStruct();
1709             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1710             const TField *field               = structure->fields()[index->getIConst(0)];
1711 
1712             // In cases where indexing returns a sampler, we need to access the sampler variable
1713             // that has been moved out of the struct.
1714             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1715             if (visit == PreVisit && indexingReturnsSampler)
1716             {
1717                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1718                 // This prefix is only output at the beginning of the indexing expression, which
1719                 // may have multiple parts.
1720                 out << "angle";
1721             }
1722             if (!indexingReturnsSampler)
1723             {
1724                 // All parts of an expression that access a sampler in a struct need to use _ as
1725                 // separator to access the sampler variable that has been moved out of the struct.
1726                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1727             }
1728             if (visit == InVisit)
1729             {
1730                 if (indexingReturnsSampler)
1731                 {
1732                     out << "_" << field->name();
1733                 }
1734                 else
1735                 {
1736                     out << "." << DecorateField(field->name(), *structure);
1737                 }
1738 
1739                 return false;
1740             }
1741         }
1742         break;
1743         case EOpIndexDirectInterfaceBlock:
1744         {
1745             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1746             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1747                                               IsInStd140UniformBlock(node->getLeft()) &&
1748                                               needStructMapping(node);
1749             if (visit == PreVisit && structInStd140UniformBlock)
1750             {
1751                 mNeedStructMapping = true;
1752                 out << "map";
1753             }
1754             if (visit == InVisit)
1755             {
1756                 const TInterfaceBlock *interfaceBlock =
1757                     node->getLeft()->getType().getInterfaceBlock();
1758                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1759                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1760                 if (structInStd140UniformBlock || mUniformBlocksTranslatedToStructuredBuffers.count(
1761                                                       interfaceBlock->uniqueId().get()) != 0)
1762                 {
1763                     out << "_";
1764                 }
1765                 else
1766                 {
1767                     out << ".";
1768                 }
1769                 out << Decorate(field->name());
1770 
1771                 return false;
1772             }
1773             break;
1774         }
1775         case EOpAdd:
1776             outputTriplet(out, visit, "(", " + ", ")");
1777             break;
1778         case EOpSub:
1779             outputTriplet(out, visit, "(", " - ", ")");
1780             break;
1781         case EOpMul:
1782             outputTriplet(out, visit, "(", " * ", ")");
1783             break;
1784         case EOpDiv:
1785             outputTriplet(out, visit, "(", " / ", ")");
1786             break;
1787         case EOpIMod:
1788             outputTriplet(out, visit, "(", " % ", ")");
1789             break;
1790         case EOpBitShiftLeft:
1791             outputTriplet(out, visit, "(", " << ", ")");
1792             break;
1793         case EOpBitShiftRight:
1794             outputTriplet(out, visit, "(", " >> ", ")");
1795             break;
1796         case EOpBitwiseAnd:
1797             outputTriplet(out, visit, "(", " & ", ")");
1798             break;
1799         case EOpBitwiseXor:
1800             outputTriplet(out, visit, "(", " ^ ", ")");
1801             break;
1802         case EOpBitwiseOr:
1803             outputTriplet(out, visit, "(", " | ", ")");
1804             break;
1805         case EOpEqual:
1806         case EOpNotEqual:
1807             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1808             break;
1809         case EOpLessThan:
1810             outputTriplet(out, visit, "(", " < ", ")");
1811             break;
1812         case EOpGreaterThan:
1813             outputTriplet(out, visit, "(", " > ", ")");
1814             break;
1815         case EOpLessThanEqual:
1816             outputTriplet(out, visit, "(", " <= ", ")");
1817             break;
1818         case EOpGreaterThanEqual:
1819             outputTriplet(out, visit, "(", " >= ", ")");
1820             break;
1821         case EOpVectorTimesScalar:
1822             outputTriplet(out, visit, "(", " * ", ")");
1823             break;
1824         case EOpMatrixTimesScalar:
1825             outputTriplet(out, visit, "(", " * ", ")");
1826             break;
1827         case EOpVectorTimesMatrix:
1828             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1829             break;
1830         case EOpMatrixTimesVector:
1831             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1832             break;
1833         case EOpMatrixTimesMatrix:
1834             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1835             break;
1836         case EOpLogicalOr:
1837             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1838             // been unfolded.
1839             ASSERT(!node->getRight()->hasSideEffects());
1840             outputTriplet(out, visit, "(", " || ", ")");
1841             return true;
1842         case EOpLogicalXor:
1843             mUsesXor = true;
1844             outputTriplet(out, visit, "xor(", ", ", ")");
1845             break;
1846         case EOpLogicalAnd:
1847             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1848             // been unfolded.
1849             ASSERT(!node->getRight()->hasSideEffects());
1850             outputTriplet(out, visit, "(", " && ", ")");
1851             return true;
1852         default:
1853             UNREACHABLE();
1854     }
1855 
1856     return true;
1857 }
1858 
visitUnary(Visit visit,TIntermUnary * node)1859 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1860 {
1861     TInfoSinkBase &out = getInfoSink();
1862 
1863     switch (node->getOp())
1864     {
1865         case EOpNegative:
1866             outputTriplet(out, visit, "(-", "", ")");
1867             break;
1868         case EOpPositive:
1869             outputTriplet(out, visit, "(+", "", ")");
1870             break;
1871         case EOpLogicalNot:
1872             outputTriplet(out, visit, "(!", "", ")");
1873             break;
1874         case EOpBitwiseNot:
1875             outputTriplet(out, visit, "(~", "", ")");
1876             break;
1877         case EOpPostIncrement:
1878             outputTriplet(out, visit, "(", "", "++)");
1879             break;
1880         case EOpPostDecrement:
1881             outputTriplet(out, visit, "(", "", "--)");
1882             break;
1883         case EOpPreIncrement:
1884             outputTriplet(out, visit, "(++", "", ")");
1885             break;
1886         case EOpPreDecrement:
1887             outputTriplet(out, visit, "(--", "", ")");
1888             break;
1889         case EOpRadians:
1890             outputTriplet(out, visit, "radians(", "", ")");
1891             break;
1892         case EOpDegrees:
1893             outputTriplet(out, visit, "degrees(", "", ")");
1894             break;
1895         case EOpSin:
1896             outputTriplet(out, visit, "sin(", "", ")");
1897             break;
1898         case EOpCos:
1899             outputTriplet(out, visit, "cos(", "", ")");
1900             break;
1901         case EOpTan:
1902             outputTriplet(out, visit, "tan(", "", ")");
1903             break;
1904         case EOpAsin:
1905             outputTriplet(out, visit, "asin(", "", ")");
1906             break;
1907         case EOpAcos:
1908             outputTriplet(out, visit, "acos(", "", ")");
1909             break;
1910         case EOpAtan:
1911             outputTriplet(out, visit, "atan(", "", ")");
1912             break;
1913         case EOpSinh:
1914             outputTriplet(out, visit, "sinh(", "", ")");
1915             break;
1916         case EOpCosh:
1917             outputTriplet(out, visit, "cosh(", "", ")");
1918             break;
1919         case EOpTanh:
1920         case EOpAsinh:
1921         case EOpAcosh:
1922         case EOpAtanh:
1923             ASSERT(node->getUseEmulatedFunction());
1924             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1925             break;
1926         case EOpExp:
1927             outputTriplet(out, visit, "exp(", "", ")");
1928             break;
1929         case EOpLog:
1930             outputTriplet(out, visit, "log(", "", ")");
1931             break;
1932         case EOpExp2:
1933             outputTriplet(out, visit, "exp2(", "", ")");
1934             break;
1935         case EOpLog2:
1936             outputTriplet(out, visit, "log2(", "", ")");
1937             break;
1938         case EOpSqrt:
1939             outputTriplet(out, visit, "sqrt(", "", ")");
1940             break;
1941         case EOpInversesqrt:
1942             outputTriplet(out, visit, "rsqrt(", "", ")");
1943             break;
1944         case EOpAbs:
1945             outputTriplet(out, visit, "abs(", "", ")");
1946             break;
1947         case EOpSign:
1948             outputTriplet(out, visit, "sign(", "", ")");
1949             break;
1950         case EOpFloor:
1951             outputTriplet(out, visit, "floor(", "", ")");
1952             break;
1953         case EOpTrunc:
1954             outputTriplet(out, visit, "trunc(", "", ")");
1955             break;
1956         case EOpRound:
1957             outputTriplet(out, visit, "round(", "", ")");
1958             break;
1959         case EOpRoundEven:
1960             ASSERT(node->getUseEmulatedFunction());
1961             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1962             break;
1963         case EOpCeil:
1964             outputTriplet(out, visit, "ceil(", "", ")");
1965             break;
1966         case EOpFract:
1967             outputTriplet(out, visit, "frac(", "", ")");
1968             break;
1969         case EOpIsnan:
1970             if (node->getUseEmulatedFunction())
1971                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1972             else
1973                 outputTriplet(out, visit, "isnan(", "", ")");
1974             mRequiresIEEEStrictCompiling = true;
1975             break;
1976         case EOpIsinf:
1977             outputTriplet(out, visit, "isinf(", "", ")");
1978             break;
1979         case EOpFloatBitsToInt:
1980             outputTriplet(out, visit, "asint(", "", ")");
1981             break;
1982         case EOpFloatBitsToUint:
1983             outputTriplet(out, visit, "asuint(", "", ")");
1984             break;
1985         case EOpIntBitsToFloat:
1986             outputTriplet(out, visit, "asfloat(", "", ")");
1987             break;
1988         case EOpUintBitsToFloat:
1989             outputTriplet(out, visit, "asfloat(", "", ")");
1990             break;
1991         case EOpPackSnorm2x16:
1992         case EOpPackUnorm2x16:
1993         case EOpPackHalf2x16:
1994         case EOpUnpackSnorm2x16:
1995         case EOpUnpackUnorm2x16:
1996         case EOpUnpackHalf2x16:
1997         case EOpPackUnorm4x8:
1998         case EOpPackSnorm4x8:
1999         case EOpUnpackUnorm4x8:
2000         case EOpUnpackSnorm4x8:
2001             ASSERT(node->getUseEmulatedFunction());
2002             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2003             break;
2004         case EOpLength:
2005             outputTriplet(out, visit, "length(", "", ")");
2006             break;
2007         case EOpNormalize:
2008             outputTriplet(out, visit, "normalize(", "", ")");
2009             break;
2010         case EOpDFdx:
2011             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2012             {
2013                 outputTriplet(out, visit, "(", "", ", 0.0)");
2014             }
2015             else
2016             {
2017                 outputTriplet(out, visit, "ddx(", "", ")");
2018             }
2019             break;
2020         case EOpDFdy:
2021             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2022             {
2023                 outputTriplet(out, visit, "(", "", ", 0.0)");
2024             }
2025             else
2026             {
2027                 outputTriplet(out, visit, "ddy(", "", ")");
2028             }
2029             break;
2030         case EOpFwidth:
2031             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2032             {
2033                 outputTriplet(out, visit, "(", "", ", 0.0)");
2034             }
2035             else
2036             {
2037                 outputTriplet(out, visit, "fwidth(", "", ")");
2038             }
2039             break;
2040         case EOpTranspose:
2041             outputTriplet(out, visit, "transpose(", "", ")");
2042             break;
2043         case EOpDeterminant:
2044             outputTriplet(out, visit, "determinant(transpose(", "", "))");
2045             break;
2046         case EOpInverse:
2047             ASSERT(node->getUseEmulatedFunction());
2048             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2049             break;
2050 
2051         case EOpAny:
2052             outputTriplet(out, visit, "any(", "", ")");
2053             break;
2054         case EOpAll:
2055             outputTriplet(out, visit, "all(", "", ")");
2056             break;
2057         case EOpLogicalNotComponentWise:
2058             outputTriplet(out, visit, "(!", "", ")");
2059             break;
2060         case EOpBitfieldReverse:
2061             outputTriplet(out, visit, "reversebits(", "", ")");
2062             break;
2063         case EOpBitCount:
2064             outputTriplet(out, visit, "countbits(", "", ")");
2065             break;
2066         case EOpFindLSB:
2067             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2068             // in GLSLTest and results are consistent with GL.
2069             outputTriplet(out, visit, "firstbitlow(", "", ")");
2070             break;
2071         case EOpFindMSB:
2072             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2073             // tested in GLSLTest and results are consistent with GL.
2074             outputTriplet(out, visit, "firstbithigh(", "", ")");
2075             break;
2076         case EOpArrayLength:
2077         {
2078             TIntermTyped *operand = node->getOperand();
2079             ASSERT(IsInShaderStorageBlock(operand));
2080             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2081             return false;
2082         }
2083         default:
2084             UNREACHABLE();
2085     }
2086 
2087     return true;
2088 }
2089 
samplerNamePrefixFromStruct(TIntermTyped * node)2090 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2091 {
2092     if (node->getAsSymbolNode())
2093     {
2094         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2095         return node->getAsSymbolNode()->getName();
2096     }
2097     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2098     switch (nodeBinary->getOp())
2099     {
2100         case EOpIndexDirect:
2101         {
2102             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2103 
2104             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2105             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2106             return ImmutableString(prefixSink.str());
2107         }
2108         case EOpIndexDirectStruct:
2109         {
2110             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2111             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2112             const TField *field = s->fields()[index];
2113 
2114             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2115             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2116                        << field->name();
2117             return ImmutableString(prefixSink.str());
2118         }
2119         default:
2120             UNREACHABLE();
2121             return kEmptyImmutableString;
2122     }
2123 }
2124 
visitBlock(Visit visit,TIntermBlock * node)2125 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2126 {
2127     TInfoSinkBase &out = getInfoSink();
2128 
2129     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2130 
2131     if (mInsideFunction)
2132     {
2133         outputLineDirective(out, node->getLine().first_line);
2134         out << "{\n";
2135         if (isMainBlock)
2136         {
2137             if (mShaderType == GL_COMPUTE_SHADER)
2138             {
2139                 out << "initGLBuiltins(input);\n";
2140             }
2141             else
2142             {
2143                 out << "@@ MAIN PROLOGUE @@\n";
2144             }
2145         }
2146     }
2147 
2148     for (TIntermNode *statement : *node->getSequence())
2149     {
2150         outputLineDirective(out, statement->getLine().first_line);
2151 
2152         statement->traverse(this);
2153 
2154         // Don't output ; after case labels, they're terminated by :
2155         // This is needed especially since outputting a ; after a case statement would turn empty
2156         // case statements into non-empty case statements, disallowing fall-through from them.
2157         // Also the output code is clearer if we don't output ; after statements where it is not
2158         // needed:
2159         //  * if statements
2160         //  * switch statements
2161         //  * blocks
2162         //  * function definitions
2163         //  * loops (do-while loops output the semicolon in VisitLoop)
2164         //  * declarations that don't generate output.
2165         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2166             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2167             statement->getAsSwitchNode() == nullptr &&
2168             statement->getAsFunctionDefinition() == nullptr &&
2169             (statement->getAsDeclarationNode() == nullptr ||
2170              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2171             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2172         {
2173             out << ";\n";
2174         }
2175     }
2176 
2177     if (mInsideFunction)
2178     {
2179         outputLineDirective(out, node->getLine().last_line);
2180         if (isMainBlock && shaderNeedsGenerateOutput())
2181         {
2182             // We could have an empty main, a main function without a branch at the end, or a main
2183             // function with a discard statement at the end. In these cases we need to add a return
2184             // statement.
2185             bool needReturnStatement =
2186                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2187                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2188             if (needReturnStatement)
2189             {
2190                 out << "return " << generateOutputCall() << ";\n";
2191             }
2192         }
2193         out << "}\n";
2194     }
2195 
2196     return false;
2197 }
2198 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2199 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2200 {
2201     TInfoSinkBase &out = getInfoSink();
2202 
2203     ASSERT(mCurrentFunctionMetadata == nullptr);
2204 
2205     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2206     ASSERT(index != CallDAG::InvalidIndex);
2207     mCurrentFunctionMetadata = &mASTMetadataList[index];
2208 
2209     const TFunction *func = node->getFunction();
2210 
2211     if (func->isMain())
2212     {
2213         // The stub strings below are replaced when shader is dynamically defined by its layout:
2214         switch (mShaderType)
2215         {
2216             case GL_VERTEX_SHADER:
2217                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2218                     << "@@ VERTEX OUTPUT @@\n\n"
2219                     << "VS_OUTPUT main(VS_INPUT input)";
2220                 break;
2221             case GL_FRAGMENT_SHADER:
2222                 out << "@@ PIXEL OUTPUT @@\n\n"
2223                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2224                 break;
2225             case GL_COMPUTE_SHADER:
2226                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2227                     << mWorkGroupSize[2] << ")]\n";
2228                 out << "void main(CS_INPUT input)";
2229                 break;
2230             default:
2231                 UNREACHABLE();
2232                 break;
2233         }
2234     }
2235     else
2236     {
2237         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2238         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2239             << (mOutputLod0Function ? "Lod0(" : "(");
2240 
2241         size_t paramCount = func->getParamCount();
2242         for (unsigned int i = 0; i < paramCount; i++)
2243         {
2244             const TVariable *param = func->getParam(i);
2245             ensureStructDefined(param->getType());
2246 
2247             writeParameter(param, out);
2248 
2249             if (i < paramCount - 1)
2250             {
2251                 out << ", ";
2252             }
2253         }
2254 
2255         out << ")\n";
2256     }
2257 
2258     mInsideFunction = true;
2259     if (func->isMain())
2260     {
2261         mInsideMain = true;
2262     }
2263     // The function body node will output braces.
2264     node->getBody()->traverse(this);
2265     mInsideFunction = false;
2266     mInsideMain     = false;
2267 
2268     mCurrentFunctionMetadata = nullptr;
2269 
2270     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2271     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2272     {
2273         ASSERT(!node->getFunction()->isMain());
2274         mOutputLod0Function = true;
2275         node->traverse(this);
2276         mOutputLod0Function = false;
2277     }
2278 
2279     return false;
2280 }
2281 
visitDeclaration(Visit visit,TIntermDeclaration * node)2282 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2283 {
2284     if (visit == PreVisit)
2285     {
2286         TIntermSequence *sequence = node->getSequence();
2287         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2288         ASSERT(sequence->size() == 1);
2289         ASSERT(declarator);
2290 
2291         if (IsDeclarationWrittenOut(node))
2292         {
2293             TInfoSinkBase &out = getInfoSink();
2294             ensureStructDefined(declarator->getType());
2295 
2296             if (!declarator->getAsSymbolNode() ||
2297                 declarator->getAsSymbolNode()->variable().symbolType() !=
2298                     SymbolType::Empty)  // Variable declaration
2299             {
2300                 if (declarator->getQualifier() == EvqShared)
2301                 {
2302                     out << "groupshared ";
2303                 }
2304                 else if (!mInsideFunction)
2305                 {
2306                     out << "static ";
2307                 }
2308 
2309                 out << TypeString(declarator->getType()) + " ";
2310 
2311                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2312 
2313                 if (symbol)
2314                 {
2315                     symbol->traverse(this);
2316                     out << ArrayString(symbol->getType());
2317                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2318                     // drivers to compile a compute shader if we add code to initialize a
2319                     // groupshared array variable with a large array size. And maybe produce
2320                     // incorrect result. See http://anglebug.com/3226.
2321                     if (declarator->getQualifier() != EvqShared)
2322                     {
2323                         out << " = " + zeroInitializer(symbol->getType());
2324                     }
2325                 }
2326                 else
2327                 {
2328                     declarator->traverse(this);
2329                 }
2330             }
2331         }
2332         else if (IsVaryingOut(declarator->getQualifier()))
2333         {
2334             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2335             ASSERT(symbol);  // Varying declarations can't have initializers.
2336 
2337             const TVariable &variable = symbol->variable();
2338 
2339             if (variable.symbolType() != SymbolType::Empty)
2340             {
2341                 // Vertex outputs which are declared but not written to should still be declared to
2342                 // allow successful linking.
2343                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2344             }
2345         }
2346     }
2347     return false;
2348 }
2349 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2350 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2351                                                  TIntermGlobalQualifierDeclaration *node)
2352 {
2353     // Do not do any translation
2354     return false;
2355 }
2356 
visitFunctionPrototype(TIntermFunctionPrototype * node)2357 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2358 {
2359     TInfoSinkBase &out = getInfoSink();
2360 
2361     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2362     // Skip the prototype if it is not implemented (and thus not used)
2363     if (index == CallDAG::InvalidIndex)
2364     {
2365         return;
2366     }
2367 
2368     const TFunction *func = node->getFunction();
2369 
2370     TString name = DecorateFunctionIfNeeded(func);
2371     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2372         << (mOutputLod0Function ? "Lod0(" : "(");
2373 
2374     size_t paramCount = func->getParamCount();
2375     for (unsigned int i = 0; i < paramCount; i++)
2376     {
2377         writeParameter(func->getParam(i), out);
2378 
2379         if (i < paramCount - 1)
2380         {
2381             out << ", ";
2382         }
2383     }
2384 
2385     out << ");\n";
2386 
2387     // Also prototype the Lod0 variant if needed
2388     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2389     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2390     {
2391         mOutputLod0Function = true;
2392         node->traverse(this);
2393         mOutputLod0Function = false;
2394     }
2395 }
2396 
// Writes HLSL for a TIntermAggregate node: function calls and built-in operators that take an
// argument list.
//
// Calls (EOpCallBuiltInFunction, EOpCallFunctionInAST, EOpCallInternalRawFunction) are written
// out completely here, argument list included, and return false so the traverser does not visit
// the children a second time. The same holds for atomic operations on shader storage block
// memory. All other operators are written piece by piece on the pre/in/post visits (through
// outputTriplet(), writeEmulatedFunctionTriplet() or outputConstructor()) and return true so the
// traverser emits the operands in between.
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
    TInfoSinkBase &out = getInfoSink();

    switch (node->getOp())
    {
        case EOpCallBuiltInFunction:
        case EOpCallFunctionInAST:
        case EOpCallInternalRawFunction:
        {
            TIntermSequence *arguments = node->getSequence();

            // Lod0 variants are only considered in fragment shaders, and only inside a
            // discontinuous loop or while mOutputLod0Function is set.
            bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
                        mShaderType == GL_FRAGMENT_SHADER;
            if (node->getOp() == EOpCallFunctionInAST)
            {
                if (node->isArray())
                {
                    UNIMPLEMENTED();
                }
                // Call the "Lod0" flavour of a user-defined function only if the callee's
                // metadata (looked up through the call DAG) says it needs one.
                size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
                ASSERT(index != CallDAG::InvalidIndex);
                lod0 &= mASTMetadataList[index].mNeedsLod0;

                out << DecorateFunctionIfNeeded(node->getFunction());
                out << DisambiguateFunctionName(node->getSequence());
                out << (lod0 ? "Lod0(" : "(");
            }
            else if (node->getOp() == EOpCallInternalRawFunction)
            {
                // This path is used for internal functions that don't have their definitions in the
                // AST, such as precision emulation functions.
                out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
            }
            else if (node->getFunction()->isImageFunction())
            {
                // Image built-ins are routed to a helper chosen by the image type, its internal
                // format and whether it is readonly; the helper's name is what gets called.
                const ImmutableString &name              = node->getFunction()->name();
                TType type                               = (*arguments)[0]->getAsTyped()->getType();
                const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
                    name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
                    type.getMemoryQualifier().readonly);
                out << imageFunctionName << "(";
            }
            else if (node->getFunction()->isAtomicCounterFunction())
            {
                const ImmutableString &name = node->getFunction()->name();
                ImmutableString atomicFunctionName =
                    mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
                out << atomicFunctionName << "(";
            }
            else
            {
                // Remaining built-in calls are treated as texture functions. The helper is selected
                // from the sampler type (first argument), the size of the coordinate argument (the
                // second argument, if any), the argument count, lod0 and the shader type.
                const ImmutableString &name = node->getFunction()->name();
                TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
                int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
                if (arguments->size() > 1)
                {
                    coords = (*arguments)[1]->getAsTyped()->getNominalSize();
                }
                const ImmutableString &textureFunctionName =
                    mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
                                                             arguments->size(), lod0, mShaderType);
                out << textureFunctionName << "(";
            }

            // Write the argument list by traversing each argument, comma separated.
            for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
            {
                TIntermTyped *typedArg = (*arg)->getAsTyped();
                if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
                {
                    // Feature level 9_3 passes a sampler as a texture/sampler pair. Together with
                    // the traverse below, this emits "texture_<arg>, sampler_<arg>".
                    out << "texture_";
                    (*arg)->traverse(this);
                    out << ", sampler_";
                }

                (*arg)->traverse(this);

                if (typedArg->getType().isStructureContainingSamplers())
                {
                    // HLSL structs can't hold samplers, so the samplers extracted from a struct
                    // argument are appended as extra arguments named "angle_<struct prefix>...".
                    // NOTE(review): these are presumably expected to line up with the extra
                    // parameters writeParameter() declares -- confirm when changing either side.
                    const TType &argType = typedArg->getType();
                    TVector<const TVariable *> samplerSymbols;
                    ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
                    std::string namePrefix     = "angle_";
                    namePrefix += structName.data();
                    argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
                                                 nullptr, mSymbolTable);
                    for (const TVariable *sampler : samplerSymbols)
                    {
                        if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
                        {
                            out << ", texture_" << sampler->name();
                            out << ", sampler_" << sampler->name();
                        }
                        else
                        {
                            // In case of HLSL 4.1+, this symbol is the sampler index, and in case
                            // of D3D9, it's the sampler variable.
                            out << ", " << sampler->name();
                        }
                    }
                }

                if (arg < arguments->end() - 1)
                {
                    out << ", ";
                }
            }

            out << ")";

            // Children were already written above; don't let the traverser visit them again.
            return false;
        }
        case EOpConstruct:
            outputConstructor(out, visit, node);
            break;
        case EOpEqualComponentWise:
            outputTriplet(out, visit, "(", " == ", ")");
            break;
        case EOpNotEqualComponentWise:
            outputTriplet(out, visit, "(", " != ", ")");
            break;
        case EOpLessThanComponentWise:
            outputTriplet(out, visit, "(", " < ", ")");
            break;
        case EOpGreaterThanComponentWise:
            outputTriplet(out, visit, "(", " > ", ")");
            break;
        case EOpLessThanEqualComponentWise:
            outputTriplet(out, visit, "(", " <= ", ")");
            break;
        case EOpGreaterThanEqualComponentWise:
            outputTriplet(out, visit, "(", " >= ", ")");
            break;
        case EOpMod:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpModf:
            outputTriplet(out, visit, "modf(", ", ", ")");
            break;
        case EOpPow:
            outputTriplet(out, visit, "pow(", ", ", ")");
            break;
        case EOpAtan:
            ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpMin:
            outputTriplet(out, visit, "min(", ", ", ")");
            break;
        case EOpMax:
            outputTriplet(out, visit, "max(", ", ", ")");
            break;
        case EOpClamp:
            outputTriplet(out, visit, "clamp(", ", ", ")");
            break;
        case EOpMix:
        {
            TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
            if (lastParamNode->getType().getBasicType() == EbtBool)
            {
                // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
                // y, genBType a)",
                // so use emulated version.
                ASSERT(node->getUseEmulatedFunction());
                writeEmulatedFunctionTriplet(out, visit, node->getOp());
            }
            else
            {
                outputTriplet(out, visit, "lerp(", ", ", ")");
            }
            break;
        }
        case EOpStep:
            outputTriplet(out, visit, "step(", ", ", ")");
            break;
        case EOpSmoothstep:
            outputTriplet(out, visit, "smoothstep(", ", ", ")");
            break;
        case EOpFma:
            outputTriplet(out, visit, "mad(", ", ", ")");
            break;
        case EOpFrexp:
        case EOpLdexp:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpDistance:
            outputTriplet(out, visit, "distance(", ", ", ")");
            break;
        case EOpDot:
            outputTriplet(out, visit, "dot(", ", ", ")");
            break;
        case EOpCross:
            outputTriplet(out, visit, "cross(", ", ", ")");
            break;
        case EOpFaceforward:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpReflect:
            outputTriplet(out, visit, "reflect(", ", ", ")");
            break;
        case EOpRefract:
            outputTriplet(out, visit, "refract(", ", ", ")");
            break;
        case EOpOuterProduct:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpMulMatrixComponentWise:
            outputTriplet(out, visit, "(", " * ", ")");
            break;
        case EOpBitfieldExtract:
        case EOpBitfieldInsert:
        case EOpUaddCarry:
        case EOpUsubBorrow:
        case EOpUmulExtended:
        case EOpImulExtended:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpBarrier:
            // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
            // cheapest *WithGroupSync() function, without any functionality loss, but
            // with the potential for severe performance loss.
            outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
            break;
        case EOpMemoryBarrierShared:
            outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
            break;
        case EOpMemoryBarrierAtomicCounter:
        case EOpMemoryBarrierBuffer:
        case EOpMemoryBarrierImage:
            outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
            break;
        case EOpGroupMemoryBarrier:
        case EOpMemoryBarrier:
            outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
            break;

        // Single atomic function calls without return value.
        // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
        case EOpAtomicAdd:
        case EOpAtomicMin:
        case EOpAtomicMax:
        case EOpAtomicAnd:
        case EOpAtomicOr:
        case EOpAtomicXor:
        // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
        // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
        // optional.
        // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
        // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
        // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
        // compare_value, value) should all be modified into the form of "int temp; temp =
        // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
        // compare_value, value);" in the intermediate tree before traversing outputHLSL.
        case EOpAtomicExchange:
        case EOpAtomicCompSwap:
        {
            ASSERT(node->getChildCount() > 1);
            TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
            if (IsInShaderStorageBlock(memNode))
            {
                // Atomic memory functions for SSBO.
                // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
                mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
                // Write the rest argument list to |out|.
                for (size_t i = 1; i < node->getChildCount(); i++)
                {
                    out << ", ";
                    TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
                    // Arguments that themselves live in an SSBO must be read through a load call.
                    if (IsInShaderStorageBlock(argument))
                    {
                        mSSBOOutputHLSL->outputLoadFunctionCall(argument);
                    }
                    else
                    {
                        argument->traverse(this);
                    }
                }

                out << ")";
                // The whole call has been written; skip the normal child traversal.
                return false;
            }
            else
            {
                // Atomic memory functions for shared variable.
                if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
                {
                    outputTriplet(out, visit,
                                  GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
                                  ")");
                }
                else
                {
                    UNREACHABLE();
                }
            }

            break;
        }
        default:
            UNREACHABLE();
    }

    return true;
}
2707 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2708 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2709 {
2710     out << "if (";
2711 
2712     node->getCondition()->traverse(this);
2713 
2714     out << ")\n";
2715 
2716     outputLineDirective(out, node->getLine().first_line);
2717 
2718     bool discard = false;
2719 
2720     if (node->getTrueBlock())
2721     {
2722         // The trueBlock child node will output braces.
2723         node->getTrueBlock()->traverse(this);
2724 
2725         // Detect true discard
2726         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2727     }
2728     else
2729     {
2730         // TODO(oetuaho): Check if the semicolon inside is necessary.
2731         // It's there as a result of conservative refactoring of the output.
2732         out << "{;}\n";
2733     }
2734 
2735     outputLineDirective(out, node->getLine().first_line);
2736 
2737     if (node->getFalseBlock())
2738     {
2739         out << "else\n";
2740 
2741         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2742 
2743         // The falseBlock child node will output braces.
2744         node->getFalseBlock()->traverse(this);
2745 
2746         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2747 
2748         // Detect false discard
2749         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2750     }
2751 
2752     // ANGLE issue 486: Detect problematic conditional discard
2753     if (discard)
2754     {
2755         mUsesDiscardRewriting = true;
2756     }
2757 }
2758 
visitTernary(Visit,TIntermTernary *)2759 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2760 {
2761     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2762     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2763     UNREACHABLE();
2764     return false;
2765 }
2766 
visitIfElse(Visit visit,TIntermIfElse * node)2767 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2768 {
2769     TInfoSinkBase &out = getInfoSink();
2770 
2771     ASSERT(mInsideFunction);
2772 
2773     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2774     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2775     {
2776         out << "FLATTEN ";
2777     }
2778 
2779     writeIfElse(out, node);
2780 
2781     return false;
2782 }
2783 
visitSwitch(Visit visit,TIntermSwitch * node)2784 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2785 {
2786     TInfoSinkBase &out = getInfoSink();
2787 
2788     ASSERT(node->getStatementList());
2789     if (visit == PreVisit)
2790     {
2791         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2792     }
2793     outputTriplet(out, visit, "switch (", ") ", "");
2794     // The curly braces get written when visiting the statementList block.
2795     return true;
2796 }
2797 
visitCase(Visit visit,TIntermCase * node)2798 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2799 {
2800     TInfoSinkBase &out = getInfoSink();
2801 
2802     if (node->hasCondition())
2803     {
2804         outputTriplet(out, visit, "case (", "", "):\n");
2805         return true;
2806     }
2807     else
2808     {
2809         out << "default:\n";
2810         return false;
2811     }
2812 }
2813 
visitConstantUnion(TIntermConstantUnion * node)2814 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2815 {
2816     TInfoSinkBase &out = getInfoSink();
2817     writeConstantUnion(out, node->getType(), node->getConstantValue());
2818 }
2819 
visitLoop(Visit visit,TIntermLoop * node)2820 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2821 {
2822     mNestedLoopDepth++;
2823 
2824     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2825     mInsideDiscontinuousLoop =
2826         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2827 
2828     TInfoSinkBase &out = getInfoSink();
2829 
2830     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2831     {
2832         if (handleExcessiveLoop(out, node))
2833         {
2834             mInsideDiscontinuousLoop = wasDiscontinuous;
2835             mNestedLoopDepth--;
2836 
2837             return false;
2838         }
2839     }
2840 
2841     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2842     if (node->getType() == ELoopDoWhile)
2843     {
2844         out << "{" << unroll << " do\n";
2845 
2846         outputLineDirective(out, node->getLine().first_line);
2847     }
2848     else
2849     {
2850         out << "{" << unroll << " for(";
2851 
2852         if (node->getInit())
2853         {
2854             node->getInit()->traverse(this);
2855         }
2856 
2857         out << "; ";
2858 
2859         if (node->getCondition())
2860         {
2861             node->getCondition()->traverse(this);
2862         }
2863 
2864         out << "; ";
2865 
2866         if (node->getExpression())
2867         {
2868             node->getExpression()->traverse(this);
2869         }
2870 
2871         out << ")\n";
2872 
2873         outputLineDirective(out, node->getLine().first_line);
2874     }
2875 
2876     if (node->getBody())
2877     {
2878         // The loop body node will output braces.
2879         node->getBody()->traverse(this);
2880     }
2881     else
2882     {
2883         // TODO(oetuaho): Check if the semicolon inside is necessary.
2884         // It's there as a result of conservative refactoring of the output.
2885         out << "{;}\n";
2886     }
2887 
2888     outputLineDirective(out, node->getLine().first_line);
2889 
2890     if (node->getType() == ELoopDoWhile)
2891     {
2892         outputLineDirective(out, node->getCondition()->getLine().first_line);
2893         out << "while (";
2894 
2895         node->getCondition()->traverse(this);
2896 
2897         out << ");\n";
2898     }
2899 
2900     out << "}\n";
2901 
2902     mInsideDiscontinuousLoop = wasDiscontinuous;
2903     mNestedLoopDepth--;
2904 
2905     return false;
2906 }
2907 
visitBranch(Visit visit,TIntermBranch * node)2908 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2909 {
2910     if (visit == PreVisit)
2911     {
2912         TInfoSinkBase &out = getInfoSink();
2913 
2914         switch (node->getFlowOp())
2915         {
2916             case EOpKill:
2917                 out << "discard";
2918                 break;
2919             case EOpBreak:
2920                 if (mNestedLoopDepth > 1)
2921                 {
2922                     mUsesNestedBreak = true;
2923                 }
2924 
2925                 if (mExcessiveLoopIndex)
2926                 {
2927                     out << "{Break";
2928                     mExcessiveLoopIndex->traverse(this);
2929                     out << " = true; break;}\n";
2930                 }
2931                 else
2932                 {
2933                     out << "break";
2934                 }
2935                 break;
2936             case EOpContinue:
2937                 out << "continue";
2938                 break;
2939             case EOpReturn:
2940                 if (node->getExpression())
2941                 {
2942                     ASSERT(!mInsideMain);
2943                     out << "return ";
2944                 }
2945                 else
2946                 {
2947                     if (mInsideMain && shaderNeedsGenerateOutput())
2948                     {
2949                         out << "return " << generateOutputCall();
2950                     }
2951                     else
2952                     {
2953                         out << "return";
2954                     }
2955                 }
2956                 break;
2957             default:
2958                 UNREACHABLE();
2959         }
2960     }
2961 
2962     return true;
2963 }
2964 
2965 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2966 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2967 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2968 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2969 {
2970     const int MAX_LOOP_ITERATIONS = 254;
2971 
2972     // Parse loops of the form:
2973     // for(int index = initial; index [comparator] limit; index += increment)
2974     TIntermSymbol *index = nullptr;
2975     TOperator comparator = EOpNull;
2976     int initial          = 0;
2977     int limit            = 0;
2978     int increment        = 0;
2979 
2980     // Parse index name and intial value
2981     if (node->getInit())
2982     {
2983         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2984 
2985         if (init)
2986         {
2987             TIntermSequence *sequence = init->getSequence();
2988             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2989 
2990             if (variable && variable->getQualifier() == EvqTemporary)
2991             {
2992                 TIntermBinary *assign = variable->getAsBinaryNode();
2993 
2994                 if (assign->getOp() == EOpInitialize)
2995                 {
2996                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2997                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2998 
2999                     if (symbol && constant)
3000                     {
3001                         if (constant->getBasicType() == EbtInt && constant->isScalar())
3002                         {
3003                             index   = symbol;
3004                             initial = constant->getIConst(0);
3005                         }
3006                     }
3007                 }
3008             }
3009         }
3010     }
3011 
3012     // Parse comparator and limit value
3013     if (index != nullptr && node->getCondition())
3014     {
3015         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
3016 
3017         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
3018         {
3019             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
3020 
3021             if (constant)
3022             {
3023                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3024                 {
3025                     comparator = test->getOp();
3026                     limit      = constant->getIConst(0);
3027                 }
3028             }
3029         }
3030     }
3031 
3032     // Parse increment
3033     if (index != nullptr && comparator != EOpNull && node->getExpression())
3034     {
3035         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
3036         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
3037 
3038         if (binaryTerminal)
3039         {
3040             TOperator op                   = binaryTerminal->getOp();
3041             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
3042 
3043             if (constant)
3044             {
3045                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3046                 {
3047                     int value = constant->getIConst(0);
3048 
3049                     switch (op)
3050                     {
3051                         case EOpAddAssign:
3052                             increment = value;
3053                             break;
3054                         case EOpSubAssign:
3055                             increment = -value;
3056                             break;
3057                         default:
3058                             UNIMPLEMENTED();
3059                     }
3060                 }
3061             }
3062         }
3063         else if (unaryTerminal)
3064         {
3065             TOperator op = unaryTerminal->getOp();
3066 
3067             switch (op)
3068             {
3069                 case EOpPostIncrement:
3070                     increment = 1;
3071                     break;
3072                 case EOpPostDecrement:
3073                     increment = -1;
3074                     break;
3075                 case EOpPreIncrement:
3076                     increment = 1;
3077                     break;
3078                 case EOpPreDecrement:
3079                     increment = -1;
3080                     break;
3081                 default:
3082                     UNIMPLEMENTED();
3083             }
3084         }
3085     }
3086 
3087     if (index != nullptr && comparator != EOpNull && increment != 0)
3088     {
3089         if (comparator == EOpLessThanEqual)
3090         {
3091             comparator = EOpLessThan;
3092             limit += 1;
3093         }
3094 
3095         if (comparator == EOpLessThan)
3096         {
3097             int iterations = (limit - initial) / increment;
3098 
3099             if (iterations <= MAX_LOOP_ITERATIONS)
3100             {
3101                 return false;  // Not an excessive loop
3102             }
3103 
3104             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3105             mExcessiveLoopIndex         = index;
3106 
3107             out << "{int ";
3108             index->traverse(this);
3109             out << ";\n"
3110                    "bool Break";
3111             index->traverse(this);
3112             out << " = false;\n";
3113 
3114             bool firstLoopFragment = true;
3115 
3116             while (iterations > 0)
3117             {
3118                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3119 
3120                 if (!firstLoopFragment)
3121                 {
3122                     out << "if (!Break";
3123                     index->traverse(this);
3124                     out << ") {\n";
3125                 }
3126 
3127                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3128                 {
3129                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3130                 }
3131 
3132                 // for(int index = initial; index < clampedLimit; index += increment)
3133                 const char *unroll =
3134                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3135 
3136                 out << unroll << " for(";
3137                 index->traverse(this);
3138                 out << " = ";
3139                 out << initial;
3140 
3141                 out << "; ";
3142                 index->traverse(this);
3143                 out << " < ";
3144                 out << clampedLimit;
3145 
3146                 out << "; ";
3147                 index->traverse(this);
3148                 out << " += ";
3149                 out << increment;
3150                 out << ")\n";
3151 
3152                 outputLineDirective(out, node->getLine().first_line);
3153                 out << "{\n";
3154 
3155                 if (node->getBody())
3156                 {
3157                     node->getBody()->traverse(this);
3158                 }
3159 
3160                 outputLineDirective(out, node->getLine().first_line);
3161                 out << ";}\n";
3162 
3163                 if (!firstLoopFragment)
3164                 {
3165                     out << "}\n";
3166                 }
3167 
3168                 firstLoopFragment = false;
3169 
3170                 initial += MAX_LOOP_ITERATIONS * increment;
3171                 iterations -= MAX_LOOP_ITERATIONS;
3172             }
3173 
3174             out << "}";
3175 
3176             mExcessiveLoopIndex = restoreIndex;
3177 
3178             return true;
3179         }
3180         else
3181             UNIMPLEMENTED();
3182     }
3183 
3184     return false;  // Not handled as an excessive loop
3185 }
3186 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3187 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3188                                Visit visit,
3189                                const char *preString,
3190                                const char *inString,
3191                                const char *postString)
3192 {
3193     if (visit == PreVisit)
3194     {
3195         out << preString;
3196     }
3197     else if (visit == InVisit)
3198     {
3199         out << inString;
3200     }
3201     else if (visit == PostVisit)
3202     {
3203         out << postString;
3204     }
3205 }
3206 
outputLineDirective(TInfoSinkBase & out,int line)3207 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3208 {
3209     if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
3210     {
3211         out << "\n";
3212         out << "#line " << line;
3213 
3214         if (mSourcePath)
3215         {
3216             out << " \"" << mSourcePath << "\"";
3217         }
3218 
3219         out << "\n";
3220     }
3221 }
3222 
writeParameter(const TVariable * param,TInfoSinkBase & out)3223 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3224 {
3225     const TType &type    = param->getType();
3226     TQualifier qualifier = type.getQualifier();
3227 
3228     TString nameStr = DecorateVariableIfNeeded(*param);
3229     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3230 
3231     if (IsSampler(type.getBasicType()))
3232     {
3233         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3234         {
3235             // Samplers are passed as indices to the sampler array.
3236             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3237             out << "const uint " << nameStr << ArrayString(type);
3238             return;
3239         }
3240         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3241         {
3242             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3243                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3244                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3245                 << ArrayString(type);
3246             return;
3247         }
3248     }
3249 
3250     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3251     // buffer offset.
3252     if (IsAtomicCounter(type.getBasicType()))
3253     {
3254         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3255             << nameStr << "_offset";
3256     }
3257     else
3258     {
3259         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3260             << ArrayString(type);
3261     }
3262 
3263     // If the structure parameter contains samplers, they need to be passed into the function as
3264     // separate parameters. HLSL doesn't natively support samplers in structs.
3265     if (type.isStructureContainingSamplers())
3266     {
3267         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3268         TVector<const TVariable *> samplerSymbols;
3269         std::string namePrefix = "angle";
3270         namePrefix += nameStr.c_str();
3271         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3272                                   mSymbolTable);
3273         for (const TVariable *sampler : samplerSymbols)
3274         {
3275             const TType &samplerType = sampler->getType();
3276             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3277             {
3278                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3279             }
3280             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3281             {
3282                 ASSERT(IsSampler(samplerType.getBasicType()));
3283                 out << ", " << QualifierString(qualifier) << " "
3284                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3285                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3286                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3287                     << ArrayString(samplerType);
3288             }
3289             else
3290             {
3291                 ASSERT(IsSampler(samplerType.getBasicType()));
3292                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3293                     << sampler->name() << ArrayString(samplerType);
3294             }
3295         }
3296     }
3297 }
3298 
zeroInitializer(const TType & type) const3299 TString OutputHLSL::zeroInitializer(const TType &type) const
3300 {
3301     TString string;
3302 
3303     size_t size = type.getObjectSize();
3304     if (size >= kZeroCount)
3305     {
3306         mUseZeroArray = true;
3307     }
3308     string = GetZeroInitializer(size).c_str();
3309 
3310     return "{" + string + "}";
3311 }
3312 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3313 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3314 {
3315     // Array constructors should have been already pruned from the code.
3316     ASSERT(!node->getType().isArray());
3317 
3318     if (visit == PreVisit)
3319     {
3320         TString constructorName;
3321         if (node->getBasicType() == EbtStruct)
3322         {
3323             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3324         }
3325         else
3326         {
3327             constructorName =
3328                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3329         }
3330         out << constructorName << "(";
3331     }
3332     else if (visit == InVisit)
3333     {
3334         out << ", ";
3335     }
3336     else if (visit == PostVisit)
3337     {
3338         out << ")";
3339     }
3340 }
3341 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3342 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3343                                                      const TType &type,
3344                                                      const TConstantUnion *const constUnion)
3345 {
3346     ASSERT(!type.isArray());
3347 
3348     const TConstantUnion *constUnionIterated = constUnion;
3349 
3350     const TStructure *structure = type.getStruct();
3351     if (structure)
3352     {
3353         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3354 
3355         const TFieldList &fields = structure->fields();
3356 
3357         for (size_t i = 0; i < fields.size(); i++)
3358         {
3359             const TType *fieldType = fields[i]->type();
3360             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3361 
3362             if (i != fields.size() - 1)
3363             {
3364                 out << ", ";
3365             }
3366         }
3367 
3368         out << ")";
3369     }
3370     else
3371     {
3372         size_t size    = type.getObjectSize();
3373         bool writeType = size > 1;
3374 
3375         if (writeType)
3376         {
3377             out << TypeString(type) << "(";
3378         }
3379         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3380         if (writeType)
3381         {
3382             out << ")";
3383         }
3384     }
3385 
3386     return constUnionIterated;
3387 }
3388 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)3389 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
3390 {
3391     if (visit == PreVisit)
3392     {
3393         const char *opStr = GetOperatorString(op);
3394         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
3395         out << "(";
3396     }
3397     else
3398     {
3399         outputTriplet(out, visit, nullptr, ", ", ")");
3400     }
3401 }
3402 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3403 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3404                                             TIntermSymbol *symbolNode,
3405                                             TIntermTyped *expression)
3406 {
3407     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3408     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3409 
3410     if (symbolInInitializer)
3411     {
3412         // Type already printed
3413         out << "t" + str(mUniqueIndex) + " = ";
3414         expression->traverse(this);
3415         out << ", ";
3416         symbolNode->traverse(this);
3417         out << " = t" + str(mUniqueIndex);
3418 
3419         mUniqueIndex++;
3420         return true;
3421     }
3422 
3423     return false;
3424 }
3425 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3426 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3427                                              TIntermSymbol *symbolNode,
3428                                              TIntermTyped *initializer)
3429 {
3430     if (initializer->hasConstantValue())
3431     {
3432         symbolNode->traverse(this);
3433         out << ArrayString(symbolNode->getType());
3434         out << " = {";
3435         writeConstantUnionArray(out, initializer->getConstantValue(),
3436                                 initializer->getType().getObjectSize());
3437         out << "}";
3438         return true;
3439     }
3440     return false;
3441 }
3442 
addStructEqualityFunction(const TStructure & structure)3443 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3444 {
3445     const TFieldList &fields = structure.fields();
3446 
3447     for (const auto &eqFunction : mStructEqualityFunctions)
3448     {
3449         if (eqFunction->structure == &structure)
3450         {
3451             return eqFunction->functionName;
3452         }
3453     }
3454 
3455     const TString &structNameString = StructNameString(structure);
3456 
3457     StructEqualityFunction *function = new StructEqualityFunction();
3458     function->structure              = &structure;
3459     function->functionName           = "angle_eq_" + structNameString;
3460 
3461     TInfoSinkBase fnOut;
3462 
3463     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3464           << structNameString + " b)\n"
3465           << "{\n"
3466              "    return ";
3467 
3468     for (size_t i = 0; i < fields.size(); i++)
3469     {
3470         const TField *field    = fields[i];
3471         const TType *fieldType = field->type();
3472 
3473         const TString &fieldNameA = "a." + Decorate(field->name());
3474         const TString &fieldNameB = "b." + Decorate(field->name());
3475 
3476         if (i > 0)
3477         {
3478             fnOut << " && ";
3479         }
3480 
3481         fnOut << "(";
3482         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3483         fnOut << fieldNameA;
3484         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3485         fnOut << fieldNameB;
3486         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3487         fnOut << ")";
3488     }
3489 
3490     fnOut << ";\n"
3491           << "}\n";
3492 
3493     function->functionDefinition = fnOut.c_str();
3494 
3495     mStructEqualityFunctions.push_back(function);
3496     mEqualityFunctions.push_back(function);
3497 
3498     return function->functionName;
3499 }
3500 
addArrayEqualityFunction(const TType & type)3501 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3502 {
3503     for (const auto &eqFunction : mArrayEqualityFunctions)
3504     {
3505         if (eqFunction->type == type)
3506         {
3507             return eqFunction->functionName;
3508         }
3509     }
3510 
3511     TType elementType(type);
3512     elementType.toArrayElementType();
3513 
3514     ArrayHelperFunction *function = new ArrayHelperFunction();
3515     function->type                = type;
3516 
3517     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3518 
3519     TInfoSinkBase fnOut;
3520 
3521     const TString &typeName = TypeString(type);
3522     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3523           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3524           << "{\n"
3525              "    for (int i = 0; i < "
3526           << type.getOutermostArraySize()
3527           << "; ++i)\n"
3528              "    {\n"
3529              "        if (";
3530 
3531     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3532     fnOut << "a[i]";
3533     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3534     fnOut << "b[i]";
3535     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3536 
3537     fnOut << ") { return false; }\n"
3538              "    }\n"
3539              "    return true;\n"
3540              "}\n";
3541 
3542     function->functionDefinition = fnOut.c_str();
3543 
3544     mArrayEqualityFunctions.push_back(function);
3545     mEqualityFunctions.push_back(function);
3546 
3547     return function->functionName;
3548 }
3549 
addArrayAssignmentFunction(const TType & type)3550 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3551 {
3552     for (const auto &assignFunction : mArrayAssignmentFunctions)
3553     {
3554         if (assignFunction.type == type)
3555         {
3556             return assignFunction.functionName;
3557         }
3558     }
3559 
3560     TType elementType(type);
3561     elementType.toArrayElementType();
3562 
3563     ArrayHelperFunction function;
3564     function.type = type;
3565 
3566     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3567 
3568     TInfoSinkBase fnOut;
3569 
3570     const TString &typeName = TypeString(type);
3571     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3572           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3573           << "{\n"
3574              "    for (int i = 0; i < "
3575           << type.getOutermostArraySize()
3576           << "; ++i)\n"
3577              "    {\n"
3578              "        ";
3579 
3580     outputAssign(PreVisit, elementType, fnOut);
3581     fnOut << "a[i]";
3582     outputAssign(InVisit, elementType, fnOut);
3583     fnOut << "b[i]";
3584     outputAssign(PostVisit, elementType, fnOut);
3585 
3586     fnOut << ";\n"
3587              "    }\n"
3588              "}\n";
3589 
3590     function.functionDefinition = fnOut.c_str();
3591 
3592     mArrayAssignmentFunctions.push_back(function);
3593 
3594     return function.functionName;
3595 }
3596 
addArrayConstructIntoFunction(const TType & type)3597 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3598 {
3599     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3600     {
3601         if (constructIntoFunction.type == type)
3602         {
3603             return constructIntoFunction.functionName;
3604         }
3605     }
3606 
3607     TType elementType(type);
3608     elementType.toArrayElementType();
3609 
3610     ArrayHelperFunction function;
3611     function.type = type;
3612 
3613     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3614 
3615     TInfoSinkBase fnOut;
3616 
3617     const TString &typeName = TypeString(type);
3618     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3619     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3620     {
3621         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3622     }
3623     fnOut << ")\n"
3624              "{\n";
3625 
3626     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3627     {
3628         fnOut << "    ";
3629         outputAssign(PreVisit, elementType, fnOut);
3630         fnOut << "a[" << i << "]";
3631         outputAssign(InVisit, elementType, fnOut);
3632         fnOut << "b" << i;
3633         outputAssign(PostVisit, elementType, fnOut);
3634         fnOut << ";\n";
3635     }
3636     fnOut << "}\n";
3637 
3638     function.functionDefinition = fnOut.c_str();
3639 
3640     mArrayConstructIntoFunctions.push_back(function);
3641 
3642     return function.functionName;
3643 }
3644 
ensureStructDefined(const TType & type)3645 void OutputHLSL::ensureStructDefined(const TType &type)
3646 {
3647     const TStructure *structure = type.getStruct();
3648     if (structure)
3649     {
3650         ASSERT(type.getBasicType() == EbtStruct);
3651         mStructureHLSL->ensureStructDefined(*structure);
3652     }
3653 }
3654 
shaderNeedsGenerateOutput() const3655 bool OutputHLSL::shaderNeedsGenerateOutput() const
3656 {
3657     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3658 }
3659 
generateOutputCall() const3660 const char *OutputHLSL::generateOutputCall() const
3661 {
3662     if (mShaderType == GL_VERTEX_SHADER)
3663     {
3664         return "generateOutput(input)";
3665     }
3666     else
3667     {
3668         return "generateOutput()";
3669     }
3670 }
3671 }  // namespace sh
3672