1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker comment string for image2D function declarations. How the marker is consumed is not
// visible in this part of the file.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)89 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
90 {
91     switch (op)
92     {
93         case EOpAtomicAdd:
94             return "InterlockedAdd(";
95         case EOpAtomicMin:
96             return "InterlockedMin(";
97         case EOpAtomicMax:
98             return "InterlockedMax(";
99         case EOpAtomicAnd:
100             return "InterlockedAnd(";
101         case EOpAtomicOr:
102             return "InterlockedOr(";
103         case EOpAtomicXor:
104             return "InterlockedXor(";
105         case EOpAtomicExchange:
106             return "InterlockedExchange(";
107         case EOpAtomicCompSwap:
108             return "InterlockedCompareExchange(";
109         default:
110             UNREACHABLE();
111             return "";
112     }
113 }
114 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)115 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
116 {
117     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
118     if (aggregateNode == nullptr)
119     {
120         return false;
121     }
122 
123     if (node.getOp() == EOpAssign && IsAtomicFunction(aggregateNode->getOp()))
124     {
125         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
126     }
127 
128     return false;
129 }
130 
// Name of the static uint array declared by DefineZeroArray() and referenced by
// GetZeroInitializer(). Declared as a constexpr array (like kImage2DFunctionString) so that the
// name cannot be re-pointed at run time.
constexpr const char kZeros[] = "_ANGLE_ZEROS_";
// Number of elements in the kZeros array.
constexpr int kZeroCount = 256;
DefineZeroArray()133 std::string DefineZeroArray()
134 {
135     std::stringstream ss = sh::InitializeStream<std::stringstream>();
136     // For 'static', if the declaration does not include an initializer, the value is set to zero.
137     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
138     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
139     return ss.str();
140 }
141 
GetZeroInitializer(size_t size)142 std::string GetZeroInitializer(size_t size)
143 {
144     std::stringstream ss = sh::InitializeStream<std::stringstream>();
145     size_t quotient      = size / kZeroCount;
146     size_t reminder      = size % kZeroCount;
147 
148     for (size_t i = 0; i < quotient; ++i)
149     {
150         if (i != 0)
151         {
152             ss << ", ";
153         }
154         ss << kZeros;
155     }
156 
157     for (size_t i = 0; i < reminder; ++i)
158     {
159         if (quotient != 0 || i != 0)
160         {
161             ss << ", ";
162         }
163         ss << "0";
164     }
165 
166     return ss.str();
167 }
168 
169 }  // anonymous namespace
170 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)171 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
172                                    const TVariable *aInstanceVariable)
173     : block(aBlock), instanceVariable(aInstanceVariable)
174 {}
175 
needStructMapping(TIntermTyped * node)176 bool OutputHLSL::needStructMapping(TIntermTyped *node)
177 {
178     ASSERT(node->getBasicType() == EbtStruct);
179     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
180     {
181         TIntermNode *ancestor               = getAncestorNode(n);
182         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
183         if (ancestorBinary)
184         {
185             switch (ancestorBinary->getOp())
186             {
187                 case EOpIndexDirectStruct:
188                 {
189                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
190                     const TIntermConstantUnion *index =
191                         ancestorBinary->getRight()->getAsConstantUnion();
192                     const TField *field = structure->fields()[index->getIConst(0)];
193                     if (field->type()->getStruct() == nullptr)
194                     {
195                         return false;
196                     }
197                     break;
198                 }
199                 case EOpIndexDirect:
200                 case EOpIndexIndirect:
201                     break;
202                 default:
203                     return true;
204             }
205         }
206         else
207         {
208             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
209             if (ancestorAggregate)
210             {
211                 return true;
212             }
213             return false;
214         }
215     }
216     return true;
217 }
218 
writeFloat(TInfoSinkBase & out,float f)219 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
220 {
221     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
222     // regardless.
223     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
224         mOutputType == SH_HLSL_4_1_OUTPUT)
225     {
226         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
227     }
228     else
229     {
230         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
231     }
232 }
233 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)234 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
235 {
236     ASSERT(constUnion != nullptr);
237     switch (constUnion->getType())
238     {
239         case EbtFloat:
240             writeFloat(out, constUnion->getFConst());
241             break;
242         case EbtInt:
243             out << constUnion->getIConst();
244             break;
245         case EbtUInt:
246             out << constUnion->getUConst();
247             break;
248         case EbtBool:
249             out << constUnion->getBConst();
250             break;
251         default:
252             UNREACHABLE();
253     }
254 }
255 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)256 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
257                                                           const TConstantUnion *const constUnion,
258                                                           const size_t size)
259 {
260     const TConstantUnion *constUnionIterated = constUnion;
261     for (size_t i = 0; i < size; i++, constUnionIterated++)
262     {
263         writeSingleConstant(out, constUnionIterated);
264 
265         if (i != size - 1)
266         {
267             out << ", ";
268         }
269     }
270     return constUnionIterated;
271 }
272 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::vector<InterfaceBlock> & shaderStorageBlocks)273 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
274                        ShShaderSpec shaderSpec,
275                        int shaderVersion,
276                        const TExtensionBehavior &extensionBehavior,
277                        const char *sourcePath,
278                        ShShaderOutput outputType,
279                        int numRenderTargets,
280                        int maxDualSourceDrawBuffers,
281                        const std::vector<ShaderVariable> &uniforms,
282                        ShCompileOptions compileOptions,
283                        sh::WorkGroupSize workGroupSize,
284                        TSymbolTable *symbolTable,
285                        PerformanceDiagnostics *perfDiagnostics,
286                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
287     : TIntermTraverser(true, true, true, symbolTable),
288       mShaderType(shaderType),
289       mShaderSpec(shaderSpec),
290       mShaderVersion(shaderVersion),
291       mExtensionBehavior(extensionBehavior),
292       mSourcePath(sourcePath),
293       mOutputType(outputType),
294       mCompileOptions(compileOptions),
295       mInsideFunction(false),
296       mInsideMain(false),
297       mNumRenderTargets(numRenderTargets),
298       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
299       mCurrentFunctionMetadata(nullptr),
300       mWorkGroupSize(workGroupSize),
301       mPerfDiagnostics(perfDiagnostics),
302       mNeedStructMapping(false)
303 {
304     mUsesFragColor        = false;
305     mUsesFragData         = false;
306     mUsesDepthRange       = false;
307     mUsesFragCoord        = false;
308     mUsesPointCoord       = false;
309     mUsesFrontFacing      = false;
310     mUsesHelperInvocation = false;
311     mUsesPointSize        = false;
312     mUsesInstanceID       = false;
313     mHasMultiviewExtensionEnabled =
314         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
315         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
316     mUsesViewID                  = false;
317     mUsesVertexID                = false;
318     mUsesFragDepth               = false;
319     mUsesNumWorkGroups           = false;
320     mUsesWorkGroupID             = false;
321     mUsesLocalInvocationID       = false;
322     mUsesGlobalInvocationID      = false;
323     mUsesLocalInvocationIndex    = false;
324     mUsesXor                     = false;
325     mUsesDiscardRewriting        = false;
326     mUsesNestedBreak             = false;
327     mRequiresIEEEStrictCompiling = false;
328     mUseZeroArray                = false;
329     mUsesSecondaryColor          = false;
330 
331     mUniqueIndex = 0;
332 
333     mOutputLod0Function      = false;
334     mInsideDiscontinuousLoop = false;
335     mNestedLoopDepth         = 0;
336 
337     mExcessiveLoopIndex = nullptr;
338 
339     mStructureHLSL       = new StructureHLSL;
340     mTextureFunctionHLSL = new TextureFunctionHLSL;
341     mImageFunctionHLSL   = new ImageFunctionHLSL;
342     mAtomicCounterFunctionHLSL =
343         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
344 
345     unsigned int firstUniformRegister =
346         ((compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0) ? 1u : 0u;
347     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, compileOptions, uniforms,
348                                        firstUniformRegister);
349 
350     if (mOutputType == SH_HLSL_3_0_OUTPUT)
351     {
352         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
353         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
354         // dx_ViewAdjust.
355         // In both cases total 3 uniform registers need to be reserved.
356         mResourcesHLSL->reserveUniformRegisters(3);
357     }
358 
359     // Reserve registers for the default uniform block and driver constants
360     mResourcesHLSL->reserveUniformBlockRegisters(2);
361 
362     mSSBOOutputHLSL =
363         new ShaderStorageBlockOutputHLSL(this, symbolTable, mResourcesHLSL, shaderStorageBlocks);
364 }
365 
~OutputHLSL()366 OutputHLSL::~OutputHLSL()
367 {
368     SafeDelete(mSSBOOutputHLSL);
369     SafeDelete(mStructureHLSL);
370     SafeDelete(mResourcesHLSL);
371     SafeDelete(mTextureFunctionHLSL);
372     SafeDelete(mImageFunctionHLSL);
373     SafeDelete(mAtomicCounterFunctionHLSL);
374     for (auto &eqFunction : mStructEqualityFunctions)
375     {
376         SafeDelete(eqFunction);
377     }
378     for (auto &eqFunction : mArrayEqualityFunctions)
379     {
380         SafeDelete(eqFunction);
381     }
382 }
383 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)384 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
385 {
386     BuiltInFunctionEmulator builtInFunctionEmulator;
387     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
388     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
389     {
390         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
391                                                            mShaderVersion);
392     }
393 
394     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
395 
396     // Now that we are done changing the AST, do the analyses need for HLSL generation
397     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
398     ASSERT(success == CallDAG::INITDAG_SUCCESS);
399     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
400 
401     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
402     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
403     // the shader code. When we add shader storage blocks we might also consider an alternative
404     // solution, since the struct mapping won't work very well for shader storage blocks.
405 
406     // Output the body and footer first to determine what has to go in the header
407     mInfoSinkStack.push(&mBody);
408     treeRoot->traverse(this);
409     mInfoSinkStack.pop();
410 
411     mInfoSinkStack.push(&mFooter);
412     mInfoSinkStack.pop();
413 
414     mInfoSinkStack.push(&mHeader);
415     header(mHeader, std140Structs, &builtInFunctionEmulator);
416     mInfoSinkStack.pop();
417 
418     objSink << mHeader.c_str();
419     objSink << mBody.c_str();
420     objSink << mFooter.c_str();
421 
422     builtInFunctionEmulator.cleanup();
423 }
424 
getShaderStorageBlockRegisterMap() const425 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
426 {
427     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
428 }
429 
getUniformBlockRegisterMap() const430 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
431 {
432     return mResourcesHLSL->getUniformBlockRegisterMap();
433 }
434 
getUniformBlockUseStructuredBufferMap() const435 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
436 {
437     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
438 }
439 
getUniformRegisterMap() const440 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
441 {
442     return mResourcesHLSL->getUniformRegisterMap();
443 }
444 
getReadonlyImage2DRegisterIndex() const445 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
446 {
447     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
448 }
449 
getImage2DRegisterIndex() const450 unsigned int OutputHLSL::getImage2DRegisterIndex() const
451 {
452     return mResourcesHLSL->getImage2DRegisterIndex();
453 }
454 
getUsedImage2DFunctionNames() const455 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
456 {
457     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
458 }
459 
// Builds a brace-enclosed initializer expression for a value of |type| whose source expression
// is |name|.
//
// Arrays produce one nested initializer per element (using "name[i]"), structs produce one
// nested initializer per field (using "name.<decorated field name>"), and anything else produces
// |name| itself. Every line is indented by |indent| levels of four spaces and nested levels use
// indent + 1.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString indentString;
    for (int level = 0; level < indent; level++)
    {
        indentString += "    ";
    }

    TString init;

    if (type.isArray())
    {
        // All elements share one element type, so derive it once instead of copying |type| for
        // every element of the array.
        TType elementType = type;
        elementType.toArrayElementType();
        const unsigned int arraySize = type.getOutermostArraySize();

        init += indentString + "{\n";
        for (unsigned int arrayIndex = 0u; arrayIndex < arraySize; ++arrayIndex)
        {
            TStringStream indexedString = sh::InitializeStream<TStringStream>();
            indexedString << name << "[" << arrayIndex << "]";
            init += structInitializerString(indent + 1, elementType, indexedString.str());
            // Separate elements with a comma; the last element gets none.
            if (arrayIndex + 1 < arraySize)
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
    }
    else if (type.getBasicType() == EbtStruct)
    {
        const TStructure &structure = *type.getStruct();
        const TFieldList &fields    = structure.fields();
        const size_t fieldCount     = fields.size();

        init += indentString + "{\n";
        for (size_t fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++)
        {
            const TField &field    = *fields[fieldIndex];
            const TType &fieldType = *field.type();
            // Hold the concatenated name by value rather than binding a reference to a temporary.
            const TString fieldName = name + "." + Decorate(field.name());

            init += structInitializerString(indent + 1, fieldType, fieldName);
            // Separate fields with a comma; the last field gets none.
            if (fieldIndex + 1 < fieldCount)
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
    }
    else
    {
        init += indentString + name;
    }

    return init;
}
517 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const518 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
519 {
520     TString mappedStructs;
521 
522     for (auto &mappedStruct : std140Structs)
523     {
524         const TInterfaceBlock *interfaceBlock =
525             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
526         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
527         switch (qualifier)
528         {
529             case EvqUniform:
530                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
531                 {
532                     continue;
533                 }
534                 break;
535             case EvqBuffer:
536                 continue;
537             default:
538                 UNREACHABLE();
539                 return mappedStructs;
540         }
541 
542         unsigned int instanceCount = 1u;
543         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
544         if (isInstanceArray)
545         {
546             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
547         }
548 
549         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
550              ++instanceArrayIndex)
551         {
552             TString originalName;
553             TString mappedName("map");
554 
555             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
556             {
557                 const ImmutableString &instanceName =
558                     mappedStruct.blockDeclarator->variable().name();
559                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
560                 if (isInstanceArray)
561                     instanceStringArrayIndex = instanceArrayIndex;
562                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
563                     instanceName, instanceStringArrayIndex);
564                 originalName += instanceString;
565                 mappedName += instanceString;
566                 originalName += ".";
567                 mappedName += "_";
568             }
569 
570             TString fieldName = Decorate(mappedStruct.field->name());
571             originalName += fieldName;
572             mappedName += fieldName;
573 
574             TType *structType = mappedStruct.field->type();
575             mappedStructs +=
576                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
577 
578             if (structType->isArray())
579             {
580                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
581             }
582 
583             mappedStructs += " =\n";
584             mappedStructs += structInitializerString(0, *structType, originalName);
585             mappedStructs += ";\n";
586         }
587     }
588     return mappedStructs;
589 }
590 
writeReferencedAttributes(TInfoSinkBase & out) const591 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
592 {
593     for (const auto &attribute : mReferencedAttributes)
594     {
595         const TType &type           = attribute.second->getType();
596         const ImmutableString &name = attribute.second->name();
597 
598         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
599             << zeroInitializer(type) << ";\n";
600     }
601 }
602 
writeReferencedVaryings(TInfoSinkBase & out) const603 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
604 {
605     for (const auto &varying : mReferencedVaryings)
606     {
607         const TType &type = varying.second->getType();
608 
609         // Program linking depends on this exact format
610         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
611             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
612             << zeroInitializer(type) << ";\n";
613     }
614 }
615 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const616 void OutputHLSL::header(TInfoSinkBase &out,
617                         const std::vector<MappedStruct> &std140Structs,
618                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
619 {
620     TString mappedStructs;
621     if (mNeedStructMapping)
622     {
623         mappedStructs = generateStructMapping(std140Structs);
624     }
625 
626     out << mStructureHLSL->structsHeader();
627 
628     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
629     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks);
630     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
631 
632     if (!mEqualityFunctions.empty())
633     {
634         out << "\n// Equality functions\n\n";
635         for (const auto &eqFunction : mEqualityFunctions)
636         {
637             out << eqFunction->functionDefinition << "\n";
638         }
639     }
640     if (!mArrayAssignmentFunctions.empty())
641     {
642         out << "\n// Assignment functions\n\n";
643         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
644         {
645             out << assignmentFunction.functionDefinition << "\n";
646         }
647     }
648     if (!mArrayConstructIntoFunctions.empty())
649     {
650         out << "\n// Array constructor functions\n\n";
651         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
652         {
653             out << constructIntoFunction.functionDefinition << "\n";
654         }
655     }
656 
657     if (mUsesDiscardRewriting)
658     {
659         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
660     }
661 
662     if (mUsesNestedBreak)
663     {
664         out << "#define ANGLE_USES_NESTED_BREAK\n";
665     }
666 
667     if (mRequiresIEEEStrictCompiling)
668     {
669         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
670     }
671 
672     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
673            "#define LOOP [loop]\n"
674            "#define FLATTEN [flatten]\n"
675            "#else\n"
676            "#define LOOP\n"
677            "#define FLATTEN\n"
678            "#endif\n";
679 
680     // array stride for atomic counter buffers is always 4 per original extension
681     // ARB_shader_atomic_counters and discussion on
682     // https://github.com/KhronosGroup/OpenGL-API/issues/5
683     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
684 
685     if (mUseZeroArray)
686     {
687         out << DefineZeroArray() << "\n";
688     }
689 
690     if (mShaderType == GL_FRAGMENT_SHADER)
691     {
692         const bool usingMRTExtension =
693             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
694         const bool usingBFEExtension =
695             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
696 
697         out << "// Varyings\n";
698         writeReferencedVaryings(out);
699         out << "\n";
700 
701         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
702             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
703         {
704             for (const auto &outputVariable : mReferencedOutputVariables)
705             {
706                 const ImmutableString &variableName = outputVariable.second->name();
707                 const TType &variableType           = outputVariable.second->getType();
708 
709                 out << "static " << TypeString(variableType) << " out_" << variableName
710                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
711             }
712         }
713         else
714         {
715             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
716 
717             out << "static float4 gl_Color[" << numColorValues
718                 << "] =\n"
719                    "{\n";
720             for (unsigned int i = 0; i < numColorValues; i++)
721             {
722                 out << "    float4(0, 0, 0, 0)";
723                 if (i + 1 != numColorValues)
724                 {
725                     out << ",";
726                 }
727                 out << "\n";
728             }
729 
730             out << "};\n";
731 
732             if (usingBFEExtension && mUsesSecondaryColor)
733             {
734                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
735                     << "] = \n"
736                        "{\n";
737                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
738                 {
739                     out << "    float4(0, 0, 0, 0)";
740                     if (i + 1 != mMaxDualSourceDrawBuffers)
741                     {
742                         out << ",";
743                     }
744                     out << "\n";
745                 }
746                 out << "};\n";
747             }
748         }
749 
750         if (mUsesFragDepth)
751         {
752             out << "static float gl_Depth = 0.0;\n";
753         }
754 
755         if (mUsesFragCoord)
756         {
757             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
758         }
759 
760         if (mUsesPointCoord)
761         {
762             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
763         }
764 
765         if (mUsesFrontFacing)
766         {
767             out << "static bool gl_FrontFacing = false;\n";
768         }
769 
770         if (mUsesHelperInvocation)
771         {
772             out << "static bool gl_HelperInvocation = false;\n";
773         }
774 
775         out << "\n";
776 
777         if (mUsesDepthRange)
778         {
779             out << "struct gl_DepthRangeParameters\n"
780                    "{\n"
781                    "    float near;\n"
782                    "    float far;\n"
783                    "    float diff;\n"
784                    "};\n"
785                    "\n";
786         }
787 
788         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
789         {
790             out << "cbuffer DriverConstants : register(b1)\n"
791                    "{\n";
792 
793             if (mUsesDepthRange)
794             {
795                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
796             }
797 
798             if (mUsesFragCoord)
799             {
800                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
801             }
802 
803             if (mUsesFragCoord || mUsesFrontFacing)
804             {
805                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
806             }
807 
808             if (mUsesFragCoord)
809             {
810                 // dx_ViewScale is only used in the fragment shader to correct
811                 // the value for glFragCoord if necessary
812                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
813             }
814 
815             if (mHasMultiviewExtensionEnabled)
816             {
817                 // We have to add a value which we can use to keep track of which multi-view code
818                 // path is to be selected in the GS.
819                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
820             }
821 
822             if (mOutputType == SH_HLSL_4_1_OUTPUT)
823             {
824                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
825             }
826 
827             out << "};\n";
828         }
829         else
830         {
831             if (mUsesDepthRange)
832             {
833                 out << "uniform float3 dx_DepthRange : register(c0);";
834             }
835 
836             if (mUsesFragCoord)
837             {
838                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
839             }
840 
841             if (mUsesFragCoord || mUsesFrontFacing)
842             {
843                 out << "uniform float3 dx_DepthFront : register(c2);\n";
844             }
845         }
846 
847         out << "\n";
848 
849         if (mUsesDepthRange)
850         {
851             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
852                    "dx_DepthRange.y, dx_DepthRange.z};\n"
853                    "\n";
854         }
855 
856         if (usingMRTExtension && mNumRenderTargets > 1)
857         {
858             out << "#define GL_USES_MRT\n";
859         }
860 
861         if (mUsesFragColor)
862         {
863             out << "#define GL_USES_FRAG_COLOR\n";
864         }
865 
866         if (mUsesFragData)
867         {
868             out << "#define GL_USES_FRAG_DATA\n";
869         }
870 
871         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
872         {
873             out << "#define GL_USES_SECONDARY_COLOR\n";
874         }
875     }
876     else if (mShaderType == GL_VERTEX_SHADER)
877     {
878         out << "// Attributes\n";
879         writeReferencedAttributes(out);
880         out << "\n"
881                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
882 
883         if (mUsesPointSize)
884         {
885             out << "static float gl_PointSize = float(1);\n";
886         }
887 
888         if (mUsesInstanceID)
889         {
890             out << "static int gl_InstanceID;";
891         }
892 
893         if (mUsesVertexID)
894         {
895             out << "static int gl_VertexID;";
896         }
897 
898         out << "\n"
899                "// Varyings\n";
900         writeReferencedVaryings(out);
901         out << "\n";
902 
903         if (mUsesDepthRange)
904         {
905             out << "struct gl_DepthRangeParameters\n"
906                    "{\n"
907                    "    float near;\n"
908                    "    float far;\n"
909                    "    float diff;\n"
910                    "};\n"
911                    "\n";
912         }
913 
914         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
915         {
916             out << "cbuffer DriverConstants : register(b1)\n"
917                    "{\n";
918 
919             if (mUsesDepthRange)
920             {
921                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
922             }
923 
924             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
925             // shaders. However, we declare it for all shaders (including Feature Level 10+).
926             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
927             // if it's unused.
928             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
929             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
930             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
931 
932             if (mHasMultiviewExtensionEnabled)
933             {
934                 // We have to add a value which we can use to keep track of which multi-view code
935                 // path is to be selected in the GS.
936                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
937             }
938 
939             if (mOutputType == SH_HLSL_4_1_OUTPUT)
940             {
941                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
942             }
943 
944             if (mUsesVertexID)
945             {
946                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
947             }
948 
949             out << "};\n"
950                    "\n";
951         }
952         else
953         {
954             if (mUsesDepthRange)
955             {
956                 out << "uniform float3 dx_DepthRange : register(c0);\n";
957             }
958 
959             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
960             out << "uniform float2 dx_ViewCoords : register(c2);\n"
961                    "\n";
962         }
963 
964         if (mUsesDepthRange)
965         {
966             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
967                    "dx_DepthRange.y, dx_DepthRange.z};\n"
968                    "\n";
969         }
970     }
971     else  // Compute shader
972     {
973         ASSERT(mShaderType == GL_COMPUTE_SHADER);
974 
975         out << "cbuffer DriverConstants : register(b1)\n"
976                "{\n";
977         if (mUsesNumWorkGroups)
978         {
979             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
980         }
981         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
982         unsigned int registerIndex = 1;
983         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
984         // Sampler metadata struct must be two 4-vec, 32 bytes.
985         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
986         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
987         out << "};\n";
988 
989         out << kImage2DFunctionString << "\n";
990 
991         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
992         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
993 
994         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
995         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
996                                 << "{\n";
997 
998         if (mUsesWorkGroupID)
999         {
1000             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1001             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1002                                    << "SV_GroupID;\n";
1003             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1004         }
1005 
1006         if (mUsesLocalInvocationID)
1007         {
1008             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1009             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1010                                    << "SV_GroupThreadID;\n";
1011             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1012         }
1013 
1014         if (mUsesGlobalInvocationID)
1015         {
1016             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1017             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1018                                    << "SV_DispatchThreadID;\n";
1019             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1020         }
1021 
1022         if (mUsesLocalInvocationIndex)
1023         {
1024             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1025             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1026                                    << "SV_GroupIndex;\n";
1027             glBuiltinInitialization
1028                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1029         }
1030 
1031         systemValueDeclaration << "};\n\n";
1032         glBuiltinInitialization << "};\n\n";
1033 
1034         out << systemValueDeclaration.str();
1035         out << glBuiltinInitialization.str();
1036     }
1037 
1038     if (!mappedStructs.empty())
1039     {
1040         out << "// Structures from std140 blocks with padding removed\n";
1041         out << "\n";
1042         out << mappedStructs;
1043         out << "\n";
1044     }
1045 
1046     bool getDimensionsIgnoresBaseLevel =
1047         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1048     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1049     mImageFunctionHLSL->imageFunctionHeader(out);
1050     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1051 
1052     if (mUsesFragCoord)
1053     {
1054         out << "#define GL_USES_FRAG_COORD\n";
1055     }
1056 
1057     if (mUsesPointCoord)
1058     {
1059         out << "#define GL_USES_POINT_COORD\n";
1060     }
1061 
1062     if (mUsesFrontFacing)
1063     {
1064         out << "#define GL_USES_FRONT_FACING\n";
1065     }
1066 
1067     if (mUsesHelperInvocation)
1068     {
1069         out << "#define GL_USES_HELPER_INVOCATION\n";
1070     }
1071 
1072     if (mUsesPointSize)
1073     {
1074         out << "#define GL_USES_POINT_SIZE\n";
1075     }
1076 
1077     if (mHasMultiviewExtensionEnabled)
1078     {
1079         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1080     }
1081 
1082     if (mUsesVertexID)
1083     {
1084         out << "#define GL_USES_VERTEX_ID\n";
1085     }
1086 
1087     if (mUsesViewID)
1088     {
1089         out << "#define GL_USES_VIEW_ID\n";
1090     }
1091 
1092     if (mUsesFragDepth)
1093     {
1094         out << "#define GL_USES_FRAG_DEPTH\n";
1095     }
1096 
1097     if (mUsesDepthRange)
1098     {
1099         out << "#define GL_USES_DEPTH_RANGE\n";
1100     }
1101 
1102     if (mUsesXor)
1103     {
1104         out << "bool xor(bool p, bool q)\n"
1105                "{\n"
1106                "    return (p || q) && !(p && q);\n"
1107                "}\n"
1108                "\n";
1109     }
1110 
1111     builtInFunctionEmulator->outputEmulatedFunctions(out);
1112 }
1113 
visitSymbol(TIntermSymbol * node)1114 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1115 {
1116     const TVariable &variable = node->variable();
1117 
1118     // Empty symbols can only appear in declarations and function arguments, and in either of those
1119     // cases the symbol nodes are not visited.
1120     ASSERT(variable.symbolType() != SymbolType::Empty);
1121 
1122     TInfoSinkBase &out = getInfoSink();
1123 
1124     // Handle accessing std140 structs by value
1125     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1126         needStructMapping(node))
1127     {
1128         mNeedStructMapping = true;
1129         out << "map";
1130     }
1131 
1132     const ImmutableString &name     = variable.name();
1133     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1134 
1135     if (name == "gl_DepthRange")
1136     {
1137         mUsesDepthRange = true;
1138         out << name;
1139     }
1140     else if (IsAtomicCounter(variable.getType().getBasicType()))
1141     {
1142         const TType &variableType = variable.getType();
1143         if (variableType.getQualifier() == EvqUniform)
1144         {
1145             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1146             mReferencedUniforms[uniqueId.get()] = &variable;
1147             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1148         }
1149         else
1150         {
1151             TString varName = DecorateVariableIfNeeded(variable);
1152             out << varName << ", " << varName << "_offset";
1153         }
1154     }
1155     else
1156     {
1157         const TType &variableType = variable.getType();
1158         TQualifier qualifier      = variable.getType().getQualifier();
1159 
1160         ensureStructDefined(variableType);
1161 
1162         if (qualifier == EvqUniform)
1163         {
1164             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1165 
1166             if (interfaceBlock)
1167             {
1168                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1169                 {
1170                     const TVariable *instanceVariable = nullptr;
1171                     if (variableType.isInterfaceBlock())
1172                     {
1173                         instanceVariable = &variable;
1174                     }
1175                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1176                         new TReferencedBlock(interfaceBlock, instanceVariable);
1177                 }
1178             }
1179             else
1180             {
1181                 mReferencedUniforms[uniqueId.get()] = &variable;
1182             }
1183 
1184             out << DecorateVariableIfNeeded(variable);
1185         }
1186         else if (qualifier == EvqBuffer)
1187         {
1188             UNREACHABLE();
1189         }
1190         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1191         {
1192             mReferencedAttributes[uniqueId.get()] = &variable;
1193             out << Decorate(name);
1194         }
1195         else if (IsVarying(qualifier))
1196         {
1197             mReferencedVaryings[uniqueId.get()] = &variable;
1198             out << DecorateVariableIfNeeded(variable);
1199             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1200             {
1201                 mUsesViewID = true;
1202             }
1203         }
1204         else if (qualifier == EvqFragmentOut)
1205         {
1206             mReferencedOutputVariables[uniqueId.get()] = &variable;
1207             out << "out_" << name;
1208         }
1209         else if (qualifier == EvqFragColor)
1210         {
1211             out << "gl_Color[0]";
1212             mUsesFragColor = true;
1213         }
1214         else if (qualifier == EvqFragData)
1215         {
1216             out << "gl_Color";
1217             mUsesFragData = true;
1218         }
1219         else if (qualifier == EvqSecondaryFragColorEXT)
1220         {
1221             out << "gl_SecondaryColor[0]";
1222             mUsesSecondaryColor = true;
1223         }
1224         else if (qualifier == EvqSecondaryFragDataEXT)
1225         {
1226             out << "gl_SecondaryColor";
1227             mUsesSecondaryColor = true;
1228         }
1229         else if (qualifier == EvqFragCoord)
1230         {
1231             mUsesFragCoord = true;
1232             out << name;
1233         }
1234         else if (qualifier == EvqPointCoord)
1235         {
1236             mUsesPointCoord = true;
1237             out << name;
1238         }
1239         else if (qualifier == EvqFrontFacing)
1240         {
1241             mUsesFrontFacing = true;
1242             out << name;
1243         }
1244         else if (qualifier == EvqHelperInvocation)
1245         {
1246             mUsesHelperInvocation = true;
1247             out << name;
1248         }
1249         else if (qualifier == EvqPointSize)
1250         {
1251             mUsesPointSize = true;
1252             out << name;
1253         }
1254         else if (qualifier == EvqInstanceID)
1255         {
1256             mUsesInstanceID = true;
1257             out << name;
1258         }
1259         else if (qualifier == EvqVertexID)
1260         {
1261             mUsesVertexID = true;
1262             out << name;
1263         }
1264         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1265         {
1266             mUsesFragDepth = true;
1267             out << "gl_Depth";
1268         }
1269         else if (qualifier == EvqNumWorkGroups)
1270         {
1271             mUsesNumWorkGroups = true;
1272             out << name;
1273         }
1274         else if (qualifier == EvqWorkGroupID)
1275         {
1276             mUsesWorkGroupID = true;
1277             out << name;
1278         }
1279         else if (qualifier == EvqLocalInvocationID)
1280         {
1281             mUsesLocalInvocationID = true;
1282             out << name;
1283         }
1284         else if (qualifier == EvqGlobalInvocationID)
1285         {
1286             mUsesGlobalInvocationID = true;
1287             out << name;
1288         }
1289         else if (qualifier == EvqLocalInvocationIndex)
1290         {
1291             mUsesLocalInvocationIndex = true;
1292             out << name;
1293         }
1294         else
1295         {
1296             out << DecorateVariableIfNeeded(variable);
1297         }
1298     }
1299 }
1300 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1301 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1302 {
1303     if (type.isScalar() && !type.isArray())
1304     {
1305         if (op == EOpEqual)
1306         {
1307             outputTriplet(out, visit, "(", " == ", ")");
1308         }
1309         else
1310         {
1311             outputTriplet(out, visit, "(", " != ", ")");
1312         }
1313     }
1314     else
1315     {
1316         if (visit == PreVisit && op == EOpNotEqual)
1317         {
1318             out << "!";
1319         }
1320 
1321         if (type.isArray())
1322         {
1323             const TString &functionName = addArrayEqualityFunction(type);
1324             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1325         }
1326         else if (type.getBasicType() == EbtStruct)
1327         {
1328             const TStructure &structure = *type.getStruct();
1329             const TString &functionName = addStructEqualityFunction(structure);
1330             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1331         }
1332         else
1333         {
1334             ASSERT(type.isMatrix() || type.isVector());
1335             outputTriplet(out, visit, "all(", " == ", ")");
1336         }
1337     }
1338 }
1339 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1340 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1341 {
1342     if (type.isArray())
1343     {
1344         const TString &functionName = addArrayAssignmentFunction(type);
1345         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1346     }
1347     else
1348     {
1349         outputTriplet(out, visit, "(", " = ", ")");
1350     }
1351 }
1352 
ancestorEvaluatesToSamplerInStruct()1353 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1354 {
1355     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1356     {
1357         TIntermNode *ancestor               = getAncestorNode(n);
1358         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1359         if (ancestorBinary == nullptr)
1360         {
1361             return false;
1362         }
1363         switch (ancestorBinary->getOp())
1364         {
1365             case EOpIndexDirectStruct:
1366             {
1367                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1368                 const TIntermConstantUnion *index =
1369                     ancestorBinary->getRight()->getAsConstantUnion();
1370                 const TField *field = structure->fields()[index->getIConst(0)];
1371                 if (IsSampler(field->type()->getBasicType()))
1372                 {
1373                     return true;
1374                 }
1375                 break;
1376             }
1377             case EOpIndexDirect:
1378                 break;
1379             default:
1380                 // Returning a sampler from indirect indexing is not supported.
1381                 return false;
1382         }
1383     }
1384     return false;
1385 }
1386 
visitSwizzle(Visit visit,TIntermSwizzle * node)1387 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1388 {
1389     TInfoSinkBase &out = getInfoSink();
1390     if (visit == PostVisit)
1391     {
1392         out << ".";
1393         node->writeOffsetsAsXYZW(&out);
1394     }
1395     return true;
1396 }
1397 
visitBinary(Visit visit,TIntermBinary * node)1398 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1399 {
1400     TInfoSinkBase &out = getInfoSink();
1401 
1402     switch (node->getOp())
1403     {
1404         case EOpComma:
1405             outputTriplet(out, visit, "(", ", ", ")");
1406             break;
1407         case EOpAssign:
1408             if (node->isArray())
1409             {
1410                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1411                 if (rightAgg != nullptr && rightAgg->isConstructor())
1412                 {
1413                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1414                     out << functionName << "(";
1415                     node->getLeft()->traverse(this);
1416                     TIntermSequence *seq = rightAgg->getSequence();
1417                     for (auto &arrayElement : *seq)
1418                     {
1419                         out << ", ";
1420                         arrayElement->traverse(this);
1421                     }
1422                     out << ")";
1423                     return false;
1424                 }
1425                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1426                 // function call is assigned.
1427                 ASSERT(rightAgg == nullptr);
1428             }
1429             // Assignment expressions with atomic functions should be transformed into atomic
1430             // function calls in HLSL.
1431             // e.g. original_value = atomicAdd(dest, value) should be translated into
1432             //      InterlockedAdd(dest, value, original_value);
1433             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1434             {
1435                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1436                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1437                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1438                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1439                 ASSERT(argumentSeq->size() >= 2u);
1440                 for (auto &argument : *argumentSeq)
1441                 {
1442                     argument->traverse(this);
1443                     out << ", ";
1444                 }
1445                 node->getLeft()->traverse(this);
1446                 out << ")";
1447                 return false;
1448             }
1449             else if (IsInShaderStorageBlock(node->getLeft()))
1450             {
1451                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1452                 out << ", ";
1453                 if (IsInShaderStorageBlock(node->getRight()))
1454                 {
1455                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1456                 }
1457                 else
1458                 {
1459                     node->getRight()->traverse(this);
1460                 }
1461 
1462                 out << ")";
1463                 return false;
1464             }
1465             else if (IsInShaderStorageBlock(node->getRight()))
1466             {
1467                 node->getLeft()->traverse(this);
1468                 out << " = ";
1469                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1470                 return false;
1471             }
1472 
1473             outputAssign(visit, node->getType(), out);
1474             break;
1475         case EOpInitialize:
1476             if (visit == PreVisit)
1477             {
1478                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1479                 ASSERT(symbolNode);
1480                 TIntermTyped *initializer = node->getRight();
1481 
1482                 // Global initializers must be constant at this point.
1483                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1484 
1485                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1486                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1487                 // new variable is created before the assignment is evaluated), so we need to
1488                 // convert
1489                 // this to "float t = x, x = t;".
1490                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1491                 {
1492                     // Skip initializing the rest of the expression
1493                     return false;
1494                 }
1495                 else if (writeConstantInitialization(out, symbolNode, initializer))
1496                 {
1497                     return false;
1498                 }
1499             }
1500             else if (visit == InVisit)
1501             {
1502                 out << " = ";
1503                 if (IsInShaderStorageBlock(node->getRight()))
1504                 {
1505                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1506                     return false;
1507                 }
1508             }
1509             break;
1510         case EOpAddAssign:
1511             outputTriplet(out, visit, "(", " += ", ")");
1512             break;
1513         case EOpSubAssign:
1514             outputTriplet(out, visit, "(", " -= ", ")");
1515             break;
1516         case EOpMulAssign:
1517             outputTriplet(out, visit, "(", " *= ", ")");
1518             break;
1519         case EOpVectorTimesScalarAssign:
1520             outputTriplet(out, visit, "(", " *= ", ")");
1521             break;
1522         case EOpMatrixTimesScalarAssign:
1523             outputTriplet(out, visit, "(", " *= ", ")");
1524             break;
1525         case EOpVectorTimesMatrixAssign:
1526             if (visit == PreVisit)
1527             {
1528                 out << "(";
1529             }
1530             else if (visit == InVisit)
1531             {
1532                 out << " = mul(";
1533                 node->getLeft()->traverse(this);
1534                 out << ", transpose(";
1535             }
1536             else
1537             {
1538                 out << ")))";
1539             }
1540             break;
1541         case EOpMatrixTimesMatrixAssign:
1542             if (visit == PreVisit)
1543             {
1544                 out << "(";
1545             }
1546             else if (visit == InVisit)
1547             {
1548                 out << " = transpose(mul(transpose(";
1549                 node->getLeft()->traverse(this);
1550                 out << "), transpose(";
1551             }
1552             else
1553             {
1554                 out << "))))";
1555             }
1556             break;
1557         case EOpDivAssign:
1558             outputTriplet(out, visit, "(", " /= ", ")");
1559             break;
1560         case EOpIModAssign:
1561             outputTriplet(out, visit, "(", " %= ", ")");
1562             break;
1563         case EOpBitShiftLeftAssign:
1564             outputTriplet(out, visit, "(", " <<= ", ")");
1565             break;
1566         case EOpBitShiftRightAssign:
1567             outputTriplet(out, visit, "(", " >>= ", ")");
1568             break;
1569         case EOpBitwiseAndAssign:
1570             outputTriplet(out, visit, "(", " &= ", ")");
1571             break;
1572         case EOpBitwiseXorAssign:
1573             outputTriplet(out, visit, "(", " ^= ", ")");
1574             break;
1575         case EOpBitwiseOrAssign:
1576             outputTriplet(out, visit, "(", " |= ", ")");
1577             break;
1578         case EOpIndexDirect:
1579         {
1580             const TType &leftType = node->getLeft()->getType();
1581             if (leftType.isInterfaceBlock())
1582             {
1583                 if (visit == PreVisit)
1584                 {
1585                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1586                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1587 
1588                     ASSERT(leftType.getQualifier() == EvqUniform);
1589                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1590                     {
1591                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1592                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1593                     }
1594                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1595                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1596                         instanceArraySymbol->getName(), arrayIndex);
1597                     return false;
1598                 }
1599             }
1600             else if (ancestorEvaluatesToSamplerInStruct())
1601             {
1602                 // All parts of an expression that access a sampler in a struct need to use _ as
1603                 // separator to access the sampler variable that has been moved out of the struct.
1604                 outputTriplet(out, visit, "", "_", "");
1605             }
1606             else if (IsAtomicCounter(leftType.getBasicType()))
1607             {
1608                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1609             }
1610             else
1611             {
1612                 outputTriplet(out, visit, "", "[", "]");
1613             }
1614         }
1615         break;
1616         case EOpIndexIndirect:
1617         {
1618             // We do not currently support indirect references to interface blocks
1619             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1620 
1621             const TType &leftType = node->getLeft()->getType();
1622             if (IsAtomicCounter(leftType.getBasicType()))
1623             {
1624                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1625             }
1626             else
1627             {
1628                 outputTriplet(out, visit, "", "[", "]");
1629             }
1630             break;
1631         }
1632         case EOpIndexDirectStruct:
1633         {
1634             const TStructure *structure       = node->getLeft()->getType().getStruct();
1635             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1636             const TField *field               = structure->fields()[index->getIConst(0)];
1637 
1638             // In cases where indexing returns a sampler, we need to access the sampler variable
1639             // that has been moved out of the struct.
1640             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1641             if (visit == PreVisit && indexingReturnsSampler)
1642             {
1643                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1644                 // This prefix is only output at the beginning of the indexing expression, which
1645                 // may have multiple parts.
1646                 out << "angle";
1647             }
1648             if (!indexingReturnsSampler)
1649             {
1650                 // All parts of an expression that access a sampler in a struct need to use _ as
1651                 // separator to access the sampler variable that has been moved out of the struct.
1652                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1653             }
1654             if (visit == InVisit)
1655             {
1656                 if (indexingReturnsSampler)
1657                 {
1658                     out << "_" << field->name();
1659                 }
1660                 else
1661                 {
1662                     out << "." << DecorateField(field->name(), *structure);
1663                 }
1664 
1665                 return false;
1666             }
1667         }
1668         break;
1669         case EOpIndexDirectInterfaceBlock:
1670         {
1671             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1672             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1673                                               IsInStd140UniformBlock(node->getLeft()) &&
1674                                               needStructMapping(node);
1675             if (visit == PreVisit && structInStd140UniformBlock)
1676             {
1677                 mNeedStructMapping = true;
1678                 out << "map";
1679             }
1680             if (visit == InVisit)
1681             {
1682                 const TInterfaceBlock *interfaceBlock =
1683                     node->getLeft()->getType().getInterfaceBlock();
1684                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1685                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1686                 if (structInStd140UniformBlock)
1687                 {
1688                     out << "_";
1689                 }
1690                 else
1691                 {
1692                     out << ".";
1693                 }
1694                 out << Decorate(field->name());
1695 
1696                 return false;
1697             }
1698             break;
1699         }
1700         case EOpAdd:
1701             outputTriplet(out, visit, "(", " + ", ")");
1702             break;
1703         case EOpSub:
1704             outputTriplet(out, visit, "(", " - ", ")");
1705             break;
1706         case EOpMul:
1707             outputTriplet(out, visit, "(", " * ", ")");
1708             break;
1709         case EOpDiv:
1710             outputTriplet(out, visit, "(", " / ", ")");
1711             break;
1712         case EOpIMod:
1713             outputTriplet(out, visit, "(", " % ", ")");
1714             break;
1715         case EOpBitShiftLeft:
1716             outputTriplet(out, visit, "(", " << ", ")");
1717             break;
1718         case EOpBitShiftRight:
1719             outputTriplet(out, visit, "(", " >> ", ")");
1720             break;
1721         case EOpBitwiseAnd:
1722             outputTriplet(out, visit, "(", " & ", ")");
1723             break;
1724         case EOpBitwiseXor:
1725             outputTriplet(out, visit, "(", " ^ ", ")");
1726             break;
1727         case EOpBitwiseOr:
1728             outputTriplet(out, visit, "(", " | ", ")");
1729             break;
1730         case EOpEqual:
1731         case EOpNotEqual:
1732             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1733             break;
1734         case EOpLessThan:
1735             outputTriplet(out, visit, "(", " < ", ")");
1736             break;
1737         case EOpGreaterThan:
1738             outputTriplet(out, visit, "(", " > ", ")");
1739             break;
1740         case EOpLessThanEqual:
1741             outputTriplet(out, visit, "(", " <= ", ")");
1742             break;
1743         case EOpGreaterThanEqual:
1744             outputTriplet(out, visit, "(", " >= ", ")");
1745             break;
1746         case EOpVectorTimesScalar:
1747             outputTriplet(out, visit, "(", " * ", ")");
1748             break;
1749         case EOpMatrixTimesScalar:
1750             outputTriplet(out, visit, "(", " * ", ")");
1751             break;
1752         case EOpVectorTimesMatrix:
1753             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1754             break;
1755         case EOpMatrixTimesVector:
1756             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1757             break;
1758         case EOpMatrixTimesMatrix:
1759             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1760             break;
1761         case EOpLogicalOr:
1762             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1763             // been unfolded.
1764             ASSERT(!node->getRight()->hasSideEffects());
1765             outputTriplet(out, visit, "(", " || ", ")");
1766             return true;
1767         case EOpLogicalXor:
1768             mUsesXor = true;
1769             outputTriplet(out, visit, "xor(", ", ", ")");
1770             break;
1771         case EOpLogicalAnd:
1772             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1773             // been unfolded.
1774             ASSERT(!node->getRight()->hasSideEffects());
1775             outputTriplet(out, visit, "(", " && ", ")");
1776             return true;
1777         default:
1778             UNREACHABLE();
1779     }
1780 
1781     return true;
1782 }
1783 
visitUnary(Visit visit,TIntermUnary * node)1784 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1785 {
1786     TInfoSinkBase &out = getInfoSink();
1787 
1788     switch (node->getOp())
1789     {
1790         case EOpNegative:
1791             outputTriplet(out, visit, "(-", "", ")");
1792             break;
1793         case EOpPositive:
1794             outputTriplet(out, visit, "(+", "", ")");
1795             break;
1796         case EOpLogicalNot:
1797             outputTriplet(out, visit, "(!", "", ")");
1798             break;
1799         case EOpBitwiseNot:
1800             outputTriplet(out, visit, "(~", "", ")");
1801             break;
1802         case EOpPostIncrement:
1803             outputTriplet(out, visit, "(", "", "++)");
1804             break;
1805         case EOpPostDecrement:
1806             outputTriplet(out, visit, "(", "", "--)");
1807             break;
1808         case EOpPreIncrement:
1809             outputTriplet(out, visit, "(++", "", ")");
1810             break;
1811         case EOpPreDecrement:
1812             outputTriplet(out, visit, "(--", "", ")");
1813             break;
1814         case EOpRadians:
1815             outputTriplet(out, visit, "radians(", "", ")");
1816             break;
1817         case EOpDegrees:
1818             outputTriplet(out, visit, "degrees(", "", ")");
1819             break;
1820         case EOpSin:
1821             outputTriplet(out, visit, "sin(", "", ")");
1822             break;
1823         case EOpCos:
1824             outputTriplet(out, visit, "cos(", "", ")");
1825             break;
1826         case EOpTan:
1827             outputTriplet(out, visit, "tan(", "", ")");
1828             break;
1829         case EOpAsin:
1830             outputTriplet(out, visit, "asin(", "", ")");
1831             break;
1832         case EOpAcos:
1833             outputTriplet(out, visit, "acos(", "", ")");
1834             break;
1835         case EOpAtan:
1836             outputTriplet(out, visit, "atan(", "", ")");
1837             break;
1838         case EOpSinh:
1839             outputTriplet(out, visit, "sinh(", "", ")");
1840             break;
1841         case EOpCosh:
1842             outputTriplet(out, visit, "cosh(", "", ")");
1843             break;
1844         case EOpTanh:
1845         case EOpAsinh:
1846         case EOpAcosh:
1847         case EOpAtanh:
1848             ASSERT(node->getUseEmulatedFunction());
1849             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1850             break;
1851         case EOpExp:
1852             outputTriplet(out, visit, "exp(", "", ")");
1853             break;
1854         case EOpLog:
1855             outputTriplet(out, visit, "log(", "", ")");
1856             break;
1857         case EOpExp2:
1858             outputTriplet(out, visit, "exp2(", "", ")");
1859             break;
1860         case EOpLog2:
1861             outputTriplet(out, visit, "log2(", "", ")");
1862             break;
1863         case EOpSqrt:
1864             outputTriplet(out, visit, "sqrt(", "", ")");
1865             break;
1866         case EOpInversesqrt:
1867             outputTriplet(out, visit, "rsqrt(", "", ")");
1868             break;
1869         case EOpAbs:
1870             outputTriplet(out, visit, "abs(", "", ")");
1871             break;
1872         case EOpSign:
1873             outputTriplet(out, visit, "sign(", "", ")");
1874             break;
1875         case EOpFloor:
1876             outputTriplet(out, visit, "floor(", "", ")");
1877             break;
1878         case EOpTrunc:
1879             outputTriplet(out, visit, "trunc(", "", ")");
1880             break;
1881         case EOpRound:
1882             outputTriplet(out, visit, "round(", "", ")");
1883             break;
1884         case EOpRoundEven:
1885             ASSERT(node->getUseEmulatedFunction());
1886             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1887             break;
1888         case EOpCeil:
1889             outputTriplet(out, visit, "ceil(", "", ")");
1890             break;
1891         case EOpFract:
1892             outputTriplet(out, visit, "frac(", "", ")");
1893             break;
1894         case EOpIsnan:
1895             if (node->getUseEmulatedFunction())
1896                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1897             else
1898                 outputTriplet(out, visit, "isnan(", "", ")");
1899             mRequiresIEEEStrictCompiling = true;
1900             break;
1901         case EOpIsinf:
1902             outputTriplet(out, visit, "isinf(", "", ")");
1903             break;
1904         case EOpFloatBitsToInt:
1905             outputTriplet(out, visit, "asint(", "", ")");
1906             break;
1907         case EOpFloatBitsToUint:
1908             outputTriplet(out, visit, "asuint(", "", ")");
1909             break;
1910         case EOpIntBitsToFloat:
1911             outputTriplet(out, visit, "asfloat(", "", ")");
1912             break;
1913         case EOpUintBitsToFloat:
1914             outputTriplet(out, visit, "asfloat(", "", ")");
1915             break;
1916         case EOpPackSnorm2x16:
1917         case EOpPackUnorm2x16:
1918         case EOpPackHalf2x16:
1919         case EOpUnpackSnorm2x16:
1920         case EOpUnpackUnorm2x16:
1921         case EOpUnpackHalf2x16:
1922         case EOpPackUnorm4x8:
1923         case EOpPackSnorm4x8:
1924         case EOpUnpackUnorm4x8:
1925         case EOpUnpackSnorm4x8:
1926             ASSERT(node->getUseEmulatedFunction());
1927             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1928             break;
1929         case EOpLength:
1930             outputTriplet(out, visit, "length(", "", ")");
1931             break;
1932         case EOpNormalize:
1933             outputTriplet(out, visit, "normalize(", "", ")");
1934             break;
1935         case EOpDFdx:
1936             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1937             {
1938                 outputTriplet(out, visit, "(", "", ", 0.0)");
1939             }
1940             else
1941             {
1942                 outputTriplet(out, visit, "ddx(", "", ")");
1943             }
1944             break;
1945         case EOpDFdy:
1946             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1947             {
1948                 outputTriplet(out, visit, "(", "", ", 0.0)");
1949             }
1950             else
1951             {
1952                 outputTriplet(out, visit, "ddy(", "", ")");
1953             }
1954             break;
1955         case EOpFwidth:
1956             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1957             {
1958                 outputTriplet(out, visit, "(", "", ", 0.0)");
1959             }
1960             else
1961             {
1962                 outputTriplet(out, visit, "fwidth(", "", ")");
1963             }
1964             break;
1965         case EOpTranspose:
1966             outputTriplet(out, visit, "transpose(", "", ")");
1967             break;
1968         case EOpDeterminant:
1969             outputTriplet(out, visit, "determinant(transpose(", "", "))");
1970             break;
1971         case EOpInverse:
1972             ASSERT(node->getUseEmulatedFunction());
1973             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1974             break;
1975 
1976         case EOpAny:
1977             outputTriplet(out, visit, "any(", "", ")");
1978             break;
1979         case EOpAll:
1980             outputTriplet(out, visit, "all(", "", ")");
1981             break;
1982         case EOpLogicalNotComponentWise:
1983             outputTriplet(out, visit, "(!", "", ")");
1984             break;
1985         case EOpBitfieldReverse:
1986             outputTriplet(out, visit, "reversebits(", "", ")");
1987             break;
1988         case EOpBitCount:
1989             outputTriplet(out, visit, "countbits(", "", ")");
1990             break;
1991         case EOpFindLSB:
1992             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
1993             // in GLSLTest and results are consistent with GL.
1994             outputTriplet(out, visit, "firstbitlow(", "", ")");
1995             break;
1996         case EOpFindMSB:
1997             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
1998             // tested in GLSLTest and results are consistent with GL.
1999             outputTriplet(out, visit, "firstbithigh(", "", ")");
2000             break;
2001         case EOpArrayLength:
2002         {
2003             TIntermTyped *operand = node->getOperand();
2004             ASSERT(IsInShaderStorageBlock(operand));
2005             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2006             return false;
2007         }
2008         default:
2009             UNREACHABLE();
2010     }
2011 
2012     return true;
2013 }
2014 
samplerNamePrefixFromStruct(TIntermTyped * node)2015 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2016 {
2017     if (node->getAsSymbolNode())
2018     {
2019         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2020         return node->getAsSymbolNode()->getName();
2021     }
2022     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2023     switch (nodeBinary->getOp())
2024     {
2025         case EOpIndexDirect:
2026         {
2027             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2028 
2029             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2030             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2031             return ImmutableString(prefixSink.str());
2032         }
2033         case EOpIndexDirectStruct:
2034         {
2035             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2036             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2037             const TField *field = s->fields()[index];
2038 
2039             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2040             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2041                        << field->name();
2042             return ImmutableString(prefixSink.str());
2043         }
2044         default:
2045             UNREACHABLE();
2046             return kEmptyImmutableString;
2047     }
2048 }
2049 
visitBlock(Visit visit,TIntermBlock * node)2050 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2051 {
2052     TInfoSinkBase &out = getInfoSink();
2053 
2054     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2055 
2056     if (mInsideFunction)
2057     {
2058         outputLineDirective(out, node->getLine().first_line);
2059         out << "{\n";
2060         if (isMainBlock)
2061         {
2062             if (mShaderType == GL_COMPUTE_SHADER)
2063             {
2064                 out << "initGLBuiltins(input);\n";
2065             }
2066             else
2067             {
2068                 out << "@@ MAIN PROLOGUE @@\n";
2069             }
2070         }
2071     }
2072 
2073     for (TIntermNode *statement : *node->getSequence())
2074     {
2075         outputLineDirective(out, statement->getLine().first_line);
2076 
2077         statement->traverse(this);
2078 
2079         // Don't output ; after case labels, they're terminated by :
2080         // This is needed especially since outputting a ; after a case statement would turn empty
2081         // case statements into non-empty case statements, disallowing fall-through from them.
2082         // Also the output code is clearer if we don't output ; after statements where it is not
2083         // needed:
2084         //  * if statements
2085         //  * switch statements
2086         //  * blocks
2087         //  * function definitions
2088         //  * loops (do-while loops output the semicolon in VisitLoop)
2089         //  * declarations that don't generate output.
2090         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2091             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2092             statement->getAsSwitchNode() == nullptr &&
2093             statement->getAsFunctionDefinition() == nullptr &&
2094             (statement->getAsDeclarationNode() == nullptr ||
2095              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2096             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2097         {
2098             out << ";\n";
2099         }
2100     }
2101 
2102     if (mInsideFunction)
2103     {
2104         outputLineDirective(out, node->getLine().last_line);
2105         if (isMainBlock && shaderNeedsGenerateOutput())
2106         {
2107             // We could have an empty main, a main function without a branch at the end, or a main
2108             // function with a discard statement at the end. In these cases we need to add a return
2109             // statement.
2110             bool needReturnStatement =
2111                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2112                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2113             if (needReturnStatement)
2114             {
2115                 out << "return " << generateOutputCall() << ";\n";
2116             }
2117         }
2118         out << "}\n";
2119     }
2120 
2121     return false;
2122 }
2123 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2124 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2125 {
2126     TInfoSinkBase &out = getInfoSink();
2127 
2128     ASSERT(mCurrentFunctionMetadata == nullptr);
2129 
2130     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2131     ASSERT(index != CallDAG::InvalidIndex);
2132     mCurrentFunctionMetadata = &mASTMetadataList[index];
2133 
2134     const TFunction *func = node->getFunction();
2135 
2136     if (func->isMain())
2137     {
2138         // The stub strings below are replaced when shader is dynamically defined by its layout:
2139         switch (mShaderType)
2140         {
2141             case GL_VERTEX_SHADER:
2142                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2143                     << "@@ VERTEX OUTPUT @@\n\n"
2144                     << "VS_OUTPUT main(VS_INPUT input)";
2145                 break;
2146             case GL_FRAGMENT_SHADER:
2147                 out << "@@ PIXEL OUTPUT @@\n\n"
2148                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2149                 break;
2150             case GL_COMPUTE_SHADER:
2151                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2152                     << mWorkGroupSize[2] << ")]\n";
2153                 out << "void main(CS_INPUT input)";
2154                 break;
2155             default:
2156                 UNREACHABLE();
2157                 break;
2158         }
2159     }
2160     else
2161     {
2162         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2163         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2164             << (mOutputLod0Function ? "Lod0(" : "(");
2165 
2166         size_t paramCount = func->getParamCount();
2167         for (unsigned int i = 0; i < paramCount; i++)
2168         {
2169             const TVariable *param = func->getParam(i);
2170             ensureStructDefined(param->getType());
2171 
2172             writeParameter(param, out);
2173 
2174             if (i < paramCount - 1)
2175             {
2176                 out << ", ";
2177             }
2178         }
2179 
2180         out << ")\n";
2181     }
2182 
2183     mInsideFunction = true;
2184     if (func->isMain())
2185     {
2186         mInsideMain = true;
2187     }
2188     // The function body node will output braces.
2189     node->getBody()->traverse(this);
2190     mInsideFunction = false;
2191     mInsideMain     = false;
2192 
2193     mCurrentFunctionMetadata = nullptr;
2194 
2195     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2196     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2197     {
2198         ASSERT(!node->getFunction()->isMain());
2199         mOutputLod0Function = true;
2200         node->traverse(this);
2201         mOutputLod0Function = false;
2202     }
2203 
2204     return false;
2205 }
2206 
visitDeclaration(Visit visit,TIntermDeclaration * node)2207 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2208 {
2209     if (visit == PreVisit)
2210     {
2211         TIntermSequence *sequence = node->getSequence();
2212         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2213         ASSERT(sequence->size() == 1);
2214         ASSERT(declarator);
2215 
2216         if (IsDeclarationWrittenOut(node))
2217         {
2218             TInfoSinkBase &out = getInfoSink();
2219             ensureStructDefined(declarator->getType());
2220 
2221             if (!declarator->getAsSymbolNode() ||
2222                 declarator->getAsSymbolNode()->variable().symbolType() !=
2223                     SymbolType::Empty)  // Variable declaration
2224             {
2225                 if (declarator->getQualifier() == EvqShared)
2226                 {
2227                     out << "groupshared ";
2228                 }
2229                 else if (!mInsideFunction)
2230                 {
2231                     out << "static ";
2232                 }
2233 
2234                 out << TypeString(declarator->getType()) + " ";
2235 
2236                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2237 
2238                 if (symbol)
2239                 {
2240                     symbol->traverse(this);
2241                     out << ArrayString(symbol->getType());
2242                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2243                     // drivers to compile a compute shader if we add code to initialize a
2244                     // groupshared array variable with a large array size. And maybe produce
2245                     // incorrect result. See http://anglebug.com/3226.
2246                     if (declarator->getQualifier() != EvqShared)
2247                     {
2248                         out << " = " + zeroInitializer(symbol->getType());
2249                     }
2250                 }
2251                 else
2252                 {
2253                     declarator->traverse(this);
2254                 }
2255             }
2256         }
2257         else if (IsVaryingOut(declarator->getQualifier()))
2258         {
2259             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2260             ASSERT(symbol);  // Varying declarations can't have initializers.
2261 
2262             const TVariable &variable = symbol->variable();
2263 
2264             if (variable.symbolType() != SymbolType::Empty)
2265             {
2266                 // Vertex outputs which are declared but not written to should still be declared to
2267                 // allow successful linking.
2268                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2269             }
2270         }
2271     }
2272     return false;
2273 }
2274 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2275 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2276                                                  TIntermGlobalQualifierDeclaration *node)
2277 {
2278     // Do not do any translation
2279     return false;
2280 }
2281 
visitFunctionPrototype(TIntermFunctionPrototype * node)2282 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2283 {
2284     TInfoSinkBase &out = getInfoSink();
2285 
2286     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2287     // Skip the prototype if it is not implemented (and thus not used)
2288     if (index == CallDAG::InvalidIndex)
2289     {
2290         return;
2291     }
2292 
2293     const TFunction *func = node->getFunction();
2294 
2295     TString name = DecorateFunctionIfNeeded(func);
2296     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2297         << (mOutputLod0Function ? "Lod0(" : "(");
2298 
2299     size_t paramCount = func->getParamCount();
2300     for (unsigned int i = 0; i < paramCount; i++)
2301     {
2302         writeParameter(func->getParam(i), out);
2303 
2304         if (i < paramCount - 1)
2305         {
2306             out << ", ";
2307         }
2308     }
2309 
2310     out << ");\n";
2311 
2312     // Also prototype the Lod0 variant if needed
2313     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2314     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2315     {
2316         mOutputLod0Function = true;
2317         node->traverse(this);
2318         mOutputLod0Function = false;
2319     }
2320 }
2321 
visitAggregate(Visit visit,TIntermAggregate * node)2322 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
2323 {
2324     TInfoSinkBase &out = getInfoSink();
2325 
2326     switch (node->getOp())
2327     {
2328         case EOpCallBuiltInFunction:
2329         case EOpCallFunctionInAST:
2330         case EOpCallInternalRawFunction:
2331         {
2332             TIntermSequence *arguments = node->getSequence();
2333 
2334             bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
2335                         mShaderType == GL_FRAGMENT_SHADER;
2336             if (node->getOp() == EOpCallFunctionInAST)
2337             {
2338                 if (node->isArray())
2339                 {
2340                     UNIMPLEMENTED();
2341                 }
2342                 size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2343                 ASSERT(index != CallDAG::InvalidIndex);
2344                 lod0 &= mASTMetadataList[index].mNeedsLod0;
2345 
2346                 out << DecorateFunctionIfNeeded(node->getFunction());
2347                 out << DisambiguateFunctionName(node->getSequence());
2348                 out << (lod0 ? "Lod0(" : "(");
2349             }
2350             else if (node->getOp() == EOpCallInternalRawFunction)
2351             {
2352                 // This path is used for internal functions that don't have their definitions in the
2353                 // AST, such as precision emulation functions.
2354                 out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
2355             }
2356             else if (node->getFunction()->isImageFunction())
2357             {
2358                 const ImmutableString &name              = node->getFunction()->name();
2359                 TType type                               = (*arguments)[0]->getAsTyped()->getType();
2360                 const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
2361                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
2362                     type.getMemoryQualifier().readonly);
2363                 out << imageFunctionName << "(";
2364             }
2365             else if (node->getFunction()->isAtomicCounterFunction())
2366             {
2367                 const ImmutableString &name = node->getFunction()->name();
2368                 ImmutableString atomicFunctionName =
2369                     mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
2370                 out << atomicFunctionName << "(";
2371             }
2372             else
2373             {
2374                 const ImmutableString &name = node->getFunction()->name();
2375                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
2376                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
2377                 if (arguments->size() > 1)
2378                 {
2379                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
2380                 }
2381                 const ImmutableString &textureFunctionName =
2382                     mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
2383                                                              arguments->size(), lod0, mShaderType);
2384                 out << textureFunctionName << "(";
2385             }
2386 
2387             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
2388             {
2389                 TIntermTyped *typedArg = (*arg)->getAsTyped();
2390                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
2391                 {
2392                     out << "texture_";
2393                     (*arg)->traverse(this);
2394                     out << ", sampler_";
2395                 }
2396 
2397                 (*arg)->traverse(this);
2398 
2399                 if (typedArg->getType().isStructureContainingSamplers())
2400                 {
2401                     const TType &argType = typedArg->getType();
2402                     TVector<const TVariable *> samplerSymbols;
2403                     ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
2404                     std::string namePrefix     = "angle_";
2405                     namePrefix += structName.data();
2406                     argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
2407                                                  nullptr, mSymbolTable);
2408                     for (const TVariable *sampler : samplerSymbols)
2409                     {
2410                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2411                         {
2412                             out << ", texture_" << sampler->name();
2413                             out << ", sampler_" << sampler->name();
2414                         }
2415                         else
2416                         {
2417                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
2418                             // of D3D9, it's the sampler variable.
2419                             out << ", " << sampler->name();
2420                         }
2421                     }
2422                 }
2423 
2424                 if (arg < arguments->end() - 1)
2425                 {
2426                     out << ", ";
2427                 }
2428             }
2429 
2430             out << ")";
2431 
2432             return false;
2433         }
2434         case EOpConstruct:
2435             outputConstructor(out, visit, node);
2436             break;
2437         case EOpEqualComponentWise:
2438             outputTriplet(out, visit, "(", " == ", ")");
2439             break;
2440         case EOpNotEqualComponentWise:
2441             outputTriplet(out, visit, "(", " != ", ")");
2442             break;
2443         case EOpLessThanComponentWise:
2444             outputTriplet(out, visit, "(", " < ", ")");
2445             break;
2446         case EOpGreaterThanComponentWise:
2447             outputTriplet(out, visit, "(", " > ", ")");
2448             break;
2449         case EOpLessThanEqualComponentWise:
2450             outputTriplet(out, visit, "(", " <= ", ")");
2451             break;
2452         case EOpGreaterThanEqualComponentWise:
2453             outputTriplet(out, visit, "(", " >= ", ")");
2454             break;
2455         case EOpMod:
2456             ASSERT(node->getUseEmulatedFunction());
2457             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2458             break;
2459         case EOpModf:
2460             outputTriplet(out, visit, "modf(", ", ", ")");
2461             break;
2462         case EOpPow:
2463             outputTriplet(out, visit, "pow(", ", ", ")");
2464             break;
2465         case EOpAtan:
2466             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2467             ASSERT(node->getUseEmulatedFunction());
2468             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2469             break;
2470         case EOpMin:
2471             outputTriplet(out, visit, "min(", ", ", ")");
2472             break;
2473         case EOpMax:
2474             outputTriplet(out, visit, "max(", ", ", ")");
2475             break;
2476         case EOpClamp:
2477             outputTriplet(out, visit, "clamp(", ", ", ")");
2478             break;
2479         case EOpMix:
2480         {
2481             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2482             if (lastParamNode->getType().getBasicType() == EbtBool)
2483             {
2484                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2485                 // y, genBType a)",
2486                 // so use emulated version.
2487                 ASSERT(node->getUseEmulatedFunction());
2488                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
2489             }
2490             else
2491             {
2492                 outputTriplet(out, visit, "lerp(", ", ", ")");
2493             }
2494             break;
2495         }
2496         case EOpStep:
2497             outputTriplet(out, visit, "step(", ", ", ")");
2498             break;
2499         case EOpSmoothstep:
2500             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2501             break;
2502         case EOpFma:
2503             outputTriplet(out, visit, "mad(", ", ", ")");
2504             break;
2505         case EOpFrexp:
2506         case EOpLdexp:
2507             ASSERT(node->getUseEmulatedFunction());
2508             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2509             break;
2510         case EOpDistance:
2511             outputTriplet(out, visit, "distance(", ", ", ")");
2512             break;
2513         case EOpDot:
2514             outputTriplet(out, visit, "dot(", ", ", ")");
2515             break;
2516         case EOpCross:
2517             outputTriplet(out, visit, "cross(", ", ", ")");
2518             break;
2519         case EOpFaceforward:
2520             ASSERT(node->getUseEmulatedFunction());
2521             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2522             break;
2523         case EOpReflect:
2524             outputTriplet(out, visit, "reflect(", ", ", ")");
2525             break;
2526         case EOpRefract:
2527             outputTriplet(out, visit, "refract(", ", ", ")");
2528             break;
2529         case EOpOuterProduct:
2530             ASSERT(node->getUseEmulatedFunction());
2531             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2532             break;
2533         case EOpMulMatrixComponentWise:
2534             outputTriplet(out, visit, "(", " * ", ")");
2535             break;
2536         case EOpBitfieldExtract:
2537         case EOpBitfieldInsert:
2538         case EOpUaddCarry:
2539         case EOpUsubBorrow:
2540         case EOpUmulExtended:
2541         case EOpImulExtended:
2542             ASSERT(node->getUseEmulatedFunction());
2543             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2544             break;
2545         case EOpBarrier:
2546             // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
2547             // cheapest *WithGroupSync() function, without any functionality loss, but
2548             // with the potential for severe performance loss.
2549             outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
2550             break;
2551         case EOpMemoryBarrierShared:
2552             outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
2553             break;
2554         case EOpMemoryBarrierAtomicCounter:
2555         case EOpMemoryBarrierBuffer:
2556         case EOpMemoryBarrierImage:
2557             outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
2558             break;
2559         case EOpGroupMemoryBarrier:
2560         case EOpMemoryBarrier:
2561             outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
2562             break;
2563 
2564         // Single atomic function calls without return value.
2565         // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
2566         case EOpAtomicAdd:
2567         case EOpAtomicMin:
2568         case EOpAtomicMax:
2569         case EOpAtomicAnd:
2570         case EOpAtomicOr:
2571         case EOpAtomicXor:
2572         // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
2573         // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
2574         // optional.
2575         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
2576         // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
2577         // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
2578         // compare_value, value) should all be modified into the form of "int temp; temp =
2579         // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
2580         // compare_value, value);" in the intermediate tree before traversing outputHLSL.
2581         case EOpAtomicExchange:
2582         case EOpAtomicCompSwap:
2583         {
2584             ASSERT(node->getChildCount() > 1);
2585             TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
2586             if (IsInShaderStorageBlock(memNode))
2587             {
2588                 // Atomic memory functions for SSBO.
2589                 // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
2590                 mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
2591                 // Write the rest argument list to |out|.
2592                 for (size_t i = 1; i < node->getChildCount(); i++)
2593                 {
2594                     out << ", ";
2595                     TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
2596                     if (IsInShaderStorageBlock(argument))
2597                     {
2598                         mSSBOOutputHLSL->outputLoadFunctionCall(argument);
2599                     }
2600                     else
2601                     {
2602                         argument->traverse(this);
2603                     }
2604                 }
2605 
2606                 out << ")";
2607                 return false;
2608             }
2609             else
2610             {
2611                 // Atomic memory functions for shared variable.
2612                 if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
2613                 {
2614                     outputTriplet(out, visit,
2615                                   GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
2616                                   ")");
2617                 }
2618                 else
2619                 {
2620                     UNREACHABLE();
2621                 }
2622             }
2623 
2624             break;
2625         }
2626         default:
2627             UNREACHABLE();
2628     }
2629 
2630     return true;
2631 }
2632 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2633 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2634 {
2635     out << "if (";
2636 
2637     node->getCondition()->traverse(this);
2638 
2639     out << ")\n";
2640 
2641     outputLineDirective(out, node->getLine().first_line);
2642 
2643     bool discard = false;
2644 
2645     if (node->getTrueBlock())
2646     {
2647         // The trueBlock child node will output braces.
2648         node->getTrueBlock()->traverse(this);
2649 
2650         // Detect true discard
2651         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2652     }
2653     else
2654     {
2655         // TODO(oetuaho): Check if the semicolon inside is necessary.
2656         // It's there as a result of conservative refactoring of the output.
2657         out << "{;}\n";
2658     }
2659 
2660     outputLineDirective(out, node->getLine().first_line);
2661 
2662     if (node->getFalseBlock())
2663     {
2664         out << "else\n";
2665 
2666         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2667 
2668         // The falseBlock child node will output braces.
2669         node->getFalseBlock()->traverse(this);
2670 
2671         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2672 
2673         // Detect false discard
2674         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2675     }
2676 
2677     // ANGLE issue 486: Detect problematic conditional discard
2678     if (discard)
2679     {
2680         mUsesDiscardRewriting = true;
2681     }
2682 }
2683 
visitTernary(Visit,TIntermTernary *)2684 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2685 {
2686     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2687     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2688     UNREACHABLE();
2689     return false;
2690 }
2691 
visitIfElse(Visit visit,TIntermIfElse * node)2692 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2693 {
2694     TInfoSinkBase &out = getInfoSink();
2695 
2696     ASSERT(mInsideFunction);
2697 
2698     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2699     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2700     {
2701         out << "FLATTEN ";
2702     }
2703 
2704     writeIfElse(out, node);
2705 
2706     return false;
2707 }
2708 
visitSwitch(Visit visit,TIntermSwitch * node)2709 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2710 {
2711     TInfoSinkBase &out = getInfoSink();
2712 
2713     ASSERT(node->getStatementList());
2714     if (visit == PreVisit)
2715     {
2716         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2717     }
2718     outputTriplet(out, visit, "switch (", ") ", "");
2719     // The curly braces get written when visiting the statementList block.
2720     return true;
2721 }
2722 
visitCase(Visit visit,TIntermCase * node)2723 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2724 {
2725     TInfoSinkBase &out = getInfoSink();
2726 
2727     if (node->hasCondition())
2728     {
2729         outputTriplet(out, visit, "case (", "", "):\n");
2730         return true;
2731     }
2732     else
2733     {
2734         out << "default:\n";
2735         return false;
2736     }
2737 }
2738 
visitConstantUnion(TIntermConstantUnion * node)2739 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2740 {
2741     TInfoSinkBase &out = getInfoSink();
2742     writeConstantUnion(out, node->getType(), node->getConstantValue());
2743 }
2744 
visitLoop(Visit visit,TIntermLoop * node)2745 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2746 {
2747     mNestedLoopDepth++;
2748 
2749     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2750     mInsideDiscontinuousLoop =
2751         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2752 
2753     TInfoSinkBase &out = getInfoSink();
2754 
2755     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2756     {
2757         if (handleExcessiveLoop(out, node))
2758         {
2759             mInsideDiscontinuousLoop = wasDiscontinuous;
2760             mNestedLoopDepth--;
2761 
2762             return false;
2763         }
2764     }
2765 
2766     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2767     if (node->getType() == ELoopDoWhile)
2768     {
2769         out << "{" << unroll << " do\n";
2770 
2771         outputLineDirective(out, node->getLine().first_line);
2772     }
2773     else
2774     {
2775         out << "{" << unroll << " for(";
2776 
2777         if (node->getInit())
2778         {
2779             node->getInit()->traverse(this);
2780         }
2781 
2782         out << "; ";
2783 
2784         if (node->getCondition())
2785         {
2786             node->getCondition()->traverse(this);
2787         }
2788 
2789         out << "; ";
2790 
2791         if (node->getExpression())
2792         {
2793             node->getExpression()->traverse(this);
2794         }
2795 
2796         out << ")\n";
2797 
2798         outputLineDirective(out, node->getLine().first_line);
2799     }
2800 
2801     if (node->getBody())
2802     {
2803         // The loop body node will output braces.
2804         node->getBody()->traverse(this);
2805     }
2806     else
2807     {
2808         // TODO(oetuaho): Check if the semicolon inside is necessary.
2809         // It's there as a result of conservative refactoring of the output.
2810         out << "{;}\n";
2811     }
2812 
2813     outputLineDirective(out, node->getLine().first_line);
2814 
2815     if (node->getType() == ELoopDoWhile)
2816     {
2817         outputLineDirective(out, node->getCondition()->getLine().first_line);
2818         out << "while (";
2819 
2820         node->getCondition()->traverse(this);
2821 
2822         out << ");\n";
2823     }
2824 
2825     out << "}\n";
2826 
2827     mInsideDiscontinuousLoop = wasDiscontinuous;
2828     mNestedLoopDepth--;
2829 
2830     return false;
2831 }
2832 
visitBranch(Visit visit,TIntermBranch * node)2833 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2834 {
2835     if (visit == PreVisit)
2836     {
2837         TInfoSinkBase &out = getInfoSink();
2838 
2839         switch (node->getFlowOp())
2840         {
2841             case EOpKill:
2842                 out << "discard";
2843                 break;
2844             case EOpBreak:
2845                 if (mNestedLoopDepth > 1)
2846                 {
2847                     mUsesNestedBreak = true;
2848                 }
2849 
2850                 if (mExcessiveLoopIndex)
2851                 {
2852                     out << "{Break";
2853                     mExcessiveLoopIndex->traverse(this);
2854                     out << " = true; break;}\n";
2855                 }
2856                 else
2857                 {
2858                     out << "break";
2859                 }
2860                 break;
2861             case EOpContinue:
2862                 out << "continue";
2863                 break;
2864             case EOpReturn:
2865                 if (node->getExpression())
2866                 {
2867                     ASSERT(!mInsideMain);
2868                     out << "return ";
2869                 }
2870                 else
2871                 {
2872                     if (mInsideMain && shaderNeedsGenerateOutput())
2873                     {
2874                         out << "return " << generateOutputCall();
2875                     }
2876                     else
2877                     {
2878                         out << "return";
2879                     }
2880                 }
2881                 break;
2882             default:
2883                 UNREACHABLE();
2884         }
2885     }
2886 
2887     return true;
2888 }
2889 
2890 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2891 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2892 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2893 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2894 {
2895     const int MAX_LOOP_ITERATIONS = 254;
2896 
2897     // Parse loops of the form:
2898     // for(int index = initial; index [comparator] limit; index += increment)
2899     TIntermSymbol *index = nullptr;
2900     TOperator comparator = EOpNull;
2901     int initial          = 0;
2902     int limit            = 0;
2903     int increment        = 0;
2904 
2905     // Parse index name and intial value
2906     if (node->getInit())
2907     {
2908         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2909 
2910         if (init)
2911         {
2912             TIntermSequence *sequence = init->getSequence();
2913             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2914 
2915             if (variable && variable->getQualifier() == EvqTemporary)
2916             {
2917                 TIntermBinary *assign = variable->getAsBinaryNode();
2918 
2919                 if (assign->getOp() == EOpInitialize)
2920                 {
2921                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2922                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2923 
2924                     if (symbol && constant)
2925                     {
2926                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2927                         {
2928                             index   = symbol;
2929                             initial = constant->getIConst(0);
2930                         }
2931                     }
2932                 }
2933             }
2934         }
2935     }
2936 
2937     // Parse comparator and limit value
2938     if (index != nullptr && node->getCondition())
2939     {
2940         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
2941 
2942         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
2943         {
2944             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
2945 
2946             if (constant)
2947             {
2948                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2949                 {
2950                     comparator = test->getOp();
2951                     limit      = constant->getIConst(0);
2952                 }
2953             }
2954         }
2955     }
2956 
2957     // Parse increment
2958     if (index != nullptr && comparator != EOpNull && node->getExpression())
2959     {
2960         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
2961         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
2962 
2963         if (binaryTerminal)
2964         {
2965             TOperator op                   = binaryTerminal->getOp();
2966             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
2967 
2968             if (constant)
2969             {
2970                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2971                 {
2972                     int value = constant->getIConst(0);
2973 
2974                     switch (op)
2975                     {
2976                         case EOpAddAssign:
2977                             increment = value;
2978                             break;
2979                         case EOpSubAssign:
2980                             increment = -value;
2981                             break;
2982                         default:
2983                             UNIMPLEMENTED();
2984                     }
2985                 }
2986             }
2987         }
2988         else if (unaryTerminal)
2989         {
2990             TOperator op = unaryTerminal->getOp();
2991 
2992             switch (op)
2993             {
2994                 case EOpPostIncrement:
2995                     increment = 1;
2996                     break;
2997                 case EOpPostDecrement:
2998                     increment = -1;
2999                     break;
3000                 case EOpPreIncrement:
3001                     increment = 1;
3002                     break;
3003                 case EOpPreDecrement:
3004                     increment = -1;
3005                     break;
3006                 default:
3007                     UNIMPLEMENTED();
3008             }
3009         }
3010     }
3011 
3012     if (index != nullptr && comparator != EOpNull && increment != 0)
3013     {
3014         if (comparator == EOpLessThanEqual)
3015         {
3016             comparator = EOpLessThan;
3017             limit += 1;
3018         }
3019 
3020         if (comparator == EOpLessThan)
3021         {
3022             int iterations = (limit - initial) / increment;
3023 
3024             if (iterations <= MAX_LOOP_ITERATIONS)
3025             {
3026                 return false;  // Not an excessive loop
3027             }
3028 
3029             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3030             mExcessiveLoopIndex         = index;
3031 
3032             out << "{int ";
3033             index->traverse(this);
3034             out << ";\n"
3035                    "bool Break";
3036             index->traverse(this);
3037             out << " = false;\n";
3038 
3039             bool firstLoopFragment = true;
3040 
3041             while (iterations > 0)
3042             {
3043                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3044 
3045                 if (!firstLoopFragment)
3046                 {
3047                     out << "if (!Break";
3048                     index->traverse(this);
3049                     out << ") {\n";
3050                 }
3051 
3052                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3053                 {
3054                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3055                 }
3056 
3057                 // for(int index = initial; index < clampedLimit; index += increment)
3058                 const char *unroll =
3059                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3060 
3061                 out << unroll << " for(";
3062                 index->traverse(this);
3063                 out << " = ";
3064                 out << initial;
3065 
3066                 out << "; ";
3067                 index->traverse(this);
3068                 out << " < ";
3069                 out << clampedLimit;
3070 
3071                 out << "; ";
3072                 index->traverse(this);
3073                 out << " += ";
3074                 out << increment;
3075                 out << ")\n";
3076 
3077                 outputLineDirective(out, node->getLine().first_line);
3078                 out << "{\n";
3079 
3080                 if (node->getBody())
3081                 {
3082                     node->getBody()->traverse(this);
3083                 }
3084 
3085                 outputLineDirective(out, node->getLine().first_line);
3086                 out << ";}\n";
3087 
3088                 if (!firstLoopFragment)
3089                 {
3090                     out << "}\n";
3091                 }
3092 
3093                 firstLoopFragment = false;
3094 
3095                 initial += MAX_LOOP_ITERATIONS * increment;
3096                 iterations -= MAX_LOOP_ITERATIONS;
3097             }
3098 
3099             out << "}";
3100 
3101             mExcessiveLoopIndex = restoreIndex;
3102 
3103             return true;
3104         }
3105         else
3106             UNIMPLEMENTED();
3107     }
3108 
3109     return false;  // Not handled as an excessive loop
3110 }
3111 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3112 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3113                                Visit visit,
3114                                const char *preString,
3115                                const char *inString,
3116                                const char *postString)
3117 {
3118     if (visit == PreVisit)
3119     {
3120         out << preString;
3121     }
3122     else if (visit == InVisit)
3123     {
3124         out << inString;
3125     }
3126     else if (visit == PostVisit)
3127     {
3128         out << postString;
3129     }
3130 }
3131 
outputLineDirective(TInfoSinkBase & out,int line)3132 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3133 {
3134     if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
3135     {
3136         out << "\n";
3137         out << "#line " << line;
3138 
3139         if (mSourcePath)
3140         {
3141             out << " \"" << mSourcePath << "\"";
3142         }
3143 
3144         out << "\n";
3145     }
3146 }
3147 
writeParameter(const TVariable * param,TInfoSinkBase & out)3148 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3149 {
3150     const TType &type    = param->getType();
3151     TQualifier qualifier = type.getQualifier();
3152 
3153     TString nameStr = DecorateVariableIfNeeded(*param);
3154     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3155 
3156     if (IsSampler(type.getBasicType()))
3157     {
3158         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3159         {
3160             // Samplers are passed as indices to the sampler array.
3161             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3162             out << "const uint " << nameStr << ArrayString(type);
3163             return;
3164         }
3165         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3166         {
3167             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3168                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3169                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3170                 << ArrayString(type);
3171             return;
3172         }
3173     }
3174 
3175     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3176     // buffer offset.
3177     if (IsAtomicCounter(type.getBasicType()))
3178     {
3179         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3180             << nameStr << "_offset";
3181     }
3182     else
3183     {
3184         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3185             << ArrayString(type);
3186     }
3187 
3188     // If the structure parameter contains samplers, they need to be passed into the function as
3189     // separate parameters. HLSL doesn't natively support samplers in structs.
3190     if (type.isStructureContainingSamplers())
3191     {
3192         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3193         TVector<const TVariable *> samplerSymbols;
3194         std::string namePrefix = "angle";
3195         namePrefix += nameStr.c_str();
3196         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3197                                   mSymbolTable);
3198         for (const TVariable *sampler : samplerSymbols)
3199         {
3200             const TType &samplerType = sampler->getType();
3201             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3202             {
3203                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3204             }
3205             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3206             {
3207                 ASSERT(IsSampler(samplerType.getBasicType()));
3208                 out << ", " << QualifierString(qualifier) << " "
3209                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3210                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3211                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3212                     << ArrayString(samplerType);
3213             }
3214             else
3215             {
3216                 ASSERT(IsSampler(samplerType.getBasicType()));
3217                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3218                     << sampler->name() << ArrayString(samplerType);
3219             }
3220         }
3221     }
3222 }
3223 
zeroInitializer(const TType & type) const3224 TString OutputHLSL::zeroInitializer(const TType &type) const
3225 {
3226     TString string;
3227 
3228     size_t size = type.getObjectSize();
3229     if (size >= kZeroCount)
3230     {
3231         mUseZeroArray = true;
3232     }
3233     string = GetZeroInitializer(size).c_str();
3234 
3235     return "{" + string + "}";
3236 }
3237 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3238 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3239 {
3240     // Array constructors should have been already pruned from the code.
3241     ASSERT(!node->getType().isArray());
3242 
3243     if (visit == PreVisit)
3244     {
3245         TString constructorName;
3246         if (node->getBasicType() == EbtStruct)
3247         {
3248             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3249         }
3250         else
3251         {
3252             constructorName =
3253                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3254         }
3255         out << constructorName << "(";
3256     }
3257     else if (visit == InVisit)
3258     {
3259         out << ", ";
3260     }
3261     else if (visit == PostVisit)
3262     {
3263         out << ")";
3264     }
3265 }
3266 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3267 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3268                                                      const TType &type,
3269                                                      const TConstantUnion *const constUnion)
3270 {
3271     ASSERT(!type.isArray());
3272 
3273     const TConstantUnion *constUnionIterated = constUnion;
3274 
3275     const TStructure *structure = type.getStruct();
3276     if (structure)
3277     {
3278         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3279 
3280         const TFieldList &fields = structure->fields();
3281 
3282         for (size_t i = 0; i < fields.size(); i++)
3283         {
3284             const TType *fieldType = fields[i]->type();
3285             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3286 
3287             if (i != fields.size() - 1)
3288             {
3289                 out << ", ";
3290             }
3291         }
3292 
3293         out << ")";
3294     }
3295     else
3296     {
3297         size_t size    = type.getObjectSize();
3298         bool writeType = size > 1;
3299 
3300         if (writeType)
3301         {
3302             out << TypeString(type) << "(";
3303         }
3304         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3305         if (writeType)
3306         {
3307             out << ")";
3308         }
3309     }
3310 
3311     return constUnionIterated;
3312 }
3313 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)3314 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
3315 {
3316     if (visit == PreVisit)
3317     {
3318         const char *opStr = GetOperatorString(op);
3319         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
3320         out << "(";
3321     }
3322     else
3323     {
3324         outputTriplet(out, visit, nullptr, ", ", ")");
3325     }
3326 }
3327 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3328 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3329                                             TIntermSymbol *symbolNode,
3330                                             TIntermTyped *expression)
3331 {
3332     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3333     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3334 
3335     if (symbolInInitializer)
3336     {
3337         // Type already printed
3338         out << "t" + str(mUniqueIndex) + " = ";
3339         expression->traverse(this);
3340         out << ", ";
3341         symbolNode->traverse(this);
3342         out << " = t" + str(mUniqueIndex);
3343 
3344         mUniqueIndex++;
3345         return true;
3346     }
3347 
3348     return false;
3349 }
3350 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3351 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3352                                              TIntermSymbol *symbolNode,
3353                                              TIntermTyped *initializer)
3354 {
3355     if (initializer->hasConstantValue())
3356     {
3357         symbolNode->traverse(this);
3358         out << ArrayString(symbolNode->getType());
3359         out << " = {";
3360         writeConstantUnionArray(out, initializer->getConstantValue(),
3361                                 initializer->getType().getObjectSize());
3362         out << "}";
3363         return true;
3364     }
3365     return false;
3366 }
3367 
addStructEqualityFunction(const TStructure & structure)3368 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3369 {
3370     const TFieldList &fields = structure.fields();
3371 
3372     for (const auto &eqFunction : mStructEqualityFunctions)
3373     {
3374         if (eqFunction->structure == &structure)
3375         {
3376             return eqFunction->functionName;
3377         }
3378     }
3379 
3380     const TString &structNameString = StructNameString(structure);
3381 
3382     StructEqualityFunction *function = new StructEqualityFunction();
3383     function->structure              = &structure;
3384     function->functionName           = "angle_eq_" + structNameString;
3385 
3386     TInfoSinkBase fnOut;
3387 
3388     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3389           << structNameString + " b)\n"
3390           << "{\n"
3391              "    return ";
3392 
3393     for (size_t i = 0; i < fields.size(); i++)
3394     {
3395         const TField *field    = fields[i];
3396         const TType *fieldType = field->type();
3397 
3398         const TString &fieldNameA = "a." + Decorate(field->name());
3399         const TString &fieldNameB = "b." + Decorate(field->name());
3400 
3401         if (i > 0)
3402         {
3403             fnOut << " && ";
3404         }
3405 
3406         fnOut << "(";
3407         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3408         fnOut << fieldNameA;
3409         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3410         fnOut << fieldNameB;
3411         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3412         fnOut << ")";
3413     }
3414 
3415     fnOut << ";\n"
3416           << "}\n";
3417 
3418     function->functionDefinition = fnOut.c_str();
3419 
3420     mStructEqualityFunctions.push_back(function);
3421     mEqualityFunctions.push_back(function);
3422 
3423     return function->functionName;
3424 }
3425 
addArrayEqualityFunction(const TType & type)3426 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3427 {
3428     for (const auto &eqFunction : mArrayEqualityFunctions)
3429     {
3430         if (eqFunction->type == type)
3431         {
3432             return eqFunction->functionName;
3433         }
3434     }
3435 
3436     TType elementType(type);
3437     elementType.toArrayElementType();
3438 
3439     ArrayHelperFunction *function = new ArrayHelperFunction();
3440     function->type                = type;
3441 
3442     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3443 
3444     TInfoSinkBase fnOut;
3445 
3446     const TString &typeName = TypeString(type);
3447     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3448           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3449           << "{\n"
3450              "    for (int i = 0; i < "
3451           << type.getOutermostArraySize()
3452           << "; ++i)\n"
3453              "    {\n"
3454              "        if (";
3455 
3456     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3457     fnOut << "a[i]";
3458     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3459     fnOut << "b[i]";
3460     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3461 
3462     fnOut << ") { return false; }\n"
3463              "    }\n"
3464              "    return true;\n"
3465              "}\n";
3466 
3467     function->functionDefinition = fnOut.c_str();
3468 
3469     mArrayEqualityFunctions.push_back(function);
3470     mEqualityFunctions.push_back(function);
3471 
3472     return function->functionName;
3473 }
3474 
addArrayAssignmentFunction(const TType & type)3475 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3476 {
3477     for (const auto &assignFunction : mArrayAssignmentFunctions)
3478     {
3479         if (assignFunction.type == type)
3480         {
3481             return assignFunction.functionName;
3482         }
3483     }
3484 
3485     TType elementType(type);
3486     elementType.toArrayElementType();
3487 
3488     ArrayHelperFunction function;
3489     function.type = type;
3490 
3491     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3492 
3493     TInfoSinkBase fnOut;
3494 
3495     const TString &typeName = TypeString(type);
3496     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3497           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3498           << "{\n"
3499              "    for (int i = 0; i < "
3500           << type.getOutermostArraySize()
3501           << "; ++i)\n"
3502              "    {\n"
3503              "        ";
3504 
3505     outputAssign(PreVisit, elementType, fnOut);
3506     fnOut << "a[i]";
3507     outputAssign(InVisit, elementType, fnOut);
3508     fnOut << "b[i]";
3509     outputAssign(PostVisit, elementType, fnOut);
3510 
3511     fnOut << ";\n"
3512              "    }\n"
3513              "}\n";
3514 
3515     function.functionDefinition = fnOut.c_str();
3516 
3517     mArrayAssignmentFunctions.push_back(function);
3518 
3519     return function.functionName;
3520 }
3521 
addArrayConstructIntoFunction(const TType & type)3522 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3523 {
3524     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3525     {
3526         if (constructIntoFunction.type == type)
3527         {
3528             return constructIntoFunction.functionName;
3529         }
3530     }
3531 
3532     TType elementType(type);
3533     elementType.toArrayElementType();
3534 
3535     ArrayHelperFunction function;
3536     function.type = type;
3537 
3538     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3539 
3540     TInfoSinkBase fnOut;
3541 
3542     const TString &typeName = TypeString(type);
3543     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3544     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3545     {
3546         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3547     }
3548     fnOut << ")\n"
3549              "{\n";
3550 
3551     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3552     {
3553         fnOut << "    ";
3554         outputAssign(PreVisit, elementType, fnOut);
3555         fnOut << "a[" << i << "]";
3556         outputAssign(InVisit, elementType, fnOut);
3557         fnOut << "b" << i;
3558         outputAssign(PostVisit, elementType, fnOut);
3559         fnOut << ";\n";
3560     }
3561     fnOut << "}\n";
3562 
3563     function.functionDefinition = fnOut.c_str();
3564 
3565     mArrayConstructIntoFunctions.push_back(function);
3566 
3567     return function.functionName;
3568 }
3569 
ensureStructDefined(const TType & type)3570 void OutputHLSL::ensureStructDefined(const TType &type)
3571 {
3572     const TStructure *structure = type.getStruct();
3573     if (structure)
3574     {
3575         ASSERT(type.getBasicType() == EbtStruct);
3576         mStructureHLSL->ensureStructDefined(*structure);
3577     }
3578 }
3579 
shaderNeedsGenerateOutput() const3580 bool OutputHLSL::shaderNeedsGenerateOutput() const
3581 {
3582     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3583 }
3584 
generateOutputCall() const3585 const char *OutputHLSL::generateOutputCall() const
3586 {
3587     if (mShaderType == GL_VERTEX_SHADER)
3588     {
3589         return "generateOutput(input)";
3590     }
3591     else
3592     {
3593         return "generateOutput()";
3594     }
3595 }
3596 
3597 }  // namespace sh
3598