1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/AtomicCounterFunctionHLSL.h"
17 #include "compiler/translator/BuiltInFunctionEmulator.h"
18 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
19 #include "compiler/translator/ImageFunctionHLSL.h"
20 #include "compiler/translator/InfoSink.h"
21 #include "compiler/translator/ResourcesHLSL.h"
22 #include "compiler/translator/StructureHLSL.h"
23 #include "compiler/translator/TextureFunctionHLSL.h"
24 #include "compiler/translator/TranslatorHLSL.h"
25 #include "compiler/translator/UtilsHLSL.h"
26 #include "compiler/translator/blocklayout.h"
27 #include "compiler/translator/tree_ops/d3d/RemoveSwitchFallThrough.h"
28 #include "compiler/translator/tree_util/FindSymbolNode.h"
29 #include "compiler/translator/tree_util/NodeSearch.h"
30 #include "compiler/translator/util.h"
31 
32 namespace sh
33 {
34 
35 namespace
36 {
37 
// Marker comment text for the image2D declaration functions.
// NOTE(review): presumably a placeholder that is substituted later — confirm against its users.
constexpr char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
39 
ArrayHelperFunctionName(const char * prefix,const TType & type)40 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
41 {
42     TStringStream fnName = sh::InitializeStream<TStringStream>();
43     fnName << prefix << "_";
44     if (type.isArray())
45     {
46         for (unsigned int arraySize : type.getArraySizes())
47         {
48             fnName << arraySize << "_";
49         }
50     }
51     fnName << TypeString(type);
52     return fnName.str();
53 }
54 
IsDeclarationWrittenOut(TIntermDeclaration * node)55 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
56 {
57     TIntermSequence *sequence = node->getSequence();
58     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
59     ASSERT(sequence->size() == 1);
60     ASSERT(variable);
61     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
62             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
63 }
64 
IsInStd140UniformBlock(TIntermTyped * node)65 bool IsInStd140UniformBlock(TIntermTyped *node)
66 {
67     TIntermBinary *binaryNode = node->getAsBinaryNode();
68 
69     if (binaryNode)
70     {
71         return IsInStd140UniformBlock(binaryNode->getLeft());
72     }
73 
74     const TType &type = node->getType();
75 
76     if (type.getQualifier() == EvqUniform)
77     {
78         // determine if we are in the standard layout
79         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80         if (interfaceBlock)
81         {
82             return (interfaceBlock->blockStorage() == EbsStd140);
83         }
84     }
85 
86     return false;
87 }
88 
GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped * node)89 const TInterfaceBlock *GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped *node)
90 {
91     const TIntermBinary *binaryNode = node->getAsBinaryNode();
92     if (binaryNode)
93     {
94         if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
95         {
96             return binaryNode->getLeft()->getType().getInterfaceBlock();
97         }
98     }
99 
100     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
101     if (symbolNode)
102     {
103         const TVariable &variable = symbolNode->variable();
104         const TType &variableType = variable.getType();
105 
106         if (variableType.getQualifier() == EvqUniform &&
107             variable.symbolType() == SymbolType::UserDefined)
108         {
109             return variableType.getInterfaceBlock();
110         }
111     }
112 
113     return nullptr;
114 }
115 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)116 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
117 {
118     switch (op)
119     {
120         case EOpAtomicAdd:
121             return "InterlockedAdd(";
122         case EOpAtomicMin:
123             return "InterlockedMin(";
124         case EOpAtomicMax:
125             return "InterlockedMax(";
126         case EOpAtomicAnd:
127             return "InterlockedAnd(";
128         case EOpAtomicOr:
129             return "InterlockedOr(";
130         case EOpAtomicXor:
131             return "InterlockedXor(";
132         case EOpAtomicExchange:
133             return "InterlockedExchange(";
134         case EOpAtomicCompSwap:
135             return "InterlockedCompareExchange(";
136         default:
137             UNREACHABLE();
138             return "";
139     }
140 }
141 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)142 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
143 {
144     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
145     if (aggregateNode == nullptr)
146     {
147         return false;
148     }
149 
150     if (node.getOp() == EOpAssign && IsAtomicFunction(aggregateNode->getOp()))
151     {
152         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped());
153     }
154 
155     return false;
156 }
157 
// Identifier of the HLSL uint array that DefineZeroArray() declares and GetZeroInitializer()
// references. Declared constexpr so the pointer itself cannot be reassigned at runtime.
constexpr const char *kZeros = "_ANGLE_ZEROS_";
// Number of elements in the array named by kZeros.
constexpr int kZeroCount = 256;
DefineZeroArray()160 std::string DefineZeroArray()
161 {
162     std::stringstream ss = sh::InitializeStream<std::stringstream>();
163     // For 'static', if the declaration does not include an initializer, the value is set to zero.
164     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
165     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
166     return ss.str();
167 }
168 
GetZeroInitializer(size_t size)169 std::string GetZeroInitializer(size_t size)
170 {
171     std::stringstream ss = sh::InitializeStream<std::stringstream>();
172     size_t quotient      = size / kZeroCount;
173     size_t reminder      = size % kZeroCount;
174 
175     for (size_t i = 0; i < quotient; ++i)
176     {
177         if (i != 0)
178         {
179             ss << ", ";
180         }
181         ss << kZeros;
182     }
183 
184     for (size_t i = 0; i < reminder; ++i)
185     {
186         if (quotient != 0 || i != 0)
187         {
188             ss << ", ";
189         }
190         ss << "0";
191     }
192 
193     return ss.str();
194 }
195 
196 }  // anonymous namespace
197 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)198 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
199                                    const TVariable *aInstanceVariable)
200     : block(aBlock), instanceVariable(aInstanceVariable)
201 {}
202 
needStructMapping(TIntermTyped * node)203 bool OutputHLSL::needStructMapping(TIntermTyped *node)
204 {
205     ASSERT(node->getBasicType() == EbtStruct);
206     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
207     {
208         TIntermNode *ancestor               = getAncestorNode(n);
209         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
210         if (ancestorBinary)
211         {
212             switch (ancestorBinary->getOp())
213             {
214                 case EOpIndexDirectStruct:
215                 {
216                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
217                     const TIntermConstantUnion *index =
218                         ancestorBinary->getRight()->getAsConstantUnion();
219                     const TField *field = structure->fields()[index->getIConst(0)];
220                     if (field->type()->getStruct() == nullptr)
221                     {
222                         return false;
223                     }
224                     break;
225                 }
226                 case EOpIndexDirect:
227                 case EOpIndexIndirect:
228                     break;
229                 default:
230                     return true;
231             }
232         }
233         else
234         {
235             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
236             if (ancestorAggregate)
237             {
238                 return true;
239             }
240             return false;
241         }
242     }
243     return true;
244 }
245 
writeFloat(TInfoSinkBase & out,float f)246 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
247 {
248     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
249     // regardless.
250     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
251         mOutputType == SH_HLSL_4_1_OUTPUT)
252     {
253         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
254     }
255     else
256     {
257         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
258     }
259 }
260 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)261 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
262 {
263     ASSERT(constUnion != nullptr);
264     switch (constUnion->getType())
265     {
266         case EbtFloat:
267             writeFloat(out, constUnion->getFConst());
268             break;
269         case EbtInt:
270             out << constUnion->getIConst();
271             break;
272         case EbtUInt:
273             out << constUnion->getUConst();
274             break;
275         case EbtBool:
276             out << constUnion->getBConst();
277             break;
278         default:
279             UNREACHABLE();
280     }
281 }
282 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)283 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
284                                                           const TConstantUnion *const constUnion,
285                                                           const size_t size)
286 {
287     const TConstantUnion *constUnionIterated = constUnion;
288     for (size_t i = 0; i < size; i++, constUnionIterated++)
289     {
290         writeSingleConstant(out, constUnionIterated);
291 
292         if (i != size - 1)
293         {
294             out << ", ";
295         }
296     }
297     return constUnionIterated;
298 }
299 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,ShCompileOptions compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::map<int,const TInterfaceBlock * > & uniformBlockOptimizedMap,const std::vector<InterfaceBlock> & shaderStorageBlocks)300 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
301                        ShShaderSpec shaderSpec,
302                        int shaderVersion,
303                        const TExtensionBehavior &extensionBehavior,
304                        const char *sourcePath,
305                        ShShaderOutput outputType,
306                        int numRenderTargets,
307                        int maxDualSourceDrawBuffers,
308                        const std::vector<ShaderVariable> &uniforms,
309                        ShCompileOptions compileOptions,
310                        sh::WorkGroupSize workGroupSize,
311                        TSymbolTable *symbolTable,
312                        PerformanceDiagnostics *perfDiagnostics,
313                        const std::map<int, const TInterfaceBlock *> &uniformBlockOptimizedMap,
314                        const std::vector<InterfaceBlock> &shaderStorageBlocks)
315     : TIntermTraverser(true, true, true, symbolTable),
316       mShaderType(shaderType),
317       mShaderSpec(shaderSpec),
318       mShaderVersion(shaderVersion),
319       mExtensionBehavior(extensionBehavior),
320       mSourcePath(sourcePath),
321       mOutputType(outputType),
322       mCompileOptions(compileOptions),
323       mInsideFunction(false),
324       mInsideMain(false),
325       mUniformBlockOptimizedMap(uniformBlockOptimizedMap),
326       mNumRenderTargets(numRenderTargets),
327       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
328       mCurrentFunctionMetadata(nullptr),
329       mWorkGroupSize(workGroupSize),
330       mPerfDiagnostics(perfDiagnostics),
331       mNeedStructMapping(false)
332 {
333     mUsesFragColor        = false;
334     mUsesFragData         = false;
335     mUsesDepthRange       = false;
336     mUsesFragCoord        = false;
337     mUsesPointCoord       = false;
338     mUsesFrontFacing      = false;
339     mUsesHelperInvocation = false;
340     mUsesPointSize        = false;
341     mUsesInstanceID       = false;
342     mHasMultiviewExtensionEnabled =
343         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
344         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
345     mUsesViewID                  = false;
346     mUsesVertexID                = false;
347     mUsesFragDepth               = false;
348     mUsesNumWorkGroups           = false;
349     mUsesWorkGroupID             = false;
350     mUsesLocalInvocationID       = false;
351     mUsesGlobalInvocationID      = false;
352     mUsesLocalInvocationIndex    = false;
353     mUsesXor                     = false;
354     mUsesDiscardRewriting        = false;
355     mUsesNestedBreak             = false;
356     mRequiresIEEEStrictCompiling = false;
357     mUseZeroArray                = false;
358     mUsesSecondaryColor          = false;
359 
360     mUniqueIndex = 0;
361 
362     mOutputLod0Function      = false;
363     mInsideDiscontinuousLoop = false;
364     mNestedLoopDepth         = 0;
365 
366     mExcessiveLoopIndex = nullptr;
367 
368     mStructureHLSL       = new StructureHLSL;
369     mTextureFunctionHLSL = new TextureFunctionHLSL;
370     mImageFunctionHLSL   = new ImageFunctionHLSL;
371     mAtomicCounterFunctionHLSL =
372         new AtomicCounterFunctionHLSL((compileOptions & SH_FORCE_ATOMIC_VALUE_RESOLUTION) != 0);
373 
374     unsigned int firstUniformRegister =
375         (compileOptions & SH_SKIP_D3D_CONSTANT_REGISTER_ZERO) != 0 ? 1u : 0u;
376     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, uniforms, firstUniformRegister);
377 
378     if (mOutputType == SH_HLSL_3_0_OUTPUT)
379     {
380         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
381         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
382         // dx_ViewAdjust.
383         // In both cases total 3 uniform registers need to be reserved.
384         mResourcesHLSL->reserveUniformRegisters(3);
385     }
386 
387     // Reserve registers for the default uniform block and driver constants
388     mResourcesHLSL->reserveUniformBlockRegisters(2);
389 
390     mSSBOOutputHLSL = new ShaderStorageBlockOutputHLSL(this, mResourcesHLSL, shaderStorageBlocks);
391 }
392 
~OutputHLSL()393 OutputHLSL::~OutputHLSL()
394 {
395     SafeDelete(mSSBOOutputHLSL);
396     SafeDelete(mStructureHLSL);
397     SafeDelete(mResourcesHLSL);
398     SafeDelete(mTextureFunctionHLSL);
399     SafeDelete(mImageFunctionHLSL);
400     SafeDelete(mAtomicCounterFunctionHLSL);
401     for (auto &eqFunction : mStructEqualityFunctions)
402     {
403         SafeDelete(eqFunction);
404     }
405     for (auto &eqFunction : mArrayEqualityFunctions)
406     {
407         SafeDelete(eqFunction);
408     }
409 }
410 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)411 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
412 {
413     BuiltInFunctionEmulator builtInFunctionEmulator;
414     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
415     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
416     {
417         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
418                                                            mShaderVersion);
419     }
420 
421     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
422 
423     // Now that we are done changing the AST, do the analyses need for HLSL generation
424     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
425     ASSERT(success == CallDAG::INITDAG_SUCCESS);
426     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
427 
428     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
429     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
430     // the shader code. When we add shader storage blocks we might also consider an alternative
431     // solution, since the struct mapping won't work very well for shader storage blocks.
432 
433     // Output the body and footer first to determine what has to go in the header
434     mInfoSinkStack.push(&mBody);
435     treeRoot->traverse(this);
436     mInfoSinkStack.pop();
437 
438     mInfoSinkStack.push(&mFooter);
439     mInfoSinkStack.pop();
440 
441     mInfoSinkStack.push(&mHeader);
442     header(mHeader, std140Structs, &builtInFunctionEmulator);
443     mInfoSinkStack.pop();
444 
445     objSink << mHeader.c_str();
446     objSink << mBody.c_str();
447     objSink << mFooter.c_str();
448 
449     builtInFunctionEmulator.cleanup();
450 }
451 
getShaderStorageBlockRegisterMap() const452 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
453 {
454     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
455 }
456 
getUniformBlockRegisterMap() const457 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
458 {
459     return mResourcesHLSL->getUniformBlockRegisterMap();
460 }
461 
getUniformBlockUseStructuredBufferMap() const462 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
463 {
464     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
465 }
466 
getUniformRegisterMap() const467 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
468 {
469     return mResourcesHLSL->getUniformRegisterMap();
470 }
471 
getReadonlyImage2DRegisterIndex() const472 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
473 {
474     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
475 }
476 
getImage2DRegisterIndex() const477 unsigned int OutputHLSL::getImage2DRegisterIndex() const
478 {
479     return mResourcesHLSL->getImage2DRegisterIndex();
480 }
481 
getUsedImage2DFunctionNames() const482 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
483 {
484     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
485 }
486 
// Builds a brace-enclosed initializer expression for a value of |type| named |name|.
//  - indent: nesting depth; every emitted line is prefixed with four spaces per level.
//  - type:   arrays expand to one nested initializer per element ("name[i]"), structs to one
//            nested initializer per field ("name.<decorated field>"), anything else to |name|.
//  - name:   expression that refers to the value being initialized from.
// Returns the initializer text without a trailing newline.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString init;

    TString indentString;
    for (int spaces = 0; spaces < indent; spaces++)
    {
        indentString += "    ";
    }

    if (type.isArray())
    {
        // The element type and the element count are the same for every index, so compute
        // them once instead of copying the type on each iteration.
        TType elementType = type;
        elementType.toArrayElementType();
        const unsigned int elementCount = type.getOutermostArraySize();

        init += indentString + "{\n";
        for (unsigned int arrayIndex = 0u; arrayIndex < elementCount; ++arrayIndex)
        {
            TStringStream indexedString = sh::InitializeStream<TStringStream>();
            indexedString << name << "[" << arrayIndex << "]";
            init += structInitializerString(indent + 1, elementType, indexedString.str());
            // Comma after every element except the last.
            if (arrayIndex + 1 < elementCount)
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
    }
    else if (type.getBasicType() == EbtStruct)
    {
        init += indentString + "{\n";
        const TStructure &structure = *type.getStruct();
        const TFieldList &fields    = structure.fields();
        for (unsigned int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++)
        {
            const TField &field      = *fields[fieldIndex];
            const TString &fieldName = name + "." + Decorate(field.name());
            const TType &fieldType   = *field.type();

            init += structInitializerString(indent + 1, fieldType, fieldName);
            // Comma after every field except the last.
            if (fieldIndex + 1 < fields.size())
            {
                init += ",";
            }
            init += "\n";
        }
        init += indentString + "}";
    }
    else
    {
        // Leaf value: the expression itself is the initializer.
        init += indentString + name;
    }

    return init;
}
544 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const545 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
546 {
547     TString mappedStructs;
548 
549     for (auto &mappedStruct : std140Structs)
550     {
551         const TInterfaceBlock *interfaceBlock =
552             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
553         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
554         switch (qualifier)
555         {
556             case EvqUniform:
557                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
558                 {
559                     continue;
560                 }
561                 break;
562             case EvqBuffer:
563                 continue;
564             default:
565                 UNREACHABLE();
566                 return mappedStructs;
567         }
568 
569         unsigned int instanceCount = 1u;
570         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
571         if (isInstanceArray)
572         {
573             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
574         }
575 
576         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
577              ++instanceArrayIndex)
578         {
579             TString originalName;
580             TString mappedName("map");
581 
582             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
583             {
584                 const ImmutableString &instanceName =
585                     mappedStruct.blockDeclarator->variable().name();
586                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
587                 if (isInstanceArray)
588                     instanceStringArrayIndex = instanceArrayIndex;
589                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
590                     instanceName, instanceStringArrayIndex);
591                 originalName += instanceString;
592                 mappedName += instanceString;
593                 originalName += ".";
594                 mappedName += "_";
595             }
596 
597             TString fieldName = Decorate(mappedStruct.field->name());
598             originalName += fieldName;
599             mappedName += fieldName;
600 
601             TType *structType = mappedStruct.field->type();
602             mappedStructs +=
603                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
604 
605             if (structType->isArray())
606             {
607                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
608             }
609 
610             mappedStructs += " =\n";
611             mappedStructs += structInitializerString(0, *structType, originalName);
612             mappedStructs += ";\n";
613         }
614     }
615     return mappedStructs;
616 }
617 
writeReferencedAttributes(TInfoSinkBase & out) const618 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
619 {
620     for (const auto &attribute : mReferencedAttributes)
621     {
622         const TType &type           = attribute.second->getType();
623         const ImmutableString &name = attribute.second->name();
624 
625         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
626             << zeroInitializer(type) << ";\n";
627     }
628 }
629 
writeReferencedVaryings(TInfoSinkBase & out) const630 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
631 {
632     for (const auto &varying : mReferencedVaryings)
633     {
634         const TType &type = varying.second->getType();
635 
636         // Program linking depends on this exact format
637         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
638             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
639             << zeroInitializer(type) << ";\n";
640     }
641 }
642 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const643 void OutputHLSL::header(TInfoSinkBase &out,
644                         const std::vector<MappedStruct> &std140Structs,
645                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
646 {
647     TString mappedStructs;
648     if (mNeedStructMapping)
649     {
650         mappedStructs = generateStructMapping(std140Structs);
651     }
652 
653     // Suppress some common warnings:
654     // 3556 : Integer divides might be much slower, try using uints if possible.
655     // 3571 : The pow(f, e) intrinsic function won't work for negative f, use abs(f) or
656     //        conditionally handle negative values if you expect them.
657     out << "#pragma warning( disable: 3556 3571 )\n";
658 
659     out << mStructureHLSL->structsHeader();
660 
661     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
662     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks, mUniformBlockOptimizedMap);
663     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(out);
664 
665     if (!mEqualityFunctions.empty())
666     {
667         out << "\n// Equality functions\n\n";
668         for (const auto &eqFunction : mEqualityFunctions)
669         {
670             out << eqFunction->functionDefinition << "\n";
671         }
672     }
673     if (!mArrayAssignmentFunctions.empty())
674     {
675         out << "\n// Assignment functions\n\n";
676         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
677         {
678             out << assignmentFunction.functionDefinition << "\n";
679         }
680     }
681     if (!mArrayConstructIntoFunctions.empty())
682     {
683         out << "\n// Array constructor functions\n\n";
684         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
685         {
686             out << constructIntoFunction.functionDefinition << "\n";
687         }
688     }
689 
690     if (mUsesDiscardRewriting)
691     {
692         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
693     }
694 
695     if (mUsesNestedBreak)
696     {
697         out << "#define ANGLE_USES_NESTED_BREAK\n";
698     }
699 
700     if (mRequiresIEEEStrictCompiling)
701     {
702         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
703     }
704 
705     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
706            "#define LOOP [loop]\n"
707            "#define FLATTEN [flatten]\n"
708            "#else\n"
709            "#define LOOP\n"
710            "#define FLATTEN\n"
711            "#endif\n";
712 
713     // array stride for atomic counter buffers is always 4 per original extension
714     // ARB_shader_atomic_counters and discussion on
715     // https://github.com/KhronosGroup/OpenGL-API/issues/5
716     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
717 
718     if (mUseZeroArray)
719     {
720         out << DefineZeroArray() << "\n";
721     }
722 
723     if (mShaderType == GL_FRAGMENT_SHADER)
724     {
725         const bool usingMRTExtension =
726             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
727         const bool usingBFEExtension =
728             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
729 
730         out << "// Varyings\n";
731         writeReferencedVaryings(out);
732         out << "\n";
733 
734         if ((IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 130) ||
735             (!IsDesktopGLSpec(mShaderSpec) && mShaderVersion >= 300))
736         {
737             for (const auto &outputVariable : mReferencedOutputVariables)
738             {
739                 const ImmutableString &variableName = outputVariable.second->name();
740                 const TType &variableType           = outputVariable.second->getType();
741 
742                 out << "static " << TypeString(variableType) << " out_" << variableName
743                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
744             }
745         }
746         else
747         {
748             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
749 
750             out << "static float4 gl_Color[" << numColorValues
751                 << "] =\n"
752                    "{\n";
753             for (unsigned int i = 0; i < numColorValues; i++)
754             {
755                 out << "    float4(0, 0, 0, 0)";
756                 if (i + 1 != numColorValues)
757                 {
758                     out << ",";
759                 }
760                 out << "\n";
761             }
762 
763             out << "};\n";
764 
765             if (usingBFEExtension && mUsesSecondaryColor)
766             {
767                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
768                     << "] = \n"
769                        "{\n";
770                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
771                 {
772                     out << "    float4(0, 0, 0, 0)";
773                     if (i + 1 != mMaxDualSourceDrawBuffers)
774                     {
775                         out << ",";
776                     }
777                     out << "\n";
778                 }
779                 out << "};\n";
780             }
781         }
782 
783         if (mUsesFragDepth)
784         {
785             out << "static float gl_Depth = 0.0;\n";
786         }
787 
788         if (mUsesFragCoord)
789         {
790             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
791         }
792 
793         if (mUsesPointCoord)
794         {
795             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
796         }
797 
798         if (mUsesFrontFacing)
799         {
800             out << "static bool gl_FrontFacing = false;\n";
801         }
802 
803         if (mUsesHelperInvocation)
804         {
805             out << "static bool gl_HelperInvocation = false;\n";
806         }
807 
808         out << "\n";
809 
810         if (mUsesDepthRange)
811         {
812             out << "struct gl_DepthRangeParameters\n"
813                    "{\n"
814                    "    float near;\n"
815                    "    float far;\n"
816                    "    float diff;\n"
817                    "};\n"
818                    "\n";
819         }
820 
821         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
822         {
823             out << "cbuffer DriverConstants : register(b1)\n"
824                    "{\n";
825 
826             if (mUsesDepthRange)
827             {
828                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
829             }
830 
831             if (mUsesFragCoord)
832             {
833                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
834             }
835 
836             if (mUsesFragCoord || mUsesFrontFacing)
837             {
838                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
839             }
840 
841             if (mUsesFragCoord)
842             {
843                 // dx_ViewScale is only used in the fragment shader to correct
844                 // the value for glFragCoord if necessary
845                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
846             }
847 
848             if (mHasMultiviewExtensionEnabled)
849             {
850                 // We have to add a value which we can use to keep track of which multi-view code
851                 // path is to be selected in the GS.
852                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
853             }
854 
855             if (mOutputType == SH_HLSL_4_1_OUTPUT)
856             {
857                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
858             }
859 
860             out << "};\n";
861         }
862         else
863         {
864             if (mUsesDepthRange)
865             {
866                 out << "uniform float3 dx_DepthRange : register(c0);";
867             }
868 
869             if (mUsesFragCoord)
870             {
871                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
872             }
873 
874             if (mUsesFragCoord || mUsesFrontFacing)
875             {
876                 out << "uniform float3 dx_DepthFront : register(c2);\n";
877             }
878         }
879 
880         out << "\n";
881 
882         if (mUsesDepthRange)
883         {
884             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
885                    "dx_DepthRange.y, dx_DepthRange.z};\n"
886                    "\n";
887         }
888 
889         if (usingMRTExtension && mNumRenderTargets > 1)
890         {
891             out << "#define GL_USES_MRT\n";
892         }
893 
894         if (mUsesFragColor)
895         {
896             out << "#define GL_USES_FRAG_COLOR\n";
897         }
898 
899         if (mUsesFragData)
900         {
901             out << "#define GL_USES_FRAG_DATA\n";
902         }
903 
904         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
905         {
906             out << "#define GL_USES_SECONDARY_COLOR\n";
907         }
908     }
909     else if (mShaderType == GL_VERTEX_SHADER)
910     {
911         out << "// Attributes\n";
912         writeReferencedAttributes(out);
913         out << "\n"
914                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
915 
916         if (mUsesPointSize)
917         {
918             out << "static float gl_PointSize = float(1);\n";
919         }
920 
921         if (mUsesInstanceID)
922         {
923             out << "static int gl_InstanceID;";
924         }
925 
926         if (mUsesVertexID)
927         {
928             out << "static int gl_VertexID;";
929         }
930 
931         out << "\n"
932                "// Varyings\n";
933         writeReferencedVaryings(out);
934         out << "\n";
935 
936         if (mUsesDepthRange)
937         {
938             out << "struct gl_DepthRangeParameters\n"
939                    "{\n"
940                    "    float near;\n"
941                    "    float far;\n"
942                    "    float diff;\n"
943                    "};\n"
944                    "\n";
945         }
946 
947         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
948         {
949             out << "cbuffer DriverConstants : register(b1)\n"
950                    "{\n";
951 
952             if (mUsesDepthRange)
953             {
954                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
955             }
956 
957             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
958             // shaders. However, we declare it for all shaders (including Feature Level 10+).
959             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
960             // if it's unused.
961             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
962             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
963             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
964 
965             if (mHasMultiviewExtensionEnabled)
966             {
967                 // We have to add a value which we can use to keep track of which multi-view code
968                 // path is to be selected in the GS.
969                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
970             }
971 
972             if (mOutputType == SH_HLSL_4_1_OUTPUT)
973             {
974                 mResourcesHLSL->samplerMetadataUniforms(out, 4);
975             }
976 
977             if (mUsesVertexID)
978             {
979                 out << "    uint dx_VertexID : packoffset(c3.w);\n";
980             }
981 
982             out << "};\n"
983                    "\n";
984         }
985         else
986         {
987             if (mUsesDepthRange)
988             {
989                 out << "uniform float3 dx_DepthRange : register(c0);\n";
990             }
991 
992             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
993             out << "uniform float2 dx_ViewCoords : register(c2);\n"
994                    "\n";
995         }
996 
997         if (mUsesDepthRange)
998         {
999             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
1000                    "dx_DepthRange.y, dx_DepthRange.z};\n"
1001                    "\n";
1002         }
1003     }
1004     else  // Compute shader
1005     {
1006         ASSERT(mShaderType == GL_COMPUTE_SHADER);
1007 
1008         out << "cbuffer DriverConstants : register(b1)\n"
1009                "{\n";
1010         if (mUsesNumWorkGroups)
1011         {
1012             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1013         }
1014         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1015         unsigned int registerIndex = 1;
1016         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1017         // Sampler metadata struct must be two 4-vec, 32 bytes.
1018         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1019         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1020         out << "};\n";
1021 
1022         out << kImage2DFunctionString << "\n";
1023 
1024         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1025         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1026 
1027         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1028         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n"
1029                                 << "{\n";
1030 
1031         if (mUsesWorkGroupID)
1032         {
1033             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1034             systemValueDeclaration << "    uint3 dx_WorkGroupID : "
1035                                    << "SV_GroupID;\n";
1036             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1037         }
1038 
1039         if (mUsesLocalInvocationID)
1040         {
1041             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1042             systemValueDeclaration << "    uint3 dx_LocalInvocationID : "
1043                                    << "SV_GroupThreadID;\n";
1044             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1045         }
1046 
1047         if (mUsesGlobalInvocationID)
1048         {
1049             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1050             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1051                                    << "SV_DispatchThreadID;\n";
1052             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1053         }
1054 
1055         if (mUsesLocalInvocationIndex)
1056         {
1057             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1058             systemValueDeclaration << "    uint dx_LocalInvocationIndex : "
1059                                    << "SV_GroupIndex;\n";
1060             glBuiltinInitialization
1061                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1062         }
1063 
1064         systemValueDeclaration << "};\n\n";
1065         glBuiltinInitialization << "};\n\n";
1066 
1067         out << systemValueDeclaration.str();
1068         out << glBuiltinInitialization.str();
1069     }
1070 
1071     if (!mappedStructs.empty())
1072     {
1073         out << "// Structures from std140 blocks with padding removed\n";
1074         out << "\n";
1075         out << mappedStructs;
1076         out << "\n";
1077     }
1078 
1079     bool getDimensionsIgnoresBaseLevel =
1080         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
1081     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1082     mImageFunctionHLSL->imageFunctionHeader(out);
1083     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1084 
1085     if (mUsesFragCoord)
1086     {
1087         out << "#define GL_USES_FRAG_COORD\n";
1088     }
1089 
1090     if (mUsesPointCoord)
1091     {
1092         out << "#define GL_USES_POINT_COORD\n";
1093     }
1094 
1095     if (mUsesFrontFacing)
1096     {
1097         out << "#define GL_USES_FRONT_FACING\n";
1098     }
1099 
1100     if (mUsesHelperInvocation)
1101     {
1102         out << "#define GL_USES_HELPER_INVOCATION\n";
1103     }
1104 
1105     if (mUsesPointSize)
1106     {
1107         out << "#define GL_USES_POINT_SIZE\n";
1108     }
1109 
1110     if (mHasMultiviewExtensionEnabled)
1111     {
1112         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
1113     }
1114 
1115     if (mUsesVertexID)
1116     {
1117         out << "#define GL_USES_VERTEX_ID\n";
1118     }
1119 
1120     if (mUsesViewID)
1121     {
1122         out << "#define GL_USES_VIEW_ID\n";
1123     }
1124 
1125     if (mUsesFragDepth)
1126     {
1127         out << "#define GL_USES_FRAG_DEPTH\n";
1128     }
1129 
1130     if (mUsesDepthRange)
1131     {
1132         out << "#define GL_USES_DEPTH_RANGE\n";
1133     }
1134 
1135     if (mUsesXor)
1136     {
1137         out << "bool xor(bool p, bool q)\n"
1138                "{\n"
1139                "    return (p || q) && !(p && q);\n"
1140                "}\n"
1141                "\n";
1142     }
1143 
1144     builtInFunctionEmulator->outputEmulatedFunctions(out);
1145 }
1146 
visitSymbol(TIntermSymbol * node)1147 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1148 {
1149     const TVariable &variable = node->variable();
1150 
1151     // Empty symbols can only appear in declarations and function arguments, and in either of those
1152     // cases the symbol nodes are not visited.
1153     ASSERT(variable.symbolType() != SymbolType::Empty);
1154 
1155     TInfoSinkBase &out = getInfoSink();
1156 
1157     // Handle accessing std140 structs by value
1158     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1159         needStructMapping(node))
1160     {
1161         mNeedStructMapping = true;
1162         out << "map";
1163     }
1164 
1165     const ImmutableString &name     = variable.name();
1166     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1167 
1168     if (name == "gl_DepthRange")
1169     {
1170         mUsesDepthRange = true;
1171         out << name;
1172     }
1173     else if (IsAtomicCounter(variable.getType().getBasicType()))
1174     {
1175         const TType &variableType = variable.getType();
1176         if (variableType.getQualifier() == EvqUniform)
1177         {
1178             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1179             mReferencedUniforms[uniqueId.get()] = &variable;
1180             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1181         }
1182         else
1183         {
1184             TString varName = DecorateVariableIfNeeded(variable);
1185             out << varName << ", " << varName << "_offset";
1186         }
1187     }
1188     else
1189     {
1190         const TType &variableType = variable.getType();
1191         TQualifier qualifier      = variable.getType().getQualifier();
1192 
1193         ensureStructDefined(variableType);
1194 
1195         if (qualifier == EvqUniform)
1196         {
1197             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1198 
1199             if (interfaceBlock)
1200             {
1201                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1202                 {
1203                     const TVariable *instanceVariable = nullptr;
1204                     if (variableType.isInterfaceBlock())
1205                     {
1206                         instanceVariable = &variable;
1207                     }
1208                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1209                         new TReferencedBlock(interfaceBlock, instanceVariable);
1210                 }
1211             }
1212             else
1213             {
1214                 mReferencedUniforms[uniqueId.get()] = &variable;
1215             }
1216 
1217             out << DecorateVariableIfNeeded(variable);
1218         }
1219         else if (qualifier == EvqBuffer)
1220         {
1221             UNREACHABLE();
1222         }
1223         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1224         {
1225             mReferencedAttributes[uniqueId.get()] = &variable;
1226             out << Decorate(name);
1227         }
1228         else if (IsVarying(qualifier))
1229         {
1230             mReferencedVaryings[uniqueId.get()] = &variable;
1231             out << DecorateVariableIfNeeded(variable);
1232             if (variable.symbolType() == SymbolType::AngleInternal && name == "ViewID_OVR")
1233             {
1234                 mUsesViewID = true;
1235             }
1236         }
1237         else if (qualifier == EvqFragmentOut)
1238         {
1239             mReferencedOutputVariables[uniqueId.get()] = &variable;
1240             out << "out_" << name;
1241         }
1242         else if (qualifier == EvqFragColor)
1243         {
1244             out << "gl_Color[0]";
1245             mUsesFragColor = true;
1246         }
1247         else if (qualifier == EvqFragData)
1248         {
1249             out << "gl_Color";
1250             mUsesFragData = true;
1251         }
1252         else if (qualifier == EvqSecondaryFragColorEXT)
1253         {
1254             out << "gl_SecondaryColor[0]";
1255             mUsesSecondaryColor = true;
1256         }
1257         else if (qualifier == EvqSecondaryFragDataEXT)
1258         {
1259             out << "gl_SecondaryColor";
1260             mUsesSecondaryColor = true;
1261         }
1262         else if (qualifier == EvqFragCoord)
1263         {
1264             mUsesFragCoord = true;
1265             out << name;
1266         }
1267         else if (qualifier == EvqPointCoord)
1268         {
1269             mUsesPointCoord = true;
1270             out << name;
1271         }
1272         else if (qualifier == EvqFrontFacing)
1273         {
1274             mUsesFrontFacing = true;
1275             out << name;
1276         }
1277         else if (qualifier == EvqHelperInvocation)
1278         {
1279             mUsesHelperInvocation = true;
1280             out << name;
1281         }
1282         else if (qualifier == EvqPointSize)
1283         {
1284             mUsesPointSize = true;
1285             out << name;
1286         }
1287         else if (qualifier == EvqInstanceID)
1288         {
1289             mUsesInstanceID = true;
1290             out << name;
1291         }
1292         else if (qualifier == EvqVertexID)
1293         {
1294             mUsesVertexID = true;
1295             out << name;
1296         }
1297         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1298         {
1299             mUsesFragDepth = true;
1300             out << "gl_Depth";
1301         }
1302         else if (qualifier == EvqNumWorkGroups)
1303         {
1304             mUsesNumWorkGroups = true;
1305             out << name;
1306         }
1307         else if (qualifier == EvqWorkGroupID)
1308         {
1309             mUsesWorkGroupID = true;
1310             out << name;
1311         }
1312         else if (qualifier == EvqLocalInvocationID)
1313         {
1314             mUsesLocalInvocationID = true;
1315             out << name;
1316         }
1317         else if (qualifier == EvqGlobalInvocationID)
1318         {
1319             mUsesGlobalInvocationID = true;
1320             out << name;
1321         }
1322         else if (qualifier == EvqLocalInvocationIndex)
1323         {
1324             mUsesLocalInvocationIndex = true;
1325             out << name;
1326         }
1327         else
1328         {
1329             out << DecorateVariableIfNeeded(variable);
1330         }
1331     }
1332 }
1333 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1334 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1335 {
1336     if (type.isScalar() && !type.isArray())
1337     {
1338         if (op == EOpEqual)
1339         {
1340             outputTriplet(out, visit, "(", " == ", ")");
1341         }
1342         else
1343         {
1344             outputTriplet(out, visit, "(", " != ", ")");
1345         }
1346     }
1347     else
1348     {
1349         if (visit == PreVisit && op == EOpNotEqual)
1350         {
1351             out << "!";
1352         }
1353 
1354         if (type.isArray())
1355         {
1356             const TString &functionName = addArrayEqualityFunction(type);
1357             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1358         }
1359         else if (type.getBasicType() == EbtStruct)
1360         {
1361             const TStructure &structure = *type.getStruct();
1362             const TString &functionName = addStructEqualityFunction(structure);
1363             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1364         }
1365         else
1366         {
1367             ASSERT(type.isMatrix() || type.isVector());
1368             outputTriplet(out, visit, "all(", " == ", ")");
1369         }
1370     }
1371 }
1372 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1373 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1374 {
1375     if (type.isArray())
1376     {
1377         const TString &functionName = addArrayAssignmentFunction(type);
1378         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1379     }
1380     else
1381     {
1382         outputTriplet(out, visit, "(", " = ", ")");
1383     }
1384 }
1385 
ancestorEvaluatesToSamplerInStruct()1386 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1387 {
1388     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1389     {
1390         TIntermNode *ancestor               = getAncestorNode(n);
1391         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1392         if (ancestorBinary == nullptr)
1393         {
1394             return false;
1395         }
1396         switch (ancestorBinary->getOp())
1397         {
1398             case EOpIndexDirectStruct:
1399             {
1400                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1401                 const TIntermConstantUnion *index =
1402                     ancestorBinary->getRight()->getAsConstantUnion();
1403                 const TField *field = structure->fields()[index->getIConst(0)];
1404                 if (IsSampler(field->type()->getBasicType()))
1405                 {
1406                     return true;
1407                 }
1408                 break;
1409             }
1410             case EOpIndexDirect:
1411                 break;
1412             default:
1413                 // Returning a sampler from indirect indexing is not supported.
1414                 return false;
1415         }
1416     }
1417     return false;
1418 }
1419 
visitSwizzle(Visit visit,TIntermSwizzle * node)1420 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1421 {
1422     TInfoSinkBase &out = getInfoSink();
1423     if (visit == PostVisit)
1424     {
1425         out << ".";
1426         node->writeOffsetsAsXYZW(&out);
1427     }
1428     return true;
1429 }
1430 
visitBinary(Visit visit,TIntermBinary * node)1431 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1432 {
1433     TInfoSinkBase &out = getInfoSink();
1434 
1435     switch (node->getOp())
1436     {
1437         case EOpComma:
1438             outputTriplet(out, visit, "(", ", ", ")");
1439             break;
1440         case EOpAssign:
1441             if (node->isArray())
1442             {
1443                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1444                 if (rightAgg != nullptr && rightAgg->isConstructor())
1445                 {
1446                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1447                     out << functionName << "(";
1448                     node->getLeft()->traverse(this);
1449                     TIntermSequence *seq = rightAgg->getSequence();
1450                     for (auto &arrayElement : *seq)
1451                     {
1452                         out << ", ";
1453                         arrayElement->traverse(this);
1454                     }
1455                     out << ")";
1456                     return false;
1457                 }
1458                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1459                 // function call is assigned.
1460                 ASSERT(rightAgg == nullptr);
1461             }
1462             // Assignment expressions with atomic functions should be transformed into atomic
1463             // function calls in HLSL.
1464             // e.g. original_value = atomicAdd(dest, value) should be translated into
1465             //      InterlockedAdd(dest, value, original_value);
1466             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1467             {
1468                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1469                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1470                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1471                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1472                 ASSERT(argumentSeq->size() >= 2u);
1473                 for (auto &argument : *argumentSeq)
1474                 {
1475                     argument->traverse(this);
1476                     out << ", ";
1477                 }
1478                 node->getLeft()->traverse(this);
1479                 out << ")";
1480                 return false;
1481             }
1482             else if (IsInShaderStorageBlock(node->getLeft()))
1483             {
1484                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1485                 out << ", ";
1486                 if (IsInShaderStorageBlock(node->getRight()))
1487                 {
1488                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1489                 }
1490                 else
1491                 {
1492                     node->getRight()->traverse(this);
1493                 }
1494 
1495                 out << ")";
1496                 return false;
1497             }
1498             else if (IsInShaderStorageBlock(node->getRight()))
1499             {
1500                 node->getLeft()->traverse(this);
1501                 out << " = ";
1502                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1503                 return false;
1504             }
1505 
1506             outputAssign(visit, node->getType(), out);
1507             break;
1508         case EOpInitialize:
1509             if (visit == PreVisit)
1510             {
1511                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1512                 ASSERT(symbolNode);
1513                 TIntermTyped *initializer = node->getRight();
1514 
1515                 // Global initializers must be constant at this point.
1516                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1517 
1518                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1519                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1520                 // new variable is created before the assignment is evaluated), so we need to
1521                 // convert
1522                 // this to "float t = x, x = t;".
1523                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1524                 {
1525                     // Skip initializing the rest of the expression
1526                     return false;
1527                 }
1528                 else if (writeConstantInitialization(out, symbolNode, initializer))
1529                 {
1530                     return false;
1531                 }
1532             }
1533             else if (visit == InVisit)
1534             {
1535                 out << " = ";
1536                 if (IsInShaderStorageBlock(node->getRight()))
1537                 {
1538                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1539                     return false;
1540                 }
1541             }
1542             break;
1543         case EOpAddAssign:
1544             outputTriplet(out, visit, "(", " += ", ")");
1545             break;
1546         case EOpSubAssign:
1547             outputTriplet(out, visit, "(", " -= ", ")");
1548             break;
1549         case EOpMulAssign:
1550             outputTriplet(out, visit, "(", " *= ", ")");
1551             break;
1552         case EOpVectorTimesScalarAssign:
1553             outputTriplet(out, visit, "(", " *= ", ")");
1554             break;
1555         case EOpMatrixTimesScalarAssign:
1556             outputTriplet(out, visit, "(", " *= ", ")");
1557             break;
1558         case EOpVectorTimesMatrixAssign:
1559             if (visit == PreVisit)
1560             {
1561                 out << "(";
1562             }
1563             else if (visit == InVisit)
1564             {
1565                 out << " = mul(";
1566                 node->getLeft()->traverse(this);
1567                 out << ", transpose(";
1568             }
1569             else
1570             {
1571                 out << ")))";
1572             }
1573             break;
1574         case EOpMatrixTimesMatrixAssign:
1575             if (visit == PreVisit)
1576             {
1577                 out << "(";
1578             }
1579             else if (visit == InVisit)
1580             {
1581                 out << " = transpose(mul(transpose(";
1582                 node->getLeft()->traverse(this);
1583                 out << "), transpose(";
1584             }
1585             else
1586             {
1587                 out << "))))";
1588             }
1589             break;
1590         case EOpDivAssign:
1591             outputTriplet(out, visit, "(", " /= ", ")");
1592             break;
1593         case EOpIModAssign:
1594             outputTriplet(out, visit, "(", " %= ", ")");
1595             break;
1596         case EOpBitShiftLeftAssign:
1597             outputTriplet(out, visit, "(", " <<= ", ")");
1598             break;
1599         case EOpBitShiftRightAssign:
1600             outputTriplet(out, visit, "(", " >>= ", ")");
1601             break;
1602         case EOpBitwiseAndAssign:
1603             outputTriplet(out, visit, "(", " &= ", ")");
1604             break;
1605         case EOpBitwiseXorAssign:
1606             outputTriplet(out, visit, "(", " ^= ", ")");
1607             break;
1608         case EOpBitwiseOrAssign:
1609             outputTriplet(out, visit, "(", " |= ", ")");
1610             break;
1611         case EOpIndexDirect:
1612         {
1613             const TType &leftType = node->getLeft()->getType();
1614             if (leftType.isInterfaceBlock())
1615             {
1616                 if (visit == PreVisit)
1617                 {
1618                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1619                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1620 
1621                     ASSERT(leftType.getQualifier() == EvqUniform);
1622                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1623                     {
1624                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1625                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1626                     }
1627                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1628                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1629                         instanceArraySymbol->getName(), arrayIndex);
1630                     return false;
1631                 }
1632             }
1633             else if (ancestorEvaluatesToSamplerInStruct())
1634             {
1635                 // All parts of an expression that access a sampler in a struct need to use _ as
1636                 // separator to access the sampler variable that has been moved out of the struct.
1637                 outputTriplet(out, visit, "", "_", "");
1638             }
1639             else if (IsAtomicCounter(leftType.getBasicType()))
1640             {
1641                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1642             }
1643             else
1644             {
1645                 outputTriplet(out, visit, "", "[", "]");
1646                 if (visit == PostVisit)
1647                 {
1648                     const TInterfaceBlock *interfaceBlock =
1649                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1650                     if (interfaceBlock &&
1651                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1652                     {
1653                         // If the uniform block member's type is not structure, we had explicitly
1654                         // packed the member into a structure, so need to add an operator of field
1655                         // slection.
1656                         const TField *field    = interfaceBlock->fields()[0];
1657                         const TType *fieldType = field->type();
1658                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1659                             fieldType->isScalarArray())
1660                         {
1661                             out << "." << Decorate(field->name());
1662                         }
1663                     }
1664                 }
1665             }
1666         }
1667         break;
1668         case EOpIndexIndirect:
1669         {
1670             // We do not currently support indirect references to interface blocks
1671             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1672 
1673             const TType &leftType = node->getLeft()->getType();
1674             if (IsAtomicCounter(leftType.getBasicType()))
1675             {
1676                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1677             }
1678             else
1679             {
1680                 outputTriplet(out, visit, "", "[", "]");
1681                 if (visit == PostVisit)
1682                 {
1683                     const TInterfaceBlock *interfaceBlock =
1684                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1685                     if (interfaceBlock &&
1686                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1687                     {
                        // If the uniform block member's type is not a structure, we have explicitly
                        // packed the member into a structure, so we need to add a field selection
                        // operator.
1691                         const TField *field    = interfaceBlock->fields()[0];
1692                         const TType *fieldType = field->type();
1693                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1694                             fieldType->isScalarArray())
1695                         {
1696                             out << "." << Decorate(field->name());
1697                         }
1698                     }
1699                 }
1700             }
1701             break;
1702         }
1703         case EOpIndexDirectStruct:
1704         {
1705             const TStructure *structure       = node->getLeft()->getType().getStruct();
1706             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1707             const TField *field               = structure->fields()[index->getIConst(0)];
1708 
1709             // In cases where indexing returns a sampler, we need to access the sampler variable
1710             // that has been moved out of the struct.
1711             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1712             if (visit == PreVisit && indexingReturnsSampler)
1713             {
1714                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1715                 // This prefix is only output at the beginning of the indexing expression, which
1716                 // may have multiple parts.
1717                 out << "angle";
1718             }
1719             if (!indexingReturnsSampler)
1720             {
1721                 // All parts of an expression that access a sampler in a struct need to use _ as
1722                 // separator to access the sampler variable that has been moved out of the struct.
1723                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1724             }
1725             if (visit == InVisit)
1726             {
1727                 if (indexingReturnsSampler)
1728                 {
1729                     out << "_" << field->name();
1730                 }
1731                 else
1732                 {
1733                     out << "." << DecorateField(field->name(), *structure);
1734                 }
1735 
1736                 return false;
1737             }
1738         }
1739         break;
1740         case EOpIndexDirectInterfaceBlock:
1741         {
1742             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1743             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1744                                               IsInStd140UniformBlock(node->getLeft()) &&
1745                                               needStructMapping(node);
1746             if (visit == PreVisit && structInStd140UniformBlock)
1747             {
1748                 mNeedStructMapping = true;
1749                 out << "map";
1750             }
1751             if (visit == InVisit)
1752             {
1753                 const TInterfaceBlock *interfaceBlock =
1754                     node->getLeft()->getType().getInterfaceBlock();
1755                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1756                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1757                 if (structInStd140UniformBlock ||
1758                     mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1759                 {
1760                     out << "_";
1761                 }
1762                 else
1763                 {
1764                     out << ".";
1765                 }
1766                 out << Decorate(field->name());
1767 
1768                 return false;
1769             }
1770             break;
1771         }
1772         case EOpAdd:
1773             outputTriplet(out, visit, "(", " + ", ")");
1774             break;
1775         case EOpSub:
1776             outputTriplet(out, visit, "(", " - ", ")");
1777             break;
1778         case EOpMul:
1779             outputTriplet(out, visit, "(", " * ", ")");
1780             break;
1781         case EOpDiv:
1782             outputTriplet(out, visit, "(", " / ", ")");
1783             break;
1784         case EOpIMod:
1785             outputTriplet(out, visit, "(", " % ", ")");
1786             break;
1787         case EOpBitShiftLeft:
1788             outputTriplet(out, visit, "(", " << ", ")");
1789             break;
1790         case EOpBitShiftRight:
1791             outputTriplet(out, visit, "(", " >> ", ")");
1792             break;
1793         case EOpBitwiseAnd:
1794             outputTriplet(out, visit, "(", " & ", ")");
1795             break;
1796         case EOpBitwiseXor:
1797             outputTriplet(out, visit, "(", " ^ ", ")");
1798             break;
1799         case EOpBitwiseOr:
1800             outputTriplet(out, visit, "(", " | ", ")");
1801             break;
1802         case EOpEqual:
1803         case EOpNotEqual:
1804             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1805             break;
1806         case EOpLessThan:
1807             outputTriplet(out, visit, "(", " < ", ")");
1808             break;
1809         case EOpGreaterThan:
1810             outputTriplet(out, visit, "(", " > ", ")");
1811             break;
1812         case EOpLessThanEqual:
1813             outputTriplet(out, visit, "(", " <= ", ")");
1814             break;
1815         case EOpGreaterThanEqual:
1816             outputTriplet(out, visit, "(", " >= ", ")");
1817             break;
1818         case EOpVectorTimesScalar:
1819             outputTriplet(out, visit, "(", " * ", ")");
1820             break;
1821         case EOpMatrixTimesScalar:
1822             outputTriplet(out, visit, "(", " * ", ")");
1823             break;
1824         case EOpVectorTimesMatrix:
1825             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1826             break;
1827         case EOpMatrixTimesVector:
1828             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1829             break;
1830         case EOpMatrixTimesMatrix:
1831             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1832             break;
1833         case EOpLogicalOr:
1834             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1835             // been unfolded.
1836             ASSERT(!node->getRight()->hasSideEffects());
1837             outputTriplet(out, visit, "(", " || ", ")");
1838             return true;
1839         case EOpLogicalXor:
1840             mUsesXor = true;
1841             outputTriplet(out, visit, "xor(", ", ", ")");
1842             break;
1843         case EOpLogicalAnd:
1844             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1845             // been unfolded.
1846             ASSERT(!node->getRight()->hasSideEffects());
1847             outputTriplet(out, visit, "(", " && ", ")");
1848             return true;
1849         default:
1850             UNREACHABLE();
1851     }
1852 
1853     return true;
1854 }
1855 
visitUnary(Visit visit,TIntermUnary * node)1856 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1857 {
1858     TInfoSinkBase &out = getInfoSink();
1859 
1860     switch (node->getOp())
1861     {
1862         case EOpNegative:
1863             outputTriplet(out, visit, "(-", "", ")");
1864             break;
1865         case EOpPositive:
1866             outputTriplet(out, visit, "(+", "", ")");
1867             break;
1868         case EOpLogicalNot:
1869             outputTriplet(out, visit, "(!", "", ")");
1870             break;
1871         case EOpBitwiseNot:
1872             outputTriplet(out, visit, "(~", "", ")");
1873             break;
1874         case EOpPostIncrement:
1875             outputTriplet(out, visit, "(", "", "++)");
1876             break;
1877         case EOpPostDecrement:
1878             outputTriplet(out, visit, "(", "", "--)");
1879             break;
1880         case EOpPreIncrement:
1881             outputTriplet(out, visit, "(++", "", ")");
1882             break;
1883         case EOpPreDecrement:
1884             outputTriplet(out, visit, "(--", "", ")");
1885             break;
1886         case EOpRadians:
1887             outputTriplet(out, visit, "radians(", "", ")");
1888             break;
1889         case EOpDegrees:
1890             outputTriplet(out, visit, "degrees(", "", ")");
1891             break;
1892         case EOpSin:
1893             outputTriplet(out, visit, "sin(", "", ")");
1894             break;
1895         case EOpCos:
1896             outputTriplet(out, visit, "cos(", "", ")");
1897             break;
1898         case EOpTan:
1899             outputTriplet(out, visit, "tan(", "", ")");
1900             break;
1901         case EOpAsin:
1902             outputTriplet(out, visit, "asin(", "", ")");
1903             break;
1904         case EOpAcos:
1905             outputTriplet(out, visit, "acos(", "", ")");
1906             break;
1907         case EOpAtan:
1908             outputTriplet(out, visit, "atan(", "", ")");
1909             break;
1910         case EOpSinh:
1911             outputTriplet(out, visit, "sinh(", "", ")");
1912             break;
1913         case EOpCosh:
1914             outputTriplet(out, visit, "cosh(", "", ")");
1915             break;
1916         case EOpTanh:
1917         case EOpAsinh:
1918         case EOpAcosh:
1919         case EOpAtanh:
1920             ASSERT(node->getUseEmulatedFunction());
1921             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1922             break;
1923         case EOpExp:
1924             outputTriplet(out, visit, "exp(", "", ")");
1925             break;
1926         case EOpLog:
1927             outputTriplet(out, visit, "log(", "", ")");
1928             break;
1929         case EOpExp2:
1930             outputTriplet(out, visit, "exp2(", "", ")");
1931             break;
1932         case EOpLog2:
1933             outputTriplet(out, visit, "log2(", "", ")");
1934             break;
1935         case EOpSqrt:
1936             outputTriplet(out, visit, "sqrt(", "", ")");
1937             break;
1938         case EOpInversesqrt:
1939             outputTriplet(out, visit, "rsqrt(", "", ")");
1940             break;
1941         case EOpAbs:
1942             outputTriplet(out, visit, "abs(", "", ")");
1943             break;
1944         case EOpSign:
1945             outputTriplet(out, visit, "sign(", "", ")");
1946             break;
1947         case EOpFloor:
1948             outputTriplet(out, visit, "floor(", "", ")");
1949             break;
1950         case EOpTrunc:
1951             outputTriplet(out, visit, "trunc(", "", ")");
1952             break;
1953         case EOpRound:
1954             outputTriplet(out, visit, "round(", "", ")");
1955             break;
1956         case EOpRoundEven:
1957             ASSERT(node->getUseEmulatedFunction());
1958             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1959             break;
1960         case EOpCeil:
1961             outputTriplet(out, visit, "ceil(", "", ")");
1962             break;
1963         case EOpFract:
1964             outputTriplet(out, visit, "frac(", "", ")");
1965             break;
1966         case EOpIsnan:
1967             if (node->getUseEmulatedFunction())
1968                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1969             else
1970                 outputTriplet(out, visit, "isnan(", "", ")");
1971             mRequiresIEEEStrictCompiling = true;
1972             break;
1973         case EOpIsinf:
1974             outputTriplet(out, visit, "isinf(", "", ")");
1975             break;
1976         case EOpFloatBitsToInt:
1977             outputTriplet(out, visit, "asint(", "", ")");
1978             break;
1979         case EOpFloatBitsToUint:
1980             outputTriplet(out, visit, "asuint(", "", ")");
1981             break;
1982         case EOpIntBitsToFloat:
1983             outputTriplet(out, visit, "asfloat(", "", ")");
1984             break;
1985         case EOpUintBitsToFloat:
1986             outputTriplet(out, visit, "asfloat(", "", ")");
1987             break;
1988         case EOpPackSnorm2x16:
1989         case EOpPackUnorm2x16:
1990         case EOpPackHalf2x16:
1991         case EOpUnpackSnorm2x16:
1992         case EOpUnpackUnorm2x16:
1993         case EOpUnpackHalf2x16:
1994         case EOpPackUnorm4x8:
1995         case EOpPackSnorm4x8:
1996         case EOpUnpackUnorm4x8:
1997         case EOpUnpackSnorm4x8:
1998             ASSERT(node->getUseEmulatedFunction());
1999             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2000             break;
2001         case EOpLength:
2002             outputTriplet(out, visit, "length(", "", ")");
2003             break;
2004         case EOpNormalize:
2005             outputTriplet(out, visit, "normalize(", "", ")");
2006             break;
2007         case EOpDFdx:
2008             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2009             {
2010                 outputTriplet(out, visit, "(", "", ", 0.0)");
2011             }
2012             else
2013             {
2014                 outputTriplet(out, visit, "ddx(", "", ")");
2015             }
2016             break;
2017         case EOpDFdy:
2018             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2019             {
2020                 outputTriplet(out, visit, "(", "", ", 0.0)");
2021             }
2022             else
2023             {
2024                 outputTriplet(out, visit, "ddy(", "", ")");
2025             }
2026             break;
2027         case EOpFwidth:
2028             if (mInsideDiscontinuousLoop || mOutputLod0Function)
2029             {
2030                 outputTriplet(out, visit, "(", "", ", 0.0)");
2031             }
2032             else
2033             {
2034                 outputTriplet(out, visit, "fwidth(", "", ")");
2035             }
2036             break;
2037         case EOpTranspose:
2038             outputTriplet(out, visit, "transpose(", "", ")");
2039             break;
2040         case EOpDeterminant:
2041             outputTriplet(out, visit, "determinant(transpose(", "", "))");
2042             break;
2043         case EOpInverse:
2044             ASSERT(node->getUseEmulatedFunction());
2045             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2046             break;
2047 
2048         case EOpAny:
2049             outputTriplet(out, visit, "any(", "", ")");
2050             break;
2051         case EOpAll:
2052             outputTriplet(out, visit, "all(", "", ")");
2053             break;
2054         case EOpLogicalNotComponentWise:
2055             outputTriplet(out, visit, "(!", "", ")");
2056             break;
2057         case EOpBitfieldReverse:
2058             outputTriplet(out, visit, "reversebits(", "", ")");
2059             break;
2060         case EOpBitCount:
2061             outputTriplet(out, visit, "countbits(", "", ")");
2062             break;
2063         case EOpFindLSB:
2064             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2065             // in GLSLTest and results are consistent with GL.
2066             outputTriplet(out, visit, "firstbitlow(", "", ")");
2067             break;
2068         case EOpFindMSB:
2069             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2070             // tested in GLSLTest and results are consistent with GL.
2071             outputTriplet(out, visit, "firstbithigh(", "", ")");
2072             break;
2073         case EOpArrayLength:
2074         {
2075             TIntermTyped *operand = node->getOperand();
2076             ASSERT(IsInShaderStorageBlock(operand));
2077             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2078             return false;
2079         }
2080         default:
2081             UNREACHABLE();
2082     }
2083 
2084     return true;
2085 }
2086 
samplerNamePrefixFromStruct(TIntermTyped * node)2087 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2088 {
2089     if (node->getAsSymbolNode())
2090     {
2091         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2092         return node->getAsSymbolNode()->getName();
2093     }
2094     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2095     switch (nodeBinary->getOp())
2096     {
2097         case EOpIndexDirect:
2098         {
2099             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2100 
2101             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2102             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2103             return ImmutableString(prefixSink.str());
2104         }
2105         case EOpIndexDirectStruct:
2106         {
2107             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2108             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2109             const TField *field = s->fields()[index];
2110 
2111             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2112             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2113                        << field->name();
2114             return ImmutableString(prefixSink.str());
2115         }
2116         default:
2117             UNREACHABLE();
2118             return kEmptyImmutableString;
2119     }
2120 }
2121 
visitBlock(Visit visit,TIntermBlock * node)2122 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2123 {
2124     TInfoSinkBase &out = getInfoSink();
2125 
2126     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2127 
2128     if (mInsideFunction)
2129     {
2130         outputLineDirective(out, node->getLine().first_line);
2131         out << "{\n";
2132         if (isMainBlock)
2133         {
2134             if (mShaderType == GL_COMPUTE_SHADER)
2135             {
2136                 out << "initGLBuiltins(input);\n";
2137             }
2138             else
2139             {
2140                 out << "@@ MAIN PROLOGUE @@\n";
2141             }
2142         }
2143     }
2144 
2145     for (TIntermNode *statement : *node->getSequence())
2146     {
2147         outputLineDirective(out, statement->getLine().first_line);
2148 
2149         statement->traverse(this);
2150 
2151         // Don't output ; after case labels, they're terminated by :
2152         // This is needed especially since outputting a ; after a case statement would turn empty
2153         // case statements into non-empty case statements, disallowing fall-through from them.
2154         // Also the output code is clearer if we don't output ; after statements where it is not
2155         // needed:
2156         //  * if statements
2157         //  * switch statements
2158         //  * blocks
2159         //  * function definitions
2160         //  * loops (do-while loops output the semicolon in VisitLoop)
2161         //  * declarations that don't generate output.
2162         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2163             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2164             statement->getAsSwitchNode() == nullptr &&
2165             statement->getAsFunctionDefinition() == nullptr &&
2166             (statement->getAsDeclarationNode() == nullptr ||
2167              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2168             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2169         {
2170             out << ";\n";
2171         }
2172     }
2173 
2174     if (mInsideFunction)
2175     {
2176         outputLineDirective(out, node->getLine().last_line);
2177         if (isMainBlock && shaderNeedsGenerateOutput())
2178         {
2179             // We could have an empty main, a main function without a branch at the end, or a main
2180             // function with a discard statement at the end. In these cases we need to add a return
2181             // statement.
2182             bool needReturnStatement =
2183                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2184                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2185             if (needReturnStatement)
2186             {
2187                 out << "return " << generateOutputCall() << ";\n";
2188             }
2189         }
2190         out << "}\n";
2191     }
2192 
2193     return false;
2194 }
2195 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2196 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2197 {
2198     TInfoSinkBase &out = getInfoSink();
2199 
2200     ASSERT(mCurrentFunctionMetadata == nullptr);
2201 
2202     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2203     ASSERT(index != CallDAG::InvalidIndex);
2204     mCurrentFunctionMetadata = &mASTMetadataList[index];
2205 
2206     const TFunction *func = node->getFunction();
2207 
2208     if (func->isMain())
2209     {
2210         // The stub strings below are replaced when shader is dynamically defined by its layout:
2211         switch (mShaderType)
2212         {
2213             case GL_VERTEX_SHADER:
2214                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2215                     << "@@ VERTEX OUTPUT @@\n\n"
2216                     << "VS_OUTPUT main(VS_INPUT input)";
2217                 break;
2218             case GL_FRAGMENT_SHADER:
2219                 out << "@@ PIXEL OUTPUT @@\n\n"
2220                     << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2221                 break;
2222             case GL_COMPUTE_SHADER:
2223                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2224                     << mWorkGroupSize[2] << ")]\n";
2225                 out << "void main(CS_INPUT input)";
2226                 break;
2227             default:
2228                 UNREACHABLE();
2229                 break;
2230         }
2231     }
2232     else
2233     {
2234         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2235         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2236             << (mOutputLod0Function ? "Lod0(" : "(");
2237 
2238         size_t paramCount = func->getParamCount();
2239         for (unsigned int i = 0; i < paramCount; i++)
2240         {
2241             const TVariable *param = func->getParam(i);
2242             ensureStructDefined(param->getType());
2243 
2244             writeParameter(param, out);
2245 
2246             if (i < paramCount - 1)
2247             {
2248                 out << ", ";
2249             }
2250         }
2251 
2252         out << ")\n";
2253     }
2254 
2255     mInsideFunction = true;
2256     if (func->isMain())
2257     {
2258         mInsideMain = true;
2259     }
2260     // The function body node will output braces.
2261     node->getBody()->traverse(this);
2262     mInsideFunction = false;
2263     mInsideMain     = false;
2264 
2265     mCurrentFunctionMetadata = nullptr;
2266 
2267     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2268     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2269     {
2270         ASSERT(!node->getFunction()->isMain());
2271         mOutputLod0Function = true;
2272         node->traverse(this);
2273         mOutputLod0Function = false;
2274     }
2275 
2276     return false;
2277 }
2278 
visitDeclaration(Visit visit,TIntermDeclaration * node)2279 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2280 {
2281     if (visit == PreVisit)
2282     {
2283         TIntermSequence *sequence = node->getSequence();
2284         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2285         ASSERT(sequence->size() == 1);
2286         ASSERT(declarator);
2287 
2288         if (IsDeclarationWrittenOut(node))
2289         {
2290             TInfoSinkBase &out = getInfoSink();
2291             ensureStructDefined(declarator->getType());
2292 
2293             if (!declarator->getAsSymbolNode() ||
2294                 declarator->getAsSymbolNode()->variable().symbolType() !=
2295                     SymbolType::Empty)  // Variable declaration
2296             {
2297                 if (declarator->getQualifier() == EvqShared)
2298                 {
2299                     out << "groupshared ";
2300                 }
2301                 else if (!mInsideFunction)
2302                 {
2303                     out << "static ";
2304                 }
2305 
2306                 out << TypeString(declarator->getType()) + " ";
2307 
2308                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2309 
2310                 if (symbol)
2311                 {
2312                     symbol->traverse(this);
2313                     out << ArrayString(symbol->getType());
2314                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2315                     // drivers to compile a compute shader if we add code to initialize a
2316                     // groupshared array variable with a large array size. And maybe produce
2317                     // incorrect result. See http://anglebug.com/3226.
2318                     if (declarator->getQualifier() != EvqShared)
2319                     {
2320                         out << " = " + zeroInitializer(symbol->getType());
2321                     }
2322                 }
2323                 else
2324                 {
2325                     declarator->traverse(this);
2326                 }
2327             }
2328         }
2329         else if (IsVaryingOut(declarator->getQualifier()))
2330         {
2331             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2332             ASSERT(symbol);  // Varying declarations can't have initializers.
2333 
2334             const TVariable &variable = symbol->variable();
2335 
2336             if (variable.symbolType() != SymbolType::Empty)
2337             {
2338                 // Vertex outputs which are declared but not written to should still be declared to
2339                 // allow successful linking.
2340                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2341             }
2342         }
2343     }
2344     return false;
2345 }
2346 
visitGlobalQualifierDeclaration(Visit visit,TIntermGlobalQualifierDeclaration * node)2347 bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
2348                                                  TIntermGlobalQualifierDeclaration *node)
2349 {
2350     // Do not do any translation
2351     return false;
2352 }
2353 
visitFunctionPrototype(TIntermFunctionPrototype * node)2354 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2355 {
2356     TInfoSinkBase &out = getInfoSink();
2357 
2358     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2359     // Skip the prototype if it is not implemented (and thus not used)
2360     if (index == CallDAG::InvalidIndex)
2361     {
2362         return;
2363     }
2364 
2365     const TFunction *func = node->getFunction();
2366 
2367     TString name = DecorateFunctionIfNeeded(func);
2368     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2369         << (mOutputLod0Function ? "Lod0(" : "(");
2370 
2371     size_t paramCount = func->getParamCount();
2372     for (unsigned int i = 0; i < paramCount; i++)
2373     {
2374         writeParameter(func->getParam(i), out);
2375 
2376         if (i < paramCount - 1)
2377         {
2378             out << ", ";
2379         }
2380     }
2381 
2382     out << ");\n";
2383 
2384     // Also prototype the Lod0 variant if needed
2385     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2386     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2387     {
2388         mOutputLod0Function = true;
2389         node->traverse(this);
2390         mOutputLod0Function = false;
2391     }
2392 }
2393 
// Emits HLSL for an aggregate node, selected by node->getOp().
//
// Function calls (EOpCallBuiltInFunction, EOpCallFunctionInAST, EOpCallInternalRawFunction) are
// written out completely in one go, including their argument lists, and return false. The same
// holds for atomic memory functions whose first argument lives in a shader storage block. All
// other handled ops emit only the piece of text that belongs to |visit| (through outputTriplet,
// writeEmulatedFunctionTriplet or outputConstructor) and return true. An op without a case hits
// UNREACHABLE().
// NOTE(review): returning false presumably tells the traverser not to visit the children again,
// since they were already traversed here - confirm against the traverser base class.
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
    TInfoSinkBase &out = getInfoSink();

    switch (node->getOp())
    {
        case EOpCallBuiltInFunction:
        case EOpCallFunctionInAST:
        case EOpCallInternalRawFunction:
        {
            TIntermSequence *arguments = node->getSequence();

            // lod0 can only be true in fragment shaders, and only inside a discontinuous loop or
            // while mOutputLod0Function is set.
            bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
                        mShaderType == GL_FRAGMENT_SHADER;
            if (node->getOp() == EOpCallFunctionInAST)
            {
                if (node->isArray())
                {
                    UNIMPLEMENTED();
                }
                size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
                ASSERT(index != CallDAG::InvalidIndex);
                // Only call the "Lod0" variant if the callee's metadata says it needs one.
                lod0 &= mASTMetadataList[index].mNeedsLod0;

                out << DecorateFunctionIfNeeded(node->getFunction());
                out << DisambiguateFunctionName(node->getSequence());
                out << (lod0 ? "Lod0(" : "(");
            }
            else if (node->getOp() == EOpCallInternalRawFunction)
            {
                // This path is used for internal functions that don't have their definitions in the
                // AST, such as precision emulation functions.
                out << DecorateFunctionIfNeeded(node->getFunction()) << "(";
            }
            else if (node->getFunction()->isImageFunction())
            {
                // The emitted function name is obtained from mImageFunctionHLSL, based on the
                // built-in's name and the type of the first (image) argument.
                const ImmutableString &name              = node->getFunction()->name();
                TType type                               = (*arguments)[0]->getAsTyped()->getType();
                const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
                    name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
                    type.getMemoryQualifier().readonly);
                out << imageFunctionName << "(";
            }
            else if (node->getFunction()->isAtomicCounterFunction())
            {
                const ImmutableString &name = node->getFunction()->name();
                ImmutableString atomicFunctionName =
                    mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
                out << atomicFunctionName << "(";
            }
            else
            {
                // Remaining built-in calls are handled as texture functions: the emitted name is
                // obtained from mTextureFunctionHLSL, based on the sampler type of the first
                // argument, the size of the second argument (if any), the argument count and
                // lod0.
                const ImmutableString &name = node->getFunction()->name();
                TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
                int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
                if (arguments->size() > 1)
                {
                    coords = (*arguments)[1]->getAsTyped()->getNominalSize();
                }
                const ImmutableString &textureFunctionName =
                    mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
                                                             arguments->size(), lod0, mShaderType);
                out << textureFunctionName << "(";
            }

            // Write the argument list. The opening parenthesis has been written above.
            for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
            {
                TIntermTyped *typedArg = (*arg)->getAsTyped();
                if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
                {
                    // For this output type a sampler argument is written twice: once prefixed
                    // with "texture_" here, and once prefixed with "sampler_" by the common
                    // traverse call below.
                    out << "texture_";
                    (*arg)->traverse(this);
                    out << ", sampler_";
                }

                (*arg)->traverse(this);

                if (typedArg->getType().isStructureContainingSamplers())
                {
                    // Append one extra argument (two for the FL9_3 output) for every sampler
                    // symbol created from the struct type, named with an "angle_" prefix followed
                    // by the prefix derived from the struct argument.
                    const TType &argType = typedArg->getType();
                    TVector<const TVariable *> samplerSymbols;
                    ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
                    std::string namePrefix     = "angle_";
                    namePrefix += structName.data();
                    argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
                                                 nullptr, mSymbolTable);
                    for (const TVariable *sampler : samplerSymbols)
                    {
                        if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
                        {
                            out << ", texture_" << sampler->name();
                            out << ", sampler_" << sampler->name();
                        }
                        else
                        {
                            // In case of HLSL 4.1+, this symbol is the sampler index, and in case
                            // of D3D9, it's the sampler variable.
                            out << ", " << sampler->name();
                        }
                    }
                }

                // Separator after every argument except the last one.
                if (arg < arguments->end() - 1)
                {
                    out << ", ";
                }
            }

            out << ")";

            return false;
        }
        case EOpConstruct:
            outputConstructor(out, visit, node);
            break;
        // Component-wise comparisons are written as parenthesized infix operators.
        case EOpEqualComponentWise:
            outputTriplet(out, visit, "(", " == ", ")");
            break;
        case EOpNotEqualComponentWise:
            outputTriplet(out, visit, "(", " != ", ")");
            break;
        case EOpLessThanComponentWise:
            outputTriplet(out, visit, "(", " < ", ")");
            break;
        case EOpGreaterThanComponentWise:
            outputTriplet(out, visit, "(", " > ", ")");
            break;
        case EOpLessThanEqualComponentWise:
            outputTriplet(out, visit, "(", " <= ", ")");
            break;
        case EOpGreaterThanEqualComponentWise:
            outputTriplet(out, visit, "(", " >= ", ")");
            break;
        // Ops that assert getUseEmulatedFunction() are written as calls to emulated functions
        // via writeEmulatedFunctionTriplet; the others map to an HLSL intrinsic name.
        case EOpMod:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpModf:
            outputTriplet(out, visit, "modf(", ", ", ")");
            break;
        case EOpPow:
            outputTriplet(out, visit, "pow(", ", ", ")");
            break;
        case EOpAtan:
            ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpMin:
            outputTriplet(out, visit, "min(", ", ", ")");
            break;
        case EOpMax:
            outputTriplet(out, visit, "max(", ", ", ")");
            break;
        case EOpClamp:
            outputTriplet(out, visit, "clamp(", ", ", ")");
            break;
        case EOpMix:
        {
            // The third argument decides between the emulated (bool selector) version and lerp.
            TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
            if (lastParamNode->getType().getBasicType() == EbtBool)
            {
                // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
                // y, genBType a)",
                // so use emulated version.
                ASSERT(node->getUseEmulatedFunction());
                writeEmulatedFunctionTriplet(out, visit, node->getOp());
            }
            else
            {
                outputTriplet(out, visit, "lerp(", ", ", ")");
            }
            break;
        }
        case EOpStep:
            outputTriplet(out, visit, "step(", ", ", ")");
            break;
        case EOpSmoothstep:
            outputTriplet(out, visit, "smoothstep(", ", ", ")");
            break;
        case EOpFma:
            outputTriplet(out, visit, "mad(", ", ", ")");
            break;
        case EOpFrexp:
        case EOpLdexp:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpDistance:
            outputTriplet(out, visit, "distance(", ", ", ")");
            break;
        case EOpDot:
            outputTriplet(out, visit, "dot(", ", ", ")");
            break;
        case EOpCross:
            outputTriplet(out, visit, "cross(", ", ", ")");
            break;
        case EOpFaceforward:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpReflect:
            outputTriplet(out, visit, "reflect(", ", ", ")");
            break;
        case EOpRefract:
            outputTriplet(out, visit, "refract(", ", ", ")");
            break;
        case EOpOuterProduct:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpMulMatrixComponentWise:
            outputTriplet(out, visit, "(", " * ", ")");
            break;
        case EOpBitfieldExtract:
        case EOpBitfieldInsert:
        case EOpUaddCarry:
        case EOpUsubBorrow:
        case EOpUmulExtended:
        case EOpImulExtended:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getOp());
            break;
        case EOpBarrier:
            // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
            // cheapest *WithGroupSync() function, without any functionality loss, but
            // with the potential for severe performance loss.
            outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
            break;
        case EOpMemoryBarrierShared:
            outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
            break;
        case EOpMemoryBarrierAtomicCounter:
        case EOpMemoryBarrierBuffer:
        case EOpMemoryBarrierImage:
            outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
            break;
        case EOpGroupMemoryBarrier:
        case EOpMemoryBarrier:
            outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
            break;

        // Single atomic function calls without return value.
        // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
        case EOpAtomicAdd:
        case EOpAtomicMin:
        case EOpAtomicMax:
        case EOpAtomicAnd:
        case EOpAtomicOr:
        case EOpAtomicXor:
        // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
        // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
        // optional.
        // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
        // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
        // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
        // compare_value, value) should all be modified into the form of "int temp; temp =
        // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
        // compare_value, value);" in the intermediate tree before traversing outputHLSL.
        case EOpAtomicExchange:
        case EOpAtomicCompSwap:
        {
            ASSERT(node->getChildCount() > 1);
            // The first argument is the memory location that the atomic function operates on.
            TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
            if (IsInShaderStorageBlock(memNode))
            {
                // Atomic memory functions for SSBO.
                // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
                mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
                // Write the rest argument list to |out|.
                for (size_t i = 1; i < node->getChildCount(); i++)
                {
                    out << ", ";
                    TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
                    // Arguments that are themselves in a shader storage block are written as
                    // load function calls; all others are traversed normally.
                    if (IsInShaderStorageBlock(argument))
                    {
                        mSSBOOutputHLSL->outputLoadFunctionCall(argument);
                    }
                    else
                    {
                        argument->traverse(this);
                    }
                }

                out << ")";
                return false;
            }
            else
            {
                // Atomic memory functions for shared variable.
                if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
                {
                    outputTriplet(out, visit,
                                  GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
                                  ")");
                }
                else
                {
                    // Exchange and compare-swap are not expected here; see the comment above the
                    // EOpAtomicExchange case label.
                    UNREACHABLE();
                }
            }

            break;
        }
        default:
            UNREACHABLE();
    }

    return true;
}
2704 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2705 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2706 {
2707     out << "if (";
2708 
2709     node->getCondition()->traverse(this);
2710 
2711     out << ")\n";
2712 
2713     outputLineDirective(out, node->getLine().first_line);
2714 
2715     bool discard = false;
2716 
2717     if (node->getTrueBlock())
2718     {
2719         // The trueBlock child node will output braces.
2720         node->getTrueBlock()->traverse(this);
2721 
2722         // Detect true discard
2723         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2724     }
2725     else
2726     {
2727         // TODO(oetuaho): Check if the semicolon inside is necessary.
2728         // It's there as a result of conservative refactoring of the output.
2729         out << "{;}\n";
2730     }
2731 
2732     outputLineDirective(out, node->getLine().first_line);
2733 
2734     if (node->getFalseBlock())
2735     {
2736         out << "else\n";
2737 
2738         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2739 
2740         // The falseBlock child node will output braces.
2741         node->getFalseBlock()->traverse(this);
2742 
2743         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2744 
2745         // Detect false discard
2746         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2747     }
2748 
2749     // ANGLE issue 486: Detect problematic conditional discard
2750     if (discard)
2751     {
2752         mUsesDiscardRewriting = true;
2753     }
2754 }
2755 
// Ternary nodes are never expected to reach HLSL output; both parameters are unused and the
// function only asserts via UNREACHABLE() before returning false.
bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
{
    // Ternary ops should have been already converted to something else in the AST. HLSL ternary
    // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
    UNREACHABLE();
    return false;
}
2763 
visitIfElse(Visit visit,TIntermIfElse * node)2764 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2765 {
2766     TInfoSinkBase &out = getInfoSink();
2767 
2768     ASSERT(mInsideFunction);
2769 
2770     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2771     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2772     {
2773         out << "FLATTEN ";
2774     }
2775 
2776     writeIfElse(out, node);
2777 
2778     return false;
2779 }
2780 
visitSwitch(Visit visit,TIntermSwitch * node)2781 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2782 {
2783     TInfoSinkBase &out = getInfoSink();
2784 
2785     ASSERT(node->getStatementList());
2786     if (visit == PreVisit)
2787     {
2788         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2789     }
2790     outputTriplet(out, visit, "switch (", ") ", "");
2791     // The curly braces get written when visiting the statementList block.
2792     return true;
2793 }
2794 
visitCase(Visit visit,TIntermCase * node)2795 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2796 {
2797     TInfoSinkBase &out = getInfoSink();
2798 
2799     if (node->hasCondition())
2800     {
2801         outputTriplet(out, visit, "case (", "", "):\n");
2802         return true;
2803     }
2804     else
2805     {
2806         out << "default:\n";
2807         return false;
2808     }
2809 }
2810 
visitConstantUnion(TIntermConstantUnion * node)2811 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2812 {
2813     TInfoSinkBase &out = getInfoSink();
2814     writeConstantUnion(out, node->getType(), node->getConstantValue());
2815 }
2816 
visitLoop(Visit visit,TIntermLoop * node)2817 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2818 {
2819     mNestedLoopDepth++;
2820 
2821     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2822     mInsideDiscontinuousLoop =
2823         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2824 
2825     TInfoSinkBase &out = getInfoSink();
2826 
2827     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2828     {
2829         if (handleExcessiveLoop(out, node))
2830         {
2831             mInsideDiscontinuousLoop = wasDiscontinuous;
2832             mNestedLoopDepth--;
2833 
2834             return false;
2835         }
2836     }
2837 
2838     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2839     if (node->getType() == ELoopDoWhile)
2840     {
2841         out << "{" << unroll << " do\n";
2842 
2843         outputLineDirective(out, node->getLine().first_line);
2844     }
2845     else
2846     {
2847         out << "{" << unroll << " for(";
2848 
2849         if (node->getInit())
2850         {
2851             node->getInit()->traverse(this);
2852         }
2853 
2854         out << "; ";
2855 
2856         if (node->getCondition())
2857         {
2858             node->getCondition()->traverse(this);
2859         }
2860 
2861         out << "; ";
2862 
2863         if (node->getExpression())
2864         {
2865             node->getExpression()->traverse(this);
2866         }
2867 
2868         out << ")\n";
2869 
2870         outputLineDirective(out, node->getLine().first_line);
2871     }
2872 
2873     if (node->getBody())
2874     {
2875         // The loop body node will output braces.
2876         node->getBody()->traverse(this);
2877     }
2878     else
2879     {
2880         // TODO(oetuaho): Check if the semicolon inside is necessary.
2881         // It's there as a result of conservative refactoring of the output.
2882         out << "{;}\n";
2883     }
2884 
2885     outputLineDirective(out, node->getLine().first_line);
2886 
2887     if (node->getType() == ELoopDoWhile)
2888     {
2889         outputLineDirective(out, node->getCondition()->getLine().first_line);
2890         out << "while (";
2891 
2892         node->getCondition()->traverse(this);
2893 
2894         out << ");\n";
2895     }
2896 
2897     out << "}\n";
2898 
2899     mInsideDiscontinuousLoop = wasDiscontinuous;
2900     mNestedLoopDepth--;
2901 
2902     return false;
2903 }
2904 
visitBranch(Visit visit,TIntermBranch * node)2905 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2906 {
2907     if (visit == PreVisit)
2908     {
2909         TInfoSinkBase &out = getInfoSink();
2910 
2911         switch (node->getFlowOp())
2912         {
2913             case EOpKill:
2914                 out << "discard";
2915                 break;
2916             case EOpBreak:
2917                 if (mNestedLoopDepth > 1)
2918                 {
2919                     mUsesNestedBreak = true;
2920                 }
2921 
2922                 if (mExcessiveLoopIndex)
2923                 {
2924                     out << "{Break";
2925                     mExcessiveLoopIndex->traverse(this);
2926                     out << " = true; break;}\n";
2927                 }
2928                 else
2929                 {
2930                     out << "break";
2931                 }
2932                 break;
2933             case EOpContinue:
2934                 out << "continue";
2935                 break;
2936             case EOpReturn:
2937                 if (node->getExpression())
2938                 {
2939                     ASSERT(!mInsideMain);
2940                     out << "return ";
2941                 }
2942                 else
2943                 {
2944                     if (mInsideMain && shaderNeedsGenerateOutput())
2945                     {
2946                         out << "return " << generateOutputCall();
2947                     }
2948                     else
2949                     {
2950                         out << "return";
2951                     }
2952                 }
2953                 break;
2954             default:
2955                 UNREACHABLE();
2956         }
2957     }
2958 
2959     return true;
2960 }
2961 
2962 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2963 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2964 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2965 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2966 {
2967     const int MAX_LOOP_ITERATIONS = 254;
2968 
2969     // Parse loops of the form:
2970     // for(int index = initial; index [comparator] limit; index += increment)
2971     TIntermSymbol *index = nullptr;
2972     TOperator comparator = EOpNull;
2973     int initial          = 0;
2974     int limit            = 0;
2975     int increment        = 0;
2976 
2977     // Parse index name and intial value
2978     if (node->getInit())
2979     {
2980         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2981 
2982         if (init)
2983         {
2984             TIntermSequence *sequence = init->getSequence();
2985             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2986 
2987             if (variable && variable->getQualifier() == EvqTemporary)
2988             {
2989                 TIntermBinary *assign = variable->getAsBinaryNode();
2990 
2991                 if (assign->getOp() == EOpInitialize)
2992                 {
2993                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2994                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2995 
2996                     if (symbol && constant)
2997                     {
2998                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2999                         {
3000                             index   = symbol;
3001                             initial = constant->getIConst(0);
3002                         }
3003                     }
3004                 }
3005             }
3006         }
3007     }
3008 
3009     // Parse comparator and limit value
3010     if (index != nullptr && node->getCondition())
3011     {
3012         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
3013 
3014         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
3015         {
3016             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
3017 
3018             if (constant)
3019             {
3020                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3021                 {
3022                     comparator = test->getOp();
3023                     limit      = constant->getIConst(0);
3024                 }
3025             }
3026         }
3027     }
3028 
3029     // Parse increment
3030     if (index != nullptr && comparator != EOpNull && node->getExpression())
3031     {
3032         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
3033         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
3034 
3035         if (binaryTerminal)
3036         {
3037             TOperator op                   = binaryTerminal->getOp();
3038             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
3039 
3040             if (constant)
3041             {
3042                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3043                 {
3044                     int value = constant->getIConst(0);
3045 
3046                     switch (op)
3047                     {
3048                         case EOpAddAssign:
3049                             increment = value;
3050                             break;
3051                         case EOpSubAssign:
3052                             increment = -value;
3053                             break;
3054                         default:
3055                             UNIMPLEMENTED();
3056                     }
3057                 }
3058             }
3059         }
3060         else if (unaryTerminal)
3061         {
3062             TOperator op = unaryTerminal->getOp();
3063 
3064             switch (op)
3065             {
3066                 case EOpPostIncrement:
3067                     increment = 1;
3068                     break;
3069                 case EOpPostDecrement:
3070                     increment = -1;
3071                     break;
3072                 case EOpPreIncrement:
3073                     increment = 1;
3074                     break;
3075                 case EOpPreDecrement:
3076                     increment = -1;
3077                     break;
3078                 default:
3079                     UNIMPLEMENTED();
3080             }
3081         }
3082     }
3083 
3084     if (index != nullptr && comparator != EOpNull && increment != 0)
3085     {
3086         if (comparator == EOpLessThanEqual)
3087         {
3088             comparator = EOpLessThan;
3089             limit += 1;
3090         }
3091 
3092         if (comparator == EOpLessThan)
3093         {
3094             int iterations = (limit - initial) / increment;
3095 
3096             if (iterations <= MAX_LOOP_ITERATIONS)
3097             {
3098                 return false;  // Not an excessive loop
3099             }
3100 
3101             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3102             mExcessiveLoopIndex         = index;
3103 
3104             out << "{int ";
3105             index->traverse(this);
3106             out << ";\n"
3107                    "bool Break";
3108             index->traverse(this);
3109             out << " = false;\n";
3110 
3111             bool firstLoopFragment = true;
3112 
3113             while (iterations > 0)
3114             {
3115                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3116 
3117                 if (!firstLoopFragment)
3118                 {
3119                     out << "if (!Break";
3120                     index->traverse(this);
3121                     out << ") {\n";
3122                 }
3123 
3124                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3125                 {
3126                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3127                 }
3128 
3129                 // for(int index = initial; index < clampedLimit; index += increment)
3130                 const char *unroll =
3131                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3132 
3133                 out << unroll << " for(";
3134                 index->traverse(this);
3135                 out << " = ";
3136                 out << initial;
3137 
3138                 out << "; ";
3139                 index->traverse(this);
3140                 out << " < ";
3141                 out << clampedLimit;
3142 
3143                 out << "; ";
3144                 index->traverse(this);
3145                 out << " += ";
3146                 out << increment;
3147                 out << ")\n";
3148 
3149                 outputLineDirective(out, node->getLine().first_line);
3150                 out << "{\n";
3151 
3152                 if (node->getBody())
3153                 {
3154                     node->getBody()->traverse(this);
3155                 }
3156 
3157                 outputLineDirective(out, node->getLine().first_line);
3158                 out << ";}\n";
3159 
3160                 if (!firstLoopFragment)
3161                 {
3162                     out << "}\n";
3163                 }
3164 
3165                 firstLoopFragment = false;
3166 
3167                 initial += MAX_LOOP_ITERATIONS * increment;
3168                 iterations -= MAX_LOOP_ITERATIONS;
3169             }
3170 
3171             out << "}";
3172 
3173             mExcessiveLoopIndex = restoreIndex;
3174 
3175             return true;
3176         }
3177         else
3178             UNIMPLEMENTED();
3179     }
3180 
3181     return false;  // Not handled as an excessive loop
3182 }
3183 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3184 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3185                                Visit visit,
3186                                const char *preString,
3187                                const char *inString,
3188                                const char *postString)
3189 {
3190     if (visit == PreVisit)
3191     {
3192         out << preString;
3193     }
3194     else if (visit == InVisit)
3195     {
3196         out << inString;
3197     }
3198     else if (visit == PostVisit)
3199     {
3200         out << postString;
3201     }
3202 }
3203 
outputLineDirective(TInfoSinkBase & out,int line)3204 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3205 {
3206     if ((mCompileOptions & SH_LINE_DIRECTIVES) != 0 && line > 0)
3207     {
3208         out << "\n";
3209         out << "#line " << line;
3210 
3211         if (mSourcePath)
3212         {
3213             out << " \"" << mSourcePath << "\"";
3214         }
3215 
3216         out << "\n";
3217     }
3218 }
3219 
writeParameter(const TVariable * param,TInfoSinkBase & out)3220 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3221 {
3222     const TType &type    = param->getType();
3223     TQualifier qualifier = type.getQualifier();
3224 
3225     TString nameStr = DecorateVariableIfNeeded(*param);
3226     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3227 
3228     if (IsSampler(type.getBasicType()))
3229     {
3230         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3231         {
3232             // Samplers are passed as indices to the sampler array.
3233             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3234             out << "const uint " << nameStr << ArrayString(type);
3235             return;
3236         }
3237         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3238         {
3239             out << QualifierString(qualifier) << " " << TextureString(type.getBasicType())
3240                 << " texture_" << nameStr << ArrayString(type) << ", " << QualifierString(qualifier)
3241                 << " " << SamplerString(type.getBasicType()) << " sampler_" << nameStr
3242                 << ArrayString(type);
3243             return;
3244         }
3245     }
3246 
3247     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3248     // buffer offset.
3249     if (IsAtomicCounter(type.getBasicType()))
3250     {
3251         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3252             << nameStr << "_offset";
3253     }
3254     else
3255     {
3256         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3257             << ArrayString(type);
3258     }
3259 
3260     // If the structure parameter contains samplers, they need to be passed into the function as
3261     // separate parameters. HLSL doesn't natively support samplers in structs.
3262     if (type.isStructureContainingSamplers())
3263     {
3264         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
3265         TVector<const TVariable *> samplerSymbols;
3266         std::string namePrefix = "angle";
3267         namePrefix += nameStr.c_str();
3268         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3269                                   mSymbolTable);
3270         for (const TVariable *sampler : samplerSymbols)
3271         {
3272             const TType &samplerType = sampler->getType();
3273             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3274             {
3275                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3276             }
3277             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
3278             {
3279                 ASSERT(IsSampler(samplerType.getBasicType()));
3280                 out << ", " << QualifierString(qualifier) << " "
3281                     << TextureString(samplerType.getBasicType()) << " texture_" << sampler->name()
3282                     << ArrayString(samplerType) << ", " << QualifierString(qualifier) << " "
3283                     << SamplerString(samplerType.getBasicType()) << " sampler_" << sampler->name()
3284                     << ArrayString(samplerType);
3285             }
3286             else
3287             {
3288                 ASSERT(IsSampler(samplerType.getBasicType()));
3289                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3290                     << sampler->name() << ArrayString(samplerType);
3291             }
3292         }
3293     }
3294 }
3295 
zeroInitializer(const TType & type) const3296 TString OutputHLSL::zeroInitializer(const TType &type) const
3297 {
3298     TString string;
3299 
3300     size_t size = type.getObjectSize();
3301     if (size >= kZeroCount)
3302     {
3303         mUseZeroArray = true;
3304     }
3305     string = GetZeroInitializer(size).c_str();
3306 
3307     return "{" + string + "}";
3308 }
3309 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3310 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3311 {
3312     // Array constructors should have been already pruned from the code.
3313     ASSERT(!node->getType().isArray());
3314 
3315     if (visit == PreVisit)
3316     {
3317         TString constructorName;
3318         if (node->getBasicType() == EbtStruct)
3319         {
3320             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3321         }
3322         else
3323         {
3324             constructorName =
3325                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3326         }
3327         out << constructorName << "(";
3328     }
3329     else if (visit == InVisit)
3330     {
3331         out << ", ";
3332     }
3333     else if (visit == PostVisit)
3334     {
3335         out << ")";
3336     }
3337 }
3338 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3339 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3340                                                      const TType &type,
3341                                                      const TConstantUnion *const constUnion)
3342 {
3343     ASSERT(!type.isArray());
3344 
3345     const TConstantUnion *constUnionIterated = constUnion;
3346 
3347     const TStructure *structure = type.getStruct();
3348     if (structure)
3349     {
3350         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3351 
3352         const TFieldList &fields = structure->fields();
3353 
3354         for (size_t i = 0; i < fields.size(); i++)
3355         {
3356             const TType *fieldType = fields[i]->type();
3357             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3358 
3359             if (i != fields.size() - 1)
3360             {
3361                 out << ", ";
3362             }
3363         }
3364 
3365         out << ")";
3366     }
3367     else
3368     {
3369         size_t size    = type.getObjectSize();
3370         bool writeType = size > 1;
3371 
3372         if (writeType)
3373         {
3374             out << TypeString(type) << "(";
3375         }
3376         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3377         if (writeType)
3378         {
3379             out << ")";
3380         }
3381     }
3382 
3383     return constUnionIterated;
3384 }
3385 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)3386 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
3387 {
3388     if (visit == PreVisit)
3389     {
3390         const char *opStr = GetOperatorString(op);
3391         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
3392         out << "(";
3393     }
3394     else
3395     {
3396         outputTriplet(out, visit, nullptr, ", ", ")");
3397     }
3398 }
3399 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3400 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3401                                             TIntermSymbol *symbolNode,
3402                                             TIntermTyped *expression)
3403 {
3404     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3405     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3406 
3407     if (symbolInInitializer)
3408     {
3409         // Type already printed
3410         out << "t" + str(mUniqueIndex) + " = ";
3411         expression->traverse(this);
3412         out << ", ";
3413         symbolNode->traverse(this);
3414         out << " = t" + str(mUniqueIndex);
3415 
3416         mUniqueIndex++;
3417         return true;
3418     }
3419 
3420     return false;
3421 }
3422 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3423 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3424                                              TIntermSymbol *symbolNode,
3425                                              TIntermTyped *initializer)
3426 {
3427     if (initializer->hasConstantValue())
3428     {
3429         symbolNode->traverse(this);
3430         out << ArrayString(symbolNode->getType());
3431         out << " = {";
3432         writeConstantUnionArray(out, initializer->getConstantValue(),
3433                                 initializer->getType().getObjectSize());
3434         out << "}";
3435         return true;
3436     }
3437     return false;
3438 }
3439 
addStructEqualityFunction(const TStructure & structure)3440 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3441 {
3442     const TFieldList &fields = structure.fields();
3443 
3444     for (const auto &eqFunction : mStructEqualityFunctions)
3445     {
3446         if (eqFunction->structure == &structure)
3447         {
3448             return eqFunction->functionName;
3449         }
3450     }
3451 
3452     const TString &structNameString = StructNameString(structure);
3453 
3454     StructEqualityFunction *function = new StructEqualityFunction();
3455     function->structure              = &structure;
3456     function->functionName           = "angle_eq_" + structNameString;
3457 
3458     TInfoSinkBase fnOut;
3459 
3460     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3461           << structNameString + " b)\n"
3462           << "{\n"
3463              "    return ";
3464 
3465     for (size_t i = 0; i < fields.size(); i++)
3466     {
3467         const TField *field    = fields[i];
3468         const TType *fieldType = field->type();
3469 
3470         const TString &fieldNameA = "a." + Decorate(field->name());
3471         const TString &fieldNameB = "b." + Decorate(field->name());
3472 
3473         if (i > 0)
3474         {
3475             fnOut << " && ";
3476         }
3477 
3478         fnOut << "(";
3479         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3480         fnOut << fieldNameA;
3481         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3482         fnOut << fieldNameB;
3483         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3484         fnOut << ")";
3485     }
3486 
3487     fnOut << ";\n"
3488           << "}\n";
3489 
3490     function->functionDefinition = fnOut.c_str();
3491 
3492     mStructEqualityFunctions.push_back(function);
3493     mEqualityFunctions.push_back(function);
3494 
3495     return function->functionName;
3496 }
3497 
addArrayEqualityFunction(const TType & type)3498 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3499 {
3500     for (const auto &eqFunction : mArrayEqualityFunctions)
3501     {
3502         if (eqFunction->type == type)
3503         {
3504             return eqFunction->functionName;
3505         }
3506     }
3507 
3508     TType elementType(type);
3509     elementType.toArrayElementType();
3510 
3511     ArrayHelperFunction *function = new ArrayHelperFunction();
3512     function->type                = type;
3513 
3514     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3515 
3516     TInfoSinkBase fnOut;
3517 
3518     const TString &typeName = TypeString(type);
3519     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3520           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3521           << "{\n"
3522              "    for (int i = 0; i < "
3523           << type.getOutermostArraySize()
3524           << "; ++i)\n"
3525              "    {\n"
3526              "        if (";
3527 
3528     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3529     fnOut << "a[i]";
3530     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3531     fnOut << "b[i]";
3532     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3533 
3534     fnOut << ") { return false; }\n"
3535              "    }\n"
3536              "    return true;\n"
3537              "}\n";
3538 
3539     function->functionDefinition = fnOut.c_str();
3540 
3541     mArrayEqualityFunctions.push_back(function);
3542     mEqualityFunctions.push_back(function);
3543 
3544     return function->functionName;
3545 }
3546 
addArrayAssignmentFunction(const TType & type)3547 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3548 {
3549     for (const auto &assignFunction : mArrayAssignmentFunctions)
3550     {
3551         if (assignFunction.type == type)
3552         {
3553             return assignFunction.functionName;
3554         }
3555     }
3556 
3557     TType elementType(type);
3558     elementType.toArrayElementType();
3559 
3560     ArrayHelperFunction function;
3561     function.type = type;
3562 
3563     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3564 
3565     TInfoSinkBase fnOut;
3566 
3567     const TString &typeName = TypeString(type);
3568     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3569           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3570           << "{\n"
3571              "    for (int i = 0; i < "
3572           << type.getOutermostArraySize()
3573           << "; ++i)\n"
3574              "    {\n"
3575              "        ";
3576 
3577     outputAssign(PreVisit, elementType, fnOut);
3578     fnOut << "a[i]";
3579     outputAssign(InVisit, elementType, fnOut);
3580     fnOut << "b[i]";
3581     outputAssign(PostVisit, elementType, fnOut);
3582 
3583     fnOut << ";\n"
3584              "    }\n"
3585              "}\n";
3586 
3587     function.functionDefinition = fnOut.c_str();
3588 
3589     mArrayAssignmentFunctions.push_back(function);
3590 
3591     return function.functionName;
3592 }
3593 
addArrayConstructIntoFunction(const TType & type)3594 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3595 {
3596     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3597     {
3598         if (constructIntoFunction.type == type)
3599         {
3600             return constructIntoFunction.functionName;
3601         }
3602     }
3603 
3604     TType elementType(type);
3605     elementType.toArrayElementType();
3606 
3607     ArrayHelperFunction function;
3608     function.type = type;
3609 
3610     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3611 
3612     TInfoSinkBase fnOut;
3613 
3614     const TString &typeName = TypeString(type);
3615     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3616     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3617     {
3618         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3619     }
3620     fnOut << ")\n"
3621              "{\n";
3622 
3623     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3624     {
3625         fnOut << "    ";
3626         outputAssign(PreVisit, elementType, fnOut);
3627         fnOut << "a[" << i << "]";
3628         outputAssign(InVisit, elementType, fnOut);
3629         fnOut << "b" << i;
3630         outputAssign(PostVisit, elementType, fnOut);
3631         fnOut << ";\n";
3632     }
3633     fnOut << "}\n";
3634 
3635     function.functionDefinition = fnOut.c_str();
3636 
3637     mArrayConstructIntoFunctions.push_back(function);
3638 
3639     return function.functionName;
3640 }
3641 
ensureStructDefined(const TType & type)3642 void OutputHLSL::ensureStructDefined(const TType &type)
3643 {
3644     const TStructure *structure = type.getStruct();
3645     if (structure)
3646     {
3647         ASSERT(type.getBasicType() == EbtStruct);
3648         mStructureHLSL->ensureStructDefined(*structure);
3649     }
3650 }
3651 
shaderNeedsGenerateOutput() const3652 bool OutputHLSL::shaderNeedsGenerateOutput() const
3653 {
3654     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3655 }
3656 
generateOutputCall() const3657 const char *OutputHLSL::generateOutputCall() const
3658 {
3659     if (mShaderType == GL_VERTEX_SHADER)
3660     {
3661         return "generateOutput(input)";
3662     }
3663     else
3664     {
3665         return "generateOutput()";
3666     }
3667 }
3668 }  // namespace sh
3669