1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ShaderStorageBlockOutputHLSL: A traverser to translate a ssbo_access_chain to an offset of
7 // RWByteAddressBuffer.
8 //     //EOpIndexDirectInterfaceBlock
9 //     ssbo_variable :=
10 //       | the name of the SSBO
11 //       | the name of a variable in an SSBO backed interface block
12 
13 //     // EOpIndexIndirect
14 //     // EOpIndexDirect
15 //     ssbo_array_indexing := ssbo_access_chain[expr_no_ssbo]
16 
17 //     // EOpIndexDirectStruct
18 //     ssbo_structure_access := ssbo_access_chain.identifier
19 
20 //     ssbo_access_chain :=
21 //       | ssbo_variable
22 //       | ssbo_array_indexing
23 //       | ssbo_structure_access
24 //
25 
26 #include "compiler/translator/ShaderStorageBlockOutputHLSL.h"
27 
28 #include "compiler/translator/ResourcesHLSL.h"
29 #include "compiler/translator/blocklayoutHLSL.h"
30 #include "compiler/translator/tree_util/IntermNode_util.h"
31 #include "compiler/translator/util.h"
32 
33 namespace sh
34 {
35 
36 namespace
37 {
38 
GetBlockLayoutInfo(TIntermTyped * node,bool rowMajorAlreadyAssigned,TLayoutBlockStorage * storage,bool * rowMajor)39 void GetBlockLayoutInfo(TIntermTyped *node,
40                         bool rowMajorAlreadyAssigned,
41                         TLayoutBlockStorage *storage,
42                         bool *rowMajor)
43 {
44     TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
45     if (swizzleNode)
46     {
47         return GetBlockLayoutInfo(swizzleNode->getOperand(), rowMajorAlreadyAssigned, storage,
48                                   rowMajor);
49     }
50 
51     TIntermBinary *binaryNode = node->getAsBinaryNode();
52     if (binaryNode)
53     {
54         switch (binaryNode->getOp())
55         {
56             case EOpIndexDirectInterfaceBlock:
57             {
58                 // The column_major/row_major qualifier of a field member overrides the interface
59                 // block's row_major/column_major. So we can assign rowMajor here and don't need to
60                 // assign it again. But we still need to call recursively to get the storage's
61                 // value.
62                 const TType &type = node->getType();
63                 *rowMajor         = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
64                 return GetBlockLayoutInfo(binaryNode->getLeft(), true, storage, rowMajor);
65             }
66             case EOpIndexIndirect:
67             case EOpIndexDirect:
68             case EOpIndexDirectStruct:
69                 return GetBlockLayoutInfo(binaryNode->getLeft(), rowMajorAlreadyAssigned, storage,
70                                           rowMajor);
71             default:
72                 UNREACHABLE();
73                 return;
74         }
75     }
76 
77     const TType &type = node->getType();
78     ASSERT(type.getQualifier() == EvqBuffer);
79     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
80     ASSERT(interfaceBlock);
81     *storage = interfaceBlock->blockStorage();
82     // If the block doesn't have an instance name, rowMajorAlreadyAssigned will be false. In
83     // this situation, we still need to set rowMajor's value.
84     if (!rowMajorAlreadyAssigned)
85     {
86         *rowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
87     }
88 }
89 
90 // It's possible that the current type has lost the original layout information. So we should pass
91 // the right layout information to GetBlockMemberInfoByType.
GetBlockMemberInfoByType(const TType & type,TLayoutBlockStorage storage,bool rowMajor)92 const BlockMemberInfo GetBlockMemberInfoByType(const TType &type,
93                                                TLayoutBlockStorage storage,
94                                                bool rowMajor)
95 {
96     sh::Std140BlockEncoder std140Encoder;
97     sh::Std430BlockEncoder std430Encoder;
98     sh::HLSLBlockEncoder hlslEncoder(sh::HLSLBlockEncoder::ENCODE_PACKED, false);
99     sh::BlockLayoutEncoder *encoder = nullptr;
100 
101     if (storage == EbsStd140)
102     {
103         encoder = &std140Encoder;
104     }
105     else if (storage == EbsStd430)
106     {
107         encoder = &std430Encoder;
108     }
109     else
110     {
111         encoder = &hlslEncoder;
112     }
113 
114     std::vector<unsigned int> arraySizes;
115     const TSpan<const unsigned int> &typeArraySizes = type.getArraySizes();
116     if (!typeArraySizes.empty())
117     {
118         arraySizes.assign(typeArraySizes.begin(), typeArraySizes.end());
119     }
120     return encoder->encodeType(GLVariableType(type), arraySizes, rowMajor);
121 }
122 
GetFieldMemberInShaderStorageBlock(const TInterfaceBlock * interfaceBlock,const ImmutableString & variableName)123 const TField *GetFieldMemberInShaderStorageBlock(const TInterfaceBlock *interfaceBlock,
124                                                  const ImmutableString &variableName)
125 {
126     for (const TField *field : interfaceBlock->fields())
127     {
128         if (field->name() == variableName)
129         {
130             return field;
131         }
132     }
133     return nullptr;
134 }
135 
FindInterfaceBlock(const TInterfaceBlock * needle,const std::vector<InterfaceBlock> & haystack)136 const InterfaceBlock *FindInterfaceBlock(const TInterfaceBlock *needle,
137                                          const std::vector<InterfaceBlock> &haystack)
138 {
139     for (const InterfaceBlock &block : haystack)
140     {
141         if (strcmp(block.name.c_str(), needle->name().data()) == 0)
142         {
143             ASSERT(block.fields.size() == needle->fields().size());
144             return &block;
145         }
146     }
147 
148     UNREACHABLE();
149     return nullptr;
150 }
151 
// Returns a copy of |nameIn| with every "[...]" array subscript removed, e.g. "a[0].b[12]"
// becomes "a.b".
//
// Each '[' is paired with the first ']' found after it. A '[' without a closing ']' is a
// malformed name and trips the ASSERT below; if the ASSERT does not stop execution, everything
// from that '[' to the end of the string is dropped. The earlier erase-based loop computed an
// erase length that wrapped to zero for an unterminated '[' at index 0 and therefore never
// advanced; building the result in one forward pass cannot get stuck.
std::string StripArrayIndices(const std::string &nameIn)
{
    std::string name;
    name.reserve(nameIn.size());

    size_t cursor = 0;
    while (cursor < nameIn.size())
    {
        const size_t openPos = nameIn.find('[', cursor);
        if (openPos == std::string::npos)
        {
            // No further subscripts: keep the remainder verbatim.
            name.append(nameIn, cursor, std::string::npos);
            break;
        }

        // Keep the text between the previous subscript (or the start) and this '['.
        name.append(nameIn, cursor, openPos - cursor);

        const size_t closePos = nameIn.find(']', openPos);
        ASSERT(closePos != std::string::npos);
        if (closePos == std::string::npos)
        {
            // Unterminated subscript: discard the tail instead of looping on it.
            break;
        }

        cursor = closePos + 1;
    }

    ASSERT(name.find(']') == std::string::npos);
    return name;
}
166 
167 // Does not include any array indices.
MapVariableToField(const ShaderVariable & variable,const TField * field,std::string currentName,ShaderVarToFieldMap * shaderVarToFieldMap)168 void MapVariableToField(const ShaderVariable &variable,
169                         const TField *field,
170                         std::string currentName,
171                         ShaderVarToFieldMap *shaderVarToFieldMap)
172 {
173     ASSERT((field->type()->getStruct() == nullptr) == variable.fields.empty());
174     (*shaderVarToFieldMap)[currentName] = field;
175 
176     if (!variable.fields.empty())
177     {
178         const TStructure *subStruct = field->type()->getStruct();
179         ASSERT(variable.fields.size() == subStruct->fields().size());
180 
181         for (size_t index = 0; index < variable.fields.size(); ++index)
182         {
183             const TField *subField            = subStruct->fields()[index];
184             const ShaderVariable &subVariable = variable.fields[index];
185             std::string subName               = currentName + "." + subVariable.name;
186             MapVariableToField(subVariable, subField, subName, shaderVarToFieldMap);
187         }
188     }
189 }
190 
191 class BlockInfoVisitor final : public BlockEncoderVisitor
192 {
193   public:
BlockInfoVisitor(const std::string & prefix,TLayoutBlockStorage storage,const ShaderVarToFieldMap & shaderVarToFieldMap,BlockMemberInfoMap * blockInfoOut)194     BlockInfoVisitor(const std::string &prefix,
195                      TLayoutBlockStorage storage,
196                      const ShaderVarToFieldMap &shaderVarToFieldMap,
197                      BlockMemberInfoMap *blockInfoOut)
198         : BlockEncoderVisitor(prefix, "", getEncoder(storage)),
199           mShaderVarToFieldMap(shaderVarToFieldMap),
200           mBlockInfoOut(blockInfoOut),
201           mHLSLEncoder(HLSLBlockEncoder::ENCODE_PACKED, false),
202           mStorage(storage)
203     {}
204 
getEncoder(TLayoutBlockStorage storage)205     BlockLayoutEncoder *getEncoder(TLayoutBlockStorage storage)
206     {
207         switch (storage)
208         {
209             case EbsStd140:
210                 return &mStd140Encoder;
211             case EbsStd430:
212                 return &mStd430Encoder;
213             default:
214                 return &mHLSLEncoder;
215         }
216     }
217 
enterStructAccess(const ShaderVariable & structVar,bool isRowMajor)218     void enterStructAccess(const ShaderVariable &structVar, bool isRowMajor) override
219     {
220         BlockEncoderVisitor::enterStructAccess(structVar, isRowMajor);
221 
222         std::string variableName = StripArrayIndices(collapseNameStack());
223 
224         // Remove the trailing "."
225         variableName.pop_back();
226 
227         BlockInfoVisitor childVisitor(variableName, mStorage, mShaderVarToFieldMap, mBlockInfoOut);
228         childVisitor.getEncoder(mStorage)->enterAggregateType(structVar);
229         TraverseShaderVariables(structVar.fields, isRowMajor, &childVisitor);
230         childVisitor.getEncoder(mStorage)->exitAggregateType(structVar);
231 
232         int offset      = static_cast<int>(getEncoder(mStorage)->getCurrentOffset());
233         int arrayStride = static_cast<int>(childVisitor.getEncoder(mStorage)->getCurrentOffset());
234 
235         auto iter = mShaderVarToFieldMap.find(variableName);
236         if (iter == mShaderVarToFieldMap.end())
237             return;
238 
239         const TField *structField = iter->second;
240         if (mBlockInfoOut->count(structField) == 0)
241         {
242             mBlockInfoOut->emplace(structField, BlockMemberInfo(offset, arrayStride, -1, false));
243         }
244     }
245 
encodeVariable(const ShaderVariable & variable,const BlockMemberInfo & variableInfo,const std::string & name,const std::string & mappedName)246     void encodeVariable(const ShaderVariable &variable,
247                         const BlockMemberInfo &variableInfo,
248                         const std::string &name,
249                         const std::string &mappedName) override
250     {
251         auto iter = mShaderVarToFieldMap.find(StripArrayIndices(name));
252         if (iter == mShaderVarToFieldMap.end())
253             return;
254 
255         const TField *field = iter->second;
256         if (mBlockInfoOut->count(field) == 0)
257         {
258             mBlockInfoOut->emplace(field, variableInfo);
259         }
260     }
261 
262   private:
263     const ShaderVarToFieldMap &mShaderVarToFieldMap;
264     BlockMemberInfoMap *mBlockInfoOut;
265     Std140BlockEncoder mStd140Encoder;
266     Std430BlockEncoder mStd430Encoder;
267     HLSLBlockEncoder mHLSLEncoder;
268     TLayoutBlockStorage mStorage;
269 };
270 
GetShaderStorageBlockMembersInfo(const TInterfaceBlock * interfaceBlock,const std::vector<InterfaceBlock> & shaderStorageBlocks,BlockMemberInfoMap * blockInfoOut)271 void GetShaderStorageBlockMembersInfo(const TInterfaceBlock *interfaceBlock,
272                                       const std::vector<InterfaceBlock> &shaderStorageBlocks,
273                                       BlockMemberInfoMap *blockInfoOut)
274 {
275     // Find the sh::InterfaceBlock.
276     const InterfaceBlock *block = FindInterfaceBlock(interfaceBlock, shaderStorageBlocks);
277     ASSERT(block);
278 
279     // Map ShaderVariable to TField.
280     ShaderVarToFieldMap shaderVarToFieldMap;
281     for (size_t index = 0; index < block->fields.size(); ++index)
282     {
283         const TField *field            = interfaceBlock->fields()[index];
284         const ShaderVariable &variable = block->fields[index];
285         MapVariableToField(variable, field, variable.name, &shaderVarToFieldMap);
286     }
287 
288     BlockInfoVisitor visitor("", interfaceBlock->blockStorage(), shaderVarToFieldMap, blockInfoOut);
289     TraverseShaderVariables(block->fields, false, &visitor);
290 }
291 
Mul(TIntermTyped * left,TIntermTyped * right)292 TIntermTyped *Mul(TIntermTyped *left, TIntermTyped *right)
293 {
294     return left && right ? new TIntermBinary(EOpMul, left, right) : nullptr;
295 }
296 
Add(TIntermTyped * left,TIntermTyped * right)297 TIntermTyped *Add(TIntermTyped *left, TIntermTyped *right)
298 {
299     return left ? right ? new TIntermBinary(EOpAdd, left, right) : left : right;
300 }
301 
302 }  // anonymous namespace
303 
ShaderStorageBlockOutputHLSL(OutputHLSL * outputHLSL,ResourcesHLSL * resourcesHLSL,const std::vector<InterfaceBlock> & shaderStorageBlocks)304 ShaderStorageBlockOutputHLSL::ShaderStorageBlockOutputHLSL(
305     OutputHLSL *outputHLSL,
306     ResourcesHLSL *resourcesHLSL,
307     const std::vector<InterfaceBlock> &shaderStorageBlocks)
308     : mOutputHLSL(outputHLSL),
309       mResourcesHLSL(resourcesHLSL),
310       mShaderStorageBlocks(shaderStorageBlocks)
311 {
312     mSSBOFunctionHLSL = new ShaderStorageBlockFunctionHLSL;
313 }
314 
~ShaderStorageBlockOutputHLSL()315 ShaderStorageBlockOutputHLSL::~ShaderStorageBlockOutputHLSL()
316 {
317     SafeDelete(mSSBOFunctionHLSL);
318 }
319 
outputStoreFunctionCallPrefix(TIntermTyped * node)320 void ShaderStorageBlockOutputHLSL::outputStoreFunctionCallPrefix(TIntermTyped *node)
321 {
322     traverseSSBOAccess(node, SSBOMethod::STORE);
323 }
324 
outputLoadFunctionCall(TIntermTyped * node)325 void ShaderStorageBlockOutputHLSL::outputLoadFunctionCall(TIntermTyped *node)
326 {
327     traverseSSBOAccess(node, SSBOMethod::LOAD);
328     mOutputHLSL->getInfoSink() << ")";
329 }
330 
outputLengthFunctionCall(TIntermTyped * node)331 void ShaderStorageBlockOutputHLSL::outputLengthFunctionCall(TIntermTyped *node)
332 {
333     traverseSSBOAccess(node, SSBOMethod::LENGTH);
334     mOutputHLSL->getInfoSink() << ")";
335 }
336 
outputAtomicMemoryFunctionCallPrefix(TIntermTyped * node,TOperator op)337 void ShaderStorageBlockOutputHLSL::outputAtomicMemoryFunctionCallPrefix(TIntermTyped *node,
338                                                                         TOperator op)
339 {
340     switch (op)
341     {
342         case EOpAtomicAdd:
343             traverseSSBOAccess(node, SSBOMethod::ATOMIC_ADD);
344             break;
345         case EOpAtomicMin:
346             traverseSSBOAccess(node, SSBOMethod::ATOMIC_MIN);
347             break;
348         case EOpAtomicMax:
349             traverseSSBOAccess(node, SSBOMethod::ATOMIC_MAX);
350             break;
351         case EOpAtomicAnd:
352             traverseSSBOAccess(node, SSBOMethod::ATOMIC_AND);
353             break;
354         case EOpAtomicOr:
355             traverseSSBOAccess(node, SSBOMethod::ATOMIC_OR);
356             break;
357         case EOpAtomicXor:
358             traverseSSBOAccess(node, SSBOMethod::ATOMIC_XOR);
359             break;
360         case EOpAtomicExchange:
361             traverseSSBOAccess(node, SSBOMethod::ATOMIC_EXCHANGE);
362             break;
363         case EOpAtomicCompSwap:
364             traverseSSBOAccess(node, SSBOMethod::ATOMIC_COMPSWAP);
365             break;
366         default:
367             UNREACHABLE();
368             break;
369     }
370 }
371 
372 // Note that we must calculate the matrix stride here instead of ShaderStorageBlockFunctionHLSL.
373 // It's because that if the current node's type is a vector which comes from a matrix, we will
374 // lose the matrix type info once we enter ShaderStorageBlockFunctionHLSL.
getMatrixStride(TIntermTyped * node,TLayoutBlockStorage storage,bool rowMajor,bool * isRowMajorMatrix) const375 int ShaderStorageBlockOutputHLSL::getMatrixStride(TIntermTyped *node,
376                                                   TLayoutBlockStorage storage,
377                                                   bool rowMajor,
378                                                   bool *isRowMajorMatrix) const
379 {
380     if (node->getType().isMatrix())
381     {
382         *isRowMajorMatrix = rowMajor;
383         return GetBlockMemberInfoByType(node->getType(), storage, rowMajor).matrixStride;
384     }
385 
386     if (node->getType().isVector())
387     {
388         TIntermBinary *binaryNode = node->getAsBinaryNode();
389         if (binaryNode)
390         {
391             return getMatrixStride(binaryNode->getLeft(), storage, rowMajor, isRowMajorMatrix);
392         }
393         else
394         {
395             TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
396             if (swizzleNode)
397             {
398                 return getMatrixStride(swizzleNode->getOperand(), storage, rowMajor,
399                                        isRowMajorMatrix);
400             }
401         }
402     }
403     return 0;
404 }
405 
collectShaderStorageBlocks(TIntermTyped * node)406 void ShaderStorageBlockOutputHLSL::collectShaderStorageBlocks(TIntermTyped *node)
407 {
408     TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
409     if (swizzleNode)
410     {
411         return collectShaderStorageBlocks(swizzleNode->getOperand());
412     }
413 
414     TIntermBinary *binaryNode = node->getAsBinaryNode();
415     if (binaryNode)
416     {
417         switch (binaryNode->getOp())
418         {
419             case EOpIndexDirectInterfaceBlock:
420             case EOpIndexIndirect:
421             case EOpIndexDirect:
422             case EOpIndexDirectStruct:
423                 return collectShaderStorageBlocks(binaryNode->getLeft());
424             default:
425                 UNREACHABLE();
426                 return;
427         }
428     }
429 
430     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
431     const TType &type               = symbolNode->getType();
432     ASSERT(type.getQualifier() == EvqBuffer);
433     const TVariable &variable = symbolNode->variable();
434 
435     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
436     ASSERT(interfaceBlock);
437     if (mReferencedShaderStorageBlocks.count(interfaceBlock->uniqueId().get()) == 0)
438     {
439         const TVariable *instanceVariable = nullptr;
440         if (type.isInterfaceBlock())
441         {
442             instanceVariable = &variable;
443         }
444         mReferencedShaderStorageBlocks[interfaceBlock->uniqueId().get()] =
445             new TReferencedBlock(interfaceBlock, instanceVariable);
446         GetShaderStorageBlockMembersInfo(interfaceBlock, mShaderStorageBlocks,
447                                          &mBlockMemberInfoMap);
448     }
449 }
450 
traverseSSBOAccess(TIntermTyped * node,SSBOMethod method)451 void ShaderStorageBlockOutputHLSL::traverseSSBOAccess(TIntermTyped *node, SSBOMethod method)
452 {
453     // TODO: Merge collectShaderStorageBlocks and GetBlockLayoutInfo to simplify the code.
454     collectShaderStorageBlocks(node);
455 
456     // Note that we don't have correct BlockMemberInfo from mBlockMemberInfoMap at the current
457     // point. But we must use those information to generate the right function name. So here we have
458     // to calculate them again.
459     TLayoutBlockStorage storage;
460     bool rowMajor;
461     GetBlockLayoutInfo(node, false, &storage, &rowMajor);
462     int unsizedArrayStride = 0;
463     if (node->getType().isUnsizedArray())
464     {
465         // The unsized array member must be the last member of a shader storage block.
466         TIntermBinary *binaryNode = node->getAsBinaryNode();
467         if (binaryNode)
468         {
469             const TInterfaceBlock *interfaceBlock =
470                 binaryNode->getLeft()->getType().getInterfaceBlock();
471             ASSERT(interfaceBlock);
472             const TIntermConstantUnion *index = binaryNode->getRight()->getAsConstantUnion();
473             const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
474             auto fieldInfoIter                = mBlockMemberInfoMap.find(field);
475             ASSERT(fieldInfoIter != mBlockMemberInfoMap.end());
476             unsizedArrayStride = fieldInfoIter->second.arrayStride;
477         }
478         else
479         {
480             const TIntermSymbol *symbolNode       = node->getAsSymbolNode();
481             const TVariable &variable             = symbolNode->variable();
482             const TInterfaceBlock *interfaceBlock = symbolNode->getType().getInterfaceBlock();
483             ASSERT(interfaceBlock);
484             const TField *field =
485                 GetFieldMemberInShaderStorageBlock(interfaceBlock, variable.name());
486             auto fieldInfoIter = mBlockMemberInfoMap.find(field);
487             ASSERT(fieldInfoIter != mBlockMemberInfoMap.end());
488             unsizedArrayStride = fieldInfoIter->second.arrayStride;
489         }
490     }
491     bool isRowMajorMatrix = false;
492     int matrixStride      = getMatrixStride(node, storage, rowMajor, &isRowMajorMatrix);
493 
494     const TString &functionName = mSSBOFunctionHLSL->registerShaderStorageBlockFunction(
495         node->getType(), method, storage, isRowMajorMatrix, matrixStride, unsizedArrayStride,
496         node->getAsSwizzleNode());
497     TInfoSinkBase &out = mOutputHLSL->getInfoSink();
498     out << functionName;
499     out << "(";
500     BlockMemberInfo blockMemberInfo;
501     TIntermNode *loc = traverseNode(out, node, &blockMemberInfo);
502     out << ", ";
503     loc->traverse(mOutputHLSL);
504 }
505 
writeShaderStorageBlocksHeader(TInfoSinkBase & out) const506 void ShaderStorageBlockOutputHLSL::writeShaderStorageBlocksHeader(TInfoSinkBase &out) const
507 {
508     out << mResourcesHLSL->shaderStorageBlocksHeader(mReferencedShaderStorageBlocks);
509     mSSBOFunctionHLSL->shaderStorageBlockFunctionHeader(out);
510 }
511 
traverseNode(TInfoSinkBase & out,TIntermTyped * node,BlockMemberInfo * blockMemberInfo)512 TIntermTyped *ShaderStorageBlockOutputHLSL::traverseNode(TInfoSinkBase &out,
513                                                          TIntermTyped *node,
514                                                          BlockMemberInfo *blockMemberInfo)
515 {
516     if (TIntermSymbol *symbolNode = node->getAsSymbolNode())
517     {
518         const TVariable &variable = symbolNode->variable();
519         const TType &type         = variable.getType();
520         if (type.isInterfaceBlock())
521         {
522             out << DecorateVariableIfNeeded(variable);
523         }
524         else
525         {
526             const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
527             out << Decorate(interfaceBlock->name());
528             const TField *field =
529                 GetFieldMemberInShaderStorageBlock(interfaceBlock, variable.name());
530             return createFieldOffset(field, blockMemberInfo);
531         }
532     }
533     else if (TIntermSwizzle *swizzleNode = node->getAsSwizzleNode())
534     {
535         return traverseNode(out, swizzleNode->getOperand(), blockMemberInfo);
536     }
537     else if (TIntermBinary *binaryNode = node->getAsBinaryNode())
538     {
539         switch (binaryNode->getOp())
540         {
541             case EOpIndexDirect:
542             {
543                 const TType &leftType = binaryNode->getLeft()->getType();
544                 if (leftType.isInterfaceBlock())
545                 {
546                     ASSERT(leftType.getQualifier() == EvqBuffer);
547                     TIntermSymbol *instanceArraySymbol = binaryNode->getLeft()->getAsSymbolNode();
548 
549                     const int arrayIndex =
550                         binaryNode->getRight()->getAsConstantUnion()->getIConst(0);
551                     out << mResourcesHLSL->InterfaceBlockInstanceString(
552                         instanceArraySymbol->getName(), arrayIndex);
553                 }
554                 else
555                 {
556                     return writeEOpIndexDirectOrIndirectOutput(out, binaryNode, blockMemberInfo);
557                 }
558                 break;
559             }
560             case EOpIndexIndirect:
561             {
562                 // We do not currently support indirect references to interface blocks
563                 ASSERT(binaryNode->getLeft()->getBasicType() != EbtInterfaceBlock);
564                 return writeEOpIndexDirectOrIndirectOutput(out, binaryNode, blockMemberInfo);
565                 break;
566             }
567             case EOpIndexDirectStruct:
568             {
569                 // We do not currently support direct references to interface blocks
570                 ASSERT(binaryNode->getLeft()->getBasicType() != EbtInterfaceBlock);
571                 TIntermTyped *left = traverseNode(out, binaryNode->getLeft(), blockMemberInfo);
572                 const TStructure *structure       = binaryNode->getLeft()->getType().getStruct();
573                 const TIntermConstantUnion *index = binaryNode->getRight()->getAsConstantUnion();
574                 const TField *field               = structure->fields()[index->getIConst(0)];
575                 return Add(createFieldOffset(field, blockMemberInfo), left);
576                 break;
577             }
578             case EOpIndexDirectInterfaceBlock:
579             {
580                 ASSERT(IsInShaderStorageBlock(binaryNode->getLeft()));
581                 traverseNode(out, binaryNode->getLeft(), blockMemberInfo);
582                 const TInterfaceBlock *interfaceBlock =
583                     binaryNode->getLeft()->getType().getInterfaceBlock();
584                 const TIntermConstantUnion *index = binaryNode->getRight()->getAsConstantUnion();
585                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
586                 return createFieldOffset(field, blockMemberInfo);
587                 break;
588             }
589             default:
590                 return nullptr;
591         }
592     }
593     return nullptr;
594 }
595 
writeEOpIndexDirectOrIndirectOutput(TInfoSinkBase & out,TIntermBinary * node,BlockMemberInfo * blockMemberInfo)596 TIntermTyped *ShaderStorageBlockOutputHLSL::writeEOpIndexDirectOrIndirectOutput(
597     TInfoSinkBase &out,
598     TIntermBinary *node,
599     BlockMemberInfo *blockMemberInfo)
600 {
601     ASSERT(IsInShaderStorageBlock(node->getLeft()));
602     TIntermTyped *left  = traverseNode(out, node->getLeft(), blockMemberInfo);
603     TIntermTyped *right = node->getRight()->deepCopy();
604     const TType &type   = node->getLeft()->getType();
605     TLayoutBlockStorage storage;
606     bool rowMajor;
607     GetBlockLayoutInfo(node, false, &storage, &rowMajor);
608 
609     if (type.isArray())
610     {
611         const TSpan<const unsigned int> &arraySizes = type.getArraySizes();
612         for (unsigned int i = 0; i < arraySizes.size() - 1; i++)
613         {
614             right = Mul(CreateUIntNode(arraySizes[i]), right);
615         }
616         right = Mul(CreateUIntNode(blockMemberInfo->arrayStride), right);
617     }
618     else if (type.isMatrix())
619     {
620         if (rowMajor)
621         {
622             right = Mul(CreateUIntNode(BlockLayoutEncoder::kBytesPerComponent), right);
623         }
624         else
625         {
626             right = Mul(CreateUIntNode(blockMemberInfo->matrixStride), right);
627         }
628     }
629     else if (type.isVector())
630     {
631         if (blockMemberInfo->isRowMajorMatrix)
632         {
633             right = Mul(CreateUIntNode(blockMemberInfo->matrixStride), right);
634         }
635         else
636         {
637             right = Mul(CreateUIntNode(BlockLayoutEncoder::kBytesPerComponent), right);
638         }
639     }
640     return Add(left, right);
641 }
642 
createFieldOffset(const TField * field,BlockMemberInfo * blockMemberInfo)643 TIntermTyped *ShaderStorageBlockOutputHLSL::createFieldOffset(const TField *field,
644                                                               BlockMemberInfo *blockMemberInfo)
645 {
646     auto fieldInfoIter = mBlockMemberInfoMap.find(field);
647     ASSERT(fieldInfoIter != mBlockMemberInfoMap.end());
648     *blockMemberInfo = fieldInfoIter->second;
649     return CreateUIntNode(blockMemberInfo->offset);
650 }
651 
652 }  // namespace sh
653