//#include "Engine/Assert.h"
#include "Engine.h"

#include "HLSLTree.h"
#include <assert.h>
#include <algorithm>
#include <climits>
#include <cmath>
#include <map>
#include <string>
9 
10 namespace M4
11 {
12 
HLSLTree(Allocator * allocator)13 HLSLTree::HLSLTree(Allocator* allocator) :
14     m_allocator(allocator), m_stringPool(allocator)
15 {
16     m_firstPage         = m_allocator->New<NodePage>();
17     m_firstPage->next   = NULL;
18 
19     m_currentPage       = m_firstPage;
20     m_currentPageOffset = 0;
21 
22     m_root              = AddNode<HLSLRoot>(NULL, 1);
23 }
24 
~HLSLTree()25 HLSLTree::~HLSLTree()
26 {
27     NodePage* page = m_firstPage;
28     while (page != NULL)
29     {
30         NodePage* next = page->next;
31         m_allocator->Delete(page);
32         page = next;
33     }
34 }
35 
AllocatePage()36 void HLSLTree::AllocatePage()
37 {
38     NodePage* newPage    = m_allocator->New<NodePage>();
39     newPage->next        = NULL;
40     m_currentPage->next  = newPage;
41     m_currentPageOffset  = 0;
42     m_currentPage        = newPage;
43 }
44 
AddString(const char * string)45 const char* HLSLTree::AddString(const char* string)
46 {
47     return m_stringPool.AddString(string);
48 }
49 
AddStringFormat(const char * format,...)50 const char* HLSLTree::AddStringFormat(const char* format, ...)
51 {
52     va_list args;
53     va_start(args, format);
54     const char * string = m_stringPool.AddStringFormatList(format, args);
55     va_end(args);
56     return string;
57 }
58 
GetContainsString(const char * string) const59 bool HLSLTree::GetContainsString(const char* string) const
60 {
61     return m_stringPool.GetContainsString(string);
62 }
63 
GetRoot() const64 HLSLRoot* HLSLTree::GetRoot() const
65 {
66     return m_root;
67 }
68 
AllocateMemory(size_t size)69 void* HLSLTree::AllocateMemory(size_t size)
70 {
71     if (m_currentPageOffset + size > s_nodePageSize)
72     {
73         AllocatePage();
74     }
75     void* buffer = m_currentPage->buffer + m_currentPageOffset;
76     m_currentPageOffset += size;
77     return buffer;
78 }
79 
80 // @@ This doesn't do any parameter matching. Simply returns the first function with that name.
FindFunction(const char * name)81 HLSLFunction * HLSLTree::FindFunction(const char * name)
82 {
83     HLSLStatement * statement = m_root->statement;
84     while (statement != NULL)
85     {
86         if (statement->nodeType == HLSLNodeType_Function)
87         {
88             HLSLFunction * function = (HLSLFunction *)statement;
89             if (String_Equal(name, function->name))
90             {
91                 return function;
92             }
93         }
94 
95         statement = statement->nextStatement;
96     }
97 
98     return NULL;
99 }
100 
FindGlobalDeclaration(const char * name,HLSLBuffer ** buffer_out)101 HLSLDeclaration * HLSLTree::FindGlobalDeclaration(const char * name, HLSLBuffer ** buffer_out/*=NULL*/)
102 {
103     HLSLStatement * statement = m_root->statement;
104     while (statement != NULL)
105     {
106         if (statement->nodeType == HLSLNodeType_Declaration)
107         {
108             HLSLDeclaration * declaration = (HLSLDeclaration *)statement;
109             if (String_Equal(name, declaration->name))
110             {
111                 if (buffer_out) *buffer_out = NULL;
112                 return declaration;
113             }
114         }
115         else if (statement->nodeType == HLSLNodeType_Buffer)
116         {
117             HLSLBuffer* buffer = (HLSLBuffer*)statement;
118 
119             HLSLDeclaration* field = buffer->field;
120             while (field != NULL)
121             {
122                 ASSERT(field->nodeType == HLSLNodeType_Declaration);
123                 if (String_Equal(name, field->name))
124                 {
125                     if (buffer_out) *buffer_out = buffer;
126                     return field;
127                 }
128                 field = (HLSLDeclaration*)field->nextStatement;
129             }
130         }
131 
132         statement = statement->nextStatement;
133     }
134 
135     if (buffer_out) *buffer_out = NULL;
136     return NULL;
137 }
138 
FindGlobalStruct(const char * name)139 HLSLStruct * HLSLTree::FindGlobalStruct(const char * name)
140 {
141     HLSLStatement * statement = m_root->statement;
142     while (statement != NULL)
143     {
144         if (statement->nodeType == HLSLNodeType_Struct)
145         {
146             HLSLStruct * declaration = (HLSLStruct *)statement;
147             if (String_Equal(name, declaration->name))
148             {
149                 return declaration;
150             }
151         }
152 
153         statement = statement->nextStatement;
154     }
155 
156     return NULL;
157 }
158 
FindTechnique(const char * name)159 HLSLTechnique * HLSLTree::FindTechnique(const char * name)
160 {
161     HLSLStatement * statement = m_root->statement;
162     while (statement != NULL)
163     {
164         if (statement->nodeType == HLSLNodeType_Technique)
165         {
166             HLSLTechnique * technique = (HLSLTechnique *)statement;
167             if (String_Equal(name, technique->name))
168             {
169                 return technique;
170             }
171         }
172 
173         statement = statement->nextStatement;
174     }
175 
176     return NULL;
177 }
178 
FindFirstPipeline()179 HLSLPipeline * HLSLTree::FindFirstPipeline()
180 {
181     return FindNextPipeline(NULL);
182 }
183 
FindNextPipeline(HLSLPipeline * current)184 HLSLPipeline * HLSLTree::FindNextPipeline(HLSLPipeline * current)
185 {
186     HLSLStatement * statement = current ? current : m_root->statement;
187     while (statement != NULL)
188     {
189         if (statement->nodeType == HLSLNodeType_Pipeline)
190         {
191             return (HLSLPipeline *)statement;
192         }
193 
194         statement = statement->nextStatement;
195     }
196 
197     return NULL;
198 }
199 
FindPipeline(const char * name)200 HLSLPipeline * HLSLTree::FindPipeline(const char * name)
201 {
202     HLSLStatement * statement = m_root->statement;
203     while (statement != NULL)
204     {
205         if (statement->nodeType == HLSLNodeType_Pipeline)
206         {
207             HLSLPipeline * pipeline = (HLSLPipeline *)statement;
208             if (String_Equal(name, pipeline->name))
209             {
210                 return pipeline;
211             }
212         }
213 
214         statement = statement->nextStatement;
215     }
216 
217     return NULL;
218 }
219 
FindBuffer(const char * name)220 HLSLBuffer * HLSLTree::FindBuffer(const char * name)
221 {
222     HLSLStatement * statement = m_root->statement;
223     while (statement != NULL)
224     {
225         if (statement->nodeType == HLSLNodeType_Buffer)
226         {
227             HLSLBuffer * buffer = (HLSLBuffer *)statement;
228             if (String_Equal(name, buffer->name))
229             {
230                 return buffer;
231             }
232         }
233 
234         statement = statement->nextStatement;
235     }
236 
237     return NULL;
238 }
239 
240 
241 
GetExpressionValue(HLSLExpression * expression,int & value)242 bool HLSLTree::GetExpressionValue(HLSLExpression * expression, int & value)
243 {
244     ASSERT (expression != NULL);
245 
246     // Expression must be constant.
247     if ((expression->expressionType.flags & HLSLTypeFlag_Const) == 0)
248     {
249         return false;
250     }
251 
252     // We are expecting an integer scalar. @@ Add support for type conversion from other scalar types.
253     if (expression->expressionType.baseType != HLSLBaseType_Int &&
254         expression->expressionType.baseType != HLSLBaseType_Bool)
255     {
256         return false;
257     }
258 
259     if (expression->expressionType.array)
260     {
261         return false;
262     }
263 
264     if (expression->nodeType == HLSLNodeType_BinaryExpression)
265     {
266         HLSLBinaryExpression * binaryExpression = (HLSLBinaryExpression *)expression;
267 
268         int value1, value2;
269         if (!GetExpressionValue(binaryExpression->expression1, value1) ||
270             !GetExpressionValue(binaryExpression->expression2, value2))
271         {
272             return false;
273         }
274 
275         switch(binaryExpression->binaryOp)
276         {
277             case HLSLBinaryOp_And:
278                 value = value1 && value2;
279                 return true;
280             case HLSLBinaryOp_Or:
281                 value = value1 || value2;
282                 return true;
283             case HLSLBinaryOp_Add:
284                 value = value1 + value2;
285                 return true;
286             case HLSLBinaryOp_Sub:
287                 value = value1 - value2;
288                 return true;
289             case HLSLBinaryOp_Mul:
290                 value = value1 * value2;
291                 return true;
292             case HLSLBinaryOp_Div:
293                 value = value1 / value2;
294                 return true;
295             case HLSLBinaryOp_Mod:
296                 value = value1 % value2;
297                 return true;
298             case HLSLBinaryOp_Less:
299                 value = value1 < value2;
300                 return true;
301             case HLSLBinaryOp_Greater:
302                 value = value1 > value2;
303                 return true;
304             case HLSLBinaryOp_LessEqual:
305                 value = value1 <= value2;
306                 return true;
307             case HLSLBinaryOp_GreaterEqual:
308                 value = value1 >= value2;
309                 return true;
310             case HLSLBinaryOp_Equal:
311                 value = value1 == value2;
312                 return true;
313             case HLSLBinaryOp_NotEqual:
314                 value = value1 != value2;
315                 return true;
316             case HLSLBinaryOp_BitAnd:
317                 value = value1 & value2;
318                 return true;
319             case HLSLBinaryOp_BitOr:
320                 value = value1 | value2;
321                 return true;
322             case HLSLBinaryOp_BitXor:
323                 value = value1 ^ value2;
324                 return true;
325             case HLSLBinaryOp_Assign:
326             case HLSLBinaryOp_AddAssign:
327             case HLSLBinaryOp_SubAssign:
328             case HLSLBinaryOp_MulAssign:
329             case HLSLBinaryOp_DivAssign:
330                 // IC: These are not valid on non-constant expressions and should fail earlier when querying expression value.
331                 return false;
332         }
333     }
334     else if (expression->nodeType == HLSLNodeType_UnaryExpression)
335     {
336         HLSLUnaryExpression * unaryExpression = (HLSLUnaryExpression *)expression;
337 
338         if (!GetExpressionValue(unaryExpression->expression, value))
339         {
340             return false;
341         }
342 
343         switch(unaryExpression->unaryOp)
344         {
345             case HLSLUnaryOp_Negative:
346                 value = -value;
347                 return true;
348             case HLSLUnaryOp_Positive:
349                 // nop.
350                 return true;
351             case HLSLUnaryOp_Not:
352                 value = !value;
353                 return true;
354             case HLSLUnaryOp_BitNot:
355                 value = ~value;
356                 return true;
357             case HLSLUnaryOp_PostDecrement:
358             case HLSLUnaryOp_PostIncrement:
359             case HLSLUnaryOp_PreDecrement:
360             case HLSLUnaryOp_PreIncrement:
361                 // IC: These are not valid on non-constant expressions and should fail earlier when querying expression value.
362                 return false;
363         }
364     }
365     else if (expression->nodeType == HLSLNodeType_IdentifierExpression)
366     {
367         HLSLIdentifierExpression * identifier = (HLSLIdentifierExpression *)expression;
368 
369         HLSLDeclaration * declaration = FindGlobalDeclaration(identifier->name);
370         if (declaration == NULL)
371         {
372             return false;
373         }
374         if ((declaration->type.flags & HLSLTypeFlag_Const) == 0)
375         {
376             return false;
377         }
378 
379         return GetExpressionValue(declaration->assignment, value);
380     }
381     else if (expression->nodeType == HLSLNodeType_LiteralExpression)
382     {
383         HLSLLiteralExpression * literal = (HLSLLiteralExpression *)expression;
384 
385         if (literal->expressionType.baseType == HLSLBaseType_Int) value = literal->iValue;
386         else if (literal->expressionType.baseType == HLSLBaseType_Bool) value = (int)literal->bValue;
387         else return false;
388 
389         return true;
390     }
391 
392     return false;
393 }
394 
NeedsFunction(const char * name)395 bool HLSLTree::NeedsFunction(const char* name)
396 {
397     // Early out
398     if (!GetContainsString(name))
399         return false;
400 
401     struct NeedsFunctionVisitor: HLSLTreeVisitor
402     {
403         const char* name;
404         bool result;
405 
406         virtual void VisitTopLevelStatement(HLSLStatement * node)
407         {
408             if (!node->hidden)
409                 HLSLTreeVisitor::VisitTopLevelStatement(node);
410         }
411 
412         virtual void VisitFunctionCall(HLSLFunctionCall * node)
413         {
414             result = result || String_Equal(name, node->function->name);
415 
416             HLSLTreeVisitor::VisitFunctionCall(node);
417         }
418     };
419 
420     NeedsFunctionVisitor visitor;
421     visitor.name = name;
422     visitor.result = false;
423 
424     visitor.VisitRoot(m_root);
425 
426     return visitor.result;
427 }
428 
GetVectorDimension(HLSLType & type)429 int GetVectorDimension(HLSLType & type)
430 {
431     if (type.baseType >= HLSLBaseType_FirstNumeric &&
432         type.baseType <= HLSLBaseType_LastNumeric)
433     {
434         if (type.baseType == HLSLBaseType_Float) return 1;
435         if (type.baseType == HLSLBaseType_Float2) return 2;
436         if (type.baseType == HLSLBaseType_Float3) return 3;
437         if (type.baseType == HLSLBaseType_Float4) return 4;
438 
439     }
440     return 0;
441 }
442 
443 // Returns dimension, 0 if invalid.
GetExpressionValue(HLSLExpression * expression,float values[4])444 int HLSLTree::GetExpressionValue(HLSLExpression * expression, float values[4])
445 {
446     ASSERT (expression != NULL);
447 
448     // Expression must be constant.
449     if ((expression->expressionType.flags & HLSLTypeFlag_Const) == 0)
450     {
451         return 0;
452     }
453 
454     if (expression->expressionType.baseType == HLSLBaseType_Int ||
455         expression->expressionType.baseType == HLSLBaseType_Bool)
456     {
457         int int_value;
458         if (GetExpressionValue(expression, int_value)) {
459             for (int i = 0; i < 4; i++) values[i] = (float)int_value;   // @@ Warn if conversion is not exact.
460             return 1;
461         }
462 
463         return 0;
464     }
465     if (expression->expressionType.baseType >= HLSLBaseType_FirstInteger && expression->expressionType.baseType <= HLSLBaseType_LastInteger)
466     {
467         // @@ Add support for uints?
468         // @@ Add support for int vectors?
469         return 0;
470     }
471     if (expression->expressionType.baseType > HLSLBaseType_LastNumeric)
472     {
473         return 0;
474     }
475 
476     // @@ Not supported yet, but we may need it?
477     if (expression->expressionType.array)
478     {
479         return false;
480     }
481 
482     if (expression->nodeType == HLSLNodeType_BinaryExpression)
483     {
484         HLSLBinaryExpression * binaryExpression = (HLSLBinaryExpression *)expression;
485         int dim = GetVectorDimension(binaryExpression->expressionType);
486 
487         float values1[4], values2[4];
488         int dim1 = GetExpressionValue(binaryExpression->expression1, values1);
489         int dim2 = GetExpressionValue(binaryExpression->expression2, values2);
490 
491         if (dim1 == 0 || dim2 == 0)
492         {
493             return 0;
494         }
495 
496         if (dim1 != dim2)
497         {
498             // Brodacast scalar to vector size.
499             if (dim1 == 1)
500             {
501                 for (int i = 1; i < dim2; i++) values1[i] = values1[0];
502                 dim1 = dim2;
503             }
504             else if (dim2 == 1)
505             {
506                 for (int i = 1; i < dim1; i++) values2[i] = values2[0];
507                 dim2 = dim1;
508             }
509             else
510             {
511                 return 0;
512             }
513         }
514         ASSERT(dim == dim1);
515 
516         switch(binaryExpression->binaryOp)
517         {
518             case HLSLBinaryOp_Add:
519                 for (int i = 0; i < dim; i++) values[i] = values1[i] + values2[i];
520                 return dim;
521             case HLSLBinaryOp_Sub:
522                 for (int i = 0; i < dim; i++) values[i] = values1[i] - values2[i];
523                 return dim;
524             case HLSLBinaryOp_Mul:
525                 for (int i = 0; i < dim; i++) values[i] = values1[i] * values2[i];
526                 return dim;
527             case HLSLBinaryOp_Div:
528                 for (int i = 0; i < dim; i++) values[i] = values1[i] / values2[i];
529                 return dim;
530             case HLSLBinaryOp_Mod:
531                 for (int i = 0; i < dim; i++) values[i] = int(values1[i]) % int(values2[i]);
532                 return dim;
533             default:
534                 return 0;
535         }
536     }
537     else if (expression->nodeType == HLSLNodeType_UnaryExpression)
538     {
539         HLSLUnaryExpression * unaryExpression = (HLSLUnaryExpression *)expression;
540         int dim = GetVectorDimension(unaryExpression->expressionType);
541 
542         int dim1 = GetExpressionValue(unaryExpression->expression, values);
543         if (dim1 == 0)
544         {
545             return 0;
546         }
547         ASSERT(dim == dim1);
548 
549         switch(unaryExpression->unaryOp)
550         {
551             case HLSLUnaryOp_Negative:
552                 for (int i = 0; i < dim; i++) values[i] = -values[i];
553                 return dim;
554             case HLSLUnaryOp_Positive:
555                 // nop.
556                 return dim;
557             default:
558                 return 0;
559         }
560     }
561     else if (expression->nodeType == HLSLNodeType_ConstructorExpression)
562     {
563         HLSLConstructorExpression * constructor = (HLSLConstructorExpression *)expression;
564 
565         int dim = GetVectorDimension(constructor->expressionType);
566 
567         int idx = 0;
568         HLSLExpression * arg = constructor->argument;
569         while (arg != NULL)
570         {
571             float tmp[4];
572             int n = GetExpressionValue(arg, tmp);
573             for (int i = 0; i < n; i++) values[idx + i] = tmp[i];
574             idx += n;
575 
576             arg = arg->nextExpression;
577         }
578         ASSERT(dim == idx);
579 
580         return dim;
581     }
582     else if (expression->nodeType == HLSLNodeType_IdentifierExpression)
583     {
584         HLSLIdentifierExpression * identifier = (HLSLIdentifierExpression *)expression;
585 
586         HLSLDeclaration * declaration = FindGlobalDeclaration(identifier->name);
587         if (declaration == NULL)
588         {
589             return 0;
590         }
591         if ((declaration->type.flags & HLSLTypeFlag_Const) == 0)
592         {
593             return 0;
594         }
595 
596         return GetExpressionValue(declaration->assignment, values);
597     }
598     else if (expression->nodeType == HLSLNodeType_LiteralExpression)
599     {
600         HLSLLiteralExpression * literal = (HLSLLiteralExpression *)expression;
601 
602         if (literal->expressionType.baseType == HLSLBaseType_Float) values[0] = literal->fValue;
603         else if (literal->expressionType.baseType == HLSLBaseType_Bool) values[0] = literal->bValue;
604         else if (literal->expressionType.baseType == HLSLBaseType_Int) values[0] = (float)literal->iValue;  // @@ Warn if conversion is not exact.
605         else return 0;
606 
607         return 1;
608     }
609 
610     return 0;
611 }
612 
ReplaceUniformsAssignments()613 bool HLSLTree::ReplaceUniformsAssignments()
614 {
615     struct ReplaceUniformsAssignmentsVisitor: HLSLTreeVisitor
616     {
617         HLSLTree * tree;
618         std::map<std::string, HLSLDeclaration *> uniforms;
619         std::map<std::string, std::string> uniformsReplaced;
620         bool withinAssignment;
621 
622         virtual void VisitDeclaration(HLSLDeclaration * node)
623         {
624             HLSLTreeVisitor::VisitDeclaration(node);
625 
626             // Enumerate uniforms
627             if (node->type.flags & HLSLTypeFlag_Uniform)
628             {
629                 uniforms[node->name] = node;
630             }
631         }
632 
633         virtual void VisitFunction(HLSLFunction * node)
634         {
635             uniformsReplaced.clear();
636 
637             // Detect uniforms assignments
638             HLSLTreeVisitor::VisitFunction(node);
639 
640             // Declare uniforms replacements
641             std::map<std::string, std::string>::const_iterator iter = uniformsReplaced.cbegin();
642             for ( ; iter != uniformsReplaced.cend(); ++iter)
643             {
644                 HLSLDeclaration * uniformDeclaration = uniforms[iter->first];
645                 HLSLDeclaration * declaration = tree->AddNode<HLSLDeclaration>(node->fileName, node->line);
646 
647                 declaration->name = tree->AddString(iter->second.c_str());
648                 declaration->type = uniformDeclaration->type;
649 
650                 // Add declaration within function statements
651                 declaration->nextStatement = node->statement;
652                 node->statement = declaration;
653             }
654         }
655 
656         virtual void VisitBinaryExpression(HLSLBinaryExpression * node)
657         {
658             // Visit expression 2 first to not replace possible uniform reading
659             VisitExpression(node->expression2);
660 
661             if (IsAssignOp(node->binaryOp))
662             {
663                 withinAssignment = true;
664             }
665 
666             VisitExpression(node->expression1);
667 
668             withinAssignment = false;
669         }
670 
671         virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node)
672         {
673             if (withinAssignment)
674             {
675                 // Check if variable is a uniform
676                 if (uniforms.find(node->name) != uniforms.end())
677                 {
678                     // Check if variable is not already replaced
679                     if (uniformsReplaced.find(node->name) == uniformsReplaced.end())
680                     {
681                         std::string newName(node->name);
682                         do
683                         {
684                             newName.insert(0, "new");
685                         }
686                         while(tree->GetContainsString(newName.c_str()));
687 
688                         uniformsReplaced[node->name] = newName;
689                     }
690                 }
691             }
692 
693             // Check if variable need to be replaced
694             if (uniformsReplaced.find(node->name) != uniformsReplaced.end())
695             {
696                 // Replace
697                 node->name = tree->AddString( uniformsReplaced[node->name].c_str() );
698             }
699         }
700     };
701 
702     ReplaceUniformsAssignmentsVisitor visitor;
703     visitor.tree = this;
704     visitor.withinAssignment = false;
705     visitor.VisitRoot(m_root);
706 
707     return true;
708 }
709 
710 
matrixCtorBuilder(HLSLType type,HLSLExpression * arguments)711 matrixCtor matrixCtorBuilder(HLSLType type, HLSLExpression * arguments) {
712     matrixCtor ctor;
713 
714     ctor.matrixType = type.baseType;
715 
716     // Fetch all arguments
717     HLSLExpression* argument = arguments;
718     while (argument != NULL)
719     {
720         ctor.argumentTypes.push_back(argument->expressionType.baseType);
721         argument = argument->nextExpression;
722     }
723 
724     return ctor;
725 }
726 
EnumerateMatrixCtorsNeeded(std::vector<matrixCtor> & matrixCtors)727 void HLSLTree::EnumerateMatrixCtorsNeeded(std::vector<matrixCtor> & matrixCtors) {
728 
729     struct EnumerateMatrixCtorsVisitor: HLSLTreeVisitor
730     {
731         std::vector<matrixCtor> matrixCtorsNeeded;
732 
733         virtual void VisitConstructorExpression(HLSLConstructorExpression * node)
734         {
735             if (IsMatrixType(node->expressionType.baseType))
736             {
737                 matrixCtor ctor = matrixCtorBuilder(node->expressionType, node->argument);
738 
739                 if (std::find(matrixCtorsNeeded.cbegin(), matrixCtorsNeeded.cend(), ctor) == matrixCtorsNeeded.cend())
740                 {
741                     matrixCtorsNeeded.push_back(ctor);
742                 }
743             }
744 
745             HLSLTreeVisitor::VisitConstructorExpression(node);
746         }
747 
748         virtual void VisitDeclaration(HLSLDeclaration * node)
749         {
750             if (    IsMatrixType(node->type.baseType) &&
751                     (node->type.flags & HLSLArgumentModifier_Uniform) == 0 )
752             {
753                 matrixCtor ctor = matrixCtorBuilder(node->type, node->assignment);
754 
755                 // No special constructor needed if it already a matrix
756                 bool matrixArgument = false;
757                 for(HLSLBaseType & type: ctor.argumentTypes)
758                 {
759                     if (IsMatrixType(type))
760                     {
761                         matrixArgument = true;
762                         break;
763                     }
764                 }
765 
766                 if (    !matrixArgument &&
767                         std::find(matrixCtorsNeeded.cbegin(), matrixCtorsNeeded.cend(), ctor) == matrixCtorsNeeded.cend())
768                 {
769                     matrixCtorsNeeded.push_back(ctor);
770                 }
771             }
772 
773             HLSLTreeVisitor::VisitDeclaration(node);
774         }
775     };
776 
777     EnumerateMatrixCtorsVisitor visitor;
778     visitor.VisitRoot(m_root);
779 
780     matrixCtors = visitor.matrixCtorsNeeded;
781 }
782 
783 
VisitType(HLSLType & type)784 void HLSLTreeVisitor::VisitType(HLSLType & type)
785 {
786 }
787 
VisitRoot(HLSLRoot * root)788 void HLSLTreeVisitor::VisitRoot(HLSLRoot * root)
789 {
790     HLSLStatement * statement = root->statement;
791     while (statement != NULL) {
792         VisitTopLevelStatement(statement);
793         statement = statement->nextStatement;
794     }
795 }
796 
VisitTopLevelStatement(HLSLStatement * node)797 void HLSLTreeVisitor::VisitTopLevelStatement(HLSLStatement * node)
798 {
799     if (node->nodeType == HLSLNodeType_Declaration) {
800         VisitDeclaration((HLSLDeclaration *)node);
801     }
802     else if (node->nodeType == HLSLNodeType_Struct) {
803         VisitStruct((HLSLStruct *)node);
804     }
805     else if (node->nodeType == HLSLNodeType_Buffer) {
806         VisitBuffer((HLSLBuffer *)node);
807     }
808     else if (node->nodeType == HLSLNodeType_Function) {
809         VisitFunction((HLSLFunction *)node);
810     }
811     else if (node->nodeType == HLSLNodeType_Technique) {
812         VisitTechnique((HLSLTechnique *)node);
813     }
814     else if (node->nodeType == HLSLNodeType_Pipeline) {
815         VisitPipeline((HLSLPipeline *)node);
816     }
817     else {
818         ASSERT(0);
819     }
820 }
821 
VisitStatements(HLSLStatement * statement)822 void HLSLTreeVisitor::VisitStatements(HLSLStatement * statement)
823 {
824     while (statement != NULL) {
825         VisitStatement(statement);
826         statement = statement->nextStatement;
827     }
828 }
829 
VisitStatement(HLSLStatement * node)830 void HLSLTreeVisitor::VisitStatement(HLSLStatement * node)
831 {
832     // Function statements
833     if (node->nodeType == HLSLNodeType_Declaration) {
834         VisitDeclaration((HLSLDeclaration *)node);
835     }
836     else if (node->nodeType == HLSLNodeType_ExpressionStatement) {
837         VisitExpressionStatement((HLSLExpressionStatement *)node);
838     }
839     else if (node->nodeType == HLSLNodeType_ReturnStatement) {
840         VisitReturnStatement((HLSLReturnStatement *)node);
841     }
842     else if (node->nodeType == HLSLNodeType_DiscardStatement) {
843         VisitDiscardStatement((HLSLDiscardStatement *)node);
844     }
845     else if (node->nodeType == HLSLNodeType_BreakStatement) {
846         VisitBreakStatement((HLSLBreakStatement *)node);
847     }
848     else if (node->nodeType == HLSLNodeType_ContinueStatement) {
849         VisitContinueStatement((HLSLContinueStatement *)node);
850     }
851     else if (node->nodeType == HLSLNodeType_IfStatement) {
852         VisitIfStatement((HLSLIfStatement *)node);
853     }
854     else if (node->nodeType == HLSLNodeType_ForStatement) {
855         VisitForStatement((HLSLForStatement *)node);
856     }
857     else if (node->nodeType == HLSLNodeType_WhileStatement) {
858         VisitWhileStatement((HLSLWhileStatement *)node);
859     }
860     else if (node->nodeType == HLSLNodeType_BlockStatement) {
861         VisitBlockStatement((HLSLBlockStatement *)node);
862     }
863     else {
864         ASSERT(0);
865     }
866 }
867 
VisitDeclaration(HLSLDeclaration * node)868 void HLSLTreeVisitor::VisitDeclaration(HLSLDeclaration * node)
869 {
870     VisitType(node->type);
871     /*do {
872         VisitExpression(node->assignment);
873         node = node->nextDeclaration;
874     } while (node);*/
875     if (node->assignment != NULL) {
876         VisitExpression(node->assignment);
877     }
878     if (node->nextDeclaration != NULL) {
879         VisitDeclaration(node->nextDeclaration);
880     }
881 }
882 
VisitStruct(HLSLStruct * node)883 void HLSLTreeVisitor::VisitStruct(HLSLStruct * node)
884 {
885     HLSLStructField * field = node->field;
886     while (field != NULL) {
887         VisitStructField(field);
888         field = field->nextField;
889     }
890 }
891 
VisitStructField(HLSLStructField * node)892 void HLSLTreeVisitor::VisitStructField(HLSLStructField * node)
893 {
894     VisitType(node->type);
895 }
896 
VisitBuffer(HLSLBuffer * node)897 void HLSLTreeVisitor::VisitBuffer(HLSLBuffer * node)
898 {
899     HLSLDeclaration * field = node->field;
900     while (field != NULL) {
901         ASSERT(field->nodeType == HLSLNodeType_Declaration);
902         VisitDeclaration(field);
903         ASSERT(field->nextDeclaration == NULL);
904         field = (HLSLDeclaration *)field->nextStatement;
905     }
906 }
907 
908 /*void HLSLTreeVisitor::VisitBufferField(HLSLBufferField * node)
909 {
910     VisitType(node->type);
911 }*/
912 
VisitFunction(HLSLFunction * node)913 void HLSLTreeVisitor::VisitFunction(HLSLFunction * node)
914 {
915     VisitType(node->returnType);
916 
917     HLSLArgument * argument = node->argument;
918     while (argument != NULL) {
919         VisitArgument(argument);
920         argument = argument->nextArgument;
921     }
922 
923     VisitStatements(node->statement);
924 }
925 
VisitArgument(HLSLArgument * node)926 void HLSLTreeVisitor::VisitArgument(HLSLArgument * node)
927 {
928     VisitType(node->type);
929     if (node->defaultValue != NULL) {
930         VisitExpression(node->defaultValue);
931     }
932 }
933 
VisitExpressionStatement(HLSLExpressionStatement * node)934 void HLSLTreeVisitor::VisitExpressionStatement(HLSLExpressionStatement * node)
935 {
936     VisitExpression(node->expression);
937 }
938 
VisitExpression(HLSLExpression * node)939 void HLSLTreeVisitor::VisitExpression(HLSLExpression * node)
940 {
941     VisitType(node->expressionType);
942 
943     if (node->nodeType == HLSLNodeType_UnaryExpression) {
944         VisitUnaryExpression((HLSLUnaryExpression *)node);
945     }
946     else if (node->nodeType == HLSLNodeType_BinaryExpression) {
947         VisitBinaryExpression((HLSLBinaryExpression *)node);
948     }
949     else if (node->nodeType == HLSLNodeType_ConditionalExpression) {
950         VisitConditionalExpression((HLSLConditionalExpression *)node);
951     }
952     else if (node->nodeType == HLSLNodeType_CastingExpression) {
953         VisitCastingExpression((HLSLCastingExpression *)node);
954     }
955     else if (node->nodeType == HLSLNodeType_LiteralExpression) {
956         VisitLiteralExpression((HLSLLiteralExpression *)node);
957     }
958     else if (node->nodeType == HLSLNodeType_IdentifierExpression) {
959         VisitIdentifierExpression((HLSLIdentifierExpression *)node);
960     }
961     else if (node->nodeType == HLSLNodeType_ConstructorExpression) {
962         VisitConstructorExpression((HLSLConstructorExpression *)node);
963     }
964     else if (node->nodeType == HLSLNodeType_MemberAccess) {
965         VisitMemberAccess((HLSLMemberAccess *)node);
966     }
967     else if (node->nodeType == HLSLNodeType_ArrayAccess) {
968         VisitArrayAccess((HLSLArrayAccess *)node);
969     }
970     else if (node->nodeType == HLSLNodeType_FunctionCall) {
971         VisitFunctionCall((HLSLFunctionCall *)node);
972     }
973     // Acoget-TODO: This was missing. Did adding it break anything?
974     else if (node->nodeType == HLSLNodeType_SamplerState) {
975         VisitSamplerState((HLSLSamplerState *)node);
976     }
977     else {
978         ASSERT(0);
979     }
980 }
981 
VisitReturnStatement(HLSLReturnStatement * node)982 void HLSLTreeVisitor::VisitReturnStatement(HLSLReturnStatement * node)
983 {
984     VisitExpression(node->expression);
985 }
986 
VisitDiscardStatement(HLSLDiscardStatement * node)987 void HLSLTreeVisitor::VisitDiscardStatement(HLSLDiscardStatement * node) {}
VisitBreakStatement(HLSLBreakStatement * node)988 void HLSLTreeVisitor::VisitBreakStatement(HLSLBreakStatement * node) {}
VisitContinueStatement(HLSLContinueStatement * node)989 void HLSLTreeVisitor::VisitContinueStatement(HLSLContinueStatement * node) {}
990 
VisitIfStatement(HLSLIfStatement * node)991 void HLSLTreeVisitor::VisitIfStatement(HLSLIfStatement * node)
992 {
993     VisitExpression(node->condition);
994     VisitStatements(node->statement);
995     if (node->elseStatement) {
996         VisitStatements(node->elseStatement);
997     }
998 }
999 
VisitForStatement(HLSLForStatement * node)1000 void HLSLTreeVisitor::VisitForStatement(HLSLForStatement * node)
1001 {
1002     if (node->initialization) {
1003         VisitDeclaration(node->initialization);
1004     }
1005     if (node->condition) {
1006         VisitExpression(node->condition);
1007     }
1008     if (node->increment) {
1009         VisitExpression(node->increment);
1010     }
1011     VisitStatements(node->statement);
1012 }
1013 
VisitWhileStatement(HLSLWhileStatement * node)1014 void HLSLTreeVisitor::VisitWhileStatement(HLSLWhileStatement * node)
1015 {
1016     if (node->condition) {
1017         VisitExpression(node->condition);
1018     }
1019     VisitStatements(node->statement);
1020 }
1021 
VisitBlockStatement(HLSLBlockStatement * node)1022 void HLSLTreeVisitor::VisitBlockStatement(HLSLBlockStatement * node)
1023 {
1024     VisitStatements(node->statement);
1025 }
1026 
VisitUnaryExpression(HLSLUnaryExpression * node)1027 void HLSLTreeVisitor::VisitUnaryExpression(HLSLUnaryExpression * node)
1028 {
1029     VisitExpression(node->expression);
1030 }
1031 
VisitBinaryExpression(HLSLBinaryExpression * node)1032 void HLSLTreeVisitor::VisitBinaryExpression(HLSLBinaryExpression * node)
1033 {
1034     VisitExpression(node->expression1);
1035     VisitExpression(node->expression2);
1036 }
1037 
VisitConditionalExpression(HLSLConditionalExpression * node)1038 void HLSLTreeVisitor::VisitConditionalExpression(HLSLConditionalExpression * node)
1039 {
1040     VisitExpression(node->condition);
1041     VisitExpression(node->falseExpression);
1042     VisitExpression(node->trueExpression);
1043 }
1044 
VisitCastingExpression(HLSLCastingExpression * node)1045 void HLSLTreeVisitor::VisitCastingExpression(HLSLCastingExpression * node)
1046 {
1047     VisitType(node->type);
1048     VisitExpression(node->expression);
1049 }
1050 
VisitLiteralExpression(HLSLLiteralExpression * node)1051 void HLSLTreeVisitor::VisitLiteralExpression(HLSLLiteralExpression * node) {}
VisitIdentifierExpression(HLSLIdentifierExpression * node)1052 void HLSLTreeVisitor::VisitIdentifierExpression(HLSLIdentifierExpression * node) {}
1053 
VisitConstructorExpression(HLSLConstructorExpression * node)1054 void HLSLTreeVisitor::VisitConstructorExpression(HLSLConstructorExpression * node)
1055 {
1056     HLSLExpression * argument = node->argument;
1057     while (argument != NULL) {
1058         VisitExpression(argument);
1059         argument = argument->nextExpression;
1060     }
1061 }
1062 
VisitMemberAccess(HLSLMemberAccess * node)1063 void HLSLTreeVisitor::VisitMemberAccess(HLSLMemberAccess * node)
1064 {
1065     VisitExpression(node->object);
1066 }
1067 
VisitArrayAccess(HLSLArrayAccess * node)1068 void HLSLTreeVisitor::VisitArrayAccess(HLSLArrayAccess * node)
1069 {
1070     VisitExpression(node->array);
1071     VisitExpression(node->index);
1072 }
1073 
VisitFunctionCall(HLSLFunctionCall * node)1074 void HLSLTreeVisitor::VisitFunctionCall(HLSLFunctionCall * node)
1075 {
1076     HLSLExpression * argument = node->argument;
1077     while (argument != NULL) {
1078         VisitExpression(argument);
1079         argument = argument->nextExpression;
1080     }
1081 }
1082 
VisitStateAssignment(HLSLStateAssignment * node)1083 void HLSLTreeVisitor::VisitStateAssignment(HLSLStateAssignment * node) {}
1084 
VisitSamplerState(HLSLSamplerState * node)1085 void HLSLTreeVisitor::VisitSamplerState(HLSLSamplerState * node)
1086 {
1087     HLSLStateAssignment * stateAssignment = node->stateAssignments;
1088     while (stateAssignment != NULL) {
1089         VisitStateAssignment(stateAssignment);
1090         stateAssignment = stateAssignment->nextStateAssignment;
1091     }
1092 }
1093 
VisitPass(HLSLPass * node)1094 void HLSLTreeVisitor::VisitPass(HLSLPass * node)
1095 {
1096     HLSLStateAssignment * stateAssignment = node->stateAssignments;
1097     while (stateAssignment != NULL) {
1098         VisitStateAssignment(stateAssignment);
1099         stateAssignment = stateAssignment->nextStateAssignment;
1100     }
1101 }
1102 
VisitTechnique(HLSLTechnique * node)1103 void HLSLTreeVisitor::VisitTechnique(HLSLTechnique * node)
1104 {
1105     HLSLPass * pass = node->passes;
1106     while (pass != NULL) {
1107         VisitPass(pass);
1108         pass = pass->nextPass;
1109     }
1110 }
1111 
VisitPipeline(HLSLPipeline * node)1112 void HLSLTreeVisitor::VisitPipeline(HLSLPipeline * node)
1113 {
1114     // @@ ?
1115 }
1116 
VisitFunctions(HLSLRoot * root)1117 void HLSLTreeVisitor::VisitFunctions(HLSLRoot * root)
1118 {
1119     HLSLStatement * statement = root->statement;
1120     while (statement != NULL) {
1121         if (statement->nodeType == HLSLNodeType_Function) {
1122             VisitFunction((HLSLFunction *)statement);
1123         }
1124 
1125         statement = statement->nextStatement;
1126     }
1127 }
1128 
VisitParameters(HLSLRoot * root)1129 void HLSLTreeVisitor::VisitParameters(HLSLRoot * root)
1130 {
1131     HLSLStatement * statement = root->statement;
1132     while (statement != NULL) {
1133         if (statement->nodeType == HLSLNodeType_Declaration) {
1134             VisitDeclaration((HLSLDeclaration *)statement);
1135         }
1136 
1137         statement = statement->nextStatement;
1138     }
1139 }
1140 
1141 
1142 class ResetHiddenFlagVisitor : public HLSLTreeVisitor
1143 {
1144 public:
VisitTopLevelStatement(HLSLStatement * statement)1145     virtual void VisitTopLevelStatement(HLSLStatement * statement)
1146     {
1147         statement->hidden = true;
1148 
1149         if (statement->nodeType == HLSLNodeType_Buffer)
1150         {
1151             VisitBuffer((HLSLBuffer*)statement);
1152         }
1153     }
1154 
1155     // Hide buffer fields.
VisitDeclaration(HLSLDeclaration * node)1156     virtual void VisitDeclaration(HLSLDeclaration * node)
1157     {
1158         node->hidden = true;
1159     }
1160 
VisitArgument(HLSLArgument * node)1161     virtual void VisitArgument(HLSLArgument * node)
1162     {
1163         node->hidden = false;   // Arguments are visible by default.
1164     }
1165 };
1166 
1167 class MarkVisibleStatementsVisitor : public HLSLTreeVisitor
1168 {
1169 public:
1170     HLSLTree * tree;
MarkVisibleStatementsVisitor(HLSLTree * _tree)1171     MarkVisibleStatementsVisitor(HLSLTree * _tree) : tree(_tree) {}
1172 
VisitFunction(HLSLFunction * node)1173     virtual void VisitFunction(HLSLFunction * node)
1174     {
1175         node->hidden = false;
1176         HLSLTreeVisitor::VisitFunction(node);
1177 
1178         if (node->forward)
1179             VisitFunction(node->forward);
1180     }
1181 
VisitFunctionCall(HLSLFunctionCall * node)1182     virtual void VisitFunctionCall(HLSLFunctionCall * node)
1183     {
1184         HLSLTreeVisitor::VisitFunctionCall(node);
1185 
1186         if (node->function->hidden)
1187         {
1188             VisitFunction(const_cast<HLSLFunction*>(node->function));
1189         }
1190     }
1191 
VisitIdentifierExpression(HLSLIdentifierExpression * node)1192     virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node)
1193     {
1194         HLSLTreeVisitor::VisitIdentifierExpression(node);
1195 
1196         if (node->global)
1197         {
1198             HLSLDeclaration * declaration = tree->FindGlobalDeclaration(node->name);
1199             if (declaration != NULL && declaration->hidden)
1200             {
1201                 declaration->hidden = false;
1202                 VisitDeclaration(declaration);
1203             }
1204         }
1205     }
1206 
VisitType(HLSLType & type)1207     virtual void VisitType(HLSLType & type)
1208     {
1209         if (type.baseType == HLSLBaseType_UserDefined)
1210         {
1211             HLSLStruct * globalStruct = tree->FindGlobalStruct(type.typeName);
1212             if (globalStruct != NULL)
1213             {
1214                 globalStruct->hidden = false;
1215                 VisitStruct(globalStruct);
1216             }
1217         }
1218     }
1219 
1220 };
1221 
1222 
PruneTree(HLSLTree * tree,const char * entryName0,const char * entryName1)1223 void PruneTree(HLSLTree* tree, const char* entryName0, const char* entryName1/*=NULL*/)
1224 {
1225     HLSLRoot* root = tree->GetRoot();
1226 
1227     // Reset all flags.
1228     ResetHiddenFlagVisitor reset;
1229     reset.VisitRoot(root);
1230 
1231     // Mark all the statements necessary for these entrypoints.
1232     HLSLFunction* entry = tree->FindFunction(entryName0);
1233     if (entry != NULL)
1234     {
1235         MarkVisibleStatementsVisitor mark(tree);
1236         mark.VisitFunction(entry);
1237     }
1238 
1239     if (entryName1 != NULL)
1240     {
1241         entry = tree->FindFunction(entryName1);
1242         if (entry != NULL)
1243         {
1244             MarkVisibleStatementsVisitor mark(tree);
1245             mark.VisitFunction(entry);
1246         }
1247     }
1248 
1249     // Mark buffers visible, if any of their fields is visible.
1250     HLSLStatement * statement = root->statement;
1251     while (statement != NULL)
1252     {
1253         if (statement->nodeType == HLSLNodeType_Buffer)
1254         {
1255             HLSLBuffer* buffer = (HLSLBuffer*)statement;
1256 
1257             HLSLDeclaration* field = buffer->field;
1258             while (field != NULL)
1259             {
1260                 ASSERT(field->nodeType == HLSLNodeType_Declaration);
1261                 if (!field->hidden)
1262                 {
1263                     buffer->hidden = false;
1264                     break;
1265                 }
1266                 field = (HLSLDeclaration*)field->nextStatement;
1267             }
1268         }
1269 
1270         statement = statement->nextStatement;
1271     }
1272 }
1273 
1274 
SortTree(HLSLTree * tree)1275 void SortTree(HLSLTree * tree)
1276 {
1277     // Stable sort so that statements are in this order:
1278     // structs, declarations, functions, techniques.
1279 	// but their relative order is preserved.
1280 
1281     HLSLRoot* root = tree->GetRoot();
1282 
1283     HLSLStatement* structs = NULL;
1284     HLSLStatement* lastStruct = NULL;
1285     HLSLStatement* constDeclarations = NULL;
1286     HLSLStatement* lastConstDeclaration = NULL;
1287     HLSLStatement* declarations = NULL;
1288     HLSLStatement* lastDeclaration = NULL;
1289     HLSLStatement* functions = NULL;
1290     HLSLStatement* lastFunction = NULL;
1291     HLSLStatement* other = NULL;
1292     HLSLStatement* lastOther = NULL;
1293 
1294     HLSLStatement* statement = root->statement;
1295     while (statement != NULL) {
1296         HLSLStatement* nextStatement = statement->nextStatement;
1297         statement->nextStatement = NULL;
1298 
1299         if (statement->nodeType == HLSLNodeType_Struct) {
1300             if (structs == NULL) structs = statement;
1301             if (lastStruct != NULL) lastStruct->nextStatement = statement;
1302             lastStruct = statement;
1303         }
1304         else if (statement->nodeType == HLSLNodeType_Declaration || statement->nodeType == HLSLNodeType_Buffer) {
1305             if (statement->nodeType == HLSLNodeType_Declaration && (((HLSLDeclaration *)statement)->type.flags & HLSLTypeFlag_Const)) {
1306                 if (constDeclarations == NULL) constDeclarations = statement;
1307                 if (lastConstDeclaration != NULL) lastConstDeclaration->nextStatement = statement;
1308                 lastConstDeclaration = statement;
1309             }
1310             else {
1311                 if (declarations == NULL) declarations = statement;
1312                 if (lastDeclaration != NULL) lastDeclaration->nextStatement = statement;
1313                 lastDeclaration = statement;
1314             }
1315         }
1316         else if (statement->nodeType == HLSLNodeType_Function) {
1317             if (functions == NULL) functions = statement;
1318             if (lastFunction != NULL) lastFunction->nextStatement = statement;
1319             lastFunction = statement;
1320         }
1321         else {
1322             if (other == NULL) other = statement;
1323             if (lastOther != NULL) lastOther->nextStatement = statement;
1324             lastOther = statement;
1325         }
1326 
1327         statement = nextStatement;
1328     }
1329 
1330     // Chain all the statements in the order that we want.
1331     HLSLStatement * firstStatement = structs;
1332     HLSLStatement * lastStatement = lastStruct;
1333 
1334     if (constDeclarations != NULL) {
1335         if (firstStatement == NULL) firstStatement = constDeclarations;
1336         else lastStatement->nextStatement = constDeclarations;
1337         lastStatement = lastConstDeclaration;
1338     }
1339 
1340     if (declarations != NULL) {
1341         if (firstStatement == NULL) firstStatement = declarations;
1342         else lastStatement->nextStatement = declarations;
1343         lastStatement = lastDeclaration;
1344     }
1345 
1346     if (functions != NULL) {
1347         if (firstStatement == NULL) firstStatement = functions;
1348         else lastStatement->nextStatement = functions;
1349         lastStatement = lastFunction;
1350     }
1351 
1352     if (other != NULL) {
1353         if (firstStatement == NULL) firstStatement = other;
1354         else lastStatement->nextStatement = other;
1355         lastStatement = lastOther;
1356     }
1357 
1358     root->statement = firstStatement;
1359 }
1360 
1361 
1362 
1363 
1364 
1365 // First and last can be the same.
AddStatements(HLSLRoot * root,HLSLStatement * before,HLSLStatement * first,HLSLStatement * last)1366 void AddStatements(HLSLRoot * root, HLSLStatement * before, HLSLStatement * first, HLSLStatement * last)
1367 {
1368     if (before == NULL) {
1369         last->nextStatement = root->statement;
1370         root->statement = first;
1371     }
1372     else {
1373         last->nextStatement = before->nextStatement;
1374         before->nextStatement = first;
1375     }
1376 }
1377 
AddSingleStatement(HLSLRoot * root,HLSLStatement * before,HLSLStatement * statement)1378 void AddSingleStatement(HLSLRoot * root, HLSLStatement * before, HLSLStatement * statement)
1379 {
1380     AddStatements(root, before, statement, statement);
1381 }
1382 
1383 
1384 
1385 // @@ This is very game-specific. Should be moved to pipeline_parser or somewhere else.
GroupParameters(HLSLTree * tree)1386 void GroupParameters(HLSLTree * tree)
1387 {
1388     // Sort parameters based on semantic and group them in cbuffers.
1389 
1390     HLSLRoot* root = tree->GetRoot();
1391 
1392     HLSLDeclaration * firstPerItemDeclaration = NULL;
1393     HLSLDeclaration * lastPerItemDeclaration = NULL;
1394 
1395     HLSLDeclaration * instanceDataDeclaration = NULL;
1396 
1397     HLSLDeclaration * firstPerPassDeclaration = NULL;
1398     HLSLDeclaration * lastPerPassDeclaration = NULL;
1399 
1400     HLSLDeclaration * firstPerItemSampler = NULL;
1401     HLSLDeclaration * lastPerItemSampler = NULL;
1402 
1403     HLSLDeclaration * firstPerPassSampler = NULL;
1404     HLSLDeclaration * lastPerPassSampler = NULL;
1405 
1406     HLSLStatement * statementBeforeBuffers = NULL;
1407 
1408     HLSLStatement* previousStatement = NULL;
1409     HLSLStatement* statement = root->statement;
1410     while (statement != NULL)
1411     {
1412         HLSLStatement* nextStatement = statement->nextStatement;
1413 
1414         if (statement->nodeType == HLSLNodeType_Struct) // Do not remove this, or it will mess the else clause below.
1415         {
1416             statementBeforeBuffers = statement;
1417         }
1418         else if (statement->nodeType == HLSLNodeType_Declaration)
1419         {
1420             HLSLDeclaration* declaration = (HLSLDeclaration*)statement;
1421 
1422             // We insert buffers after the last const declaration.
1423             if ((declaration->type.flags & HLSLTypeFlag_Const) != 0)
1424             {
1425                 statementBeforeBuffers = statement;
1426             }
1427 
1428             // Do not move samplers or static/const parameters.
1429             if ((declaration->type.flags & (HLSLTypeFlag_Static|HLSLTypeFlag_Const)) == 0)
1430             {
1431                 // Unlink statement.
1432                 statement->nextStatement = NULL;
1433                 if (previousStatement != NULL) previousStatement->nextStatement = nextStatement;
1434                 else root->statement = nextStatement;
1435 
1436                 while(declaration != NULL)
1437                 {
1438                     HLSLDeclaration* nextDeclaration = declaration->nextDeclaration;
1439 
1440                     if (declaration->semantic != NULL && String_EqualNoCase(declaration->semantic, "PER_INSTANCED_ITEM"))
1441                     {
1442                         ASSERT(instanceDataDeclaration == NULL);
1443                         instanceDataDeclaration = declaration;
1444                     }
1445                     else
1446                     {
1447                         // Select group based on type and semantic.
1448                         HLSLDeclaration ** first, ** last;
1449                         if (declaration->semantic == NULL || String_EqualNoCase(declaration->semantic, "PER_ITEM") || String_EqualNoCase(declaration->semantic, "PER_MATERIAL"))
1450                         {
1451                             if (IsSamplerType(declaration->type))
1452                             {
1453                                 first = &firstPerItemSampler;
1454                                 last = &lastPerItemSampler;
1455                             }
1456                             else
1457                             {
1458                                 first = &firstPerItemDeclaration;
1459                                 last = &lastPerItemDeclaration;
1460                             }
1461                         }
1462                         else
1463                         {
1464                             if (IsSamplerType(declaration->type))
1465                             {
1466                                 first = &firstPerPassSampler;
1467                                 last = &lastPerPassSampler;
1468                             }
1469                             else
1470                             {
1471                                 first = &firstPerPassDeclaration;
1472                                 last = &lastPerPassDeclaration;
1473                             }
1474                         }
1475 
1476                         // Add declaration to new list.
1477                         if (*first == NULL) *first = declaration;
1478                         else (*last)->nextStatement = declaration;
1479                         *last = declaration;
1480                     }
1481 
1482                     // Unlink from declaration list.
1483                     declaration->nextDeclaration = NULL;
1484 
1485                     // Reset attributes.
1486                     declaration->registerName = NULL;
1487                     //declaration->semantic = NULL;         // @@ Don't do this!
1488 
1489                     declaration = nextDeclaration;
1490                 }
1491             }
1492         }
1493         /*else
1494         {
1495             if (statementBeforeBuffers == NULL) {
1496                 // This is the location where we will insert our buffers.
1497                 statementBeforeBuffers = previousStatement;
1498             }
1499         }*/
1500 
1501         if (statement->nextStatement == nextStatement) {
1502             previousStatement = statement;
1503         }
1504         statement = nextStatement;
1505     }
1506 
1507 
1508     // Add instance data declaration at the end of the per_item buffer.
1509     if (instanceDataDeclaration != NULL)
1510     {
1511         if (firstPerItemDeclaration == NULL) firstPerItemDeclaration = instanceDataDeclaration;
1512         else lastPerItemDeclaration->nextStatement = instanceDataDeclaration;
1513     }
1514 
1515 
1516     // Add samplers.
1517     if (firstPerItemSampler != NULL) {
1518         AddStatements(root, statementBeforeBuffers, firstPerItemSampler, lastPerItemSampler);
1519         statementBeforeBuffers = lastPerItemSampler;
1520     }
1521     if (firstPerPassSampler != NULL) {
1522         AddStatements(root, statementBeforeBuffers, firstPerPassSampler, lastPerPassSampler);
1523         statementBeforeBuffers = lastPerPassSampler;
1524     }
1525 
1526 
1527     // @@ We are assuming per_item and per_pass buffers don't already exist. @@ We should assert on that.
1528 
1529     if (firstPerItemDeclaration != NULL)
1530     {
1531         // Create buffer statement.
1532         HLSLBuffer * perItemBuffer = tree->AddNode<HLSLBuffer>(firstPerItemDeclaration->fileName, firstPerItemDeclaration->line-1);
1533         perItemBuffer->name = tree->AddString("per_item");
1534         perItemBuffer->registerName = tree->AddString("b0");
1535         perItemBuffer->field = firstPerItemDeclaration;
1536 
1537         // Set declaration buffer pointers.
1538         HLSLDeclaration * field = perItemBuffer->field;
1539         while (field != NULL)
1540         {
1541             field->buffer = perItemBuffer;
1542             field = (HLSLDeclaration *)field->nextStatement;
1543         }
1544 
1545         // Add buffer to statements.
1546         AddSingleStatement(root, statementBeforeBuffers, perItemBuffer);
1547         statementBeforeBuffers = perItemBuffer;
1548     }
1549 
1550     if (firstPerPassDeclaration != NULL)
1551     {
1552         // Create buffer statement.
1553         HLSLBuffer * perPassBuffer = tree->AddNode<HLSLBuffer>(firstPerPassDeclaration->fileName, firstPerPassDeclaration->line-1);
1554         perPassBuffer->name = tree->AddString("per_pass");
1555         perPassBuffer->registerName = tree->AddString("b1");
1556         perPassBuffer->field = firstPerPassDeclaration;
1557 
1558         // Set declaration buffer pointers.
1559         HLSLDeclaration * field = perPassBuffer->field;
1560         while (field != NULL)
1561         {
1562             field->buffer = perPassBuffer;
1563             field = (HLSLDeclaration *)field->nextStatement;
1564         }
1565 
1566         // Add buffer to statements.
1567         AddSingleStatement(root, statementBeforeBuffers, perPassBuffer);
1568     }
1569 }
1570 
1571 
1572 class FindArgumentVisitor : public HLSLTreeVisitor
1573 {
1574 public:
1575     bool found;
1576     const char * name;
1577 
FindArgumentVisitor()1578 	FindArgumentVisitor()
1579 	{
1580 		found = false;
1581 		name  = NULL;
1582 	}
1583 
FindArgument(const char * _name,HLSLFunction * function)1584     bool FindArgument(const char * _name, HLSLFunction * function)
1585     {
1586         this->found = false;
1587         this->name = _name;
1588         VisitStatements(function->statement);
1589         return found;
1590     }
1591 
VisitStatements(HLSLStatement * statement)1592     virtual void VisitStatements(HLSLStatement * statement) override
1593     {
1594         while (statement != NULL && !found)
1595         {
1596             VisitStatement(statement);
1597             statement = statement->nextStatement;
1598         }
1599     }
1600 
VisitIdentifierExpression(HLSLIdentifierExpression * node)1601     virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node) override
1602     {
1603         if (node->name == name)
1604         {
1605             found = true;
1606         }
1607     }
1608 };
1609 
1610 
HideUnusedArguments(HLSLFunction * function)1611 void HideUnusedArguments(HLSLFunction * function)
1612 {
1613     FindArgumentVisitor visitor;
1614 
1615     // For each argument.
1616     HLSLArgument * arg = function->argument;
1617     while (arg != NULL)
1618     {
1619         if (!visitor.FindArgument(arg->name, function))
1620         {
1621             arg->hidden = true;
1622         }
1623 
1624         arg = arg->nextArgument;
1625     }
1626 }
1627 
EmulateAlphaTest(HLSLTree * tree,const char * entryName,float alphaRef)1628 bool EmulateAlphaTest(HLSLTree* tree, const char* entryName, float alphaRef/*=0.5*/)
1629 {
1630     // Find all return statements of this entry point.
1631     HLSLFunction* entry = tree->FindFunction(entryName);
1632     if (entry != NULL)
1633     {
1634         HLSLStatement ** ptr = &entry->statement;
1635         HLSLStatement * statement = entry->statement;
1636         while (statement != NULL)
1637         {
1638             if (statement->nodeType == HLSLNodeType_ReturnStatement)
1639             {
1640                 HLSLReturnStatement * returnStatement = (HLSLReturnStatement *)statement;
1641                 HLSLBaseType returnType = returnStatement->expression->expressionType.baseType;
1642 
1643                 // Build statement: "if (%s.a < 0.5) discard;"
1644 
1645                 HLSLDiscardStatement * discard = tree->AddNode<HLSLDiscardStatement>(statement->fileName, statement->line);
1646 
1647                 HLSLExpression * alpha = NULL;
1648                 if (returnType == HLSLBaseType_Float4)
1649                 {
1650                     // @@ If return expression is a constructor, grab 4th argument.
1651                     // That's not as easy, since we support 'float4(float3, float)' or 'float4(float, float3)', extracting
1652                     // the latter is not that easy.
1653                     /*if (returnStatement->expression->nodeType == HLSLNodeType_ConstructorExpression) {
1654                         HLSLConstructorExpression * constructor = (HLSLConstructorExpression *)returnStatement->expression;
1655                         //constructor->
1656                     }
1657                     */
1658 
1659                     if (alpha == NULL) {
1660                         HLSLMemberAccess * access = tree->AddNode<HLSLMemberAccess>(statement->fileName, statement->line);
1661                         access->expressionType = HLSLType(HLSLBaseType_Float);
1662                         access->object = returnStatement->expression;     // @@ Is reference OK? Or should we clone expression?
1663                         access->field = tree->AddString("a");
1664                         access->swizzle = true;
1665 
1666                         alpha = access;
1667                     }
1668                 }
1669                 else if (returnType == HLSLBaseType_Float)
1670                 {
1671                     alpha = returnStatement->expression;     // @@ Is reference OK? Or should we clone expression?
1672                 }
1673                 else
1674                 {
1675                     return false;
1676                 }
1677 
1678                 HLSLLiteralExpression * threshold = tree->AddNode<HLSLLiteralExpression>(statement->fileName, statement->line);
1679                 threshold->expressionType = HLSLType(HLSLBaseType_Float);
1680                 threshold->fValue = alphaRef;
1681                 threshold->type = HLSLBaseType_Float;
1682 
1683                 HLSLBinaryExpression * condition = tree->AddNode<HLSLBinaryExpression>(statement->fileName, statement->line);
1684                 condition->expressionType = HLSLType(HLSLBaseType_Bool);
1685                 condition->binaryOp = HLSLBinaryOp_Less;
1686                 condition->expression1 = alpha;
1687                 condition->expression2 = threshold;
1688 
1689                 // Insert statement.
1690                 HLSLIfStatement * st = tree->AddNode<HLSLIfStatement>(statement->fileName, statement->line);
1691                 st->condition = condition;
1692                 st->statement = discard;
1693                 st->nextStatement = statement;
1694                 *ptr = st;
1695             }
1696 
1697             ptr = &statement->nextStatement;
1698             statement = statement->nextStatement;
1699         }
1700     }
1701 
1702     return true;
1703 }
1704 
NeedsFlattening(HLSLExpression * expr,int level=0)1705 bool NeedsFlattening(HLSLExpression * expr, int level = 0) {
1706     if (expr == NULL) {
1707         return false;
1708     }
1709     if (expr->nodeType == HLSLNodeType_UnaryExpression) {
1710         HLSLUnaryExpression * unaryExpr = (HLSLUnaryExpression *)expr;
1711         return NeedsFlattening(unaryExpr->expression, level+1) || NeedsFlattening(expr->nextExpression, level);
1712     }
1713     else if (expr->nodeType == HLSLNodeType_BinaryExpression) {
1714         HLSLBinaryExpression * binaryExpr = (HLSLBinaryExpression *)expr;
1715         if (IsAssignOp(binaryExpr->binaryOp)) {
1716             return NeedsFlattening(binaryExpr->expression2, level+1) || NeedsFlattening(expr->nextExpression, level);
1717         }
1718         else {
1719             return NeedsFlattening(binaryExpr->expression1, level+1) || NeedsFlattening(binaryExpr->expression2, level+1) || NeedsFlattening(expr->nextExpression, level);
1720         }
1721     }
1722     else if (expr->nodeType == HLSLNodeType_ConditionalExpression) {
1723         HLSLConditionalExpression * conditionalExpr = (HLSLConditionalExpression *)expr;
1724         return NeedsFlattening(conditionalExpr->condition, level+1) || NeedsFlattening(conditionalExpr->trueExpression, level+1) || NeedsFlattening(conditionalExpr->falseExpression, level+1) || NeedsFlattening(expr->nextExpression, level);
1725     }
1726     else if (expr->nodeType == HLSLNodeType_CastingExpression) {
1727         HLSLCastingExpression * castingExpr = (HLSLCastingExpression *)expr;
1728         return NeedsFlattening(castingExpr->expression, level+1) || NeedsFlattening(expr->nextExpression, level);
1729     }
1730     else if (expr->nodeType == HLSLNodeType_LiteralExpression) {
1731         return NeedsFlattening(expr->nextExpression, level);
1732     }
1733     else if (expr->nodeType == HLSLNodeType_IdentifierExpression) {
1734         return NeedsFlattening(expr->nextExpression, level);
1735     }
1736     else if (expr->nodeType == HLSLNodeType_ConstructorExpression) {
1737         HLSLConstructorExpression * constructorExpr = (HLSLConstructorExpression *)expr;
1738         return NeedsFlattening(constructorExpr->argument, level+1) || NeedsFlattening(expr->nextExpression, level);
1739     }
1740     else if (expr->nodeType == HLSLNodeType_MemberAccess) {
1741         return NeedsFlattening(expr->nextExpression, level+1);
1742     }
1743     else if (expr->nodeType == HLSLNodeType_ArrayAccess) {
1744         HLSLArrayAccess * arrayAccess = (HLSLArrayAccess *)expr;
1745         return NeedsFlattening(arrayAccess->array, level+1) || NeedsFlattening(arrayAccess->index, level+1) || NeedsFlattening(expr->nextExpression, level);
1746     }
1747     else if (expr->nodeType == HLSLNodeType_FunctionCall) {
1748         HLSLFunctionCall * functionCall = (HLSLFunctionCall *)expr;
1749         if (functionCall->function->numOutputArguments > 0) {
1750             if (level > 0) {
1751                 return true;
1752             }
1753         }
1754         return NeedsFlattening(functionCall->argument, level+1) || NeedsFlattening(expr->nextExpression, level);
1755     }
1756     else {
1757         //assert(false);
1758         return false;
1759     }
1760 }
1761 
1762 
1763 struct StatementList {
1764     HLSLStatement * head = NULL;
1765     HLSLStatement * tail = NULL;
appendM4::StatementList1766     void append(HLSLStatement * st) {
1767         if (head == NULL) {
1768             tail = head = st;
1769         }
1770         tail->nextStatement = st;
1771         tail = st;
1772     }
1773 };
1774 
1775 
1776     class ExpressionFlattener : public HLSLTreeVisitor
1777     {
1778     public:
1779         HLSLTree * m_tree;
1780         int tmp_index;
1781         HLSLStatement ** statement_pointer;
1782         HLSLFunction * current_function;
1783 
ExpressionFlattener()1784         ExpressionFlattener()
1785         {
1786             m_tree = NULL;
1787             tmp_index = 0;
1788             statement_pointer = NULL;
1789             current_function = NULL;
1790         }
1791 
FlattenExpressions(HLSLTree * tree)1792         void FlattenExpressions(HLSLTree * tree)
1793         {
1794             m_tree = tree;
1795             VisitRoot(tree->GetRoot());
1796         }
1797 
1798         // Visit all statements updating the statement_pointer so that we can insert and replace statements. @@ Add this to the default visitor?
VisitFunction(HLSLFunction * node)1799         virtual void VisitFunction(HLSLFunction * node) override
1800         {
1801             current_function = node;
1802             statement_pointer = &node->statement;
1803             VisitStatements(node->statement);
1804             statement_pointer = NULL;
1805             current_function = NULL;
1806         }
1807 
VisitIfStatement(HLSLIfStatement * node)1808         virtual void VisitIfStatement(HLSLIfStatement * node) override
1809         {
1810             if (NeedsFlattening(node->condition, 1)) {
1811                 assert(false);  // @@ Add statements before if statement.
1812             }
1813 
1814             statement_pointer = &node->statement;
1815             VisitStatements(node->statement);
1816             if (node->elseStatement) {
1817                 statement_pointer = &node->elseStatement;
1818                 VisitStatements(node->elseStatement);
1819             }
1820         }
1821 
VisitForStatement(HLSLForStatement * node)1822         virtual void VisitForStatement(HLSLForStatement * node) override
1823         {
1824             if (NeedsFlattening(node->initialization->assignment, 1)) {
1825                 assert(false);  // @@ Add statements before for statement.
1826             }
1827             if (NeedsFlattening(node->condition, 1) || NeedsFlattening(node->increment, 1)) {
1828                 assert(false);  // @@ These are tricky to implement. Need to handle all loop exits.
1829             }
1830 
1831             statement_pointer = &node->statement;
1832             VisitStatements(node->statement);
1833         }
1834 
VisitBlockStatement(HLSLBlockStatement * node)1835         virtual void VisitBlockStatement(HLSLBlockStatement * node) override
1836         {
1837             statement_pointer = &node->statement;
1838             VisitStatements(node->statement);
1839         }
1840 
VisitStatements(HLSLStatement * statement)1841         virtual void VisitStatements(HLSLStatement * statement) override
1842         {
1843             while (statement != NULL) {
1844                 VisitStatement(statement);
1845                 statement_pointer = &statement->nextStatement;
1846                 statement = statement->nextStatement;
1847             }
1848         }
1849 
1850         // This is usually a function call or assignment.
VisitExpressionStatement(HLSLExpressionStatement * node)1851         virtual void VisitExpressionStatement(HLSLExpressionStatement * node) override
1852         {
1853             if (NeedsFlattening(node->expression, 0))
1854             {
1855                 StatementList statements;
1856                 Flatten(node->expression, statements, false);
1857 
1858                 // Link beginning of statement list.
1859                 *statement_pointer = statements.head;
1860 
1861                 // Link end of statement list.
1862                 HLSLStatement * tail = statements.tail;
1863                 tail->nextStatement = node->nextStatement;
1864 
1865                 // Update statement pointer.
1866                 statement_pointer = &tail->nextStatement;
1867 
1868                 // @@ Delete node?
1869             }
1870         }
1871 
VisitDeclaration(HLSLDeclaration * node)1872         virtual void VisitDeclaration(HLSLDeclaration * node) override
1873         {
1874             // Skip global declarations.
1875             if (statement_pointer == NULL) return;
1876 
1877             if (NeedsFlattening(node->assignment, 1))
1878             {
1879                 StatementList statements;
1880                 HLSLIdentifierExpression * ident = Flatten(node->assignment, statements, true);
1881 
1882                 // @@ Delete node->assignment?
1883 
1884                 node->assignment = ident;
1885                 statements.append(node);
1886 
1887                 // Link beginning of statement list.
1888                 *statement_pointer = statements.head;
1889 
1890                 // Link end of statement list.
1891                 HLSLStatement * tail = statements.tail;
1892                 tail->nextStatement = node->nextStatement;
1893 
1894                 // Update statement pointer.
1895                 statement_pointer = &tail->nextStatement;
1896             }
1897         }
1898 
VisitReturnStatement(HLSLReturnStatement * node)1899         virtual void VisitReturnStatement(HLSLReturnStatement * node) override
1900         {
1901             if (NeedsFlattening(node->expression, 1))
1902             {
1903                 StatementList statements;
1904                 HLSLIdentifierExpression * ident = Flatten(node->expression, statements, true);
1905 
1906                 // @@ Delete node->expression?
1907 
1908                 node->expression = ident;
1909                 statements.append(node);
1910 
1911                 // Link beginning of statement list.
1912                 *statement_pointer = statements.head;
1913 
1914                 // Link end of statement list.
1915                 HLSLStatement * tail = statements.tail;
1916                 tail->nextStatement = node->nextStatement;
1917 
1918                 // Update statement pointer.
1919                 statement_pointer = &tail->nextStatement;
1920             }
1921         }
1922 
1923 
BuildTemporaryDeclaration(HLSLExpression * expr)1924         HLSLDeclaration * BuildTemporaryDeclaration(HLSLExpression * expr)
1925         {
1926             assert(expr->expressionType.baseType != HLSLBaseType_Void);
1927 
1928             HLSLDeclaration * declaration = m_tree->AddNode<HLSLDeclaration>(expr->fileName, expr->line);
1929             declaration->name = m_tree->AddStringFormat("tmp%d", tmp_index++);
1930             declaration->type = expr->expressionType;
1931             declaration->assignment = expr;
1932 
1933             return declaration;
1934         }
1935 
BuildExpressionStatement(HLSLExpression * expr)1936         HLSLExpressionStatement * BuildExpressionStatement(HLSLExpression * expr)
1937         {
1938             HLSLExpressionStatement * statement = m_tree->AddNode<HLSLExpressionStatement>(expr->fileName, expr->line);
1939             statement->expression = expr;
1940             return statement;
1941         }
1942 
AddExpressionStatement(HLSLExpression * expr,StatementList & statements,bool wantIdent)1943         HLSLIdentifierExpression * AddExpressionStatement(HLSLExpression * expr, StatementList & statements, bool wantIdent)
1944         {
1945             if (wantIdent) {
1946                 HLSLDeclaration * declaration = BuildTemporaryDeclaration(expr);
1947                 statements.append(declaration);
1948 
1949                 HLSLIdentifierExpression * ident = m_tree->AddNode<HLSLIdentifierExpression>(expr->fileName, expr->line);
1950                 ident->name = declaration->name;
1951                 ident->expressionType = declaration->type;
1952                 return ident;
1953             }
1954             else {
1955                 HLSLExpressionStatement * statement = BuildExpressionStatement(expr);
1956                 statements.append(statement);
1957                 return NULL;
1958             }
1959         }
1960 
Flatten(HLSLExpression * expr,StatementList & statements,bool wantIdent=true)1961         HLSLIdentifierExpression * Flatten(HLSLExpression * expr, StatementList & statements, bool wantIdent = true)
1962         {
1963             if (!NeedsFlattening(expr, wantIdent)) {
1964                 return AddExpressionStatement(expr, statements, wantIdent);
1965             }
1966 
1967             if (expr->nodeType == HLSLNodeType_UnaryExpression) {
1968                 assert(expr->nextExpression == NULL);
1969 
1970                 HLSLUnaryExpression * unaryExpr = (HLSLUnaryExpression *)expr;
1971 
1972                 HLSLIdentifierExpression * tmp = Flatten(unaryExpr->expression, statements, true);
1973 
1974                 HLSLUnaryExpression * newUnaryExpr = m_tree->AddNode<HLSLUnaryExpression>(unaryExpr->fileName, unaryExpr->line);
1975                 newUnaryExpr->unaryOp = unaryExpr->unaryOp;
1976                 newUnaryExpr->expression = tmp;
1977                 newUnaryExpr->expressionType = unaryExpr->expressionType;
1978 
1979                 return AddExpressionStatement(newUnaryExpr, statements, wantIdent);
1980             }
1981             else if (expr->nodeType == HLSLNodeType_BinaryExpression) {
1982                 assert(expr->nextExpression == NULL);
1983 
1984                 HLSLBinaryExpression * binaryExpr = (HLSLBinaryExpression *)expr;
1985 
1986                 if (IsAssignOp(binaryExpr->binaryOp)) {
1987                     // Flatten right hand side only.
1988                     HLSLIdentifierExpression * tmp2 = Flatten(binaryExpr->expression2, statements, true);
1989 
1990                     HLSLBinaryExpression * newBinaryExpr = m_tree->AddNode<HLSLBinaryExpression>(binaryExpr->fileName, binaryExpr->line);
1991                     newBinaryExpr->binaryOp = binaryExpr->binaryOp;
1992                     newBinaryExpr->expression1 = binaryExpr->expression1;
1993                     newBinaryExpr->expression2 = tmp2;
1994                     newBinaryExpr->expressionType = binaryExpr->expressionType;
1995 
1996                     return AddExpressionStatement(newBinaryExpr, statements, wantIdent);
1997                 }
1998                 else {
1999                     HLSLIdentifierExpression * tmp1 = Flatten(binaryExpr->expression1, statements, true);
2000                     HLSLIdentifierExpression * tmp2 = Flatten(binaryExpr->expression2, statements, true);
2001 
2002                     HLSLBinaryExpression * newBinaryExpr = m_tree->AddNode<HLSLBinaryExpression>(binaryExpr->fileName, binaryExpr->line);
2003                     newBinaryExpr->binaryOp = binaryExpr->binaryOp;
2004                     newBinaryExpr->expression1 = tmp1;
2005                     newBinaryExpr->expression2 = tmp2;
2006                     newBinaryExpr->expressionType = binaryExpr->expressionType;
2007 
2008                     return AddExpressionStatement(newBinaryExpr, statements, wantIdent);
2009                 }
2010             }
2011             else if (expr->nodeType == HLSLNodeType_ConditionalExpression) {
2012                 assert(false);
2013             }
2014             else if (expr->nodeType == HLSLNodeType_CastingExpression) {
2015                 assert(false);
2016             }
2017             else if (expr->nodeType == HLSLNodeType_LiteralExpression) {
2018                 assert(false);
2019             }
2020             else if (expr->nodeType == HLSLNodeType_IdentifierExpression) {
2021                 assert(false);
2022             }
2023             else if (expr->nodeType == HLSLNodeType_ConstructorExpression) {
2024                 assert(false);
2025             }
2026             else if (expr->nodeType == HLSLNodeType_MemberAccess) {
2027                 assert(false);
2028             }
2029             else if (expr->nodeType == HLSLNodeType_ArrayAccess) {
2030                 assert(false);
2031             }
2032             else if (expr->nodeType == HLSLNodeType_FunctionCall) {
2033                 HLSLFunctionCall * functionCall = (HLSLFunctionCall *)expr;
2034 
2035                 // @@ Output function as is?
2036                 // @@ We have to flatten function arguments! This is tricky, need to handle input/output arguments.
2037                 assert(!NeedsFlattening(functionCall->argument));
2038 
2039                 return AddExpressionStatement(expr, statements, wantIdent);
2040             }
2041             else {
2042                 assert(false);
2043             }
2044             return NULL;
2045         }
2046     };
2047 
2048 
FlattenExpressions(HLSLTree * tree)2049 void FlattenExpressions(HLSLTree* tree) {
2050     ExpressionFlattener flattener;
2051     flattener.FlattenExpressions(tree);
2052 }
2053 
2054 } // M4
2055 
2056