1 //#include "Engine/Assert.h"
2 #include "Engine.h"
3
4 #include "HLSLTree.h"
5 #include <assert.h>
6 #include <map>
7 #include <string>
8 #include <algorithm>
9
10 namespace M4
11 {
12
HLSLTree(Allocator * allocator)13 HLSLTree::HLSLTree(Allocator* allocator) :
14 m_allocator(allocator), m_stringPool(allocator)
15 {
16 m_firstPage = m_allocator->New<NodePage>();
17 m_firstPage->next = NULL;
18
19 m_currentPage = m_firstPage;
20 m_currentPageOffset = 0;
21
22 m_root = AddNode<HLSLRoot>(NULL, 1);
23 }
24
~HLSLTree()25 HLSLTree::~HLSLTree()
26 {
27 NodePage* page = m_firstPage;
28 while (page != NULL)
29 {
30 NodePage* next = page->next;
31 m_allocator->Delete(page);
32 page = next;
33 }
34 }
35
AllocatePage()36 void HLSLTree::AllocatePage()
37 {
38 NodePage* newPage = m_allocator->New<NodePage>();
39 newPage->next = NULL;
40 m_currentPage->next = newPage;
41 m_currentPageOffset = 0;
42 m_currentPage = newPage;
43 }
44
AddString(const char * string)45 const char* HLSLTree::AddString(const char* string)
46 {
47 return m_stringPool.AddString(string);
48 }
49
AddStringFormat(const char * format,...)50 const char* HLSLTree::AddStringFormat(const char* format, ...)
51 {
52 va_list args;
53 va_start(args, format);
54 const char * string = m_stringPool.AddStringFormatList(format, args);
55 va_end(args);
56 return string;
57 }
58
GetContainsString(const char * string) const59 bool HLSLTree::GetContainsString(const char* string) const
60 {
61 return m_stringPool.GetContainsString(string);
62 }
63
GetRoot() const64 HLSLRoot* HLSLTree::GetRoot() const
65 {
66 return m_root;
67 }
68
AllocateMemory(size_t size)69 void* HLSLTree::AllocateMemory(size_t size)
70 {
71 if (m_currentPageOffset + size > s_nodePageSize)
72 {
73 AllocatePage();
74 }
75 void* buffer = m_currentPage->buffer + m_currentPageOffset;
76 m_currentPageOffset += size;
77 return buffer;
78 }
79
80 // @@ This doesn't do any parameter matching. Simply returns the first function with that name.
FindFunction(const char * name)81 HLSLFunction * HLSLTree::FindFunction(const char * name)
82 {
83 HLSLStatement * statement = m_root->statement;
84 while (statement != NULL)
85 {
86 if (statement->nodeType == HLSLNodeType_Function)
87 {
88 HLSLFunction * function = (HLSLFunction *)statement;
89 if (String_Equal(name, function->name))
90 {
91 return function;
92 }
93 }
94
95 statement = statement->nextStatement;
96 }
97
98 return NULL;
99 }
100
FindGlobalDeclaration(const char * name,HLSLBuffer ** buffer_out)101 HLSLDeclaration * HLSLTree::FindGlobalDeclaration(const char * name, HLSLBuffer ** buffer_out/*=NULL*/)
102 {
103 HLSLStatement * statement = m_root->statement;
104 while (statement != NULL)
105 {
106 if (statement->nodeType == HLSLNodeType_Declaration)
107 {
108 HLSLDeclaration * declaration = (HLSLDeclaration *)statement;
109 if (String_Equal(name, declaration->name))
110 {
111 if (buffer_out) *buffer_out = NULL;
112 return declaration;
113 }
114 }
115 else if (statement->nodeType == HLSLNodeType_Buffer)
116 {
117 HLSLBuffer* buffer = (HLSLBuffer*)statement;
118
119 HLSLDeclaration* field = buffer->field;
120 while (field != NULL)
121 {
122 ASSERT(field->nodeType == HLSLNodeType_Declaration);
123 if (String_Equal(name, field->name))
124 {
125 if (buffer_out) *buffer_out = buffer;
126 return field;
127 }
128 field = (HLSLDeclaration*)field->nextStatement;
129 }
130 }
131
132 statement = statement->nextStatement;
133 }
134
135 if (buffer_out) *buffer_out = NULL;
136 return NULL;
137 }
138
FindGlobalStruct(const char * name)139 HLSLStruct * HLSLTree::FindGlobalStruct(const char * name)
140 {
141 HLSLStatement * statement = m_root->statement;
142 while (statement != NULL)
143 {
144 if (statement->nodeType == HLSLNodeType_Struct)
145 {
146 HLSLStruct * declaration = (HLSLStruct *)statement;
147 if (String_Equal(name, declaration->name))
148 {
149 return declaration;
150 }
151 }
152
153 statement = statement->nextStatement;
154 }
155
156 return NULL;
157 }
158
FindTechnique(const char * name)159 HLSLTechnique * HLSLTree::FindTechnique(const char * name)
160 {
161 HLSLStatement * statement = m_root->statement;
162 while (statement != NULL)
163 {
164 if (statement->nodeType == HLSLNodeType_Technique)
165 {
166 HLSLTechnique * technique = (HLSLTechnique *)statement;
167 if (String_Equal(name, technique->name))
168 {
169 return technique;
170 }
171 }
172
173 statement = statement->nextStatement;
174 }
175
176 return NULL;
177 }
178
FindFirstPipeline()179 HLSLPipeline * HLSLTree::FindFirstPipeline()
180 {
181 return FindNextPipeline(NULL);
182 }
183
FindNextPipeline(HLSLPipeline * current)184 HLSLPipeline * HLSLTree::FindNextPipeline(HLSLPipeline * current)
185 {
186 HLSLStatement * statement = current ? current : m_root->statement;
187 while (statement != NULL)
188 {
189 if (statement->nodeType == HLSLNodeType_Pipeline)
190 {
191 return (HLSLPipeline *)statement;
192 }
193
194 statement = statement->nextStatement;
195 }
196
197 return NULL;
198 }
199
FindPipeline(const char * name)200 HLSLPipeline * HLSLTree::FindPipeline(const char * name)
201 {
202 HLSLStatement * statement = m_root->statement;
203 while (statement != NULL)
204 {
205 if (statement->nodeType == HLSLNodeType_Pipeline)
206 {
207 HLSLPipeline * pipeline = (HLSLPipeline *)statement;
208 if (String_Equal(name, pipeline->name))
209 {
210 return pipeline;
211 }
212 }
213
214 statement = statement->nextStatement;
215 }
216
217 return NULL;
218 }
219
FindBuffer(const char * name)220 HLSLBuffer * HLSLTree::FindBuffer(const char * name)
221 {
222 HLSLStatement * statement = m_root->statement;
223 while (statement != NULL)
224 {
225 if (statement->nodeType == HLSLNodeType_Buffer)
226 {
227 HLSLBuffer * buffer = (HLSLBuffer *)statement;
228 if (String_Equal(name, buffer->name))
229 {
230 return buffer;
231 }
232 }
233
234 statement = statement->nextStatement;
235 }
236
237 return NULL;
238 }
239
240
241
GetExpressionValue(HLSLExpression * expression,int & value)242 bool HLSLTree::GetExpressionValue(HLSLExpression * expression, int & value)
243 {
244 ASSERT (expression != NULL);
245
246 // Expression must be constant.
247 if ((expression->expressionType.flags & HLSLTypeFlag_Const) == 0)
248 {
249 return false;
250 }
251
252 // We are expecting an integer scalar. @@ Add support for type conversion from other scalar types.
253 if (expression->expressionType.baseType != HLSLBaseType_Int &&
254 expression->expressionType.baseType != HLSLBaseType_Bool)
255 {
256 return false;
257 }
258
259 if (expression->expressionType.array)
260 {
261 return false;
262 }
263
264 if (expression->nodeType == HLSLNodeType_BinaryExpression)
265 {
266 HLSLBinaryExpression * binaryExpression = (HLSLBinaryExpression *)expression;
267
268 int value1, value2;
269 if (!GetExpressionValue(binaryExpression->expression1, value1) ||
270 !GetExpressionValue(binaryExpression->expression2, value2))
271 {
272 return false;
273 }
274
275 switch(binaryExpression->binaryOp)
276 {
277 case HLSLBinaryOp_And:
278 value = value1 && value2;
279 return true;
280 case HLSLBinaryOp_Or:
281 value = value1 || value2;
282 return true;
283 case HLSLBinaryOp_Add:
284 value = value1 + value2;
285 return true;
286 case HLSLBinaryOp_Sub:
287 value = value1 - value2;
288 return true;
289 case HLSLBinaryOp_Mul:
290 value = value1 * value2;
291 return true;
292 case HLSLBinaryOp_Div:
293 value = value1 / value2;
294 return true;
295 case HLSLBinaryOp_Mod:
296 value = value1 % value2;
297 return true;
298 case HLSLBinaryOp_Less:
299 value = value1 < value2;
300 return true;
301 case HLSLBinaryOp_Greater:
302 value = value1 > value2;
303 return true;
304 case HLSLBinaryOp_LessEqual:
305 value = value1 <= value2;
306 return true;
307 case HLSLBinaryOp_GreaterEqual:
308 value = value1 >= value2;
309 return true;
310 case HLSLBinaryOp_Equal:
311 value = value1 == value2;
312 return true;
313 case HLSLBinaryOp_NotEqual:
314 value = value1 != value2;
315 return true;
316 case HLSLBinaryOp_BitAnd:
317 value = value1 & value2;
318 return true;
319 case HLSLBinaryOp_BitOr:
320 value = value1 | value2;
321 return true;
322 case HLSLBinaryOp_BitXor:
323 value = value1 ^ value2;
324 return true;
325 case HLSLBinaryOp_Assign:
326 case HLSLBinaryOp_AddAssign:
327 case HLSLBinaryOp_SubAssign:
328 case HLSLBinaryOp_MulAssign:
329 case HLSLBinaryOp_DivAssign:
330 // IC: These are not valid on non-constant expressions and should fail earlier when querying expression value.
331 return false;
332 }
333 }
334 else if (expression->nodeType == HLSLNodeType_UnaryExpression)
335 {
336 HLSLUnaryExpression * unaryExpression = (HLSLUnaryExpression *)expression;
337
338 if (!GetExpressionValue(unaryExpression->expression, value))
339 {
340 return false;
341 }
342
343 switch(unaryExpression->unaryOp)
344 {
345 case HLSLUnaryOp_Negative:
346 value = -value;
347 return true;
348 case HLSLUnaryOp_Positive:
349 // nop.
350 return true;
351 case HLSLUnaryOp_Not:
352 value = !value;
353 return true;
354 case HLSLUnaryOp_BitNot:
355 value = ~value;
356 return true;
357 case HLSLUnaryOp_PostDecrement:
358 case HLSLUnaryOp_PostIncrement:
359 case HLSLUnaryOp_PreDecrement:
360 case HLSLUnaryOp_PreIncrement:
361 // IC: These are not valid on non-constant expressions and should fail earlier when querying expression value.
362 return false;
363 }
364 }
365 else if (expression->nodeType == HLSLNodeType_IdentifierExpression)
366 {
367 HLSLIdentifierExpression * identifier = (HLSLIdentifierExpression *)expression;
368
369 HLSLDeclaration * declaration = FindGlobalDeclaration(identifier->name);
370 if (declaration == NULL)
371 {
372 return false;
373 }
374 if ((declaration->type.flags & HLSLTypeFlag_Const) == 0)
375 {
376 return false;
377 }
378
379 return GetExpressionValue(declaration->assignment, value);
380 }
381 else if (expression->nodeType == HLSLNodeType_LiteralExpression)
382 {
383 HLSLLiteralExpression * literal = (HLSLLiteralExpression *)expression;
384
385 if (literal->expressionType.baseType == HLSLBaseType_Int) value = literal->iValue;
386 else if (literal->expressionType.baseType == HLSLBaseType_Bool) value = (int)literal->bValue;
387 else return false;
388
389 return true;
390 }
391
392 return false;
393 }
394
NeedsFunction(const char * name)395 bool HLSLTree::NeedsFunction(const char* name)
396 {
397 // Early out
398 if (!GetContainsString(name))
399 return false;
400
401 struct NeedsFunctionVisitor: HLSLTreeVisitor
402 {
403 const char* name;
404 bool result;
405
406 virtual void VisitTopLevelStatement(HLSLStatement * node)
407 {
408 if (!node->hidden)
409 HLSLTreeVisitor::VisitTopLevelStatement(node);
410 }
411
412 virtual void VisitFunctionCall(HLSLFunctionCall * node)
413 {
414 result = result || String_Equal(name, node->function->name);
415
416 HLSLTreeVisitor::VisitFunctionCall(node);
417 }
418 };
419
420 NeedsFunctionVisitor visitor;
421 visitor.name = name;
422 visitor.result = false;
423
424 visitor.VisitRoot(m_root);
425
426 return visitor.result;
427 }
428
GetVectorDimension(HLSLType & type)429 int GetVectorDimension(HLSLType & type)
430 {
431 if (type.baseType >= HLSLBaseType_FirstNumeric &&
432 type.baseType <= HLSLBaseType_LastNumeric)
433 {
434 if (type.baseType == HLSLBaseType_Float) return 1;
435 if (type.baseType == HLSLBaseType_Float2) return 2;
436 if (type.baseType == HLSLBaseType_Float3) return 3;
437 if (type.baseType == HLSLBaseType_Float4) return 4;
438
439 }
440 return 0;
441 }
442
443 // Returns dimension, 0 if invalid.
GetExpressionValue(HLSLExpression * expression,float values[4])444 int HLSLTree::GetExpressionValue(HLSLExpression * expression, float values[4])
445 {
446 ASSERT (expression != NULL);
447
448 // Expression must be constant.
449 if ((expression->expressionType.flags & HLSLTypeFlag_Const) == 0)
450 {
451 return 0;
452 }
453
454 if (expression->expressionType.baseType == HLSLBaseType_Int ||
455 expression->expressionType.baseType == HLSLBaseType_Bool)
456 {
457 int int_value;
458 if (GetExpressionValue(expression, int_value)) {
459 for (int i = 0; i < 4; i++) values[i] = (float)int_value; // @@ Warn if conversion is not exact.
460 return 1;
461 }
462
463 return 0;
464 }
465 if (expression->expressionType.baseType >= HLSLBaseType_FirstInteger && expression->expressionType.baseType <= HLSLBaseType_LastInteger)
466 {
467 // @@ Add support for uints?
468 // @@ Add support for int vectors?
469 return 0;
470 }
471 if (expression->expressionType.baseType > HLSLBaseType_LastNumeric)
472 {
473 return 0;
474 }
475
476 // @@ Not supported yet, but we may need it?
477 if (expression->expressionType.array)
478 {
479 return false;
480 }
481
482 if (expression->nodeType == HLSLNodeType_BinaryExpression)
483 {
484 HLSLBinaryExpression * binaryExpression = (HLSLBinaryExpression *)expression;
485 int dim = GetVectorDimension(binaryExpression->expressionType);
486
487 float values1[4], values2[4];
488 int dim1 = GetExpressionValue(binaryExpression->expression1, values1);
489 int dim2 = GetExpressionValue(binaryExpression->expression2, values2);
490
491 if (dim1 == 0 || dim2 == 0)
492 {
493 return 0;
494 }
495
496 if (dim1 != dim2)
497 {
498 // Brodacast scalar to vector size.
499 if (dim1 == 1)
500 {
501 for (int i = 1; i < dim2; i++) values1[i] = values1[0];
502 dim1 = dim2;
503 }
504 else if (dim2 == 1)
505 {
506 for (int i = 1; i < dim1; i++) values2[i] = values2[0];
507 dim2 = dim1;
508 }
509 else
510 {
511 return 0;
512 }
513 }
514 ASSERT(dim == dim1);
515
516 switch(binaryExpression->binaryOp)
517 {
518 case HLSLBinaryOp_Add:
519 for (int i = 0; i < dim; i++) values[i] = values1[i] + values2[i];
520 return dim;
521 case HLSLBinaryOp_Sub:
522 for (int i = 0; i < dim; i++) values[i] = values1[i] - values2[i];
523 return dim;
524 case HLSLBinaryOp_Mul:
525 for (int i = 0; i < dim; i++) values[i] = values1[i] * values2[i];
526 return dim;
527 case HLSLBinaryOp_Div:
528 for (int i = 0; i < dim; i++) values[i] = values1[i] / values2[i];
529 return dim;
530 case HLSLBinaryOp_Mod:
531 for (int i = 0; i < dim; i++) values[i] = int(values1[i]) % int(values2[i]);
532 return dim;
533 default:
534 return 0;
535 }
536 }
537 else if (expression->nodeType == HLSLNodeType_UnaryExpression)
538 {
539 HLSLUnaryExpression * unaryExpression = (HLSLUnaryExpression *)expression;
540 int dim = GetVectorDimension(unaryExpression->expressionType);
541
542 int dim1 = GetExpressionValue(unaryExpression->expression, values);
543 if (dim1 == 0)
544 {
545 return 0;
546 }
547 ASSERT(dim == dim1);
548
549 switch(unaryExpression->unaryOp)
550 {
551 case HLSLUnaryOp_Negative:
552 for (int i = 0; i < dim; i++) values[i] = -values[i];
553 return dim;
554 case HLSLUnaryOp_Positive:
555 // nop.
556 return dim;
557 default:
558 return 0;
559 }
560 }
561 else if (expression->nodeType == HLSLNodeType_ConstructorExpression)
562 {
563 HLSLConstructorExpression * constructor = (HLSLConstructorExpression *)expression;
564
565 int dim = GetVectorDimension(constructor->expressionType);
566
567 int idx = 0;
568 HLSLExpression * arg = constructor->argument;
569 while (arg != NULL)
570 {
571 float tmp[4];
572 int n = GetExpressionValue(arg, tmp);
573 for (int i = 0; i < n; i++) values[idx + i] = tmp[i];
574 idx += n;
575
576 arg = arg->nextExpression;
577 }
578 ASSERT(dim == idx);
579
580 return dim;
581 }
582 else if (expression->nodeType == HLSLNodeType_IdentifierExpression)
583 {
584 HLSLIdentifierExpression * identifier = (HLSLIdentifierExpression *)expression;
585
586 HLSLDeclaration * declaration = FindGlobalDeclaration(identifier->name);
587 if (declaration == NULL)
588 {
589 return 0;
590 }
591 if ((declaration->type.flags & HLSLTypeFlag_Const) == 0)
592 {
593 return 0;
594 }
595
596 return GetExpressionValue(declaration->assignment, values);
597 }
598 else if (expression->nodeType == HLSLNodeType_LiteralExpression)
599 {
600 HLSLLiteralExpression * literal = (HLSLLiteralExpression *)expression;
601
602 if (literal->expressionType.baseType == HLSLBaseType_Float) values[0] = literal->fValue;
603 else if (literal->expressionType.baseType == HLSLBaseType_Bool) values[0] = literal->bValue;
604 else if (literal->expressionType.baseType == HLSLBaseType_Int) values[0] = (float)literal->iValue; // @@ Warn if conversion is not exact.
605 else return 0;
606
607 return 1;
608 }
609
610 return 0;
611 }
612
ReplaceUniformsAssignments()613 bool HLSLTree::ReplaceUniformsAssignments()
614 {
615 struct ReplaceUniformsAssignmentsVisitor: HLSLTreeVisitor
616 {
617 HLSLTree * tree;
618 std::map<std::string, HLSLDeclaration *> uniforms;
619 std::map<std::string, std::string> uniformsReplaced;
620 bool withinAssignment;
621
622 virtual void VisitDeclaration(HLSLDeclaration * node)
623 {
624 HLSLTreeVisitor::VisitDeclaration(node);
625
626 // Enumerate uniforms
627 if (node->type.flags & HLSLTypeFlag_Uniform)
628 {
629 uniforms[node->name] = node;
630 }
631 }
632
633 virtual void VisitFunction(HLSLFunction * node)
634 {
635 uniformsReplaced.clear();
636
637 // Detect uniforms assignments
638 HLSLTreeVisitor::VisitFunction(node);
639
640 // Declare uniforms replacements
641 std::map<std::string, std::string>::const_iterator iter = uniformsReplaced.cbegin();
642 for ( ; iter != uniformsReplaced.cend(); ++iter)
643 {
644 HLSLDeclaration * uniformDeclaration = uniforms[iter->first];
645 HLSLDeclaration * declaration = tree->AddNode<HLSLDeclaration>(node->fileName, node->line);
646
647 declaration->name = tree->AddString(iter->second.c_str());
648 declaration->type = uniformDeclaration->type;
649
650 // Add declaration within function statements
651 declaration->nextStatement = node->statement;
652 node->statement = declaration;
653 }
654 }
655
656 virtual void VisitBinaryExpression(HLSLBinaryExpression * node)
657 {
658 // Visit expression 2 first to not replace possible uniform reading
659 VisitExpression(node->expression2);
660
661 if (IsAssignOp(node->binaryOp))
662 {
663 withinAssignment = true;
664 }
665
666 VisitExpression(node->expression1);
667
668 withinAssignment = false;
669 }
670
671 virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node)
672 {
673 if (withinAssignment)
674 {
675 // Check if variable is a uniform
676 if (uniforms.find(node->name) != uniforms.end())
677 {
678 // Check if variable is not already replaced
679 if (uniformsReplaced.find(node->name) == uniformsReplaced.end())
680 {
681 std::string newName(node->name);
682 do
683 {
684 newName.insert(0, "new");
685 }
686 while(tree->GetContainsString(newName.c_str()));
687
688 uniformsReplaced[node->name] = newName;
689 }
690 }
691 }
692
693 // Check if variable need to be replaced
694 if (uniformsReplaced.find(node->name) != uniformsReplaced.end())
695 {
696 // Replace
697 node->name = tree->AddString( uniformsReplaced[node->name].c_str() );
698 }
699 }
700 };
701
702 ReplaceUniformsAssignmentsVisitor visitor;
703 visitor.tree = this;
704 visitor.withinAssignment = false;
705 visitor.VisitRoot(m_root);
706
707 return true;
708 }
709
710
matrixCtorBuilder(HLSLType type,HLSLExpression * arguments)711 matrixCtor matrixCtorBuilder(HLSLType type, HLSLExpression * arguments) {
712 matrixCtor ctor;
713
714 ctor.matrixType = type.baseType;
715
716 // Fetch all arguments
717 HLSLExpression* argument = arguments;
718 while (argument != NULL)
719 {
720 ctor.argumentTypes.push_back(argument->expressionType.baseType);
721 argument = argument->nextExpression;
722 }
723
724 return ctor;
725 }
726
EnumerateMatrixCtorsNeeded(std::vector<matrixCtor> & matrixCtors)727 void HLSLTree::EnumerateMatrixCtorsNeeded(std::vector<matrixCtor> & matrixCtors) {
728
729 struct EnumerateMatrixCtorsVisitor: HLSLTreeVisitor
730 {
731 std::vector<matrixCtor> matrixCtorsNeeded;
732
733 virtual void VisitConstructorExpression(HLSLConstructorExpression * node)
734 {
735 if (IsMatrixType(node->expressionType.baseType))
736 {
737 matrixCtor ctor = matrixCtorBuilder(node->expressionType, node->argument);
738
739 if (std::find(matrixCtorsNeeded.cbegin(), matrixCtorsNeeded.cend(), ctor) == matrixCtorsNeeded.cend())
740 {
741 matrixCtorsNeeded.push_back(ctor);
742 }
743 }
744
745 HLSLTreeVisitor::VisitConstructorExpression(node);
746 }
747
748 virtual void VisitDeclaration(HLSLDeclaration * node)
749 {
750 if ( IsMatrixType(node->type.baseType) &&
751 (node->type.flags & HLSLArgumentModifier_Uniform) == 0 )
752 {
753 matrixCtor ctor = matrixCtorBuilder(node->type, node->assignment);
754
755 // No special constructor needed if it already a matrix
756 bool matrixArgument = false;
757 for(HLSLBaseType & type: ctor.argumentTypes)
758 {
759 if (IsMatrixType(type))
760 {
761 matrixArgument = true;
762 break;
763 }
764 }
765
766 if ( !matrixArgument &&
767 std::find(matrixCtorsNeeded.cbegin(), matrixCtorsNeeded.cend(), ctor) == matrixCtorsNeeded.cend())
768 {
769 matrixCtorsNeeded.push_back(ctor);
770 }
771 }
772
773 HLSLTreeVisitor::VisitDeclaration(node);
774 }
775 };
776
777 EnumerateMatrixCtorsVisitor visitor;
778 visitor.VisitRoot(m_root);
779
780 matrixCtors = visitor.matrixCtorsNeeded;
781 }
782
783
VisitType(HLSLType & type)784 void HLSLTreeVisitor::VisitType(HLSLType & type)
785 {
786 }
787
VisitRoot(HLSLRoot * root)788 void HLSLTreeVisitor::VisitRoot(HLSLRoot * root)
789 {
790 HLSLStatement * statement = root->statement;
791 while (statement != NULL) {
792 VisitTopLevelStatement(statement);
793 statement = statement->nextStatement;
794 }
795 }
796
VisitTopLevelStatement(HLSLStatement * node)797 void HLSLTreeVisitor::VisitTopLevelStatement(HLSLStatement * node)
798 {
799 if (node->nodeType == HLSLNodeType_Declaration) {
800 VisitDeclaration((HLSLDeclaration *)node);
801 }
802 else if (node->nodeType == HLSLNodeType_Struct) {
803 VisitStruct((HLSLStruct *)node);
804 }
805 else if (node->nodeType == HLSLNodeType_Buffer) {
806 VisitBuffer((HLSLBuffer *)node);
807 }
808 else if (node->nodeType == HLSLNodeType_Function) {
809 VisitFunction((HLSLFunction *)node);
810 }
811 else if (node->nodeType == HLSLNodeType_Technique) {
812 VisitTechnique((HLSLTechnique *)node);
813 }
814 else if (node->nodeType == HLSLNodeType_Pipeline) {
815 VisitPipeline((HLSLPipeline *)node);
816 }
817 else {
818 ASSERT(0);
819 }
820 }
821
VisitStatements(HLSLStatement * statement)822 void HLSLTreeVisitor::VisitStatements(HLSLStatement * statement)
823 {
824 while (statement != NULL) {
825 VisitStatement(statement);
826 statement = statement->nextStatement;
827 }
828 }
829
VisitStatement(HLSLStatement * node)830 void HLSLTreeVisitor::VisitStatement(HLSLStatement * node)
831 {
832 // Function statements
833 if (node->nodeType == HLSLNodeType_Declaration) {
834 VisitDeclaration((HLSLDeclaration *)node);
835 }
836 else if (node->nodeType == HLSLNodeType_ExpressionStatement) {
837 VisitExpressionStatement((HLSLExpressionStatement *)node);
838 }
839 else if (node->nodeType == HLSLNodeType_ReturnStatement) {
840 VisitReturnStatement((HLSLReturnStatement *)node);
841 }
842 else if (node->nodeType == HLSLNodeType_DiscardStatement) {
843 VisitDiscardStatement((HLSLDiscardStatement *)node);
844 }
845 else if (node->nodeType == HLSLNodeType_BreakStatement) {
846 VisitBreakStatement((HLSLBreakStatement *)node);
847 }
848 else if (node->nodeType == HLSLNodeType_ContinueStatement) {
849 VisitContinueStatement((HLSLContinueStatement *)node);
850 }
851 else if (node->nodeType == HLSLNodeType_IfStatement) {
852 VisitIfStatement((HLSLIfStatement *)node);
853 }
854 else if (node->nodeType == HLSLNodeType_ForStatement) {
855 VisitForStatement((HLSLForStatement *)node);
856 }
857 else if (node->nodeType == HLSLNodeType_WhileStatement) {
858 VisitWhileStatement((HLSLWhileStatement *)node);
859 }
860 else if (node->nodeType == HLSLNodeType_BlockStatement) {
861 VisitBlockStatement((HLSLBlockStatement *)node);
862 }
863 else {
864 ASSERT(0);
865 }
866 }
867
VisitDeclaration(HLSLDeclaration * node)868 void HLSLTreeVisitor::VisitDeclaration(HLSLDeclaration * node)
869 {
870 VisitType(node->type);
871 /*do {
872 VisitExpression(node->assignment);
873 node = node->nextDeclaration;
874 } while (node);*/
875 if (node->assignment != NULL) {
876 VisitExpression(node->assignment);
877 }
878 if (node->nextDeclaration != NULL) {
879 VisitDeclaration(node->nextDeclaration);
880 }
881 }
882
VisitStruct(HLSLStruct * node)883 void HLSLTreeVisitor::VisitStruct(HLSLStruct * node)
884 {
885 HLSLStructField * field = node->field;
886 while (field != NULL) {
887 VisitStructField(field);
888 field = field->nextField;
889 }
890 }
891
VisitStructField(HLSLStructField * node)892 void HLSLTreeVisitor::VisitStructField(HLSLStructField * node)
893 {
894 VisitType(node->type);
895 }
896
VisitBuffer(HLSLBuffer * node)897 void HLSLTreeVisitor::VisitBuffer(HLSLBuffer * node)
898 {
899 HLSLDeclaration * field = node->field;
900 while (field != NULL) {
901 ASSERT(field->nodeType == HLSLNodeType_Declaration);
902 VisitDeclaration(field);
903 ASSERT(field->nextDeclaration == NULL);
904 field = (HLSLDeclaration *)field->nextStatement;
905 }
906 }
907
908 /*void HLSLTreeVisitor::VisitBufferField(HLSLBufferField * node)
909 {
910 VisitType(node->type);
911 }*/
912
VisitFunction(HLSLFunction * node)913 void HLSLTreeVisitor::VisitFunction(HLSLFunction * node)
914 {
915 VisitType(node->returnType);
916
917 HLSLArgument * argument = node->argument;
918 while (argument != NULL) {
919 VisitArgument(argument);
920 argument = argument->nextArgument;
921 }
922
923 VisitStatements(node->statement);
924 }
925
VisitArgument(HLSLArgument * node)926 void HLSLTreeVisitor::VisitArgument(HLSLArgument * node)
927 {
928 VisitType(node->type);
929 if (node->defaultValue != NULL) {
930 VisitExpression(node->defaultValue);
931 }
932 }
933
VisitExpressionStatement(HLSLExpressionStatement * node)934 void HLSLTreeVisitor::VisitExpressionStatement(HLSLExpressionStatement * node)
935 {
936 VisitExpression(node->expression);
937 }
938
VisitExpression(HLSLExpression * node)939 void HLSLTreeVisitor::VisitExpression(HLSLExpression * node)
940 {
941 VisitType(node->expressionType);
942
943 if (node->nodeType == HLSLNodeType_UnaryExpression) {
944 VisitUnaryExpression((HLSLUnaryExpression *)node);
945 }
946 else if (node->nodeType == HLSLNodeType_BinaryExpression) {
947 VisitBinaryExpression((HLSLBinaryExpression *)node);
948 }
949 else if (node->nodeType == HLSLNodeType_ConditionalExpression) {
950 VisitConditionalExpression((HLSLConditionalExpression *)node);
951 }
952 else if (node->nodeType == HLSLNodeType_CastingExpression) {
953 VisitCastingExpression((HLSLCastingExpression *)node);
954 }
955 else if (node->nodeType == HLSLNodeType_LiteralExpression) {
956 VisitLiteralExpression((HLSLLiteralExpression *)node);
957 }
958 else if (node->nodeType == HLSLNodeType_IdentifierExpression) {
959 VisitIdentifierExpression((HLSLIdentifierExpression *)node);
960 }
961 else if (node->nodeType == HLSLNodeType_ConstructorExpression) {
962 VisitConstructorExpression((HLSLConstructorExpression *)node);
963 }
964 else if (node->nodeType == HLSLNodeType_MemberAccess) {
965 VisitMemberAccess((HLSLMemberAccess *)node);
966 }
967 else if (node->nodeType == HLSLNodeType_ArrayAccess) {
968 VisitArrayAccess((HLSLArrayAccess *)node);
969 }
970 else if (node->nodeType == HLSLNodeType_FunctionCall) {
971 VisitFunctionCall((HLSLFunctionCall *)node);
972 }
973 // Acoget-TODO: This was missing. Did adding it break anything?
974 else if (node->nodeType == HLSLNodeType_SamplerState) {
975 VisitSamplerState((HLSLSamplerState *)node);
976 }
977 else {
978 ASSERT(0);
979 }
980 }
981
VisitReturnStatement(HLSLReturnStatement * node)982 void HLSLTreeVisitor::VisitReturnStatement(HLSLReturnStatement * node)
983 {
984 VisitExpression(node->expression);
985 }
986
VisitDiscardStatement(HLSLDiscardStatement * node)987 void HLSLTreeVisitor::VisitDiscardStatement(HLSLDiscardStatement * node) {}
VisitBreakStatement(HLSLBreakStatement * node)988 void HLSLTreeVisitor::VisitBreakStatement(HLSLBreakStatement * node) {}
VisitContinueStatement(HLSLContinueStatement * node)989 void HLSLTreeVisitor::VisitContinueStatement(HLSLContinueStatement * node) {}
990
VisitIfStatement(HLSLIfStatement * node)991 void HLSLTreeVisitor::VisitIfStatement(HLSLIfStatement * node)
992 {
993 VisitExpression(node->condition);
994 VisitStatements(node->statement);
995 if (node->elseStatement) {
996 VisitStatements(node->elseStatement);
997 }
998 }
999
VisitForStatement(HLSLForStatement * node)1000 void HLSLTreeVisitor::VisitForStatement(HLSLForStatement * node)
1001 {
1002 if (node->initialization) {
1003 VisitDeclaration(node->initialization);
1004 }
1005 if (node->condition) {
1006 VisitExpression(node->condition);
1007 }
1008 if (node->increment) {
1009 VisitExpression(node->increment);
1010 }
1011 VisitStatements(node->statement);
1012 }
1013
VisitWhileStatement(HLSLWhileStatement * node)1014 void HLSLTreeVisitor::VisitWhileStatement(HLSLWhileStatement * node)
1015 {
1016 if (node->condition) {
1017 VisitExpression(node->condition);
1018 }
1019 VisitStatements(node->statement);
1020 }
1021
VisitBlockStatement(HLSLBlockStatement * node)1022 void HLSLTreeVisitor::VisitBlockStatement(HLSLBlockStatement * node)
1023 {
1024 VisitStatements(node->statement);
1025 }
1026
VisitUnaryExpression(HLSLUnaryExpression * node)1027 void HLSLTreeVisitor::VisitUnaryExpression(HLSLUnaryExpression * node)
1028 {
1029 VisitExpression(node->expression);
1030 }
1031
VisitBinaryExpression(HLSLBinaryExpression * node)1032 void HLSLTreeVisitor::VisitBinaryExpression(HLSLBinaryExpression * node)
1033 {
1034 VisitExpression(node->expression1);
1035 VisitExpression(node->expression2);
1036 }
1037
VisitConditionalExpression(HLSLConditionalExpression * node)1038 void HLSLTreeVisitor::VisitConditionalExpression(HLSLConditionalExpression * node)
1039 {
1040 VisitExpression(node->condition);
1041 VisitExpression(node->falseExpression);
1042 VisitExpression(node->trueExpression);
1043 }
1044
VisitCastingExpression(HLSLCastingExpression * node)1045 void HLSLTreeVisitor::VisitCastingExpression(HLSLCastingExpression * node)
1046 {
1047 VisitType(node->type);
1048 VisitExpression(node->expression);
1049 }
1050
VisitLiteralExpression(HLSLLiteralExpression * node)1051 void HLSLTreeVisitor::VisitLiteralExpression(HLSLLiteralExpression * node) {}
VisitIdentifierExpression(HLSLIdentifierExpression * node)1052 void HLSLTreeVisitor::VisitIdentifierExpression(HLSLIdentifierExpression * node) {}
1053
VisitConstructorExpression(HLSLConstructorExpression * node)1054 void HLSLTreeVisitor::VisitConstructorExpression(HLSLConstructorExpression * node)
1055 {
1056 HLSLExpression * argument = node->argument;
1057 while (argument != NULL) {
1058 VisitExpression(argument);
1059 argument = argument->nextExpression;
1060 }
1061 }
1062
VisitMemberAccess(HLSLMemberAccess * node)1063 void HLSLTreeVisitor::VisitMemberAccess(HLSLMemberAccess * node)
1064 {
1065 VisitExpression(node->object);
1066 }
1067
VisitArrayAccess(HLSLArrayAccess * node)1068 void HLSLTreeVisitor::VisitArrayAccess(HLSLArrayAccess * node)
1069 {
1070 VisitExpression(node->array);
1071 VisitExpression(node->index);
1072 }
1073
VisitFunctionCall(HLSLFunctionCall * node)1074 void HLSLTreeVisitor::VisitFunctionCall(HLSLFunctionCall * node)
1075 {
1076 HLSLExpression * argument = node->argument;
1077 while (argument != NULL) {
1078 VisitExpression(argument);
1079 argument = argument->nextExpression;
1080 }
1081 }
1082
VisitStateAssignment(HLSLStateAssignment * node)1083 void HLSLTreeVisitor::VisitStateAssignment(HLSLStateAssignment * node) {}
1084
VisitSamplerState(HLSLSamplerState * node)1085 void HLSLTreeVisitor::VisitSamplerState(HLSLSamplerState * node)
1086 {
1087 HLSLStateAssignment * stateAssignment = node->stateAssignments;
1088 while (stateAssignment != NULL) {
1089 VisitStateAssignment(stateAssignment);
1090 stateAssignment = stateAssignment->nextStateAssignment;
1091 }
1092 }
1093
VisitPass(HLSLPass * node)1094 void HLSLTreeVisitor::VisitPass(HLSLPass * node)
1095 {
1096 HLSLStateAssignment * stateAssignment = node->stateAssignments;
1097 while (stateAssignment != NULL) {
1098 VisitStateAssignment(stateAssignment);
1099 stateAssignment = stateAssignment->nextStateAssignment;
1100 }
1101 }
1102
VisitTechnique(HLSLTechnique * node)1103 void HLSLTreeVisitor::VisitTechnique(HLSLTechnique * node)
1104 {
1105 HLSLPass * pass = node->passes;
1106 while (pass != NULL) {
1107 VisitPass(pass);
1108 pass = pass->nextPass;
1109 }
1110 }
1111
VisitPipeline(HLSLPipeline * node)1112 void HLSLTreeVisitor::VisitPipeline(HLSLPipeline * node)
1113 {
1114 // @@ ?
1115 }
1116
VisitFunctions(HLSLRoot * root)1117 void HLSLTreeVisitor::VisitFunctions(HLSLRoot * root)
1118 {
1119 HLSLStatement * statement = root->statement;
1120 while (statement != NULL) {
1121 if (statement->nodeType == HLSLNodeType_Function) {
1122 VisitFunction((HLSLFunction *)statement);
1123 }
1124
1125 statement = statement->nextStatement;
1126 }
1127 }
1128
VisitParameters(HLSLRoot * root)1129 void HLSLTreeVisitor::VisitParameters(HLSLRoot * root)
1130 {
1131 HLSLStatement * statement = root->statement;
1132 while (statement != NULL) {
1133 if (statement->nodeType == HLSLNodeType_Declaration) {
1134 VisitDeclaration((HLSLDeclaration *)statement);
1135 }
1136
1137 statement = statement->nextStatement;
1138 }
1139 }
1140
1141
1142 class ResetHiddenFlagVisitor : public HLSLTreeVisitor
1143 {
1144 public:
VisitTopLevelStatement(HLSLStatement * statement)1145 virtual void VisitTopLevelStatement(HLSLStatement * statement)
1146 {
1147 statement->hidden = true;
1148
1149 if (statement->nodeType == HLSLNodeType_Buffer)
1150 {
1151 VisitBuffer((HLSLBuffer*)statement);
1152 }
1153 }
1154
1155 // Hide buffer fields.
VisitDeclaration(HLSLDeclaration * node)1156 virtual void VisitDeclaration(HLSLDeclaration * node)
1157 {
1158 node->hidden = true;
1159 }
1160
VisitArgument(HLSLArgument * node)1161 virtual void VisitArgument(HLSLArgument * node)
1162 {
1163 node->hidden = false; // Arguments are visible by default.
1164 }
1165 };
1166
1167 class MarkVisibleStatementsVisitor : public HLSLTreeVisitor
1168 {
1169 public:
1170 HLSLTree * tree;
MarkVisibleStatementsVisitor(HLSLTree * _tree)1171 MarkVisibleStatementsVisitor(HLSLTree * _tree) : tree(_tree) {}
1172
VisitFunction(HLSLFunction * node)1173 virtual void VisitFunction(HLSLFunction * node)
1174 {
1175 node->hidden = false;
1176 HLSLTreeVisitor::VisitFunction(node);
1177
1178 if (node->forward)
1179 VisitFunction(node->forward);
1180 }
1181
VisitFunctionCall(HLSLFunctionCall * node)1182 virtual void VisitFunctionCall(HLSLFunctionCall * node)
1183 {
1184 HLSLTreeVisitor::VisitFunctionCall(node);
1185
1186 if (node->function->hidden)
1187 {
1188 VisitFunction(const_cast<HLSLFunction*>(node->function));
1189 }
1190 }
1191
VisitIdentifierExpression(HLSLIdentifierExpression * node)1192 virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node)
1193 {
1194 HLSLTreeVisitor::VisitIdentifierExpression(node);
1195
1196 if (node->global)
1197 {
1198 HLSLDeclaration * declaration = tree->FindGlobalDeclaration(node->name);
1199 if (declaration != NULL && declaration->hidden)
1200 {
1201 declaration->hidden = false;
1202 VisitDeclaration(declaration);
1203 }
1204 }
1205 }
1206
VisitType(HLSLType & type)1207 virtual void VisitType(HLSLType & type)
1208 {
1209 if (type.baseType == HLSLBaseType_UserDefined)
1210 {
1211 HLSLStruct * globalStruct = tree->FindGlobalStruct(type.typeName);
1212 if (globalStruct != NULL)
1213 {
1214 globalStruct->hidden = false;
1215 VisitStruct(globalStruct);
1216 }
1217 }
1218 }
1219
1220 };
1221
1222
PruneTree(HLSLTree * tree,const char * entryName0,const char * entryName1)1223 void PruneTree(HLSLTree* tree, const char* entryName0, const char* entryName1/*=NULL*/)
1224 {
1225 HLSLRoot* root = tree->GetRoot();
1226
1227 // Reset all flags.
1228 ResetHiddenFlagVisitor reset;
1229 reset.VisitRoot(root);
1230
1231 // Mark all the statements necessary for these entrypoints.
1232 HLSLFunction* entry = tree->FindFunction(entryName0);
1233 if (entry != NULL)
1234 {
1235 MarkVisibleStatementsVisitor mark(tree);
1236 mark.VisitFunction(entry);
1237 }
1238
1239 if (entryName1 != NULL)
1240 {
1241 entry = tree->FindFunction(entryName1);
1242 if (entry != NULL)
1243 {
1244 MarkVisibleStatementsVisitor mark(tree);
1245 mark.VisitFunction(entry);
1246 }
1247 }
1248
1249 // Mark buffers visible, if any of their fields is visible.
1250 HLSLStatement * statement = root->statement;
1251 while (statement != NULL)
1252 {
1253 if (statement->nodeType == HLSLNodeType_Buffer)
1254 {
1255 HLSLBuffer* buffer = (HLSLBuffer*)statement;
1256
1257 HLSLDeclaration* field = buffer->field;
1258 while (field != NULL)
1259 {
1260 ASSERT(field->nodeType == HLSLNodeType_Declaration);
1261 if (!field->hidden)
1262 {
1263 buffer->hidden = false;
1264 break;
1265 }
1266 field = (HLSLDeclaration*)field->nextStatement;
1267 }
1268 }
1269
1270 statement = statement->nextStatement;
1271 }
1272 }
1273
1274
SortTree(HLSLTree * tree)1275 void SortTree(HLSLTree * tree)
1276 {
1277 // Stable sort so that statements are in this order:
1278 // structs, declarations, functions, techniques.
1279 // but their relative order is preserved.
1280
1281 HLSLRoot* root = tree->GetRoot();
1282
1283 HLSLStatement* structs = NULL;
1284 HLSLStatement* lastStruct = NULL;
1285 HLSLStatement* constDeclarations = NULL;
1286 HLSLStatement* lastConstDeclaration = NULL;
1287 HLSLStatement* declarations = NULL;
1288 HLSLStatement* lastDeclaration = NULL;
1289 HLSLStatement* functions = NULL;
1290 HLSLStatement* lastFunction = NULL;
1291 HLSLStatement* other = NULL;
1292 HLSLStatement* lastOther = NULL;
1293
1294 HLSLStatement* statement = root->statement;
1295 while (statement != NULL) {
1296 HLSLStatement* nextStatement = statement->nextStatement;
1297 statement->nextStatement = NULL;
1298
1299 if (statement->nodeType == HLSLNodeType_Struct) {
1300 if (structs == NULL) structs = statement;
1301 if (lastStruct != NULL) lastStruct->nextStatement = statement;
1302 lastStruct = statement;
1303 }
1304 else if (statement->nodeType == HLSLNodeType_Declaration || statement->nodeType == HLSLNodeType_Buffer) {
1305 if (statement->nodeType == HLSLNodeType_Declaration && (((HLSLDeclaration *)statement)->type.flags & HLSLTypeFlag_Const)) {
1306 if (constDeclarations == NULL) constDeclarations = statement;
1307 if (lastConstDeclaration != NULL) lastConstDeclaration->nextStatement = statement;
1308 lastConstDeclaration = statement;
1309 }
1310 else {
1311 if (declarations == NULL) declarations = statement;
1312 if (lastDeclaration != NULL) lastDeclaration->nextStatement = statement;
1313 lastDeclaration = statement;
1314 }
1315 }
1316 else if (statement->nodeType == HLSLNodeType_Function) {
1317 if (functions == NULL) functions = statement;
1318 if (lastFunction != NULL) lastFunction->nextStatement = statement;
1319 lastFunction = statement;
1320 }
1321 else {
1322 if (other == NULL) other = statement;
1323 if (lastOther != NULL) lastOther->nextStatement = statement;
1324 lastOther = statement;
1325 }
1326
1327 statement = nextStatement;
1328 }
1329
1330 // Chain all the statements in the order that we want.
1331 HLSLStatement * firstStatement = structs;
1332 HLSLStatement * lastStatement = lastStruct;
1333
1334 if (constDeclarations != NULL) {
1335 if (firstStatement == NULL) firstStatement = constDeclarations;
1336 else lastStatement->nextStatement = constDeclarations;
1337 lastStatement = lastConstDeclaration;
1338 }
1339
1340 if (declarations != NULL) {
1341 if (firstStatement == NULL) firstStatement = declarations;
1342 else lastStatement->nextStatement = declarations;
1343 lastStatement = lastDeclaration;
1344 }
1345
1346 if (functions != NULL) {
1347 if (firstStatement == NULL) firstStatement = functions;
1348 else lastStatement->nextStatement = functions;
1349 lastStatement = lastFunction;
1350 }
1351
1352 if (other != NULL) {
1353 if (firstStatement == NULL) firstStatement = other;
1354 else lastStatement->nextStatement = other;
1355 lastStatement = lastOther;
1356 }
1357
1358 root->statement = firstStatement;
1359 }
1360
1361
1362
1363
1364
1365 // First and last can be the same.
AddStatements(HLSLRoot * root,HLSLStatement * before,HLSLStatement * first,HLSLStatement * last)1366 void AddStatements(HLSLRoot * root, HLSLStatement * before, HLSLStatement * first, HLSLStatement * last)
1367 {
1368 if (before == NULL) {
1369 last->nextStatement = root->statement;
1370 root->statement = first;
1371 }
1372 else {
1373 last->nextStatement = before->nextStatement;
1374 before->nextStatement = first;
1375 }
1376 }
1377
AddSingleStatement(HLSLRoot * root,HLSLStatement * before,HLSLStatement * statement)1378 void AddSingleStatement(HLSLRoot * root, HLSLStatement * before, HLSLStatement * statement)
1379 {
1380 AddStatements(root, before, statement, statement);
1381 }
1382
1383
1384
1385 // @@ This is very game-specific. Should be moved to pipeline_parser or somewhere else.
GroupParameters(HLSLTree * tree)1386 void GroupParameters(HLSLTree * tree)
1387 {
1388 // Sort parameters based on semantic and group them in cbuffers.
1389
1390 HLSLRoot* root = tree->GetRoot();
1391
1392 HLSLDeclaration * firstPerItemDeclaration = NULL;
1393 HLSLDeclaration * lastPerItemDeclaration = NULL;
1394
1395 HLSLDeclaration * instanceDataDeclaration = NULL;
1396
1397 HLSLDeclaration * firstPerPassDeclaration = NULL;
1398 HLSLDeclaration * lastPerPassDeclaration = NULL;
1399
1400 HLSLDeclaration * firstPerItemSampler = NULL;
1401 HLSLDeclaration * lastPerItemSampler = NULL;
1402
1403 HLSLDeclaration * firstPerPassSampler = NULL;
1404 HLSLDeclaration * lastPerPassSampler = NULL;
1405
1406 HLSLStatement * statementBeforeBuffers = NULL;
1407
1408 HLSLStatement* previousStatement = NULL;
1409 HLSLStatement* statement = root->statement;
1410 while (statement != NULL)
1411 {
1412 HLSLStatement* nextStatement = statement->nextStatement;
1413
1414 if (statement->nodeType == HLSLNodeType_Struct) // Do not remove this, or it will mess the else clause below.
1415 {
1416 statementBeforeBuffers = statement;
1417 }
1418 else if (statement->nodeType == HLSLNodeType_Declaration)
1419 {
1420 HLSLDeclaration* declaration = (HLSLDeclaration*)statement;
1421
1422 // We insert buffers after the last const declaration.
1423 if ((declaration->type.flags & HLSLTypeFlag_Const) != 0)
1424 {
1425 statementBeforeBuffers = statement;
1426 }
1427
1428 // Do not move samplers or static/const parameters.
1429 if ((declaration->type.flags & (HLSLTypeFlag_Static|HLSLTypeFlag_Const)) == 0)
1430 {
1431 // Unlink statement.
1432 statement->nextStatement = NULL;
1433 if (previousStatement != NULL) previousStatement->nextStatement = nextStatement;
1434 else root->statement = nextStatement;
1435
1436 while(declaration != NULL)
1437 {
1438 HLSLDeclaration* nextDeclaration = declaration->nextDeclaration;
1439
1440 if (declaration->semantic != NULL && String_EqualNoCase(declaration->semantic, "PER_INSTANCED_ITEM"))
1441 {
1442 ASSERT(instanceDataDeclaration == NULL);
1443 instanceDataDeclaration = declaration;
1444 }
1445 else
1446 {
1447 // Select group based on type and semantic.
1448 HLSLDeclaration ** first, ** last;
1449 if (declaration->semantic == NULL || String_EqualNoCase(declaration->semantic, "PER_ITEM") || String_EqualNoCase(declaration->semantic, "PER_MATERIAL"))
1450 {
1451 if (IsSamplerType(declaration->type))
1452 {
1453 first = &firstPerItemSampler;
1454 last = &lastPerItemSampler;
1455 }
1456 else
1457 {
1458 first = &firstPerItemDeclaration;
1459 last = &lastPerItemDeclaration;
1460 }
1461 }
1462 else
1463 {
1464 if (IsSamplerType(declaration->type))
1465 {
1466 first = &firstPerPassSampler;
1467 last = &lastPerPassSampler;
1468 }
1469 else
1470 {
1471 first = &firstPerPassDeclaration;
1472 last = &lastPerPassDeclaration;
1473 }
1474 }
1475
1476 // Add declaration to new list.
1477 if (*first == NULL) *first = declaration;
1478 else (*last)->nextStatement = declaration;
1479 *last = declaration;
1480 }
1481
1482 // Unlink from declaration list.
1483 declaration->nextDeclaration = NULL;
1484
1485 // Reset attributes.
1486 declaration->registerName = NULL;
1487 //declaration->semantic = NULL; // @@ Don't do this!
1488
1489 declaration = nextDeclaration;
1490 }
1491 }
1492 }
1493 /*else
1494 {
1495 if (statementBeforeBuffers == NULL) {
1496 // This is the location where we will insert our buffers.
1497 statementBeforeBuffers = previousStatement;
1498 }
1499 }*/
1500
1501 if (statement->nextStatement == nextStatement) {
1502 previousStatement = statement;
1503 }
1504 statement = nextStatement;
1505 }
1506
1507
1508 // Add instance data declaration at the end of the per_item buffer.
1509 if (instanceDataDeclaration != NULL)
1510 {
1511 if (firstPerItemDeclaration == NULL) firstPerItemDeclaration = instanceDataDeclaration;
1512 else lastPerItemDeclaration->nextStatement = instanceDataDeclaration;
1513 }
1514
1515
1516 // Add samplers.
1517 if (firstPerItemSampler != NULL) {
1518 AddStatements(root, statementBeforeBuffers, firstPerItemSampler, lastPerItemSampler);
1519 statementBeforeBuffers = lastPerItemSampler;
1520 }
1521 if (firstPerPassSampler != NULL) {
1522 AddStatements(root, statementBeforeBuffers, firstPerPassSampler, lastPerPassSampler);
1523 statementBeforeBuffers = lastPerPassSampler;
1524 }
1525
1526
1527 // @@ We are assuming per_item and per_pass buffers don't already exist. @@ We should assert on that.
1528
1529 if (firstPerItemDeclaration != NULL)
1530 {
1531 // Create buffer statement.
1532 HLSLBuffer * perItemBuffer = tree->AddNode<HLSLBuffer>(firstPerItemDeclaration->fileName, firstPerItemDeclaration->line-1);
1533 perItemBuffer->name = tree->AddString("per_item");
1534 perItemBuffer->registerName = tree->AddString("b0");
1535 perItemBuffer->field = firstPerItemDeclaration;
1536
1537 // Set declaration buffer pointers.
1538 HLSLDeclaration * field = perItemBuffer->field;
1539 while (field != NULL)
1540 {
1541 field->buffer = perItemBuffer;
1542 field = (HLSLDeclaration *)field->nextStatement;
1543 }
1544
1545 // Add buffer to statements.
1546 AddSingleStatement(root, statementBeforeBuffers, perItemBuffer);
1547 statementBeforeBuffers = perItemBuffer;
1548 }
1549
1550 if (firstPerPassDeclaration != NULL)
1551 {
1552 // Create buffer statement.
1553 HLSLBuffer * perPassBuffer = tree->AddNode<HLSLBuffer>(firstPerPassDeclaration->fileName, firstPerPassDeclaration->line-1);
1554 perPassBuffer->name = tree->AddString("per_pass");
1555 perPassBuffer->registerName = tree->AddString("b1");
1556 perPassBuffer->field = firstPerPassDeclaration;
1557
1558 // Set declaration buffer pointers.
1559 HLSLDeclaration * field = perPassBuffer->field;
1560 while (field != NULL)
1561 {
1562 field->buffer = perPassBuffer;
1563 field = (HLSLDeclaration *)field->nextStatement;
1564 }
1565
1566 // Add buffer to statements.
1567 AddSingleStatement(root, statementBeforeBuffers, perPassBuffer);
1568 }
1569 }
1570
1571
1572 class FindArgumentVisitor : public HLSLTreeVisitor
1573 {
1574 public:
1575 bool found;
1576 const char * name;
1577
FindArgumentVisitor()1578 FindArgumentVisitor()
1579 {
1580 found = false;
1581 name = NULL;
1582 }
1583
FindArgument(const char * _name,HLSLFunction * function)1584 bool FindArgument(const char * _name, HLSLFunction * function)
1585 {
1586 this->found = false;
1587 this->name = _name;
1588 VisitStatements(function->statement);
1589 return found;
1590 }
1591
VisitStatements(HLSLStatement * statement)1592 virtual void VisitStatements(HLSLStatement * statement) override
1593 {
1594 while (statement != NULL && !found)
1595 {
1596 VisitStatement(statement);
1597 statement = statement->nextStatement;
1598 }
1599 }
1600
VisitIdentifierExpression(HLSLIdentifierExpression * node)1601 virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node) override
1602 {
1603 if (node->name == name)
1604 {
1605 found = true;
1606 }
1607 }
1608 };
1609
1610
HideUnusedArguments(HLSLFunction * function)1611 void HideUnusedArguments(HLSLFunction * function)
1612 {
1613 FindArgumentVisitor visitor;
1614
1615 // For each argument.
1616 HLSLArgument * arg = function->argument;
1617 while (arg != NULL)
1618 {
1619 if (!visitor.FindArgument(arg->name, function))
1620 {
1621 arg->hidden = true;
1622 }
1623
1624 arg = arg->nextArgument;
1625 }
1626 }
1627
EmulateAlphaTest(HLSLTree * tree,const char * entryName,float alphaRef)1628 bool EmulateAlphaTest(HLSLTree* tree, const char* entryName, float alphaRef/*=0.5*/)
1629 {
1630 // Find all return statements of this entry point.
1631 HLSLFunction* entry = tree->FindFunction(entryName);
1632 if (entry != NULL)
1633 {
1634 HLSLStatement ** ptr = &entry->statement;
1635 HLSLStatement * statement = entry->statement;
1636 while (statement != NULL)
1637 {
1638 if (statement->nodeType == HLSLNodeType_ReturnStatement)
1639 {
1640 HLSLReturnStatement * returnStatement = (HLSLReturnStatement *)statement;
1641 HLSLBaseType returnType = returnStatement->expression->expressionType.baseType;
1642
1643 // Build statement: "if (%s.a < 0.5) discard;"
1644
1645 HLSLDiscardStatement * discard = tree->AddNode<HLSLDiscardStatement>(statement->fileName, statement->line);
1646
1647 HLSLExpression * alpha = NULL;
1648 if (returnType == HLSLBaseType_Float4)
1649 {
1650 // @@ If return expression is a constructor, grab 4th argument.
1651 // That's not as easy, since we support 'float4(float3, float)' or 'float4(float, float3)', extracting
1652 // the latter is not that easy.
1653 /*if (returnStatement->expression->nodeType == HLSLNodeType_ConstructorExpression) {
1654 HLSLConstructorExpression * constructor = (HLSLConstructorExpression *)returnStatement->expression;
1655 //constructor->
1656 }
1657 */
1658
1659 if (alpha == NULL) {
1660 HLSLMemberAccess * access = tree->AddNode<HLSLMemberAccess>(statement->fileName, statement->line);
1661 access->expressionType = HLSLType(HLSLBaseType_Float);
1662 access->object = returnStatement->expression; // @@ Is reference OK? Or should we clone expression?
1663 access->field = tree->AddString("a");
1664 access->swizzle = true;
1665
1666 alpha = access;
1667 }
1668 }
1669 else if (returnType == HLSLBaseType_Float)
1670 {
1671 alpha = returnStatement->expression; // @@ Is reference OK? Or should we clone expression?
1672 }
1673 else
1674 {
1675 return false;
1676 }
1677
1678 HLSLLiteralExpression * threshold = tree->AddNode<HLSLLiteralExpression>(statement->fileName, statement->line);
1679 threshold->expressionType = HLSLType(HLSLBaseType_Float);
1680 threshold->fValue = alphaRef;
1681 threshold->type = HLSLBaseType_Float;
1682
1683 HLSLBinaryExpression * condition = tree->AddNode<HLSLBinaryExpression>(statement->fileName, statement->line);
1684 condition->expressionType = HLSLType(HLSLBaseType_Bool);
1685 condition->binaryOp = HLSLBinaryOp_Less;
1686 condition->expression1 = alpha;
1687 condition->expression2 = threshold;
1688
1689 // Insert statement.
1690 HLSLIfStatement * st = tree->AddNode<HLSLIfStatement>(statement->fileName, statement->line);
1691 st->condition = condition;
1692 st->statement = discard;
1693 st->nextStatement = statement;
1694 *ptr = st;
1695 }
1696
1697 ptr = &statement->nextStatement;
1698 statement = statement->nextStatement;
1699 }
1700 }
1701
1702 return true;
1703 }
1704
NeedsFlattening(HLSLExpression * expr,int level=0)1705 bool NeedsFlattening(HLSLExpression * expr, int level = 0) {
1706 if (expr == NULL) {
1707 return false;
1708 }
1709 if (expr->nodeType == HLSLNodeType_UnaryExpression) {
1710 HLSLUnaryExpression * unaryExpr = (HLSLUnaryExpression *)expr;
1711 return NeedsFlattening(unaryExpr->expression, level+1) || NeedsFlattening(expr->nextExpression, level);
1712 }
1713 else if (expr->nodeType == HLSLNodeType_BinaryExpression) {
1714 HLSLBinaryExpression * binaryExpr = (HLSLBinaryExpression *)expr;
1715 if (IsAssignOp(binaryExpr->binaryOp)) {
1716 return NeedsFlattening(binaryExpr->expression2, level+1) || NeedsFlattening(expr->nextExpression, level);
1717 }
1718 else {
1719 return NeedsFlattening(binaryExpr->expression1, level+1) || NeedsFlattening(binaryExpr->expression2, level+1) || NeedsFlattening(expr->nextExpression, level);
1720 }
1721 }
1722 else if (expr->nodeType == HLSLNodeType_ConditionalExpression) {
1723 HLSLConditionalExpression * conditionalExpr = (HLSLConditionalExpression *)expr;
1724 return NeedsFlattening(conditionalExpr->condition, level+1) || NeedsFlattening(conditionalExpr->trueExpression, level+1) || NeedsFlattening(conditionalExpr->falseExpression, level+1) || NeedsFlattening(expr->nextExpression, level);
1725 }
1726 else if (expr->nodeType == HLSLNodeType_CastingExpression) {
1727 HLSLCastingExpression * castingExpr = (HLSLCastingExpression *)expr;
1728 return NeedsFlattening(castingExpr->expression, level+1) || NeedsFlattening(expr->nextExpression, level);
1729 }
1730 else if (expr->nodeType == HLSLNodeType_LiteralExpression) {
1731 return NeedsFlattening(expr->nextExpression, level);
1732 }
1733 else if (expr->nodeType == HLSLNodeType_IdentifierExpression) {
1734 return NeedsFlattening(expr->nextExpression, level);
1735 }
1736 else if (expr->nodeType == HLSLNodeType_ConstructorExpression) {
1737 HLSLConstructorExpression * constructorExpr = (HLSLConstructorExpression *)expr;
1738 return NeedsFlattening(constructorExpr->argument, level+1) || NeedsFlattening(expr->nextExpression, level);
1739 }
1740 else if (expr->nodeType == HLSLNodeType_MemberAccess) {
1741 return NeedsFlattening(expr->nextExpression, level+1);
1742 }
1743 else if (expr->nodeType == HLSLNodeType_ArrayAccess) {
1744 HLSLArrayAccess * arrayAccess = (HLSLArrayAccess *)expr;
1745 return NeedsFlattening(arrayAccess->array, level+1) || NeedsFlattening(arrayAccess->index, level+1) || NeedsFlattening(expr->nextExpression, level);
1746 }
1747 else if (expr->nodeType == HLSLNodeType_FunctionCall) {
1748 HLSLFunctionCall * functionCall = (HLSLFunctionCall *)expr;
1749 if (functionCall->function->numOutputArguments > 0) {
1750 if (level > 0) {
1751 return true;
1752 }
1753 }
1754 return NeedsFlattening(functionCall->argument, level+1) || NeedsFlattening(expr->nextExpression, level);
1755 }
1756 else {
1757 //assert(false);
1758 return false;
1759 }
1760 }
1761
1762
1763 struct StatementList {
1764 HLSLStatement * head = NULL;
1765 HLSLStatement * tail = NULL;
appendM4::StatementList1766 void append(HLSLStatement * st) {
1767 if (head == NULL) {
1768 tail = head = st;
1769 }
1770 tail->nextStatement = st;
1771 tail = st;
1772 }
1773 };
1774
1775
1776 class ExpressionFlattener : public HLSLTreeVisitor
1777 {
1778 public:
1779 HLSLTree * m_tree;
1780 int tmp_index;
1781 HLSLStatement ** statement_pointer;
1782 HLSLFunction * current_function;
1783
ExpressionFlattener()1784 ExpressionFlattener()
1785 {
1786 m_tree = NULL;
1787 tmp_index = 0;
1788 statement_pointer = NULL;
1789 current_function = NULL;
1790 }
1791
FlattenExpressions(HLSLTree * tree)1792 void FlattenExpressions(HLSLTree * tree)
1793 {
1794 m_tree = tree;
1795 VisitRoot(tree->GetRoot());
1796 }
1797
1798 // Visit all statements updating the statement_pointer so that we can insert and replace statements. @@ Add this to the default visitor?
VisitFunction(HLSLFunction * node)1799 virtual void VisitFunction(HLSLFunction * node) override
1800 {
1801 current_function = node;
1802 statement_pointer = &node->statement;
1803 VisitStatements(node->statement);
1804 statement_pointer = NULL;
1805 current_function = NULL;
1806 }
1807
VisitIfStatement(HLSLIfStatement * node)1808 virtual void VisitIfStatement(HLSLIfStatement * node) override
1809 {
1810 if (NeedsFlattening(node->condition, 1)) {
1811 assert(false); // @@ Add statements before if statement.
1812 }
1813
1814 statement_pointer = &node->statement;
1815 VisitStatements(node->statement);
1816 if (node->elseStatement) {
1817 statement_pointer = &node->elseStatement;
1818 VisitStatements(node->elseStatement);
1819 }
1820 }
1821
VisitForStatement(HLSLForStatement * node)1822 virtual void VisitForStatement(HLSLForStatement * node) override
1823 {
1824 if (NeedsFlattening(node->initialization->assignment, 1)) {
1825 assert(false); // @@ Add statements before for statement.
1826 }
1827 if (NeedsFlattening(node->condition, 1) || NeedsFlattening(node->increment, 1)) {
1828 assert(false); // @@ These are tricky to implement. Need to handle all loop exits.
1829 }
1830
1831 statement_pointer = &node->statement;
1832 VisitStatements(node->statement);
1833 }
1834
VisitBlockStatement(HLSLBlockStatement * node)1835 virtual void VisitBlockStatement(HLSLBlockStatement * node) override
1836 {
1837 statement_pointer = &node->statement;
1838 VisitStatements(node->statement);
1839 }
1840
VisitStatements(HLSLStatement * statement)1841 virtual void VisitStatements(HLSLStatement * statement) override
1842 {
1843 while (statement != NULL) {
1844 VisitStatement(statement);
1845 statement_pointer = &statement->nextStatement;
1846 statement = statement->nextStatement;
1847 }
1848 }
1849
1850 // This is usually a function call or assignment.
VisitExpressionStatement(HLSLExpressionStatement * node)1851 virtual void VisitExpressionStatement(HLSLExpressionStatement * node) override
1852 {
1853 if (NeedsFlattening(node->expression, 0))
1854 {
1855 StatementList statements;
1856 Flatten(node->expression, statements, false);
1857
1858 // Link beginning of statement list.
1859 *statement_pointer = statements.head;
1860
1861 // Link end of statement list.
1862 HLSLStatement * tail = statements.tail;
1863 tail->nextStatement = node->nextStatement;
1864
1865 // Update statement pointer.
1866 statement_pointer = &tail->nextStatement;
1867
1868 // @@ Delete node?
1869 }
1870 }
1871
VisitDeclaration(HLSLDeclaration * node)1872 virtual void VisitDeclaration(HLSLDeclaration * node) override
1873 {
1874 // Skip global declarations.
1875 if (statement_pointer == NULL) return;
1876
1877 if (NeedsFlattening(node->assignment, 1))
1878 {
1879 StatementList statements;
1880 HLSLIdentifierExpression * ident = Flatten(node->assignment, statements, true);
1881
1882 // @@ Delete node->assignment?
1883
1884 node->assignment = ident;
1885 statements.append(node);
1886
1887 // Link beginning of statement list.
1888 *statement_pointer = statements.head;
1889
1890 // Link end of statement list.
1891 HLSLStatement * tail = statements.tail;
1892 tail->nextStatement = node->nextStatement;
1893
1894 // Update statement pointer.
1895 statement_pointer = &tail->nextStatement;
1896 }
1897 }
1898
VisitReturnStatement(HLSLReturnStatement * node)1899 virtual void VisitReturnStatement(HLSLReturnStatement * node) override
1900 {
1901 if (NeedsFlattening(node->expression, 1))
1902 {
1903 StatementList statements;
1904 HLSLIdentifierExpression * ident = Flatten(node->expression, statements, true);
1905
1906 // @@ Delete node->expression?
1907
1908 node->expression = ident;
1909 statements.append(node);
1910
1911 // Link beginning of statement list.
1912 *statement_pointer = statements.head;
1913
1914 // Link end of statement list.
1915 HLSLStatement * tail = statements.tail;
1916 tail->nextStatement = node->nextStatement;
1917
1918 // Update statement pointer.
1919 statement_pointer = &tail->nextStatement;
1920 }
1921 }
1922
1923
BuildTemporaryDeclaration(HLSLExpression * expr)1924 HLSLDeclaration * BuildTemporaryDeclaration(HLSLExpression * expr)
1925 {
1926 assert(expr->expressionType.baseType != HLSLBaseType_Void);
1927
1928 HLSLDeclaration * declaration = m_tree->AddNode<HLSLDeclaration>(expr->fileName, expr->line);
1929 declaration->name = m_tree->AddStringFormat("tmp%d", tmp_index++);
1930 declaration->type = expr->expressionType;
1931 declaration->assignment = expr;
1932
1933 return declaration;
1934 }
1935
BuildExpressionStatement(HLSLExpression * expr)1936 HLSLExpressionStatement * BuildExpressionStatement(HLSLExpression * expr)
1937 {
1938 HLSLExpressionStatement * statement = m_tree->AddNode<HLSLExpressionStatement>(expr->fileName, expr->line);
1939 statement->expression = expr;
1940 return statement;
1941 }
1942
AddExpressionStatement(HLSLExpression * expr,StatementList & statements,bool wantIdent)1943 HLSLIdentifierExpression * AddExpressionStatement(HLSLExpression * expr, StatementList & statements, bool wantIdent)
1944 {
1945 if (wantIdent) {
1946 HLSLDeclaration * declaration = BuildTemporaryDeclaration(expr);
1947 statements.append(declaration);
1948
1949 HLSLIdentifierExpression * ident = m_tree->AddNode<HLSLIdentifierExpression>(expr->fileName, expr->line);
1950 ident->name = declaration->name;
1951 ident->expressionType = declaration->type;
1952 return ident;
1953 }
1954 else {
1955 HLSLExpressionStatement * statement = BuildExpressionStatement(expr);
1956 statements.append(statement);
1957 return NULL;
1958 }
1959 }
1960
Flatten(HLSLExpression * expr,StatementList & statements,bool wantIdent=true)1961 HLSLIdentifierExpression * Flatten(HLSLExpression * expr, StatementList & statements, bool wantIdent = true)
1962 {
1963 if (!NeedsFlattening(expr, wantIdent)) {
1964 return AddExpressionStatement(expr, statements, wantIdent);
1965 }
1966
1967 if (expr->nodeType == HLSLNodeType_UnaryExpression) {
1968 assert(expr->nextExpression == NULL);
1969
1970 HLSLUnaryExpression * unaryExpr = (HLSLUnaryExpression *)expr;
1971
1972 HLSLIdentifierExpression * tmp = Flatten(unaryExpr->expression, statements, true);
1973
1974 HLSLUnaryExpression * newUnaryExpr = m_tree->AddNode<HLSLUnaryExpression>(unaryExpr->fileName, unaryExpr->line);
1975 newUnaryExpr->unaryOp = unaryExpr->unaryOp;
1976 newUnaryExpr->expression = tmp;
1977 newUnaryExpr->expressionType = unaryExpr->expressionType;
1978
1979 return AddExpressionStatement(newUnaryExpr, statements, wantIdent);
1980 }
1981 else if (expr->nodeType == HLSLNodeType_BinaryExpression) {
1982 assert(expr->nextExpression == NULL);
1983
1984 HLSLBinaryExpression * binaryExpr = (HLSLBinaryExpression *)expr;
1985
1986 if (IsAssignOp(binaryExpr->binaryOp)) {
1987 // Flatten right hand side only.
1988 HLSLIdentifierExpression * tmp2 = Flatten(binaryExpr->expression2, statements, true);
1989
1990 HLSLBinaryExpression * newBinaryExpr = m_tree->AddNode<HLSLBinaryExpression>(binaryExpr->fileName, binaryExpr->line);
1991 newBinaryExpr->binaryOp = binaryExpr->binaryOp;
1992 newBinaryExpr->expression1 = binaryExpr->expression1;
1993 newBinaryExpr->expression2 = tmp2;
1994 newBinaryExpr->expressionType = binaryExpr->expressionType;
1995
1996 return AddExpressionStatement(newBinaryExpr, statements, wantIdent);
1997 }
1998 else {
1999 HLSLIdentifierExpression * tmp1 = Flatten(binaryExpr->expression1, statements, true);
2000 HLSLIdentifierExpression * tmp2 = Flatten(binaryExpr->expression2, statements, true);
2001
2002 HLSLBinaryExpression * newBinaryExpr = m_tree->AddNode<HLSLBinaryExpression>(binaryExpr->fileName, binaryExpr->line);
2003 newBinaryExpr->binaryOp = binaryExpr->binaryOp;
2004 newBinaryExpr->expression1 = tmp1;
2005 newBinaryExpr->expression2 = tmp2;
2006 newBinaryExpr->expressionType = binaryExpr->expressionType;
2007
2008 return AddExpressionStatement(newBinaryExpr, statements, wantIdent);
2009 }
2010 }
2011 else if (expr->nodeType == HLSLNodeType_ConditionalExpression) {
2012 assert(false);
2013 }
2014 else if (expr->nodeType == HLSLNodeType_CastingExpression) {
2015 assert(false);
2016 }
2017 else if (expr->nodeType == HLSLNodeType_LiteralExpression) {
2018 assert(false);
2019 }
2020 else if (expr->nodeType == HLSLNodeType_IdentifierExpression) {
2021 assert(false);
2022 }
2023 else if (expr->nodeType == HLSLNodeType_ConstructorExpression) {
2024 assert(false);
2025 }
2026 else if (expr->nodeType == HLSLNodeType_MemberAccess) {
2027 assert(false);
2028 }
2029 else if (expr->nodeType == HLSLNodeType_ArrayAccess) {
2030 assert(false);
2031 }
2032 else if (expr->nodeType == HLSLNodeType_FunctionCall) {
2033 HLSLFunctionCall * functionCall = (HLSLFunctionCall *)expr;
2034
2035 // @@ Output function as is?
2036 // @@ We have to flatten function arguments! This is tricky, need to handle input/output arguments.
2037 assert(!NeedsFlattening(functionCall->argument));
2038
2039 return AddExpressionStatement(expr, statements, wantIdent);
2040 }
2041 else {
2042 assert(false);
2043 }
2044 return NULL;
2045 }
2046 };
2047
2048
FlattenExpressions(HLSLTree * tree)2049 void FlattenExpressions(HLSLTree* tree) {
2050 ExpressionFlattener flattener;
2051 flattener.FlattenExpressions(tree);
2052 }
2053
2054 } // M4
2055
2056