1 //
2 // Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/OutputHLSL.h"
8 
9 #include <algorithm>
10 #include <cfloat>
11 #include <stdio.h>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/BuiltInFunctionEmulator.h"
17 #include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
18 #include "compiler/translator/ImageFunctionHLSL.h"
19 #include "compiler/translator/InfoSink.h"
20 #include "compiler/translator/NodeSearch.h"
21 #include "compiler/translator/RemoveSwitchFallThrough.h"
22 #include "compiler/translator/SearchSymbol.h"
23 #include "compiler/translator/StructureHLSL.h"
24 #include "compiler/translator/TextureFunctionHLSL.h"
25 #include "compiler/translator/TranslatorHLSL.h"
26 #include "compiler/translator/UniformHLSL.h"
27 #include "compiler/translator/UtilsHLSL.h"
28 #include "compiler/translator/blocklayout.h"
29 #include "compiler/translator/util.h"
30 
31 namespace sh
32 {
33 
34 namespace
35 {
36 
ArrayHelperFunctionName(const char * prefix,const TType & type)37 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
38 {
39     TStringStream fnName;
40     fnName << prefix << "_";
41     if (type.isArray())
42     {
43         for (unsigned int arraySize : *type.getArraySizes())
44         {
45             fnName << arraySize << "_";
46         }
47     }
48     fnName << TypeString(type);
49     return fnName.str();
50 }
51 
IsDeclarationWrittenOut(TIntermDeclaration * node)52 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
53 {
54     TIntermSequence *sequence = node->getSequence();
55     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
56     ASSERT(sequence->size() == 1);
57     ASSERT(variable);
58     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
59             variable->getQualifier() == EvqConst);
60 }
61 
IsInStd140InterfaceBlock(TIntermTyped * node)62 bool IsInStd140InterfaceBlock(TIntermTyped *node)
63 {
64     TIntermBinary *binaryNode = node->getAsBinaryNode();
65 
66     if (binaryNode)
67     {
68         return IsInStd140InterfaceBlock(binaryNode->getLeft());
69     }
70 
71     const TType &type = node->getType();
72 
73     // determine if we are in the standard layout
74     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
75     if (interfaceBlock)
76     {
77         return (interfaceBlock->blockStorage() == EbsStd140);
78     }
79 
80     return false;
81 }
82 
83 }  // anonymous namespace
84 
writeFloat(TInfoSinkBase & out,float f)85 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
86 {
87     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
88     // regardless.
89     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
90         mOutputType == SH_HLSL_4_1_OUTPUT)
91     {
92         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
93     }
94     else
95     {
96         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
97     }
98 }
99 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)100 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
101 {
102     ASSERT(constUnion != nullptr);
103     switch (constUnion->getType())
104     {
105         case EbtFloat:
106             writeFloat(out, constUnion->getFConst());
107             break;
108         case EbtInt:
109             out << constUnion->getIConst();
110             break;
111         case EbtUInt:
112             out << constUnion->getUConst();
113             break;
114         case EbtBool:
115             out << constUnion->getBConst();
116             break;
117         default:
118             UNREACHABLE();
119     }
120 }
121 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)122 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
123                                                           const TConstantUnion *const constUnion,
124                                                           const size_t size)
125 {
126     const TConstantUnion *constUnionIterated = constUnion;
127     for (size_t i = 0; i < size; i++, constUnionIterated++)
128     {
129         writeSingleConstant(out, constUnionIterated);
130 
131         if (i != size - 1)
132         {
133             out << ", ";
134         }
135     }
136     return constUnionIterated;
137 }
138 
OutputHLSL(sh::GLenum shaderType,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,const std::vector<Uniform> & uniforms,ShCompileOptions compileOptions,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics)139 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
140                        int shaderVersion,
141                        const TExtensionBehavior &extensionBehavior,
142                        const char *sourcePath,
143                        ShShaderOutput outputType,
144                        int numRenderTargets,
145                        const std::vector<Uniform> &uniforms,
146                        ShCompileOptions compileOptions,
147                        TSymbolTable *symbolTable,
148                        PerformanceDiagnostics *perfDiagnostics)
149     : TIntermTraverser(true, true, true, symbolTable),
150       mShaderType(shaderType),
151       mShaderVersion(shaderVersion),
152       mExtensionBehavior(extensionBehavior),
153       mSourcePath(sourcePath),
154       mOutputType(outputType),
155       mCompileOptions(compileOptions),
156       mNumRenderTargets(numRenderTargets),
157       mCurrentFunctionMetadata(nullptr),
158       mPerfDiagnostics(perfDiagnostics)
159 {
160     mInsideFunction = false;
161 
162     mUsesFragColor               = false;
163     mUsesFragData                = false;
164     mUsesDepthRange              = false;
165     mUsesFragCoord               = false;
166     mUsesPointCoord              = false;
167     mUsesFrontFacing             = false;
168     mUsesPointSize               = false;
169     mUsesInstanceID              = false;
170     mHasMultiviewExtensionEnabled =
171         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview);
172     mUsesViewID                  = false;
173     mUsesVertexID                = false;
174     mUsesFragDepth               = false;
175     mUsesNumWorkGroups           = false;
176     mUsesWorkGroupID             = false;
177     mUsesLocalInvocationID       = false;
178     mUsesGlobalInvocationID      = false;
179     mUsesLocalInvocationIndex    = false;
180     mUsesXor                     = false;
181     mUsesDiscardRewriting        = false;
182     mUsesNestedBreak             = false;
183     mRequiresIEEEStrictCompiling = false;
184 
185     mUniqueIndex = 0;
186 
187     mOutputLod0Function      = false;
188     mInsideDiscontinuousLoop = false;
189     mNestedLoopDepth         = 0;
190 
191     mExcessiveLoopIndex = nullptr;
192 
193     mStructureHLSL       = new StructureHLSL;
194     mUniformHLSL         = new UniformHLSL(shaderType, mStructureHLSL, outputType, uniforms);
195     mTextureFunctionHLSL = new TextureFunctionHLSL;
196     mImageFunctionHLSL   = new ImageFunctionHLSL;
197 
198     if (mOutputType == SH_HLSL_3_0_OUTPUT)
199     {
200         // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
201         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
202         // dx_ViewAdjust.
203         // In both cases total 3 uniform registers need to be reserved.
204         mUniformHLSL->reserveUniformRegisters(3);
205     }
206 
207     // Reserve registers for the default uniform block and driver constants
208     mUniformHLSL->reserveUniformBlockRegisters(2);
209 }
210 
~OutputHLSL()211 OutputHLSL::~OutputHLSL()
212 {
213     SafeDelete(mStructureHLSL);
214     SafeDelete(mUniformHLSL);
215     SafeDelete(mTextureFunctionHLSL);
216     SafeDelete(mImageFunctionHLSL);
217     for (auto &eqFunction : mStructEqualityFunctions)
218     {
219         SafeDelete(eqFunction);
220     }
221     for (auto &eqFunction : mArrayEqualityFunctions)
222     {
223         SafeDelete(eqFunction);
224     }
225 }
226 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)227 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
228 {
229     BuiltInFunctionEmulator builtInFunctionEmulator;
230     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
231     if ((mCompileOptions & SH_EMULATE_ISNAN_FLOAT_FUNCTION) != 0)
232     {
233         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
234                                                            mShaderVersion);
235     }
236 
237     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
238 
239     // Now that we are done changing the AST, do the analyses need for HLSL generation
240     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
241     ASSERT(success == CallDAG::INITDAG_SUCCESS);
242     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
243 
244     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
245     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
246     // the shader code. When we add shader storage blocks we might also consider an alternative
247     // solution, since the struct mapping won't work very well for shader storage blocks.
248 
249     // Output the body and footer first to determine what has to go in the header
250     mInfoSinkStack.push(&mBody);
251     treeRoot->traverse(this);
252     mInfoSinkStack.pop();
253 
254     mInfoSinkStack.push(&mFooter);
255     mInfoSinkStack.pop();
256 
257     mInfoSinkStack.push(&mHeader);
258     header(mHeader, std140Structs, &builtInFunctionEmulator);
259     mInfoSinkStack.pop();
260 
261     objSink << mHeader.c_str();
262     objSink << mBody.c_str();
263     objSink << mFooter.c_str();
264 
265     builtInFunctionEmulator.cleanup();
266 }
267 
getUniformBlockRegisterMap() const268 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
269 {
270     return mUniformHLSL->getUniformBlockRegisterMap();
271 }
272 
getUniformRegisterMap() const273 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
274 {
275     return mUniformHLSL->getUniformRegisterMap();
276 }
277 
// Builds a brace initializer that copies the value named |name| of type |type| member by member.
//  - Arrays produce one nested initializer per element, addressed as name[i].
//  - Structs produce one nested initializer per field, addressed as name.<decorated field>.
//  - Anything else produces just |name|.
// |indent| is the nesting depth; each level contributes four spaces of leading indentation.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString indentString;
    for (int level = indent; level > 0; --level)
    {
        indentString.append("    ");
    }

    // Leaf case: neither an array nor a struct, so the expression itself is the initializer.
    if (!type.isArray() && type.getBasicType() != EbtStruct)
    {
        return indentString + name;
    }

    TString init = indentString + "{\n";

    if (type.isArray())
    {
        // Every element shares the same type: the array type with its outermost dimension
        // stripped.
        TType elementType = type;
        elementType.toArrayElementType();

        const unsigned int elementCount = type.getOutermostArraySize();
        for (unsigned int arrayIndex = 0u; arrayIndex < elementCount; ++arrayIndex)
        {
            TStringStream elementName;
            elementName << name << "[" << arrayIndex << "]";

            init += structInitializerString(indent + 1, elementType, elementName.str());
            // All entries except the last are followed by a comma.
            init += (arrayIndex + 1u < elementCount) ? ",\n" : "\n";
        }
    }
    else
    {
        const TFieldList &fields = type.getStruct()->fields();
        const size_t fieldCount  = fields.size();
        for (size_t fieldIndex = 0; fieldIndex < fieldCount; ++fieldIndex)
        {
            const TField &field = *fields[fieldIndex];

            init += structInitializerString(indent + 1, *field.type(),
                                            name + "." + Decorate(field.name()));
            // All entries except the last are followed by a comma.
            init += (fieldIndex + 1 < fieldCount) ? ",\n" : "\n";
        }
    }

    init += indentString + "}";
    return init;
}
335 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const336 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
337 {
338     TString mappedStructs;
339 
340     for (auto &mappedStruct : std140Structs)
341     {
342         TInterfaceBlock *interfaceBlock =
343             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
344         const TString &interfaceBlockName = interfaceBlock->name();
345         const TName &instanceName         = mappedStruct.blockDeclarator->getName();
346         if (mReferencedUniformBlocks.count(interfaceBlockName) == 0 &&
347             (instanceName.getString() == "" ||
348              mReferencedUniformBlocks.count(instanceName.getString()) == 0))
349         {
350             continue;
351         }
352 
353         unsigned int instanceCount = 1u;
354         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
355         if (isInstanceArray)
356         {
357             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
358         }
359 
360         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
361              ++instanceArrayIndex)
362         {
363             TString originalName;
364             TString mappedName("map");
365 
366             if (instanceName.getString() != "")
367             {
368                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
369                 if (isInstanceArray)
370                     instanceStringArrayIndex = instanceArrayIndex;
371                 TString instanceString = mUniformHLSL->uniformBlockInstanceString(
372                     *interfaceBlock, instanceStringArrayIndex);
373                 originalName += instanceString;
374                 mappedName += instanceString;
375                 originalName += ".";
376                 mappedName += "_";
377             }
378 
379             TString fieldName = Decorate(mappedStruct.field->name());
380             originalName += fieldName;
381             mappedName += fieldName;
382 
383             TType *structType = mappedStruct.field->type();
384             mappedStructs +=
385                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
386 
387             if (structType->isArray())
388             {
389                 mappedStructs += ArrayString(*mappedStruct.field->type());
390             }
391 
392             mappedStructs += " =\n";
393             mappedStructs += structInitializerString(0, *structType, originalName);
394             mappedStructs += ";\n";
395         }
396     }
397     return mappedStructs;
398 }
399 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const400 void OutputHLSL::header(TInfoSinkBase &out,
401                         const std::vector<MappedStruct> &std140Structs,
402                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
403 {
404     TString varyings;
405     TString attributes;
406     TString mappedStructs = generateStructMapping(std140Structs);
407 
408     for (ReferencedSymbols::const_iterator varying = mReferencedVaryings.begin();
409          varying != mReferencedVaryings.end(); varying++)
410     {
411         const TType &type   = varying->second->getType();
412         const TString &name = varying->second->getSymbol();
413 
414         // Program linking depends on this exact format
415         varyings += "static " + InterpolationString(type.getQualifier()) + " " + TypeString(type) +
416                     " " + Decorate(name) + ArrayString(type) + " = " + initializer(type) + ";\n";
417     }
418 
419     for (ReferencedSymbols::const_iterator attribute = mReferencedAttributes.begin();
420          attribute != mReferencedAttributes.end(); attribute++)
421     {
422         const TType &type   = attribute->second->getType();
423         const TString &name = attribute->second->getSymbol();
424 
425         attributes += "static " + TypeString(type) + " " + Decorate(name) + ArrayString(type) +
426                       " = " + initializer(type) + ";\n";
427     }
428 
429     out << mStructureHLSL->structsHeader();
430 
431     mUniformHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
432     out << mUniformHLSL->uniformBlocksHeader(mReferencedUniformBlocks);
433 
434     if (!mEqualityFunctions.empty())
435     {
436         out << "\n// Equality functions\n\n";
437         for (const auto &eqFunction : mEqualityFunctions)
438         {
439             out << eqFunction->functionDefinition << "\n";
440         }
441     }
442     if (!mArrayAssignmentFunctions.empty())
443     {
444         out << "\n// Assignment functions\n\n";
445         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
446         {
447             out << assignmentFunction.functionDefinition << "\n";
448         }
449     }
450     if (!mArrayConstructIntoFunctions.empty())
451     {
452         out << "\n// Array constructor functions\n\n";
453         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
454         {
455             out << constructIntoFunction.functionDefinition << "\n";
456         }
457     }
458 
459     if (mUsesDiscardRewriting)
460     {
461         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
462     }
463 
464     if (mUsesNestedBreak)
465     {
466         out << "#define ANGLE_USES_NESTED_BREAK\n";
467     }
468 
469     if (mRequiresIEEEStrictCompiling)
470     {
471         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
472     }
473 
474     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
475            "#define LOOP [loop]\n"
476            "#define FLATTEN [flatten]\n"
477            "#else\n"
478            "#define LOOP\n"
479            "#define FLATTEN\n"
480            "#endif\n";
481 
482     if (mShaderType == GL_FRAGMENT_SHADER)
483     {
484         const bool usingMRTExtension =
485             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
486 
487         out << "// Varyings\n";
488         out << varyings;
489         out << "\n";
490 
491         if (mShaderVersion >= 300)
492         {
493             for (ReferencedSymbols::const_iterator outputVariableIt =
494                      mReferencedOutputVariables.begin();
495                  outputVariableIt != mReferencedOutputVariables.end(); outputVariableIt++)
496             {
497                 const TString &variableName = outputVariableIt->first;
498                 const TType &variableType   = outputVariableIt->second->getType();
499 
500                 out << "static " + TypeString(variableType) + " out_" + variableName +
501                            ArrayString(variableType) + " = " + initializer(variableType) + ";\n";
502             }
503         }
504         else
505         {
506             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
507 
508             out << "static float4 gl_Color[" << numColorValues << "] =\n"
509                                                                   "{\n";
510             for (unsigned int i = 0; i < numColorValues; i++)
511             {
512                 out << "    float4(0, 0, 0, 0)";
513                 if (i + 1 != numColorValues)
514                 {
515                     out << ",";
516                 }
517                 out << "\n";
518             }
519 
520             out << "};\n";
521         }
522 
523         if (mUsesFragDepth)
524         {
525             out << "static float gl_Depth = 0.0;\n";
526         }
527 
528         if (mUsesFragCoord)
529         {
530             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
531         }
532 
533         if (mUsesPointCoord)
534         {
535             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
536         }
537 
538         if (mUsesFrontFacing)
539         {
540             out << "static bool gl_FrontFacing = false;\n";
541         }
542 
543         out << "\n";
544 
545         if (mUsesDepthRange)
546         {
547             out << "struct gl_DepthRangeParameters\n"
548                    "{\n"
549                    "    float near;\n"
550                    "    float far;\n"
551                    "    float diff;\n"
552                    "};\n"
553                    "\n";
554         }
555 
556         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
557         {
558             out << "cbuffer DriverConstants : register(b1)\n"
559                    "{\n";
560 
561             if (mUsesDepthRange)
562             {
563                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
564             }
565 
566             if (mUsesFragCoord)
567             {
568                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
569             }
570 
571             if (mUsesFragCoord || mUsesFrontFacing)
572             {
573                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
574             }
575 
576             if (mUsesFragCoord)
577             {
578                 // dx_ViewScale is only used in the fragment shader to correct
579                 // the value for glFragCoord if necessary
580                 out << "    float2 dx_ViewScale : packoffset(c3);\n";
581             }
582 
583             if (mHasMultiviewExtensionEnabled)
584             {
585                 // We have to add a value which we can use to keep track of which multi-view code
586                 // path is to be selected in the GS.
587                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
588             }
589 
590             if (mOutputType == SH_HLSL_4_1_OUTPUT)
591             {
592                 mUniformHLSL->samplerMetadataUniforms(out, "c4");
593             }
594 
595             out << "};\n";
596         }
597         else
598         {
599             if (mUsesDepthRange)
600             {
601                 out << "uniform float3 dx_DepthRange : register(c0);";
602             }
603 
604             if (mUsesFragCoord)
605             {
606                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
607             }
608 
609             if (mUsesFragCoord || mUsesFrontFacing)
610             {
611                 out << "uniform float3 dx_DepthFront : register(c2);\n";
612             }
613         }
614 
615         out << "\n";
616 
617         if (mUsesDepthRange)
618         {
619             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
620                    "dx_DepthRange.y, dx_DepthRange.z};\n"
621                    "\n";
622         }
623 
624         if (!mappedStructs.empty())
625         {
626             out << "// Structures from std140 blocks with padding removed\n";
627             out << "\n";
628             out << mappedStructs;
629             out << "\n";
630         }
631 
632         if (usingMRTExtension && mNumRenderTargets > 1)
633         {
634             out << "#define GL_USES_MRT\n";
635         }
636 
637         if (mUsesFragColor)
638         {
639             out << "#define GL_USES_FRAG_COLOR\n";
640         }
641 
642         if (mUsesFragData)
643         {
644             out << "#define GL_USES_FRAG_DATA\n";
645         }
646     }
647     else if (mShaderType == GL_VERTEX_SHADER)
648     {
649         out << "// Attributes\n";
650         out << attributes;
651         out << "\n"
652                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
653 
654         if (mUsesPointSize)
655         {
656             out << "static float gl_PointSize = float(1);\n";
657         }
658 
659         if (mUsesInstanceID)
660         {
661             out << "static int gl_InstanceID;";
662         }
663 
664         if (mUsesVertexID)
665         {
666             out << "static int gl_VertexID;";
667         }
668 
669         out << "\n"
670                "// Varyings\n";
671         out << varyings;
672         out << "\n";
673 
674         if (mUsesDepthRange)
675         {
676             out << "struct gl_DepthRangeParameters\n"
677                    "{\n"
678                    "    float near;\n"
679                    "    float far;\n"
680                    "    float diff;\n"
681                    "};\n"
682                    "\n";
683         }
684 
685         if (mOutputType == SH_HLSL_4_1_OUTPUT || mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
686         {
687             out << "cbuffer DriverConstants : register(b1)\n"
688                    "{\n";
689 
690             if (mUsesDepthRange)
691             {
692                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
693             }
694 
695             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
696             // shaders. However, we declare it for all shaders (including Feature Level 10+).
697             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
698             // if it's unused.
699             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
700             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
701             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
702 
703             if (mHasMultiviewExtensionEnabled)
704             {
705                 // We have to add a value which we can use to keep track of which multi-view code
706                 // path is to be selected in the GS.
707                 out << "    float multiviewSelectViewportIndex : packoffset(c3.z);\n";
708             }
709 
710             if (mOutputType == SH_HLSL_4_1_OUTPUT)
711             {
712                 mUniformHLSL->samplerMetadataUniforms(out, "c4");
713             }
714 
715             out << "};\n"
716                    "\n";
717         }
718         else
719         {
720             if (mUsesDepthRange)
721             {
722                 out << "uniform float3 dx_DepthRange : register(c0);\n";
723             }
724 
725             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
726             out << "uniform float2 dx_ViewCoords : register(c2);\n"
727                    "\n";
728         }
729 
730         if (mUsesDepthRange)
731         {
732             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
733                    "dx_DepthRange.y, dx_DepthRange.z};\n"
734                    "\n";
735         }
736 
737         if (!mappedStructs.empty())
738         {
739             out << "// Structures from std140 blocks with padding removed\n";
740             out << "\n";
741             out << mappedStructs;
742             out << "\n";
743         }
744     }
745     else  // Compute shader
746     {
747         ASSERT(mShaderType == GL_COMPUTE_SHADER);
748 
749         out << "cbuffer DriverConstants : register(b1)\n"
750                "{\n";
751         if (mUsesNumWorkGroups)
752         {
753             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
754         }
755         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
756         mUniformHLSL->samplerMetadataUniforms(out, "c1");
757         out << "};\n";
758 
759         // Follow built-in variables would be initialized in
760         // DynamicHLSL::generateComputeShaderLinkHLSL, if they
761         // are used in compute shader.
762         if (mUsesWorkGroupID)
763         {
764             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
765         }
766 
767         if (mUsesLocalInvocationID)
768         {
769             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
770         }
771 
772         if (mUsesGlobalInvocationID)
773         {
774             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
775         }
776 
777         if (mUsesLocalInvocationIndex)
778         {
779             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
780         }
781     }
782 
783     bool getDimensionsIgnoresBaseLevel =
784         (mCompileOptions & SH_HLSL_GET_DIMENSIONS_IGNORES_BASE_LEVEL) != 0;
785     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
786     mImageFunctionHLSL->imageFunctionHeader(out);
787 
788     if (mUsesFragCoord)
789     {
790         out << "#define GL_USES_FRAG_COORD\n";
791     }
792 
793     if (mUsesPointCoord)
794     {
795         out << "#define GL_USES_POINT_COORD\n";
796     }
797 
798     if (mUsesFrontFacing)
799     {
800         out << "#define GL_USES_FRONT_FACING\n";
801     }
802 
803     if (mUsesPointSize)
804     {
805         out << "#define GL_USES_POINT_SIZE\n";
806     }
807 
808     if (mHasMultiviewExtensionEnabled)
809     {
810         out << "#define GL_ANGLE_MULTIVIEW_ENABLED\n";
811     }
812 
813     if (mUsesViewID)
814     {
815         out << "#define GL_USES_VIEW_ID\n";
816     }
817 
818     if (mUsesFragDepth)
819     {
820         out << "#define GL_USES_FRAG_DEPTH\n";
821     }
822 
823     if (mUsesDepthRange)
824     {
825         out << "#define GL_USES_DEPTH_RANGE\n";
826     }
827 
828     if (mUsesNumWorkGroups)
829     {
830         out << "#define GL_USES_NUM_WORK_GROUPS\n";
831     }
832 
833     if (mUsesWorkGroupID)
834     {
835         out << "#define GL_USES_WORK_GROUP_ID\n";
836     }
837 
838     if (mUsesLocalInvocationID)
839     {
840         out << "#define GL_USES_LOCAL_INVOCATION_ID\n";
841     }
842 
843     if (mUsesGlobalInvocationID)
844     {
845         out << "#define GL_USES_GLOBAL_INVOCATION_ID\n";
846     }
847 
848     if (mUsesLocalInvocationIndex)
849     {
850         out << "#define GL_USES_LOCAL_INVOCATION_INDEX\n";
851     }
852 
853     if (mUsesXor)
854     {
855         out << "bool xor(bool p, bool q)\n"
856                "{\n"
857                "    return (p || q) && !(p && q);\n"
858                "}\n"
859                "\n";
860     }
861 
862     builtInFunctionEmulator->outputEmulatedFunctions(out);
863 }
864 
visitSymbol(TIntermSymbol * node)865 void OutputHLSL::visitSymbol(TIntermSymbol *node)
866 {
867     TInfoSinkBase &out = getInfoSink();
868 
869     // Handle accessing std140 structs by value
870     if (IsInStd140InterfaceBlock(node) && node->getBasicType() == EbtStruct)
871     {
872         out << "map";
873     }
874 
875     TString name = node->getSymbol();
876 
877     if (name == "gl_DepthRange")
878     {
879         mUsesDepthRange = true;
880         out << name;
881     }
882     else
883     {
884         const TType &nodeType = node->getType();
885         TQualifier qualifier = node->getQualifier();
886 
887         ensureStructDefined(nodeType);
888 
889         if (qualifier == EvqUniform)
890         {
891             const TInterfaceBlock *interfaceBlock = nodeType.getInterfaceBlock();
892 
893             if (interfaceBlock)
894             {
895                 mReferencedUniformBlocks[interfaceBlock->name()] = node;
896             }
897             else
898             {
899                 mReferencedUniforms[name] = node;
900             }
901 
902             out << DecorateVariableIfNeeded(node->getName());
903         }
904         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
905         {
906             mReferencedAttributes[name] = node;
907             out << Decorate(name);
908         }
909         else if (IsVarying(qualifier))
910         {
911             mReferencedVaryings[name] = node;
912             out << Decorate(name);
913             if (name == "ViewID_OVR")
914             {
915                 mUsesViewID = true;
916             }
917         }
918         else if (qualifier == EvqFragmentOut)
919         {
920             mReferencedOutputVariables[name] = node;
921             out << "out_" << name;
922         }
923         else if (qualifier == EvqFragColor)
924         {
925             out << "gl_Color[0]";
926             mUsesFragColor = true;
927         }
928         else if (qualifier == EvqFragData)
929         {
930             out << "gl_Color";
931             mUsesFragData = true;
932         }
933         else if (qualifier == EvqFragCoord)
934         {
935             mUsesFragCoord = true;
936             out << name;
937         }
938         else if (qualifier == EvqPointCoord)
939         {
940             mUsesPointCoord = true;
941             out << name;
942         }
943         else if (qualifier == EvqFrontFacing)
944         {
945             mUsesFrontFacing = true;
946             out << name;
947         }
948         else if (qualifier == EvqPointSize)
949         {
950             mUsesPointSize = true;
951             out << name;
952         }
953         else if (qualifier == EvqInstanceID)
954         {
955             mUsesInstanceID = true;
956             out << name;
957         }
958         else if (qualifier == EvqVertexID)
959         {
960             mUsesVertexID = true;
961             out << name;
962         }
963         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
964         {
965             mUsesFragDepth = true;
966             out << "gl_Depth";
967         }
968         else if (qualifier == EvqNumWorkGroups)
969         {
970             mUsesNumWorkGroups = true;
971             out << name;
972         }
973         else if (qualifier == EvqWorkGroupID)
974         {
975             mUsesWorkGroupID = true;
976             out << name;
977         }
978         else if (qualifier == EvqLocalInvocationID)
979         {
980             mUsesLocalInvocationID = true;
981             out << name;
982         }
983         else if (qualifier == EvqGlobalInvocationID)
984         {
985             mUsesGlobalInvocationID = true;
986             out << name;
987         }
988         else if (qualifier == EvqLocalInvocationIndex)
989         {
990             mUsesLocalInvocationIndex = true;
991             out << name;
992         }
993         else
994         {
995             out << DecorateVariableIfNeeded(node->getName());
996         }
997     }
998 }
999 
visitRaw(TIntermRaw * node)1000 void OutputHLSL::visitRaw(TIntermRaw *node)
1001 {
1002     getInfoSink() << node->getRawText();
1003 }
1004 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1005 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1006 {
1007     if (type.isScalar() && !type.isArray())
1008     {
1009         if (op == EOpEqual)
1010         {
1011             outputTriplet(out, visit, "(", " == ", ")");
1012         }
1013         else
1014         {
1015             outputTriplet(out, visit, "(", " != ", ")");
1016         }
1017     }
1018     else
1019     {
1020         if (visit == PreVisit && op == EOpNotEqual)
1021         {
1022             out << "!";
1023         }
1024 
1025         if (type.isArray())
1026         {
1027             const TString &functionName = addArrayEqualityFunction(type);
1028             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1029         }
1030         else if (type.getBasicType() == EbtStruct)
1031         {
1032             const TStructure &structure = *type.getStruct();
1033             const TString &functionName = addStructEqualityFunction(structure);
1034             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1035         }
1036         else
1037         {
1038             ASSERT(type.isMatrix() || type.isVector());
1039             outputTriplet(out, visit, "all(", " == ", ")");
1040         }
1041     }
1042 }
1043 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1044 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1045 {
1046     if (type.isArray())
1047     {
1048         const TString &functionName = addArrayAssignmentFunction(type);
1049         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1050     }
1051     else
1052     {
1053         outputTriplet(out, visit, "(", " = ", ")");
1054     }
1055 }
1056 
ancestorEvaluatesToSamplerInStruct()1057 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1058 {
1059     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1060     {
1061         TIntermNode *ancestor               = getAncestorNode(n);
1062         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1063         if (ancestorBinary == nullptr)
1064         {
1065             return false;
1066         }
1067         switch (ancestorBinary->getOp())
1068         {
1069             case EOpIndexDirectStruct:
1070             {
1071                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1072                 const TIntermConstantUnion *index =
1073                     ancestorBinary->getRight()->getAsConstantUnion();
1074                 const TField *field = structure->fields()[index->getIConst(0)];
1075                 if (IsSampler(field->type()->getBasicType()))
1076                 {
1077                     return true;
1078                 }
1079                 break;
1080             }
1081             case EOpIndexDirect:
1082                 break;
1083             default:
1084                 // Returning a sampler from indirect indexing is not supported.
1085                 return false;
1086         }
1087     }
1088     return false;
1089 }
1090 
visitSwizzle(Visit visit,TIntermSwizzle * node)1091 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1092 {
1093     TInfoSinkBase &out = getInfoSink();
1094     if (visit == PostVisit)
1095     {
1096         out << ".";
1097         node->writeOffsetsAsXYZW(&out);
1098     }
1099     return true;
1100 }
1101 
visitBinary(Visit visit,TIntermBinary * node)1102 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1103 {
1104     TInfoSinkBase &out = getInfoSink();
1105 
1106     switch (node->getOp())
1107     {
1108         case EOpComma:
1109             outputTriplet(out, visit, "(", ", ", ")");
1110             break;
1111         case EOpAssign:
1112             if (node->isArray())
1113             {
1114                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1115                 if (rightAgg != nullptr && rightAgg->isConstructor())
1116                 {
1117                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1118                     out << functionName << "(";
1119                     node->getLeft()->traverse(this);
1120                     TIntermSequence *seq = rightAgg->getSequence();
1121                     for (auto &arrayElement : *seq)
1122                     {
1123                         out << ", ";
1124                         arrayElement->traverse(this);
1125                     }
1126                     out << ")";
1127                     return false;
1128                 }
1129                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1130                 // function call is assigned.
1131                 ASSERT(rightAgg == nullptr);
1132             }
1133             outputAssign(visit, node->getType(), out);
1134             break;
1135         case EOpInitialize:
1136             if (visit == PreVisit)
1137             {
1138                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1139                 ASSERT(symbolNode);
1140                 TIntermTyped *expression = node->getRight();
1141 
1142                 // Global initializers must be constant at this point.
1143                 ASSERT(symbolNode->getQualifier() != EvqGlobal ||
1144                        canWriteAsHLSLLiteral(expression));
1145 
1146                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1147                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1148                 // new variable is created before the assignment is evaluated), so we need to
1149                 // convert
1150                 // this to "float t = x, x = t;".
1151                 if (writeSameSymbolInitializer(out, symbolNode, expression))
1152                 {
1153                     // Skip initializing the rest of the expression
1154                     return false;
1155                 }
1156                 else if (writeConstantInitialization(out, symbolNode, expression))
1157                 {
1158                     return false;
1159                 }
1160             }
1161             else if (visit == InVisit)
1162             {
1163                 out << " = ";
1164             }
1165             break;
1166         case EOpAddAssign:
1167             outputTriplet(out, visit, "(", " += ", ")");
1168             break;
1169         case EOpSubAssign:
1170             outputTriplet(out, visit, "(", " -= ", ")");
1171             break;
1172         case EOpMulAssign:
1173             outputTriplet(out, visit, "(", " *= ", ")");
1174             break;
1175         case EOpVectorTimesScalarAssign:
1176             outputTriplet(out, visit, "(", " *= ", ")");
1177             break;
1178         case EOpMatrixTimesScalarAssign:
1179             outputTriplet(out, visit, "(", " *= ", ")");
1180             break;
1181         case EOpVectorTimesMatrixAssign:
1182             if (visit == PreVisit)
1183             {
1184                 out << "(";
1185             }
1186             else if (visit == InVisit)
1187             {
1188                 out << " = mul(";
1189                 node->getLeft()->traverse(this);
1190                 out << ", transpose(";
1191             }
1192             else
1193             {
1194                 out << ")))";
1195             }
1196             break;
1197         case EOpMatrixTimesMatrixAssign:
1198             if (visit == PreVisit)
1199             {
1200                 out << "(";
1201             }
1202             else if (visit == InVisit)
1203             {
1204                 out << " = transpose(mul(transpose(";
1205                 node->getLeft()->traverse(this);
1206                 out << "), transpose(";
1207             }
1208             else
1209             {
1210                 out << "))))";
1211             }
1212             break;
1213         case EOpDivAssign:
1214             outputTriplet(out, visit, "(", " /= ", ")");
1215             break;
1216         case EOpIModAssign:
1217             outputTriplet(out, visit, "(", " %= ", ")");
1218             break;
1219         case EOpBitShiftLeftAssign:
1220             outputTriplet(out, visit, "(", " <<= ", ")");
1221             break;
1222         case EOpBitShiftRightAssign:
1223             outputTriplet(out, visit, "(", " >>= ", ")");
1224             break;
1225         case EOpBitwiseAndAssign:
1226             outputTriplet(out, visit, "(", " &= ", ")");
1227             break;
1228         case EOpBitwiseXorAssign:
1229             outputTriplet(out, visit, "(", " ^= ", ")");
1230             break;
1231         case EOpBitwiseOrAssign:
1232             outputTriplet(out, visit, "(", " |= ", ")");
1233             break;
1234         case EOpIndexDirect:
1235         {
1236             const TType &leftType = node->getLeft()->getType();
1237             if (leftType.isInterfaceBlock())
1238             {
1239                 if (visit == PreVisit)
1240                 {
1241                     TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1242                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1243                     mReferencedUniformBlocks[interfaceBlock->instanceName()] =
1244                         node->getLeft()->getAsSymbolNode();
1245                     out << mUniformHLSL->uniformBlockInstanceString(*interfaceBlock, arrayIndex);
1246                     return false;
1247                 }
1248             }
1249             else if (ancestorEvaluatesToSamplerInStruct())
1250             {
1251                 // All parts of an expression that access a sampler in a struct need to use _ as
1252                 // separator to access the sampler variable that has been moved out of the struct.
1253                 outputTriplet(out, visit, "", "_", "");
1254             }
1255             else
1256             {
1257                 outputTriplet(out, visit, "", "[", "]");
1258             }
1259         }
1260         break;
1261         case EOpIndexIndirect:
1262             // We do not currently support indirect references to interface blocks
1263             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1264             outputTriplet(out, visit, "", "[", "]");
1265             break;
1266         case EOpIndexDirectStruct:
1267         {
1268             const TStructure *structure       = node->getLeft()->getType().getStruct();
1269             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1270             const TField *field               = structure->fields()[index->getIConst(0)];
1271 
1272             // In cases where indexing returns a sampler, we need to access the sampler variable
1273             // that has been moved out of the struct.
1274             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1275             if (visit == PreVisit && indexingReturnsSampler)
1276             {
1277                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1278                 // This prefix is only output at the beginning of the indexing expression, which
1279                 // may have multiple parts.
1280                 out << "angle";
1281             }
1282             if (!indexingReturnsSampler)
1283             {
1284                 // All parts of an expression that access a sampler in a struct need to use _ as
1285                 // separator to access the sampler variable that has been moved out of the struct.
1286                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1287             }
1288             if (visit == InVisit)
1289             {
1290                 if (indexingReturnsSampler)
1291                 {
1292                     out << "_" + field->name();
1293                 }
1294                 else
1295                 {
1296                     out << "." + DecorateField(field->name(), *structure);
1297                 }
1298 
1299                 return false;
1300             }
1301         }
1302         break;
1303         case EOpIndexDirectInterfaceBlock:
1304         {
1305             bool structInStd140Block =
1306                 node->getBasicType() == EbtStruct && IsInStd140InterfaceBlock(node->getLeft());
1307             if (visit == PreVisit && structInStd140Block)
1308             {
1309                 out << "map";
1310             }
1311             if (visit == InVisit)
1312             {
1313                 const TInterfaceBlock *interfaceBlock =
1314                     node->getLeft()->getType().getInterfaceBlock();
1315                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1316                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1317                 if (structInStd140Block)
1318                 {
1319                     out << "_";
1320                 }
1321                 else
1322                 {
1323                     out << ".";
1324                 }
1325                 out << Decorate(field->name());
1326 
1327                 return false;
1328             }
1329             break;
1330         }
1331         case EOpAdd:
1332             outputTriplet(out, visit, "(", " + ", ")");
1333             break;
1334         case EOpSub:
1335             outputTriplet(out, visit, "(", " - ", ")");
1336             break;
1337         case EOpMul:
1338             outputTriplet(out, visit, "(", " * ", ")");
1339             break;
1340         case EOpDiv:
1341             outputTriplet(out, visit, "(", " / ", ")");
1342             break;
1343         case EOpIMod:
1344             outputTriplet(out, visit, "(", " % ", ")");
1345             break;
1346         case EOpBitShiftLeft:
1347             outputTriplet(out, visit, "(", " << ", ")");
1348             break;
1349         case EOpBitShiftRight:
1350             outputTriplet(out, visit, "(", " >> ", ")");
1351             break;
1352         case EOpBitwiseAnd:
1353             outputTriplet(out, visit, "(", " & ", ")");
1354             break;
1355         case EOpBitwiseXor:
1356             outputTriplet(out, visit, "(", " ^ ", ")");
1357             break;
1358         case EOpBitwiseOr:
1359             outputTriplet(out, visit, "(", " | ", ")");
1360             break;
1361         case EOpEqual:
1362         case EOpNotEqual:
1363             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1364             break;
1365         case EOpLessThan:
1366             outputTriplet(out, visit, "(", " < ", ")");
1367             break;
1368         case EOpGreaterThan:
1369             outputTriplet(out, visit, "(", " > ", ")");
1370             break;
1371         case EOpLessThanEqual:
1372             outputTriplet(out, visit, "(", " <= ", ")");
1373             break;
1374         case EOpGreaterThanEqual:
1375             outputTriplet(out, visit, "(", " >= ", ")");
1376             break;
1377         case EOpVectorTimesScalar:
1378             outputTriplet(out, visit, "(", " * ", ")");
1379             break;
1380         case EOpMatrixTimesScalar:
1381             outputTriplet(out, visit, "(", " * ", ")");
1382             break;
1383         case EOpVectorTimesMatrix:
1384             outputTriplet(out, visit, "mul(", ", transpose(", "))");
1385             break;
1386         case EOpMatrixTimesVector:
1387             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
1388             break;
1389         case EOpMatrixTimesMatrix:
1390             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
1391             break;
1392         case EOpLogicalOr:
1393             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
1394             // been unfolded.
1395             ASSERT(!node->getRight()->hasSideEffects());
1396             outputTriplet(out, visit, "(", " || ", ")");
1397             return true;
1398         case EOpLogicalXor:
1399             mUsesXor = true;
1400             outputTriplet(out, visit, "xor(", ", ", ")");
1401             break;
1402         case EOpLogicalAnd:
1403             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
1404             // been unfolded.
1405             ASSERT(!node->getRight()->hasSideEffects());
1406             outputTriplet(out, visit, "(", " && ", ")");
1407             return true;
1408         default:
1409             UNREACHABLE();
1410     }
1411 
1412     return true;
1413 }
1414 
visitUnary(Visit visit,TIntermUnary * node)1415 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
1416 {
1417     TInfoSinkBase &out = getInfoSink();
1418 
1419     switch (node->getOp())
1420     {
1421         case EOpNegative:
1422             outputTriplet(out, visit, "(-", "", ")");
1423             break;
1424         case EOpPositive:
1425             outputTriplet(out, visit, "(+", "", ")");
1426             break;
1427         case EOpLogicalNot:
1428             outputTriplet(out, visit, "(!", "", ")");
1429             break;
1430         case EOpBitwiseNot:
1431             outputTriplet(out, visit, "(~", "", ")");
1432             break;
1433         case EOpPostIncrement:
1434             outputTriplet(out, visit, "(", "", "++)");
1435             break;
1436         case EOpPostDecrement:
1437             outputTriplet(out, visit, "(", "", "--)");
1438             break;
1439         case EOpPreIncrement:
1440             outputTriplet(out, visit, "(++", "", ")");
1441             break;
1442         case EOpPreDecrement:
1443             outputTriplet(out, visit, "(--", "", ")");
1444             break;
1445         case EOpRadians:
1446             outputTriplet(out, visit, "radians(", "", ")");
1447             break;
1448         case EOpDegrees:
1449             outputTriplet(out, visit, "degrees(", "", ")");
1450             break;
1451         case EOpSin:
1452             outputTriplet(out, visit, "sin(", "", ")");
1453             break;
1454         case EOpCos:
1455             outputTriplet(out, visit, "cos(", "", ")");
1456             break;
1457         case EOpTan:
1458             outputTriplet(out, visit, "tan(", "", ")");
1459             break;
1460         case EOpAsin:
1461             outputTriplet(out, visit, "asin(", "", ")");
1462             break;
1463         case EOpAcos:
1464             outputTriplet(out, visit, "acos(", "", ")");
1465             break;
1466         case EOpAtan:
1467             outputTriplet(out, visit, "atan(", "", ")");
1468             break;
1469         case EOpSinh:
1470             outputTriplet(out, visit, "sinh(", "", ")");
1471             break;
1472         case EOpCosh:
1473             outputTriplet(out, visit, "cosh(", "", ")");
1474             break;
1475         case EOpTanh:
1476             outputTriplet(out, visit, "tanh(", "", ")");
1477             break;
1478         case EOpAsinh:
1479         case EOpAcosh:
1480         case EOpAtanh:
1481             ASSERT(node->getUseEmulatedFunction());
1482             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1483             break;
1484         case EOpExp:
1485             outputTriplet(out, visit, "exp(", "", ")");
1486             break;
1487         case EOpLog:
1488             outputTriplet(out, visit, "log(", "", ")");
1489             break;
1490         case EOpExp2:
1491             outputTriplet(out, visit, "exp2(", "", ")");
1492             break;
1493         case EOpLog2:
1494             outputTriplet(out, visit, "log2(", "", ")");
1495             break;
1496         case EOpSqrt:
1497             outputTriplet(out, visit, "sqrt(", "", ")");
1498             break;
1499         case EOpInverseSqrt:
1500             outputTriplet(out, visit, "rsqrt(", "", ")");
1501             break;
1502         case EOpAbs:
1503             outputTriplet(out, visit, "abs(", "", ")");
1504             break;
1505         case EOpSign:
1506             outputTriplet(out, visit, "sign(", "", ")");
1507             break;
1508         case EOpFloor:
1509             outputTriplet(out, visit, "floor(", "", ")");
1510             break;
1511         case EOpTrunc:
1512             outputTriplet(out, visit, "trunc(", "", ")");
1513             break;
1514         case EOpRound:
1515             outputTriplet(out, visit, "round(", "", ")");
1516             break;
1517         case EOpRoundEven:
1518             ASSERT(node->getUseEmulatedFunction());
1519             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1520             break;
1521         case EOpCeil:
1522             outputTriplet(out, visit, "ceil(", "", ")");
1523             break;
1524         case EOpFract:
1525             outputTriplet(out, visit, "frac(", "", ")");
1526             break;
1527         case EOpIsNan:
1528             if (node->getUseEmulatedFunction())
1529                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
1530             else
1531                 outputTriplet(out, visit, "isnan(", "", ")");
1532             mRequiresIEEEStrictCompiling = true;
1533             break;
1534         case EOpIsInf:
1535             outputTriplet(out, visit, "isinf(", "", ")");
1536             break;
1537         case EOpFloatBitsToInt:
1538             outputTriplet(out, visit, "asint(", "", ")");
1539             break;
1540         case EOpFloatBitsToUint:
1541             outputTriplet(out, visit, "asuint(", "", ")");
1542             break;
1543         case EOpIntBitsToFloat:
1544             outputTriplet(out, visit, "asfloat(", "", ")");
1545             break;
1546         case EOpUintBitsToFloat:
1547             outputTriplet(out, visit, "asfloat(", "", ")");
1548             break;
1549         case EOpPackSnorm2x16:
1550         case EOpPackUnorm2x16:
1551         case EOpPackHalf2x16:
1552         case EOpUnpackSnorm2x16:
1553         case EOpUnpackUnorm2x16:
1554         case EOpUnpackHalf2x16:
1555         case EOpPackUnorm4x8:
1556         case EOpPackSnorm4x8:
1557         case EOpUnpackUnorm4x8:
1558         case EOpUnpackSnorm4x8:
1559             ASSERT(node->getUseEmulatedFunction());
1560             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1561             break;
1562         case EOpLength:
1563             outputTriplet(out, visit, "length(", "", ")");
1564             break;
1565         case EOpNormalize:
1566             outputTriplet(out, visit, "normalize(", "", ")");
1567             break;
1568         case EOpDFdx:
1569             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1570             {
1571                 outputTriplet(out, visit, "(", "", ", 0.0)");
1572             }
1573             else
1574             {
1575                 outputTriplet(out, visit, "ddx(", "", ")");
1576             }
1577             break;
1578         case EOpDFdy:
1579             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1580             {
1581                 outputTriplet(out, visit, "(", "", ", 0.0)");
1582             }
1583             else
1584             {
1585                 outputTriplet(out, visit, "ddy(", "", ")");
1586             }
1587             break;
1588         case EOpFwidth:
1589             if (mInsideDiscontinuousLoop || mOutputLod0Function)
1590             {
1591                 outputTriplet(out, visit, "(", "", ", 0.0)");
1592             }
1593             else
1594             {
1595                 outputTriplet(out, visit, "fwidth(", "", ")");
1596             }
1597             break;
1598         case EOpTranspose:
1599             outputTriplet(out, visit, "transpose(", "", ")");
1600             break;
1601         case EOpDeterminant:
1602             outputTriplet(out, visit, "determinant(transpose(", "", "))");
1603             break;
1604         case EOpInverse:
1605             ASSERT(node->getUseEmulatedFunction());
1606             writeEmulatedFunctionTriplet(out, visit, node->getOp());
1607             break;
1608 
1609         case EOpAny:
1610             outputTriplet(out, visit, "any(", "", ")");
1611             break;
1612         case EOpAll:
1613             outputTriplet(out, visit, "all(", "", ")");
1614             break;
1615         case EOpLogicalNotComponentWise:
1616             outputTriplet(out, visit, "(!", "", ")");
1617             break;
1618         case EOpBitfieldReverse:
1619             outputTriplet(out, visit, "reversebits(", "", ")");
1620             break;
1621         case EOpBitCount:
1622             outputTriplet(out, visit, "countbits(", "", ")");
1623             break;
1624         case EOpFindLSB:
1625             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
1626             // in GLSLTest and results are consistent with GL.
1627             outputTriplet(out, visit, "firstbitlow(", "", ")");
1628             break;
1629         case EOpFindMSB:
1630             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
1631             // tested in GLSLTest and results are consistent with GL.
1632             outputTriplet(out, visit, "firstbithigh(", "", ")");
1633             break;
1634         default:
1635             UNREACHABLE();
1636     }
1637 
1638     return true;
1639 }
1640 
samplerNamePrefixFromStruct(TIntermTyped * node)1641 TString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
1642 {
1643     if (node->getAsSymbolNode())
1644     {
1645         return node->getAsSymbolNode()->getSymbol();
1646     }
1647     TIntermBinary *nodeBinary = node->getAsBinaryNode();
1648     switch (nodeBinary->getOp())
1649     {
1650         case EOpIndexDirect:
1651         {
1652             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
1653 
1654             TInfoSinkBase prefixSink;
1655             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
1656             return TString(prefixSink.c_str());
1657         }
1658         case EOpIndexDirectStruct:
1659         {
1660             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
1661             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
1662             const TField *field = s->fields()[index];
1663 
1664             TInfoSinkBase prefixSink;
1665             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
1666                        << field->name();
1667             return TString(prefixSink.c_str());
1668         }
1669         default:
1670             UNREACHABLE();
1671             return TString("");
1672     }
1673 }
1674 
visitBlock(Visit visit,TIntermBlock * node)1675 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
1676 {
1677     TInfoSinkBase &out = getInfoSink();
1678 
1679     if (mInsideFunction)
1680     {
1681         outputLineDirective(out, node->getLine().first_line);
1682         out << "{\n";
1683     }
1684 
1685     for (TIntermNode *statement : *node->getSequence())
1686     {
1687         outputLineDirective(out, statement->getLine().first_line);
1688 
1689         statement->traverse(this);
1690 
1691         // Don't output ; after case labels, they're terminated by :
1692         // This is needed especially since outputting a ; after a case statement would turn empty
1693         // case statements into non-empty case statements, disallowing fall-through from them.
1694         // Also the output code is clearer if we don't output ; after statements where it is not
1695         // needed:
1696         //  * if statements
1697         //  * switch statements
1698         //  * blocks
1699         //  * function definitions
1700         //  * loops (do-while loops output the semicolon in VisitLoop)
1701         //  * declarations that don't generate output.
1702         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
1703             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
1704             statement->getAsSwitchNode() == nullptr &&
1705             statement->getAsFunctionDefinition() == nullptr &&
1706             (statement->getAsDeclarationNode() == nullptr ||
1707              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
1708             statement->getAsInvariantDeclarationNode() == nullptr)
1709         {
1710             out << ";\n";
1711         }
1712     }
1713 
1714     if (mInsideFunction)
1715     {
1716         outputLineDirective(out, node->getLine().last_line);
1717         out << "}\n";
1718     }
1719 
1720     return false;
1721 }
1722 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)1723 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
1724 {
1725     TInfoSinkBase &out = getInfoSink();
1726 
1727     ASSERT(mCurrentFunctionMetadata == nullptr);
1728 
1729     size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
1730     ASSERT(index != CallDAG::InvalidIndex);
1731     mCurrentFunctionMetadata = &mASTMetadataList[index];
1732 
1733     out << TypeString(node->getFunctionPrototype()->getType()) << " ";
1734 
1735     TIntermSequence *parameters = node->getFunctionPrototype()->getSequence();
1736 
1737     if (node->getFunctionSymbolInfo()->isMain())
1738     {
1739         out << "gl_main(";
1740     }
1741     else
1742     {
1743         out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj())
1744             << DisambiguateFunctionName(parameters) << (mOutputLod0Function ? "Lod0(" : "(");
1745     }
1746 
1747     for (unsigned int i = 0; i < parameters->size(); i++)
1748     {
1749         TIntermSymbol *symbol = (*parameters)[i]->getAsSymbolNode();
1750 
1751         if (symbol)
1752         {
1753             ensureStructDefined(symbol->getType());
1754 
1755             out << argumentString(symbol);
1756 
1757             if (i < parameters->size() - 1)
1758             {
1759                 out << ", ";
1760             }
1761         }
1762         else
1763             UNREACHABLE();
1764     }
1765 
1766     out << ")\n";
1767 
1768     mInsideFunction = true;
1769     // The function body node will output braces.
1770     node->getBody()->traverse(this);
1771     mInsideFunction = false;
1772 
1773     mCurrentFunctionMetadata = nullptr;
1774 
1775     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
1776     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
1777     {
1778         ASSERT(!node->getFunctionSymbolInfo()->isMain());
1779         mOutputLod0Function = true;
1780         node->traverse(this);
1781         mOutputLod0Function = false;
1782     }
1783 
1784     return false;
1785 }
1786 
visitDeclaration(Visit visit,TIntermDeclaration * node)1787 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
1788 {
1789     if (visit == PreVisit)
1790     {
1791         TIntermSequence *sequence = node->getSequence();
1792         TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
1793         ASSERT(sequence->size() == 1);
1794         ASSERT(variable);
1795 
1796         if (IsDeclarationWrittenOut(node))
1797         {
1798             TInfoSinkBase &out = getInfoSink();
1799             ensureStructDefined(variable->getType());
1800 
1801             if (!variable->getAsSymbolNode() ||
1802                 variable->getAsSymbolNode()->getSymbol() != "")  // Variable declaration
1803             {
1804                 if (!mInsideFunction)
1805                 {
1806                     out << "static ";
1807                 }
1808 
1809                 out << TypeString(variable->getType()) + " ";
1810 
1811                 TIntermSymbol *symbol = variable->getAsSymbolNode();
1812 
1813                 if (symbol)
1814                 {
1815                     symbol->traverse(this);
1816                     out << ArrayString(symbol->getType());
1817                     out << " = " + initializer(symbol->getType());
1818                 }
1819                 else
1820                 {
1821                     variable->traverse(this);
1822                 }
1823             }
1824             else if (variable->getAsSymbolNode() &&
1825                      variable->getAsSymbolNode()->getSymbol() == "")  // Type (struct) declaration
1826             {
1827                 ASSERT(variable->getBasicType() == EbtStruct);
1828                 // ensureStructDefined has already been called.
1829             }
1830             else
1831                 UNREACHABLE();
1832         }
1833         else if (IsVaryingOut(variable->getQualifier()))
1834         {
1835             TIntermSymbol *symbol = variable->getAsSymbolNode();
1836             ASSERT(symbol);  // Varying declarations can't have initializers.
1837 
1838             // Vertex outputs which are declared but not written to should still be declared to
1839             // allow successful linking.
1840             mReferencedVaryings[symbol->getSymbol()] = symbol;
1841         }
1842     }
1843     return false;
1844 }
1845 
visitInvariantDeclaration(Visit visit,TIntermInvariantDeclaration * node)1846 bool OutputHLSL::visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
1847 {
1848     // Do not do any translation
1849     return false;
1850 }
1851 
visitFunctionPrototype(Visit visit,TIntermFunctionPrototype * node)1852 bool OutputHLSL::visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node)
1853 {
1854     TInfoSinkBase &out = getInfoSink();
1855 
1856     ASSERT(visit == PreVisit);
1857     size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
1858     // Skip the prototype if it is not implemented (and thus not used)
1859     if (index == CallDAG::InvalidIndex)
1860     {
1861         return false;
1862     }
1863 
1864     TIntermSequence *arguments = node->getSequence();
1865 
1866     TString name = DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
1867     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(arguments)
1868         << (mOutputLod0Function ? "Lod0(" : "(");
1869 
1870     for (unsigned int i = 0; i < arguments->size(); i++)
1871     {
1872         TIntermSymbol *symbol = (*arguments)[i]->getAsSymbolNode();
1873         ASSERT(symbol != nullptr);
1874 
1875         out << argumentString(symbol);
1876 
1877         if (i < arguments->size() - 1)
1878         {
1879             out << ", ";
1880         }
1881     }
1882 
1883     out << ");\n";
1884 
1885     // Also prototype the Lod0 variant if needed
1886     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
1887     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
1888     {
1889         mOutputLod0Function = true;
1890         node->traverse(this);
1891         mOutputLod0Function = false;
1892     }
1893 
1894     return false;
1895 }
1896 
visitAggregate(Visit visit,TIntermAggregate * node)1897 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
1898 {
1899     TInfoSinkBase &out = getInfoSink();
1900 
1901     switch (node->getOp())
1902     {
1903         case EOpCallBuiltInFunction:
1904         case EOpCallFunctionInAST:
1905         case EOpCallInternalRawFunction:
1906         {
1907             TIntermSequence *arguments = node->getSequence();
1908 
1909             bool lod0 = mInsideDiscontinuousLoop || mOutputLod0Function;
1910             if (node->getOp() == EOpCallFunctionInAST)
1911             {
1912                 if (node->isArray())
1913                 {
1914                     UNIMPLEMENTED();
1915                 }
1916                 size_t index = mCallDag.findIndex(node->getFunctionSymbolInfo());
1917                 ASSERT(index != CallDAG::InvalidIndex);
1918                 lod0 &= mASTMetadataList[index].mNeedsLod0;
1919 
1920                 out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj());
1921                 out << DisambiguateFunctionName(node->getSequence());
1922                 out << (lod0 ? "Lod0(" : "(");
1923             }
1924             else if (node->getOp() == EOpCallInternalRawFunction)
1925             {
1926                 // This path is used for internal functions that don't have their definitions in the
1927                 // AST, such as precision emulation functions.
1928                 out << DecorateFunctionIfNeeded(node->getFunctionSymbolInfo()->getNameObj()) << "(";
1929             }
1930             else if (node->getFunctionSymbolInfo()->isImageFunction())
1931             {
1932                 TString name              = node->getFunctionSymbolInfo()->getName();
1933                 TType type                = (*arguments)[0]->getAsTyped()->getType();
1934                 TString imageFunctionName = mImageFunctionHLSL->useImageFunction(
1935                     name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
1936                     type.getMemoryQualifier().readonly);
1937                 out << imageFunctionName << "(";
1938             }
1939             else
1940             {
1941                 const TString &name    = node->getFunctionSymbolInfo()->getName();
1942                 TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
1943                 int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
1944                 if (arguments->size() > 1)
1945                 {
1946                     coords = (*arguments)[1]->getAsTyped()->getNominalSize();
1947                 }
1948                 TString textureFunctionName = mTextureFunctionHLSL->useTextureFunction(
1949                     name, samplerType, coords, arguments->size(), lod0, mShaderType);
1950                 out << textureFunctionName << "(";
1951             }
1952 
1953             for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
1954             {
1955                 TIntermTyped *typedArg = (*arg)->getAsTyped();
1956                 if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT && IsSampler(typedArg->getBasicType()))
1957                 {
1958                     out << "texture_";
1959                     (*arg)->traverse(this);
1960                     out << ", sampler_";
1961                 }
1962 
1963                 (*arg)->traverse(this);
1964 
1965                 if (typedArg->getType().isStructureContainingSamplers())
1966                 {
1967                     const TType &argType = typedArg->getType();
1968                     TVector<TIntermSymbol *> samplerSymbols;
1969                     TString structName = samplerNamePrefixFromStruct(typedArg);
1970                     argType.createSamplerSymbols("angle_" + structName, "", &samplerSymbols,
1971                                                  nullptr, mSymbolTable);
1972                     for (const TIntermSymbol *sampler : samplerSymbols)
1973                     {
1974                         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
1975                         {
1976                             out << ", texture_" << sampler->getSymbol();
1977                             out << ", sampler_" << sampler->getSymbol();
1978                         }
1979                         else
1980                         {
1981                             // In case of HLSL 4.1+, this symbol is the sampler index, and in case
1982                             // of D3D9, it's the sampler variable.
1983                             out << ", " + sampler->getSymbol();
1984                         }
1985                     }
1986                 }
1987 
1988                 if (arg < arguments->end() - 1)
1989                 {
1990                     out << ", ";
1991                 }
1992             }
1993 
1994             out << ")";
1995 
1996             return false;
1997         }
1998         case EOpConstruct:
1999             outputConstructor(out, visit, node);
2000             break;
2001         case EOpEqualComponentWise:
2002             outputTriplet(out, visit, "(", " == ", ")");
2003             break;
2004         case EOpNotEqualComponentWise:
2005             outputTriplet(out, visit, "(", " != ", ")");
2006             break;
2007         case EOpLessThanComponentWise:
2008             outputTriplet(out, visit, "(", " < ", ")");
2009             break;
2010         case EOpGreaterThanComponentWise:
2011             outputTriplet(out, visit, "(", " > ", ")");
2012             break;
2013         case EOpLessThanEqualComponentWise:
2014             outputTriplet(out, visit, "(", " <= ", ")");
2015             break;
2016         case EOpGreaterThanEqualComponentWise:
2017             outputTriplet(out, visit, "(", " >= ", ")");
2018             break;
2019         case EOpMod:
2020             ASSERT(node->getUseEmulatedFunction());
2021             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2022             break;
2023         case EOpModf:
2024             outputTriplet(out, visit, "modf(", ", ", ")");
2025             break;
2026         case EOpPow:
2027             outputTriplet(out, visit, "pow(", ", ", ")");
2028             break;
2029         case EOpAtan:
2030             ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
2031             ASSERT(node->getUseEmulatedFunction());
2032             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2033             break;
2034         case EOpMin:
2035             outputTriplet(out, visit, "min(", ", ", ")");
2036             break;
2037         case EOpMax:
2038             outputTriplet(out, visit, "max(", ", ", ")");
2039             break;
2040         case EOpClamp:
2041             outputTriplet(out, visit, "clamp(", ", ", ")");
2042             break;
2043         case EOpMix:
2044         {
2045             TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
2046             if (lastParamNode->getType().getBasicType() == EbtBool)
2047             {
2048                 // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
2049                 // y, genBType a)",
2050                 // so use emulated version.
2051                 ASSERT(node->getUseEmulatedFunction());
2052                 writeEmulatedFunctionTriplet(out, visit, node->getOp());
2053             }
2054             else
2055             {
2056                 outputTriplet(out, visit, "lerp(", ", ", ")");
2057             }
2058             break;
2059         }
2060         case EOpStep:
2061             outputTriplet(out, visit, "step(", ", ", ")");
2062             break;
2063         case EOpSmoothStep:
2064             outputTriplet(out, visit, "smoothstep(", ", ", ")");
2065             break;
2066         case EOpFrexp:
2067         case EOpLdexp:
2068             ASSERT(node->getUseEmulatedFunction());
2069             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2070             break;
2071         case EOpDistance:
2072             outputTriplet(out, visit, "distance(", ", ", ")");
2073             break;
2074         case EOpDot:
2075             outputTriplet(out, visit, "dot(", ", ", ")");
2076             break;
2077         case EOpCross:
2078             outputTriplet(out, visit, "cross(", ", ", ")");
2079             break;
2080         case EOpFaceforward:
2081             ASSERT(node->getUseEmulatedFunction());
2082             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2083             break;
2084         case EOpReflect:
2085             outputTriplet(out, visit, "reflect(", ", ", ")");
2086             break;
2087         case EOpRefract:
2088             outputTriplet(out, visit, "refract(", ", ", ")");
2089             break;
2090         case EOpOuterProduct:
2091             ASSERT(node->getUseEmulatedFunction());
2092             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2093             break;
2094         case EOpMulMatrixComponentWise:
2095             outputTriplet(out, visit, "(", " * ", ")");
2096             break;
2097         case EOpBitfieldExtract:
2098         case EOpBitfieldInsert:
2099         case EOpUaddCarry:
2100         case EOpUsubBorrow:
2101         case EOpUmulExtended:
2102         case EOpImulExtended:
2103             ASSERT(node->getUseEmulatedFunction());
2104             writeEmulatedFunctionTriplet(out, visit, node->getOp());
2105             break;
2106         default:
2107             UNREACHABLE();
2108     }
2109 
2110     return true;
2111 }
2112 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2113 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2114 {
2115     out << "if (";
2116 
2117     node->getCondition()->traverse(this);
2118 
2119     out << ")\n";
2120 
2121     outputLineDirective(out, node->getLine().first_line);
2122 
2123     bool discard = false;
2124 
2125     if (node->getTrueBlock())
2126     {
2127         // The trueBlock child node will output braces.
2128         node->getTrueBlock()->traverse(this);
2129 
2130         // Detect true discard
2131         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2132     }
2133     else
2134     {
2135         // TODO(oetuaho): Check if the semicolon inside is necessary.
2136         // It's there as a result of conservative refactoring of the output.
2137         out << "{;}\n";
2138     }
2139 
2140     outputLineDirective(out, node->getLine().first_line);
2141 
2142     if (node->getFalseBlock())
2143     {
2144         out << "else\n";
2145 
2146         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2147 
2148         // The falseBlock child node will output braces.
2149         node->getFalseBlock()->traverse(this);
2150 
2151         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2152 
2153         // Detect false discard
2154         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2155     }
2156 
2157     // ANGLE issue 486: Detect problematic conditional discard
2158     if (discard)
2159     {
2160         mUsesDiscardRewriting = true;
2161     }
2162 }
2163 
visitTernary(Visit,TIntermTernary *)2164 bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
2165 {
2166     // Ternary ops should have been already converted to something else in the AST. HLSL ternary
2167     // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
2168     UNREACHABLE();
2169     return false;
2170 }
2171 
visitIfElse(Visit visit,TIntermIfElse * node)2172 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2173 {
2174     TInfoSinkBase &out = getInfoSink();
2175 
2176     ASSERT(mInsideFunction);
2177 
2178     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2179     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2180     {
2181         out << "FLATTEN ";
2182     }
2183 
2184     writeIfElse(out, node);
2185 
2186     return false;
2187 }
2188 
visitSwitch(Visit visit,TIntermSwitch * node)2189 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2190 {
2191     TInfoSinkBase &out = getInfoSink();
2192 
2193     ASSERT(node->getStatementList());
2194     if (visit == PreVisit)
2195     {
2196         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
2197     }
2198     outputTriplet(out, visit, "switch (", ") ", "");
2199     // The curly braces get written when visiting the statementList block.
2200     return true;
2201 }
2202 
visitCase(Visit visit,TIntermCase * node)2203 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
2204 {
2205     TInfoSinkBase &out = getInfoSink();
2206 
2207     if (node->hasCondition())
2208     {
2209         outputTriplet(out, visit, "case (", "", "):\n");
2210         return true;
2211     }
2212     else
2213     {
2214         out << "default:\n";
2215         return false;
2216     }
2217 }
2218 
visitConstantUnion(TIntermConstantUnion * node)2219 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
2220 {
2221     TInfoSinkBase &out = getInfoSink();
2222     writeConstantUnion(out, node->getType(), node->getUnionArrayPointer());
2223 }
2224 
visitLoop(Visit visit,TIntermLoop * node)2225 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
2226 {
2227     mNestedLoopDepth++;
2228 
2229     bool wasDiscontinuous = mInsideDiscontinuousLoop;
2230     mInsideDiscontinuousLoop =
2231         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
2232 
2233     TInfoSinkBase &out = getInfoSink();
2234 
2235     if (mOutputType == SH_HLSL_3_0_OUTPUT)
2236     {
2237         if (handleExcessiveLoop(out, node))
2238         {
2239             mInsideDiscontinuousLoop = wasDiscontinuous;
2240             mNestedLoopDepth--;
2241 
2242             return false;
2243         }
2244     }
2245 
2246     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2247     if (node->getType() == ELoopDoWhile)
2248     {
2249         out << "{" << unroll << " do\n";
2250 
2251         outputLineDirective(out, node->getLine().first_line);
2252     }
2253     else
2254     {
2255         out << "{" << unroll << " for(";
2256 
2257         if (node->getInit())
2258         {
2259             node->getInit()->traverse(this);
2260         }
2261 
2262         out << "; ";
2263 
2264         if (node->getCondition())
2265         {
2266             node->getCondition()->traverse(this);
2267         }
2268 
2269         out << "; ";
2270 
2271         if (node->getExpression())
2272         {
2273             node->getExpression()->traverse(this);
2274         }
2275 
2276         out << ")\n";
2277 
2278         outputLineDirective(out, node->getLine().first_line);
2279     }
2280 
2281     if (node->getBody())
2282     {
2283         // The loop body node will output braces.
2284         node->getBody()->traverse(this);
2285     }
2286     else
2287     {
2288         // TODO(oetuaho): Check if the semicolon inside is necessary.
2289         // It's there as a result of conservative refactoring of the output.
2290         out << "{;}\n";
2291     }
2292 
2293     outputLineDirective(out, node->getLine().first_line);
2294 
2295     if (node->getType() == ELoopDoWhile)
2296     {
2297         outputLineDirective(out, node->getCondition()->getLine().first_line);
2298         out << "while (";
2299 
2300         node->getCondition()->traverse(this);
2301 
2302         out << ");\n";
2303     }
2304 
2305     out << "}\n";
2306 
2307     mInsideDiscontinuousLoop = wasDiscontinuous;
2308     mNestedLoopDepth--;
2309 
2310     return false;
2311 }
2312 
visitBranch(Visit visit,TIntermBranch * node)2313 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
2314 {
2315     if (visit == PreVisit)
2316     {
2317         TInfoSinkBase &out = getInfoSink();
2318 
2319         switch (node->getFlowOp())
2320         {
2321             case EOpKill:
2322                 out << "discard";
2323                 break;
2324             case EOpBreak:
2325                 if (mNestedLoopDepth > 1)
2326                 {
2327                     mUsesNestedBreak = true;
2328                 }
2329 
2330                 if (mExcessiveLoopIndex)
2331                 {
2332                     out << "{Break";
2333                     mExcessiveLoopIndex->traverse(this);
2334                     out << " = true; break;}\n";
2335                 }
2336                 else
2337                 {
2338                     out << "break";
2339                 }
2340                 break;
2341             case EOpContinue:
2342                 out << "continue";
2343                 break;
2344             case EOpReturn:
2345                 if (node->getExpression())
2346                 {
2347                     out << "return ";
2348                 }
2349                 else
2350                 {
2351                     out << "return";
2352                 }
2353                 break;
2354             default:
2355                 UNREACHABLE();
2356         }
2357     }
2358 
2359     return true;
2360 }
2361 
2362 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
2363 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
2364 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)2365 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
2366 {
2367     const int MAX_LOOP_ITERATIONS = 254;
2368 
2369     // Parse loops of the form:
2370     // for(int index = initial; index [comparator] limit; index += increment)
2371     TIntermSymbol *index = nullptr;
2372     TOperator comparator = EOpNull;
2373     int initial          = 0;
2374     int limit            = 0;
2375     int increment        = 0;
2376 
2377     // Parse index name and intial value
2378     if (node->getInit())
2379     {
2380         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
2381 
2382         if (init)
2383         {
2384             TIntermSequence *sequence = init->getSequence();
2385             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
2386 
2387             if (variable && variable->getQualifier() == EvqTemporary)
2388             {
2389                 TIntermBinary *assign = variable->getAsBinaryNode();
2390 
2391                 if (assign->getOp() == EOpInitialize)
2392                 {
2393                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
2394                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
2395 
2396                     if (symbol && constant)
2397                     {
2398                         if (constant->getBasicType() == EbtInt && constant->isScalar())
2399                         {
2400                             index   = symbol;
2401                             initial = constant->getIConst(0);
2402                         }
2403                     }
2404                 }
2405             }
2406         }
2407     }
2408 
2409     // Parse comparator and limit value
2410     if (index != nullptr && node->getCondition())
2411     {
2412         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
2413 
2414         if (test && test->getLeft()->getAsSymbolNode()->getId() == index->getId())
2415         {
2416             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
2417 
2418             if (constant)
2419             {
2420                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2421                 {
2422                     comparator = test->getOp();
2423                     limit      = constant->getIConst(0);
2424                 }
2425             }
2426         }
2427     }
2428 
2429     // Parse increment
2430     if (index != nullptr && comparator != EOpNull && node->getExpression())
2431     {
2432         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
2433         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
2434 
2435         if (binaryTerminal)
2436         {
2437             TOperator op                   = binaryTerminal->getOp();
2438             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
2439 
2440             if (constant)
2441             {
2442                 if (constant->getBasicType() == EbtInt && constant->isScalar())
2443                 {
2444                     int value = constant->getIConst(0);
2445 
2446                     switch (op)
2447                     {
2448                         case EOpAddAssign:
2449                             increment = value;
2450                             break;
2451                         case EOpSubAssign:
2452                             increment = -value;
2453                             break;
2454                         default:
2455                             UNIMPLEMENTED();
2456                     }
2457                 }
2458             }
2459         }
2460         else if (unaryTerminal)
2461         {
2462             TOperator op = unaryTerminal->getOp();
2463 
2464             switch (op)
2465             {
2466                 case EOpPostIncrement:
2467                     increment = 1;
2468                     break;
2469                 case EOpPostDecrement:
2470                     increment = -1;
2471                     break;
2472                 case EOpPreIncrement:
2473                     increment = 1;
2474                     break;
2475                 case EOpPreDecrement:
2476                     increment = -1;
2477                     break;
2478                 default:
2479                     UNIMPLEMENTED();
2480             }
2481         }
2482     }
2483 
2484     if (index != nullptr && comparator != EOpNull && increment != 0)
2485     {
2486         if (comparator == EOpLessThanEqual)
2487         {
2488             comparator = EOpLessThan;
2489             limit += 1;
2490         }
2491 
2492         if (comparator == EOpLessThan)
2493         {
2494             int iterations = (limit - initial) / increment;
2495 
2496             if (iterations <= MAX_LOOP_ITERATIONS)
2497             {
2498                 return false;  // Not an excessive loop
2499             }
2500 
2501             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
2502             mExcessiveLoopIndex         = index;
2503 
2504             out << "{int ";
2505             index->traverse(this);
2506             out << ";\n"
2507                    "bool Break";
2508             index->traverse(this);
2509             out << " = false;\n";
2510 
2511             bool firstLoopFragment = true;
2512 
2513             while (iterations > 0)
2514             {
2515                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
2516 
2517                 if (!firstLoopFragment)
2518                 {
2519                     out << "if (!Break";
2520                     index->traverse(this);
2521                     out << ") {\n";
2522                 }
2523 
2524                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
2525                 {
2526                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
2527                 }
2528 
2529                 // for(int index = initial; index < clampedLimit; index += increment)
2530                 const char *unroll =
2531                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
2532 
2533                 out << unroll << " for(";
2534                 index->traverse(this);
2535                 out << " = ";
2536                 out << initial;
2537 
2538                 out << "; ";
2539                 index->traverse(this);
2540                 out << " < ";
2541                 out << clampedLimit;
2542 
2543                 out << "; ";
2544                 index->traverse(this);
2545                 out << " += ";
2546                 out << increment;
2547                 out << ")\n";
2548 
2549                 outputLineDirective(out, node->getLine().first_line);
2550                 out << "{\n";
2551 
2552                 if (node->getBody())
2553                 {
2554                     node->getBody()->traverse(this);
2555                 }
2556 
2557                 outputLineDirective(out, node->getLine().first_line);
2558                 out << ";}\n";
2559 
2560                 if (!firstLoopFragment)
2561                 {
2562                     out << "}\n";
2563                 }
2564 
2565                 firstLoopFragment = false;
2566 
2567                 initial += MAX_LOOP_ITERATIONS * increment;
2568                 iterations -= MAX_LOOP_ITERATIONS;
2569             }
2570 
2571             out << "}";
2572 
2573             mExcessiveLoopIndex = restoreIndex;
2574 
2575             return true;
2576         }
2577         else
2578             UNIMPLEMENTED();
2579     }
2580 
2581     return false;  // Not handled as an excessive loop
2582 }
2583 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)2584 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
2585                                Visit visit,
2586                                const char *preString,
2587                                const char *inString,
2588                                const char *postString)
2589 {
2590     if (visit == PreVisit)
2591     {
2592         out << preString;
2593     }
2594     else if (visit == InVisit)
2595     {
2596         out << inString;
2597     }
2598     else if (visit == PostVisit)
2599     {
2600         out << postString;
2601     }
2602 }
2603 
outputLineDirective(TInfoSinkBase & out,int line)2604 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
2605 {
2606     if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
2607     {
2608         out << "\n";
2609         out << "#line " << line;
2610 
2611         if (mSourcePath)
2612         {
2613             out << " \"" << mSourcePath << "\"";
2614         }
2615 
2616         out << "\n";
2617     }
2618 }
2619 
argumentString(const TIntermSymbol * symbol)2620 TString OutputHLSL::argumentString(const TIntermSymbol *symbol)
2621 {
2622     TQualifier qualifier = symbol->getQualifier();
2623     const TType &type    = symbol->getType();
2624     const TName &name    = symbol->getName();
2625     TString nameStr;
2626 
2627     if (name.getString().empty())  // HLSL demands named arguments, also for prototypes
2628     {
2629         nameStr = "x" + str(mUniqueIndex++);
2630     }
2631     else
2632     {
2633         nameStr = DecorateVariableIfNeeded(name);
2634     }
2635 
2636     if (IsSampler(type.getBasicType()))
2637     {
2638         if (mOutputType == SH_HLSL_4_1_OUTPUT)
2639         {
2640             // Samplers are passed as indices to the sampler array.
2641             ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
2642             return "const uint " + nameStr + ArrayString(type);
2643         }
2644         if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2645         {
2646             return QualifierString(qualifier) + " " + TextureString(type.getBasicType()) +
2647                    " texture_" + nameStr + ArrayString(type) + ", " + QualifierString(qualifier) +
2648                    " " + SamplerString(type.getBasicType()) + " sampler_" + nameStr +
2649                    ArrayString(type);
2650         }
2651     }
2652 
2653     TStringStream argString;
2654     argString << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
2655               << ArrayString(type);
2656 
2657     // If the structure parameter contains samplers, they need to be passed into the function as
2658     // separate parameters. HLSL doesn't natively support samplers in structs.
2659     if (type.isStructureContainingSamplers())
2660     {
2661         ASSERT(qualifier != EvqOut && qualifier != EvqInOut);
2662         TVector<TIntermSymbol *> samplerSymbols;
2663         type.createSamplerSymbols("angle" + nameStr, "", &samplerSymbols, nullptr, mSymbolTable);
2664         for (const TIntermSymbol *sampler : samplerSymbols)
2665         {
2666             const TType &samplerType = sampler->getType();
2667             if (mOutputType == SH_HLSL_4_1_OUTPUT)
2668             {
2669                 argString << ", const uint " << sampler->getSymbol() << ArrayString(samplerType);
2670             }
2671             else if (mOutputType == SH_HLSL_4_0_FL9_3_OUTPUT)
2672             {
2673                 ASSERT(IsSampler(samplerType.getBasicType()));
2674                 argString << ", " << QualifierString(qualifier) << " "
2675                           << TextureString(samplerType.getBasicType()) << " texture_"
2676                           << sampler->getSymbol() << ArrayString(samplerType) << ", "
2677                           << QualifierString(qualifier) << " "
2678                           << SamplerString(samplerType.getBasicType()) << " sampler_"
2679                           << sampler->getSymbol() << ArrayString(samplerType);
2680             }
2681             else
2682             {
2683                 ASSERT(IsSampler(samplerType.getBasicType()));
2684                 argString << ", " << QualifierString(qualifier) << " " << TypeString(samplerType)
2685                           << " " << sampler->getSymbol() << ArrayString(samplerType);
2686             }
2687         }
2688     }
2689 
2690     return argString.str();
2691 }
2692 
initializer(const TType & type)2693 TString OutputHLSL::initializer(const TType &type)
2694 {
2695     TString string;
2696 
2697     size_t size = type.getObjectSize();
2698     for (size_t component = 0; component < size; component++)
2699     {
2700         string += "0";
2701 
2702         if (component + 1 < size)
2703         {
2704             string += ", ";
2705         }
2706     }
2707 
2708     return "{" + string + "}";
2709 }
2710 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)2711 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
2712 {
2713     // Array constructors should have been already pruned from the code.
2714     ASSERT(!node->getType().isArray());
2715 
2716     if (visit == PreVisit)
2717     {
2718         TString constructorName;
2719         if (node->getBasicType() == EbtStruct)
2720         {
2721             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
2722         }
2723         else
2724         {
2725             constructorName =
2726                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
2727         }
2728         out << constructorName << "(";
2729     }
2730     else if (visit == InVisit)
2731     {
2732         out << ", ";
2733     }
2734     else if (visit == PostVisit)
2735     {
2736         out << ")";
2737     }
2738 }
2739 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)2740 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
2741                                                      const TType &type,
2742                                                      const TConstantUnion *const constUnion)
2743 {
2744     const TConstantUnion *constUnionIterated = constUnion;
2745 
2746     const TStructure *structure = type.getStruct();
2747     if (structure)
2748     {
2749         out << mStructureHLSL->addStructConstructor(*structure) << "(";
2750 
2751         const TFieldList &fields = structure->fields();
2752 
2753         for (size_t i = 0; i < fields.size(); i++)
2754         {
2755             const TType *fieldType = fields[i]->type();
2756             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
2757 
2758             if (i != fields.size() - 1)
2759             {
2760                 out << ", ";
2761             }
2762         }
2763 
2764         out << ")";
2765     }
2766     else
2767     {
2768         size_t size    = type.getObjectSize();
2769         bool writeType = size > 1;
2770 
2771         if (writeType)
2772         {
2773             out << TypeString(type) << "(";
2774         }
2775         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
2776         if (writeType)
2777         {
2778             out << ")";
2779         }
2780     }
2781 
2782     return constUnionIterated;
2783 }
2784 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,TOperator op)2785 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out, Visit visit, TOperator op)
2786 {
2787     if (visit == PreVisit)
2788     {
2789         const char *opStr = GetOperatorString(op);
2790         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, opStr);
2791         out << "(";
2792     }
2793     else
2794     {
2795         outputTriplet(out, visit, nullptr, ", ", ")");
2796     }
2797 }
2798 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)2799 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
2800                                             TIntermSymbol *symbolNode,
2801                                             TIntermTyped *expression)
2802 {
2803     sh::SearchSymbol searchSymbol(symbolNode->getSymbol());
2804     expression->traverse(&searchSymbol);
2805 
2806     if (searchSymbol.foundMatch())
2807     {
2808         // Type already printed
2809         out << "t" + str(mUniqueIndex) + " = ";
2810         expression->traverse(this);
2811         out << ", ";
2812         symbolNode->traverse(this);
2813         out << " = t" + str(mUniqueIndex);
2814 
2815         mUniqueIndex++;
2816         return true;
2817     }
2818 
2819     return false;
2820 }
2821 
canWriteAsHLSLLiteral(TIntermTyped * expression)2822 bool OutputHLSL::canWriteAsHLSLLiteral(TIntermTyped *expression)
2823 {
2824     // We support writing constant unions and constructors that only take constant unions as
2825     // parameters as HLSL literals.
2826     return !expression->getType().isArrayOfArrays() &&
2827            (expression->getAsConstantUnion() ||
2828             expression->isConstructorWithOnlyConstantUnionParameters());
2829 }
2830 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)2831 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
2832                                              TIntermSymbol *symbolNode,
2833                                              TIntermTyped *initializer)
2834 {
2835     if (canWriteAsHLSLLiteral(initializer))
2836     {
2837         symbolNode->traverse(this);
2838         ASSERT(!symbolNode->getType().isArrayOfArrays());
2839         if (symbolNode->getType().isArray())
2840         {
2841             out << "[" << symbolNode->getType().getOutermostArraySize() << "]";
2842         }
2843         out << " = {";
2844         if (initializer->getAsConstantUnion())
2845         {
2846             TIntermConstantUnion *nodeConst  = initializer->getAsConstantUnion();
2847             const TConstantUnion *constUnion = nodeConst->getUnionArrayPointer();
2848             writeConstantUnionArray(out, constUnion, nodeConst->getType().getObjectSize());
2849         }
2850         else
2851         {
2852             TIntermAggregate *constructor = initializer->getAsAggregate();
2853             ASSERT(constructor != nullptr);
2854             for (TIntermNode *&node : *constructor->getSequence())
2855             {
2856                 TIntermConstantUnion *nodeConst = node->getAsConstantUnion();
2857                 ASSERT(nodeConst);
2858                 const TConstantUnion *constUnion = nodeConst->getUnionArrayPointer();
2859                 writeConstantUnionArray(out, constUnion, nodeConst->getType().getObjectSize());
2860                 if (node != constructor->getSequence()->back())
2861                 {
2862                     out << ", ";
2863                 }
2864             }
2865         }
2866         out << "}";
2867         return true;
2868     }
2869     return false;
2870 }
2871 
addStructEqualityFunction(const TStructure & structure)2872 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
2873 {
2874     const TFieldList &fields = structure.fields();
2875 
2876     for (const auto &eqFunction : mStructEqualityFunctions)
2877     {
2878         if (eqFunction->structure == &structure)
2879         {
2880             return eqFunction->functionName;
2881         }
2882     }
2883 
2884     const TString &structNameString = StructNameString(structure);
2885 
2886     StructEqualityFunction *function = new StructEqualityFunction();
2887     function->structure              = &structure;
2888     function->functionName           = "angle_eq_" + structNameString;
2889 
2890     TInfoSinkBase fnOut;
2891 
2892     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
2893           << structNameString + " b)\n"
2894           << "{\n"
2895              "    return ";
2896 
2897     for (size_t i = 0; i < fields.size(); i++)
2898     {
2899         const TField *field    = fields[i];
2900         const TType *fieldType = field->type();
2901 
2902         const TString &fieldNameA = "a." + Decorate(field->name());
2903         const TString &fieldNameB = "b." + Decorate(field->name());
2904 
2905         if (i > 0)
2906         {
2907             fnOut << " && ";
2908         }
2909 
2910         fnOut << "(";
2911         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
2912         fnOut << fieldNameA;
2913         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
2914         fnOut << fieldNameB;
2915         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
2916         fnOut << ")";
2917     }
2918 
2919     fnOut << ";\n"
2920           << "}\n";
2921 
2922     function->functionDefinition = fnOut.c_str();
2923 
2924     mStructEqualityFunctions.push_back(function);
2925     mEqualityFunctions.push_back(function);
2926 
2927     return function->functionName;
2928 }
2929 
addArrayEqualityFunction(const TType & type)2930 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
2931 {
2932     for (const auto &eqFunction : mArrayEqualityFunctions)
2933     {
2934         if (eqFunction->type == type)
2935         {
2936             return eqFunction->functionName;
2937         }
2938     }
2939 
2940     TType elementType(type);
2941     elementType.toArrayElementType();
2942 
2943     ArrayHelperFunction *function = new ArrayHelperFunction();
2944     function->type                = type;
2945 
2946     function->functionName = ArrayHelperFunctionName("angle_eq", type);
2947 
2948     TInfoSinkBase fnOut;
2949 
2950     const TString &typeName = TypeString(type);
2951     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
2952           << ", " << typeName << " b" << ArrayString(type) << ")\n"
2953           << "{\n"
2954              "    for (int i = 0; i < "
2955           << type.getOutermostArraySize()
2956           << "; ++i)\n"
2957              "    {\n"
2958              "        if (";
2959 
2960     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
2961     fnOut << "a[i]";
2962     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
2963     fnOut << "b[i]";
2964     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
2965 
2966     fnOut << ") { return false; }\n"
2967              "    }\n"
2968              "    return true;\n"
2969              "}\n";
2970 
2971     function->functionDefinition = fnOut.c_str();
2972 
2973     mArrayEqualityFunctions.push_back(function);
2974     mEqualityFunctions.push_back(function);
2975 
2976     return function->functionName;
2977 }
2978 
addArrayAssignmentFunction(const TType & type)2979 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
2980 {
2981     for (const auto &assignFunction : mArrayAssignmentFunctions)
2982     {
2983         if (assignFunction.type == type)
2984         {
2985             return assignFunction.functionName;
2986         }
2987     }
2988 
2989     TType elementType(type);
2990     elementType.toArrayElementType();
2991 
2992     ArrayHelperFunction function;
2993     function.type = type;
2994 
2995     function.functionName = ArrayHelperFunctionName("angle_assign", type);
2996 
2997     TInfoSinkBase fnOut;
2998 
2999     const TString &typeName = TypeString(type);
3000     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3001           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3002           << "{\n"
3003              "    for (int i = 0; i < "
3004           << type.getOutermostArraySize()
3005           << "; ++i)\n"
3006              "    {\n"
3007              "        ";
3008 
3009     outputAssign(PreVisit, elementType, fnOut);
3010     fnOut << "a[i]";
3011     outputAssign(InVisit, elementType, fnOut);
3012     fnOut << "b[i]";
3013     outputAssign(PostVisit, elementType, fnOut);
3014 
3015     fnOut << ";\n"
3016              "    }\n"
3017              "}\n";
3018 
3019     function.functionDefinition = fnOut.c_str();
3020 
3021     mArrayAssignmentFunctions.push_back(function);
3022 
3023     return function.functionName;
3024 }
3025 
addArrayConstructIntoFunction(const TType & type)3026 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3027 {
3028     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3029     {
3030         if (constructIntoFunction.type == type)
3031         {
3032             return constructIntoFunction.functionName;
3033         }
3034     }
3035 
3036     TType elementType(type);
3037     elementType.toArrayElementType();
3038 
3039     ArrayHelperFunction function;
3040     function.type = type;
3041 
3042     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3043 
3044     TInfoSinkBase fnOut;
3045 
3046     const TString &typeName = TypeString(type);
3047     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3048     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3049     {
3050         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3051     }
3052     fnOut << ")\n"
3053              "{\n";
3054 
3055     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3056     {
3057         fnOut << "    ";
3058         outputAssign(PreVisit, elementType, fnOut);
3059         fnOut << "a[" << i << "]";
3060         outputAssign(InVisit, elementType, fnOut);
3061         fnOut << "b" << i;
3062         outputAssign(PostVisit, elementType, fnOut);
3063         fnOut << ";\n";
3064     }
3065     fnOut << "}\n";
3066 
3067     function.functionDefinition = fnOut.c_str();
3068 
3069     mArrayConstructIntoFunctions.push_back(function);
3070 
3071     return function.functionName;
3072 }
3073 
ensureStructDefined(const TType & type)3074 void OutputHLSL::ensureStructDefined(const TType &type)
3075 {
3076     const TStructure *structure = type.getStruct();
3077     if (structure)
3078     {
3079         ASSERT(type.getBasicType() == EbtStruct);
3080         mStructureHLSL->ensureStructDefined(*structure);
3081     }
3082 }
3083 
3084 }  // namespace sh
3085