1 #ifndef HLSL_TREE_H
2 #define HLSL_TREE_H
3 
4 //#include "Engine/StringPool.h"
5 #include "Engine.h"
6 
7 #include <new>
8 #include <map>
9 #include <vector>
10 #include <string>
11 
12 namespace M4
13 {
14 
/**
 * Discriminator stored in HLSLNode::nodeType.
 * Node structs in this header expose the value that identifies them through
 * a static member named s_type.
 */
enum HLSLNodeType
{
    HLSLNodeType_Root,
    HLSLNodeType_Declaration,
    HLSLNodeType_Struct,
    HLSLNodeType_StructField,
    HLSLNodeType_Buffer,
    HLSLNodeType_BufferField,
    HLSLNodeType_Function,
    HLSLNodeType_Argument,
    HLSLNodeType_Macro,
    HLSLNodeType_ExpressionStatement,
    HLSLNodeType_Expression,
    HLSLNodeType_ReturnStatement,
    HLSLNodeType_DiscardStatement,
    HLSLNodeType_BreakStatement,
    HLSLNodeType_ContinueStatement,
    HLSLNodeType_IfStatement,
    HLSLNodeType_ForStatement,
    HLSLNodeType_WhileStatement,
    HLSLNodeType_BlockStatement,
    HLSLNodeType_UnaryExpression,
    HLSLNodeType_BinaryExpression,
    HLSLNodeType_ConditionalExpression,
    HLSLNodeType_CastingExpression,
    HLSLNodeType_LiteralExpression,
    HLSLNodeType_IdentifierExpression,
    HLSLNodeType_ConstructorExpression,
    HLSLNodeType_MemberAccess,
    HLSLNodeType_ArrayAccess,
    HLSLNodeType_FunctionCall,
    HLSLNodeType_StateAssignment,
    HLSLNodeType_SamplerState,
    HLSLNodeType_Pass,
    HLSLNodeType_Technique,
    HLSLNodeType_Attribute,
    HLSLNodeType_Pipeline,
    HLSLNodeType_Stage,
};
54 
55 
/**
 * Built-in base types known to the parser.
 *
 * The enumerator order matters: the rows of baseTypeDescriptions are listed
 * in this same order, and the First/Last markers below delimit contiguous
 * ranges (numeric types run from Float to Uint4, integer types from Bool to
 * Uint4). Insert new types with care.
 */
enum HLSLBaseType
{
    HLSLBaseType_Unknown,
    HLSLBaseType_Void,
    HLSLBaseType_Float,
    HLSLBaseType_FirstNumeric = HLSLBaseType_Float,     // Alias: start of the numeric range.
    HLSLBaseType_Float2,
    HLSLBaseType_Float3,
    HLSLBaseType_Float4,

    HLSLBaseType_Float2x4,
    HLSLBaseType_Float2x3,
	HLSLBaseType_Float2x2,

    HLSLBaseType_Float3x4,
    HLSLBaseType_Float3x3,
    HLSLBaseType_Float3x2,

    HLSLBaseType_Float4x4,
    HLSLBaseType_Float4x3,
    HLSLBaseType_Float4x2,

    HLSLBaseType_Bool,
    HLSLBaseType_FirstInteger = HLSLBaseType_Bool,      // Alias: start of the integer range (bool counts as integer here).
	HLSLBaseType_Bool2,
	HLSLBaseType_Bool3,
	HLSLBaseType_Bool4,
    HLSLBaseType_Int,
    HLSLBaseType_Int2,
    HLSLBaseType_Int3,
    HLSLBaseType_Int4,
    HLSLBaseType_Uint,
    HLSLBaseType_Uint2,
    HLSLBaseType_Uint3,
    HLSLBaseType_Uint4,
    /*HLSLBaseType_Short,   // @@ Separate dimension from Base type, this is getting out of control.
    HLSLBaseType_Short2,
    HLSLBaseType_Short3,
    HLSLBaseType_Short4,
    HLSLBaseType_Ushort,
    HLSLBaseType_Ushort2,
    HLSLBaseType_Ushort3,
    HLSLBaseType_Ushort4,*/
    HLSLBaseType_LastInteger = HLSLBaseType_Uint4,      // Alias: end of the integer range.
    HLSLBaseType_LastNumeric = HLSLBaseType_Uint4,      // Alias: end of the numeric range.
    HLSLBaseType_Texture,
    HLSLBaseType_Sampler,           // @@ use type inference to determine sampler type.
    HLSLBaseType_Sampler2D,
    HLSLBaseType_Sampler3D,
    HLSLBaseType_SamplerCube,
    HLSLBaseType_Sampler2DShadow,
    HLSLBaseType_Sampler2DMS,
    HLSLBaseType_Sampler2DArray,
    HLSLBaseType_UserDefined,       // struct
    HLSLBaseType_Expression,        // type argument for defined() sizeof() and typeof().
    HLSLBaseType_Auto,

    HLSLBaseType_Count,             // Number of distinct base types; used to size per-type tables.
    HLSLBaseType_NumericCount = HLSLBaseType_LastNumeric - HLSLBaseType_FirstNumeric + 1   // Number of types in the numeric range.
};
116 
117 
118 
/** Scalar category of a base type, as recorded in BaseTypeDescription::numericType. */
enum NumericType
{
    NumericType_Float,
    NumericType_Bool,
    NumericType_Int,
    NumericType_Uint,
    NumericType_Count,  // Number of real numeric categories listed above.
    NumericType_NaN,    // Marks non-numeric types; listed after NumericType_Count, so it is not included in the count.
};
128 
129 
/**
 * One row of the baseTypeDescriptions table.
 * Kept as a plain aggregate (no constructors) so the table can be brace-initialised;
 * the field order below is the column order of that table.
 */
struct BaseTypeDescription
{
    const char*     typeName;       // Printable name, e.g. "float3x4".
    NumericType     numericType;    // Scalar category, or NumericType_NaN for non-numeric types.
    int             numComponents;  // In the table: N for floatN / floatNxM, 1 for scalars.
    int             numDimensions;  // In the table: 0 for scalars, 1 for vectors, 2 for matrices.
    int             height;         // In the table: M for floatNxM, 1 for scalars and vectors, 0 for non-numeric types.
    int             binaryOpRank;   // -1 for non-numeric types. NOTE(review): presumably orders types for binary-op promotion (float 0 ... bool 4) - confirm against users.
};
139 
140 const BaseTypeDescription baseTypeDescriptions[HLSLBaseType_Count] =
141     {
142         { "unknown type",       NumericType_NaN,        0, 0, 0, -1 },      // HLSLBaseType_Unknown
143         { "void",               NumericType_NaN,        0, 0, 0, -1 },      // HLSLBaseType_Void
144         { "float",              NumericType_Float,      1, 0, 1,  0 },      // HLSLBaseType_Float
145         { "float2",             NumericType_Float,      2, 1, 1,  0 },      // HLSLBaseType_Float2
146         { "float3",             NumericType_Float,      3, 1, 1,  0 },      // HLSLBaseType_Float3
147         { "float4",             NumericType_Float,      4, 1, 1,  0 },      // HLSLBaseType_Float4
148 
149         { "float2x4",			NumericType_Float,		2, 2, 4,  0 },		// HLSLBaseType_Float2x4
150         { "float2x3",			NumericType_Float,		2, 2, 3,  0 },		// HLSLBaseType_Float2x3
151         { "float2x2",			NumericType_Float,		2, 2, 2,  0 },		// HLSLBaseType_Float2x2
152 
153         { "float3x4",           NumericType_Float,      3, 2, 4,  0 },      // HLSLBaseType_Float3x4
154         { "float3x3",           NumericType_Float,      3, 2, 3,  0 },      // HLSLBaseType_Float3x3
155         { "float3x2",           NumericType_Float,      3, 2, 2,  0 },      // HLSLBaseType_Float3x2
156 
157         { "float4x4",           NumericType_Float,      4, 2, 4,  0 },      // HLSLBaseType_Float4x4
158         { "float4x3",           NumericType_Float,      4, 2, 3,  0 },      // HLSLBaseType_Float4x3
159         { "float4x2",           NumericType_Float,      4, 2, 2,  0 },      // HLSLBaseType_Float4x2
160 
161         { "bool",               NumericType_Bool,       1, 0, 1,  4 },      // HLSLBaseType_Bool
162         { "bool2",				NumericType_Bool,		2, 1, 1,  4 },      // HLSLBaseType_Bool2
163         { "bool3",				NumericType_Bool,		3, 1, 1,  4 },      // HLSLBaseType_Bool3
164         { "bool4",				NumericType_Bool,		4, 1, 1,  4 },      // HLSLBaseType_Bool4
165 
166         { "int",                NumericType_Int,        1, 0, 1,  3 },      // HLSLBaseType_Int
167         { "int2",               NumericType_Int,        2, 1, 1,  3 },      // HLSLBaseType_Int2
168         { "int3",               NumericType_Int,        3, 1, 1,  3 },      // HLSLBaseType_Int3
169         { "int4",               NumericType_Int,        4, 1, 1,  3 },      // HLSLBaseType_Int4
170 
171         { "uint",               NumericType_Uint,       1, 0, 1,  2 },      // HLSLBaseType_Uint
172         { "uint2",              NumericType_Uint,       2, 1, 1,  2 },      // HLSLBaseType_Uint2
173         { "uint3",              NumericType_Uint,       3, 1, 1,  2 },      // HLSLBaseType_Uint3
174         { "uint4",              NumericType_Uint,       4, 1, 1,  2 },      // HLSLBaseType_Uint4
175 
176         { "texture",            NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Texture
177         { "sampler",            NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Sampler
178         { "sampler2D",          NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Sampler2D
179         { "sampler3D",          NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Sampler3D
180         { "samplerCUBE",        NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_SamplerCube
181         { "sampler2DShadow",    NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Sampler2DShadow
182         { "sampler2DMS",        NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Sampler2DMS
183         { "sampler2DArray",     NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_Sampler2DArray
184         { "user defined",       NumericType_NaN,        1, 0, 0, -1 },      // HLSLBaseType_UserDefined
185         { "expression",         NumericType_NaN,        1, 0, 0, -1 }       // HLSLBaseType_Expression
186     };
187 
// Table with one entry per HLSLBaseType; the definition lives in another translation unit.
// NOTE(review): presumably maps each type to its scalar element type (e.g. float3 -> float) - confirm against the definition.
extern const HLSLBaseType ScalarBaseType[HLSLBaseType_Count];
189 
IsSamplerType(HLSLBaseType baseType)190 inline bool IsSamplerType(HLSLBaseType baseType)
191 {
192     return baseType == HLSLBaseType_Sampler ||
193            baseType == HLSLBaseType_Sampler2D ||
194            baseType == HLSLBaseType_Sampler3D ||
195            baseType == HLSLBaseType_SamplerCube ||
196            baseType == HLSLBaseType_Sampler2DShadow ||
197            baseType == HLSLBaseType_Sampler2DMS ||
198            baseType == HLSLBaseType_Sampler2DArray;
199 }
200 
IsMatrixType(HLSLBaseType baseType)201 inline bool IsMatrixType(HLSLBaseType baseType)
202 {
203     return baseType == HLSLBaseType_Float2x4 || baseType == HLSLBaseType_Float2x3 || baseType == HLSLBaseType_Float2x2 ||
204            baseType == HLSLBaseType_Float3x4 || baseType == HLSLBaseType_Float3x3 || baseType == HLSLBaseType_Float3x2 ||
205            baseType == HLSLBaseType_Float4x4 || baseType == HLSLBaseType_Float4x3 || baseType == HLSLBaseType_Float4x2;
206 }
207 
IsScalarType(HLSLBaseType baseType)208 inline bool IsScalarType( HLSLBaseType baseType )
209 {
210 	return  baseType == HLSLBaseType_Float ||
211 			baseType == HLSLBaseType_Bool ||
212 			baseType == HLSLBaseType_Int ||
213 			baseType == HLSLBaseType_Uint;
214 }
215 
IsVectorType(HLSLBaseType baseType)216 inline bool IsVectorType( HLSLBaseType baseType )
217 {
218 	return  baseType == HLSLBaseType_Float2 ||
219 		baseType == HLSLBaseType_Float3 ||
220 		baseType == HLSLBaseType_Float4 ||
221 		baseType == HLSLBaseType_Bool2 ||
222 		baseType == HLSLBaseType_Bool3 ||
223 		baseType == HLSLBaseType_Bool4 ||
224 		baseType == HLSLBaseType_Int2  ||
225 		baseType == HLSLBaseType_Int3  ||
226 		baseType == HLSLBaseType_Int4  ||
227 		baseType == HLSLBaseType_Uint2 ||
228 		baseType == HLSLBaseType_Uint3 ||
229 		baseType == HLSLBaseType_Uint4;
230 }
231 
232 
/**
 * Operators that can appear in an HLSLBinaryExpression.
 * The IsCompareOp / IsArithmeticOp / IsLogicOp / IsAssignOp helpers classify
 * these; the bitwise operators belong to none of those groups.
 */
enum HLSLBinaryOp
{
    HLSLBinaryOp_And,
    HLSLBinaryOp_Or,
    HLSLBinaryOp_Add,
    HLSLBinaryOp_Sub,
    HLSLBinaryOp_Mul,
    HLSLBinaryOp_Div,
    HLSLBinaryOp_Mod,
    HLSLBinaryOp_Less,
    HLSLBinaryOp_Greater,
    HLSLBinaryOp_LessEqual,
    HLSLBinaryOp_GreaterEqual,
    HLSLBinaryOp_Equal,
    HLSLBinaryOp_NotEqual,
    HLSLBinaryOp_BitAnd,
    HLSLBinaryOp_BitOr,
    HLSLBinaryOp_BitXor,
    HLSLBinaryOp_Assign,
    HLSLBinaryOp_AddAssign,
    HLSLBinaryOp_SubAssign,
    HLSLBinaryOp_MulAssign,
    HLSLBinaryOp_DivAssign,
};
257 
IsCompareOp(HLSLBinaryOp op)258 inline bool IsCompareOp( HLSLBinaryOp op )
259 {
260 	return op == HLSLBinaryOp_Less ||
261 		op == HLSLBinaryOp_Greater ||
262 		op == HLSLBinaryOp_LessEqual ||
263 		op == HLSLBinaryOp_GreaterEqual ||
264 		op == HLSLBinaryOp_Equal ||
265 		op == HLSLBinaryOp_NotEqual;
266 }
267 
IsArithmeticOp(HLSLBinaryOp op)268 inline bool IsArithmeticOp( HLSLBinaryOp op )
269 {
270     return op == HLSLBinaryOp_Add ||
271         op == HLSLBinaryOp_Sub ||
272         op == HLSLBinaryOp_Mul ||
273         op == HLSLBinaryOp_Div ||
274         op == HLSLBinaryOp_Mod;
275 }
276 
IsLogicOp(HLSLBinaryOp op)277 inline bool IsLogicOp( HLSLBinaryOp op )
278 {
279     return op == HLSLBinaryOp_And ||
280         op == HLSLBinaryOp_Or;
281 }
282 
IsAssignOp(HLSLBinaryOp op)283 inline bool IsAssignOp( HLSLBinaryOp op )
284 {
285     return op == HLSLBinaryOp_Assign ||
286         op == HLSLBinaryOp_AddAssign ||
287         op == HLSLBinaryOp_SubAssign ||
288         op == HLSLBinaryOp_MulAssign ||
289         op == HLSLBinaryOp_DivAssign;
290 }
291 
292 
/** Operators that can appear in an HLSLUnaryExpression. */
enum HLSLUnaryOp
{
    HLSLUnaryOp_Negative,       // -x
    HLSLUnaryOp_Positive,       // +x
    HLSLUnaryOp_Not,            // !x
    HLSLUnaryOp_PreIncrement,   // ++x
    HLSLUnaryOp_PreDecrement,   // --x
    HLSLUnaryOp_PostIncrement,  // x++
    HLSLUnaryOp_PostDecrement,  // x--
    HLSLUnaryOp_BitNot,         // ~x
};
304 
/** Qualifier attached to a function argument (stored in HLSLArgument::modifier). */
enum HLSLArgumentModifier
{
    HLSLArgumentModifier_None,
    HLSLArgumentModifier_In,
    HLSLArgumentModifier_Out,
    HLSLArgumentModifier_Inout,
    HLSLArgumentModifier_Uniform,
    HLSLArgumentModifier_Const,
};
314 
/**
 * Bit flags describing type qualifiers. Each flag occupies a distinct bit so
 * that several can be combined in a single integer.
 */
enum HLSLTypeFlags
{
    HLSLTypeFlag_None            = 0,

    // Storage qualifiers.
    HLSLTypeFlag_Const           = 1 << 0,      // 0x01
    HLSLTypeFlag_Static          = 1 << 1,      // 0x02
    HLSLTypeFlag_Uniform         = 1 << 2,      // 0x04
    // Currently disabled:
    //HLSLTypeFlag_Extern        = 1 << 4,      // 0x10
    //HLSLTypeFlag_Volatile      = 1 << 5,      // 0x20
    //HLSLTypeFlag_Shared        = 1 << 6,      // 0x40
    //HLSLTypeFlag_Precise       = 1 << 7,      // 0x80

    HLSLTypeFlag_Input           = 1 << 8,      // 0x100
    HLSLTypeFlag_Output          = 1 << 9,      // 0x200

    // Interpolation modifiers.
    HLSLTypeFlag_Linear          = 1 << 16,     // 0x10000
    HLSLTypeFlag_Centroid        = 1 << 17,     // 0x20000
    HLSLTypeFlag_NoInterpolation = 1 << 18,     // 0x40000
    HLSLTypeFlag_NoPerspective   = 1 << 19,     // 0x80000
    HLSLTypeFlag_Sample          = 1 << 20,     // 0x100000

    // Misc.
    HLSLTypeFlag_NoPromote       = 1 << 21,     // 0x200000
};
339 
/** Kind of a statement attribute (stored in HLSLAttribute::attributeType). */
enum HLSLAttributeType
{
    HLSLAttributeType_Unknown,      // Default assigned by the HLSLAttribute constructor.
    HLSLAttributeType_Unroll,
    HLSLAttributeType_Branch,
    HLSLAttributeType_Flatten,
    HLSLAttributeType_NoFastMath,
};
348 
/** Address space qualifier carried by HLSLType::addressSpace. */
enum HLSLAddressSpace
{
    HLSLAddressSpace_Undefined,     // Default assigned by the HLSLType constructor.
    HLSLAddressSpace_Constant,
    HLSLAddressSpace_Device,
    HLSLAddressSpace_Thread,
    HLSLAddressSpace_Shared,
};
357 
358 
// Forward declarations, so that the node structs below can hold pointers to
// one another before their full definitions appear.
struct HLSLNode;
struct HLSLRoot;
struct HLSLStatement;
struct HLSLAttribute;
struct HLSLDeclaration;
struct HLSLStruct;
struct HLSLStructField;
struct HLSLBuffer;
struct HLSLFunction;
struct HLSLArgument;
struct HLSLExpressionStatement;
struct HLSLExpression;
struct HLSLBinaryExpression;
struct HLSLLiteralExpression;
struct HLSLIdentifierExpression;
struct HLSLConstructorExpression;
struct HLSLFunctionCall;
struct HLSLArrayAccess;
struct HLSLAttribute;
378 
379 struct HLSLType
380 {
381     explicit HLSLType(HLSLBaseType _baseType = HLSLBaseType_Unknown)
382     {
383         baseType    = _baseType;
384         samplerType = HLSLBaseType_Float;
385         typeName    = NULL;
386         array       = false;
387         arraySize   = NULL;
388         flags       = 0;
389         addressSpace = HLSLAddressSpace_Undefined;
390     }
391     HLSLBaseType        baseType;
392     HLSLBaseType        samplerType;    // Half or Float
393     const char*         typeName;       // For user defined types.
394     bool                array;
395     HLSLExpression*     arraySize;
396     int                 flags;
397     HLSLAddressSpace    addressSpace;
398 };
399 
IsSamplerType(const HLSLType & type)400 inline bool IsSamplerType(const HLSLType & type)
401 {
402     return IsSamplerType(type.baseType);
403 }
404 
IsScalarType(const HLSLType & type)405 inline bool IsScalarType(const HLSLType & type)
406 {
407 	return IsScalarType(type.baseType);
408 }
409 
IsVectorType(const HLSLType & type)410 inline bool IsVectorType(const HLSLType & type)
411 {
412 	return IsVectorType(type.baseType);
413 }
414 
415 
/**
 * Base class for all nodes in the HLSL AST.
 * It has no constructor: the three fields are filled in by whoever creates the
 * node (HLSLTree::AddNode assigns all of them right after construction).
 */
struct HLSLNode
{
    HLSLNodeType        nodeType;   // Concrete kind of this node; matches the s_type of its struct.
    const char*         fileName;   // Source file the node came from.
    int                 line;       // Source line the node came from.
};
423 
424 struct HLSLRoot : public HLSLNode
425 {
426     static const HLSLNodeType s_type = HLSLNodeType_Root;
HLSLRootHLSLRoot427     HLSLRoot()          { statement = NULL; }
428     HLSLStatement*      statement;          // First statement.
429 };
430 
431 struct HLSLStatement : public HLSLNode
432 {
HLSLStatementHLSLStatement433     HLSLStatement()
434     {
435         nextStatement   = NULL;
436         attributes      = NULL;
437         hidden          = false;
438     }
439     HLSLStatement*      nextStatement;      // Next statement in the block.
440     HLSLAttribute*      attributes;
441     mutable bool        hidden;
442 };
443 
444 struct HLSLAttribute : public HLSLNode
445 {
446     static const HLSLNodeType s_type = HLSLNodeType_Attribute;
HLSLAttributeHLSLAttribute447 	HLSLAttribute()
448 	{
449 		attributeType = HLSLAttributeType_Unknown;
450 		argument      = NULL;
451 		nextAttribute = NULL;
452 	}
453     HLSLAttributeType   attributeType;
454     HLSLExpression*     argument;
455     HLSLAttribute*      nextAttribute;
456 };
457 
458 struct HLSLDeclaration : public HLSLStatement
459 {
460     static const HLSLNodeType s_type = HLSLNodeType_Declaration;
HLSLDeclarationHLSLDeclaration461     HLSLDeclaration()
462     {
463         name            = NULL;
464         registerName    = NULL;
465         semantic        = NULL;
466         nextDeclaration = NULL;
467         assignment      = NULL;
468         buffer          = NULL;
469     }
470     const char*         name;
471     HLSLType            type;
472     const char*         registerName;       // @@ Store register index?
473     const char*         semantic;
474     HLSLDeclaration*    nextDeclaration;    // If multiple variables declared on a line.
475     HLSLExpression*     assignment;
476     HLSLBuffer*         buffer;
477 };
478 
479 struct HLSLStruct : public HLSLStatement
480 {
481     static const HLSLNodeType s_type = HLSLNodeType_Struct;
HLSLStructHLSLStruct482     HLSLStruct()
483     {
484         name            = NULL;
485         field           = NULL;
486     }
487     const char*         name;
488     HLSLStructField*    field;              // First field in the structure.
489 };
490 
491 struct HLSLStructField : public HLSLNode
492 {
493     static const HLSLNodeType s_type = HLSLNodeType_StructField;
HLSLStructFieldHLSLStructField494     HLSLStructField()
495     {
496         name            = NULL;
497         semantic        = NULL;
498         sv_semantic     = NULL;
499         nextField       = NULL;
500         hidden          = false;
501     }
502     const char*         name;
503     HLSLType            type;
504     const char*         semantic;
505     const char*         sv_semantic;
506     HLSLStructField*    nextField;      // Next field in the structure.
507     bool                hidden;
508 };
509 
510 /** A cbuffer or tbuffer declaration. */
511 struct HLSLBuffer : public HLSLStatement
512 {
513     static const HLSLNodeType s_type = HLSLNodeType_Buffer;
HLSLBufferHLSLBuffer514     HLSLBuffer()
515     {
516         name            = NULL;
517         registerName    = NULL;
518         field           = NULL;
519     }
520     const char*         name;
521     const char*         registerName;
522     HLSLDeclaration*    field;
523 };
524 
525 
526 /** Function declaration */
527 struct HLSLFunction : public HLSLStatement
528 {
529     static const HLSLNodeType s_type = HLSLNodeType_Function;
HLSLFunctionHLSLFunction530     HLSLFunction()
531     {
532         name            = NULL;
533         semantic        = NULL;
534         sv_semantic     = NULL;
535         statement       = NULL;
536         argument        = NULL;
537         numArguments    = 0;
538         numOutputArguments = 0;
539         forward         = NULL;
540     }
541     const char*         name;
542     HLSLType            returnType;
543     const char*         semantic;
544     const char*         sv_semantic;
545     int                 numArguments;
546     int                 numOutputArguments;     // Includes out and inout arguments.
547     HLSLArgument*       argument;
548     HLSLStatement*      statement;
549     HLSLFunction*       forward; // Which HLSLFunction this one forward-declares
550 };
551 
552 /** Declaration of an argument to a function. */
553 struct HLSLArgument : public HLSLNode
554 {
555     static const HLSLNodeType s_type = HLSLNodeType_Argument;
HLSLArgumentHLSLArgument556     HLSLArgument()
557     {
558         name            = NULL;
559         modifier        = HLSLArgumentModifier_None;
560         semantic        = NULL;
561         sv_semantic     = NULL;
562         defaultValue    = NULL;
563         nextArgument    = NULL;
564         hidden          = false;
565     }
566     const char*             name;
567     HLSLArgumentModifier    modifier;
568     HLSLType                type;
569     const char*             semantic;
570     const char*             sv_semantic;
571     HLSLExpression*         defaultValue;
572     HLSLArgument*           nextArgument;
573     bool                    hidden;
574 };
575 
576 /** Macro declaration */
577 struct HLSLMacro : public HLSLStatement
578 {
579     static const HLSLNodeType s_type = HLSLNodeType_Macro;
HLSLMacroHLSLMacro580     HLSLMacro()
581     {
582         name            = NULL;
583         argument        = NULL;
584         numArguments    = 0;
585         macroAliased    = NULL;
586     }
587     const char*         name;
588     HLSLArgument*       argument;
589     unsigned int        numArguments;
590     std::string         value;
591     HLSLMacro*          macroAliased;
592 };
593 
594 /** A expression which forms a complete statement. */
595 struct HLSLExpressionStatement : public HLSLStatement
596 {
597     static const HLSLNodeType s_type = HLSLNodeType_ExpressionStatement;
HLSLExpressionStatementHLSLExpressionStatement598     HLSLExpressionStatement()
599     {
600         expression = NULL;
601     }
602     HLSLExpression*     expression;
603 };
604 
605 struct HLSLReturnStatement : public HLSLStatement
606 {
607     static const HLSLNodeType s_type = HLSLNodeType_ReturnStatement;
HLSLReturnStatementHLSLReturnStatement608     HLSLReturnStatement()
609     {
610         expression = NULL;
611     }
612     HLSLExpression*     expression;
613 };
614 
/** A discard statement; carries no data beyond the HLSLStatement base. */
struct HLSLDiscardStatement : public HLSLStatement
{
    static const HLSLNodeType s_type = HLSLNodeType_DiscardStatement;
};
619 
/** A break statement; carries no data beyond the HLSLStatement base. */
struct HLSLBreakStatement : public HLSLStatement
{
    static const HLSLNodeType s_type = HLSLNodeType_BreakStatement;
};
624 
/** A continue statement; carries no data beyond the HLSLStatement base. */
struct HLSLContinueStatement : public HLSLStatement
{
    static const HLSLNodeType s_type = HLSLNodeType_ContinueStatement;
};
629 
630 struct HLSLIfStatement : public HLSLStatement
631 {
632     static const HLSLNodeType s_type = HLSLNodeType_IfStatement;
HLSLIfStatementHLSLIfStatement633     HLSLIfStatement()
634     {
635         condition     = NULL;
636         statement     = NULL;
637         elseStatement = NULL;
638         isStatic      = false;
639     }
640     HLSLExpression*     condition;
641     HLSLStatement*      statement;
642     HLSLStatement*      elseStatement;
643     bool                isStatic;
644 };
645 
646 struct HLSLForStatement : public HLSLStatement
647 {
648     static const HLSLNodeType s_type = HLSLNodeType_ForStatement;
HLSLForStatementHLSLForStatement649     HLSLForStatement()
650     {
651         initialization = NULL;
652         condition = NULL;
653         increment = NULL;
654         statement = NULL;
655     }
656     HLSLDeclaration*    initialization;
657     HLSLExpression*     initializationWithoutType;
658     HLSLExpression*     condition;
659     HLSLExpression*     increment;
660     HLSLStatement*      statement;
661 };
662 
663 struct HLSLWhileStatement : public HLSLStatement
664 {
665     static const HLSLNodeType s_type = HLSLNodeType_WhileStatement;
HLSLWhileStatementHLSLWhileStatement666     HLSLWhileStatement()
667     {
668         condition = NULL;
669         statement = NULL;
670     }
671     HLSLExpression*     condition;
672     HLSLStatement*      statement;
673 };
674 
675 struct HLSLBlockStatement : public HLSLStatement
676 {
677     static const HLSLNodeType s_type = HLSLNodeType_BlockStatement;
HLSLBlockStatementHLSLBlockStatement678     HLSLBlockStatement()
679     {
680         statement = NULL;
681     }
682     HLSLStatement*      statement;
683 };
684 
685 
686 /** Base type for all types of expressions. */
687 struct HLSLExpression : public HLSLNode
688 {
689     static const HLSLNodeType s_type = HLSLNodeType_Expression;
HLSLExpressionHLSLExpression690     HLSLExpression()
691     {
692         nextExpression = NULL;
693     }
694     HLSLType            expressionType;
695     HLSLExpression*     nextExpression; // Used when the expression is part of a list, like in a function call.
696 };
697 
698 struct HLSLUnaryExpression : public HLSLExpression
699 {
700     static const HLSLNodeType s_type = HLSLNodeType_UnaryExpression;
HLSLUnaryExpressionHLSLUnaryExpression701     HLSLUnaryExpression()
702     {
703         expression = NULL;
704     }
705     HLSLUnaryOp         unaryOp;
706     HLSLExpression*     expression;
707 };
708 
709 struct HLSLBinaryExpression : public HLSLExpression
710 {
711     static const HLSLNodeType s_type = HLSLNodeType_BinaryExpression;
HLSLBinaryExpressionHLSLBinaryExpression712     HLSLBinaryExpression()
713     {
714         expression1 = NULL;
715         expression2 = NULL;
716     }
717     HLSLBinaryOp        binaryOp;
718     HLSLExpression*     expression1;
719     HLSLExpression*     expression2;
720 };
721 
722 /** ? : construct */
723 struct HLSLConditionalExpression : public HLSLExpression
724 {
725     static const HLSLNodeType s_type = HLSLNodeType_ConditionalExpression;
HLSLConditionalExpressionHLSLConditionalExpression726     HLSLConditionalExpression()
727     {
728         condition       = NULL;
729         trueExpression  = NULL;
730         falseExpression = NULL;
731     }
732     HLSLExpression*     condition;
733     HLSLExpression*     trueExpression;
734     HLSLExpression*     falseExpression;
735 };
736 
737 struct HLSLCastingExpression : public HLSLExpression
738 {
739     static const HLSLNodeType s_type = HLSLNodeType_CastingExpression;
HLSLCastingExpressionHLSLCastingExpression740     HLSLCastingExpression()
741     {
742         expression = NULL;
743     }
744     HLSLType            type;
745     HLSLExpression*     expression;
746 };
747 
/**
 * Float, integer, boolean, etc. literal constant.
 * There is no user-written constructor here, so type and the value union are
 * not explicitly initialised by this struct; whoever creates the node must set them.
 */
struct HLSLLiteralExpression : public HLSLExpression
{
    static const HLSLNodeType s_type = HLSLNodeType_LiteralExpression;
    HLSLBaseType        type;   // Note, not all types can be literals.
    // Literal payload; the members share storage. NOTE(review): presumably `type` selects the active member - confirm.
    union
    {
        bool            bValue;
        float           fValue;
        int             iValue;
    };
};
760 
761 /** An identifier, typically a variable name or structure field name. */
762 struct HLSLIdentifierExpression : public HLSLExpression
763 {
764     static const HLSLNodeType s_type = HLSLNodeType_IdentifierExpression;
HLSLIdentifierExpressionHLSLIdentifierExpression765     HLSLIdentifierExpression()
766     {
767         name     = NULL;
768         global  = false;
769     }
770     const char*         name;
771     bool                global; // This is a global variable.
772 };
773 
774 /** float2(1, 2) */
775 struct HLSLConstructorExpression : public HLSLExpression
776 {
777     static const HLSLNodeType s_type = HLSLNodeType_ConstructorExpression;
HLSLConstructorExpressionHLSLConstructorExpression778 	HLSLConstructorExpression()
779 	{
780 		argument = NULL;
781 	}
782     HLSLType            type;
783     HLSLExpression*     argument;
784 };
785 
786 /** object.member **/
787 struct HLSLMemberAccess : public HLSLExpression
788 {
789     static const HLSLNodeType s_type = HLSLNodeType_MemberAccess;
HLSLMemberAccessHLSLMemberAccess790 	HLSLMemberAccess()
791 	{
792 		object  = NULL;
793 		field   = NULL;
794 		swizzle = false;
795 	}
796     HLSLExpression*     object;
797     const char*         field;
798     bool                swizzle;
799 };
800 
801 /** array[index] **/
802 struct HLSLArrayAccess : public HLSLExpression
803 {
804     static const HLSLNodeType s_type = HLSLNodeType_ArrayAccess;
HLSLArrayAccessHLSLArrayAccess805 	HLSLArrayAccess()
806 	{
807 		array = NULL;
808 		index = NULL;
809 	}
810     HLSLExpression*     array;
811     HLSLExpression*     index;
812 };
813 
814 struct HLSLFunctionCall : public HLSLExpression
815 {
816     static const HLSLNodeType s_type = HLSLNodeType_FunctionCall;
HLSLFunctionCallHLSLFunctionCall817 	HLSLFunctionCall()
818 	{
819 		function     = NULL;
820 		argument     = NULL;
821 		numArguments = 0;
822 	}
823     const HLSLFunction* function;
824     HLSLExpression*     argument;
825     int                 numArguments;
826 };
827 
828 struct HLSLStateAssignment : public HLSLNode
829 {
830     static const HLSLNodeType s_type = HLSLNodeType_StateAssignment;
HLSLStateAssignmentHLSLStateAssignment831     HLSLStateAssignment()
832     {
833         stateName = NULL;
834         sValue = NULL;
835         nextStateAssignment = NULL;
836     }
837 
838     const char*             stateName;
839     int                     d3dRenderState;
840     union {
841         int                 iValue;
842         float               fValue;
843         const char *        sValue;
844     };
845     HLSLStateAssignment*    nextStateAssignment;
846 };
847 
848 struct HLSLSamplerState : public HLSLExpression // @@ Does this need to be an expression? Does it have a type? I guess type is useful.
849 {
850     static const HLSLNodeType s_type = HLSLNodeType_SamplerState;
HLSLSamplerStateHLSLSamplerState851     HLSLSamplerState()
852     {
853         numStateAssignments = 0;
854         stateAssignments = NULL;
855     }
856 
857     int                     numStateAssignments;
858     HLSLStateAssignment*    stateAssignments;
859 };
860 
861 struct HLSLPass : public HLSLNode
862 {
863     static const HLSLNodeType s_type = HLSLNodeType_Pass;
HLSLPassHLSLPass864     HLSLPass()
865     {
866         name = NULL;
867         numStateAssignments = 0;
868         stateAssignments = NULL;
869         nextPass = NULL;
870     }
871 
872     const char*             name;
873     int                     numStateAssignments;
874     HLSLStateAssignment*    stateAssignments;
875     HLSLPass*               nextPass;
876 };
877 
878 struct HLSLTechnique : public HLSLStatement
879 {
880     static const HLSLNodeType s_type = HLSLNodeType_Technique;
HLSLTechniqueHLSLTechnique881     HLSLTechnique()
882     {
883         name = NULL;
884         numPasses = 0;
885         passes = NULL;
886     }
887 
888     const char*         name;
889     int                 numPasses;
890     HLSLPass*           passes;
891 };
892 
893 struct HLSLPipeline : public HLSLStatement
894 {
895     static const HLSLNodeType s_type = HLSLNodeType_Pipeline;
HLSLPipelineHLSLPipeline896     HLSLPipeline()
897     {
898         name = NULL;
899         numStateAssignments = 0;
900         stateAssignments = NULL;
901     }
902 
903     const char*             name;
904     int                     numStateAssignments;
905     HLSLStateAssignment*    stateAssignments;
906 };
907 
908 struct HLSLStage : public HLSLStatement
909 {
910     static const HLSLNodeType s_type = HLSLNodeType_Stage;
HLSLStageHLSLStage911     HLSLStage()
912     {
913         name = NULL;
914         statement = NULL;
915         inputs = NULL;
916         outputs = NULL;
917     }
918 
919     const char*             name;
920     HLSLStatement*          statement;
921     HLSLDeclaration*        inputs;
922     HLSLDeclaration*        outputs;
923 };
924 
925 struct matrixCtor {
926     HLSLBaseType matrixType;
927     std::vector<HLSLBaseType> argumentTypes;
928 
929     bool operator==(const matrixCtor & other) const
930     {
931         return  matrixType == other.matrixType &&
932                 argumentTypes == other.argumentTypes;
933     }
934 
935     bool operator<(const matrixCtor & other) const
936     {
937         if (matrixType < other.matrixType)
938         {
939             return true;
940         }
941         else if (matrixType > other.matrixType)
942         {
943             return false;
944         }
945 
946         return argumentTypes < other.argumentTypes;
947     }
948 };
949 
950 
951 /**
952  * Abstract syntax tree for parsed HLSL code.
953  */
954 class HLSLTree
955 {
956 
957 public:
958 
959     explicit HLSLTree(Allocator* allocator);
960     ~HLSLTree();
961 
962     /** Adds a string to the string pool used by the tree. */
963     const char* AddString(const char* string);
964     const char* AddStringFormat(const char* string, ...);
965 
966     /** Returns true if the string is contained within the tree. */
967     bool GetContainsString(const char* string) const;
968 
969     /** Returns the root block in the tree */
970     HLSLRoot* GetRoot() const;
971 
972     /** Adds a new node to the tree with the specified type. */
973     template <class T>
AddNode(const char * fileName,int line)974     T* AddNode(const char* fileName, int line)
975     {
976         HLSLNode* node = new (AllocateMemory(sizeof(T))) T();
977         node->nodeType  = T::s_type;
978         node->fileName  = fileName;
979         node->line      = line;
980         return static_cast<T*>(node);
981     }
982 
983     HLSLFunction * FindFunction(const char * name);
984     HLSLDeclaration * FindGlobalDeclaration(const char * name, HLSLBuffer ** buffer_out = NULL);
985     HLSLStruct * FindGlobalStruct(const char * name);
986     HLSLTechnique * FindTechnique(const char * name);
987     HLSLPipeline * FindFirstPipeline();
988     HLSLPipeline * FindNextPipeline(HLSLPipeline * current);
989     HLSLPipeline * FindPipeline(const char * name);
990     HLSLBuffer * FindBuffer(const char * name);
991 
992     bool GetExpressionValue(HLSLExpression * expression, int & value);
993     int GetExpressionValue(HLSLExpression * expression, float values[4]);
994 
995     bool NeedsFunction(const char * name);
996     bool ReplaceUniformsAssignments();
997     void EnumerateMatrixCtorsNeeded(std::vector<matrixCtor> & matrixCtors);
998 
999 private:
1000 
1001     void* AllocateMemory(size_t size);
1002     void  AllocatePage();
1003 
1004 private:
1005 
1006     static const size_t s_nodePageSize = 1024 * 4;
1007 
1008     struct NodePage
1009     {
1010         NodePage*   next;
1011         char        buffer[s_nodePageSize];
1012     };
1013 
1014     Allocator*      m_allocator;
1015     StringPool      m_stringPool;
1016     HLSLRoot*       m_root;
1017 
1018     NodePage*       m_firstPage;
1019     NodePage*       m_currentPage;
1020     size_t          m_currentPageOffset;
1021 
1022 };
1023 
1024 
1025 
1026 class HLSLTreeVisitor
1027 {
1028 public:
1029     virtual void VisitType(HLSLType & type);
1030 
1031     virtual void VisitRoot(HLSLRoot * node);
1032     virtual void VisitTopLevelStatement(HLSLStatement * node);
1033     virtual void VisitStatements(HLSLStatement * statement);
1034     virtual void VisitStatement(HLSLStatement * node);
1035     virtual void VisitDeclaration(HLSLDeclaration * node);
1036     virtual void VisitStruct(HLSLStruct * node);
1037     virtual void VisitStructField(HLSLStructField * node);
1038     virtual void VisitBuffer(HLSLBuffer * node);
1039     //virtual void VisitBufferField(HLSLBufferField * node);
1040     virtual void VisitFunction(HLSLFunction * node);
1041     virtual void VisitArgument(HLSLArgument * node);
1042     virtual void VisitExpressionStatement(HLSLExpressionStatement * node);
1043     virtual void VisitExpression(HLSLExpression * node);
1044     virtual void VisitReturnStatement(HLSLReturnStatement * node);
1045     virtual void VisitDiscardStatement(HLSLDiscardStatement * node);
1046     virtual void VisitBreakStatement(HLSLBreakStatement * node);
1047     virtual void VisitContinueStatement(HLSLContinueStatement * node);
1048     virtual void VisitIfStatement(HLSLIfStatement * node);
1049     virtual void VisitForStatement(HLSLForStatement * node);
1050     virtual void VisitWhileStatement(HLSLWhileStatement * node);
1051     virtual void VisitBlockStatement(HLSLBlockStatement * node);
1052     virtual void VisitUnaryExpression(HLSLUnaryExpression * node);
1053     virtual void VisitBinaryExpression(HLSLBinaryExpression * node);
1054     virtual void VisitConditionalExpression(HLSLConditionalExpression * node);
1055     virtual void VisitCastingExpression(HLSLCastingExpression * node);
1056     virtual void VisitLiteralExpression(HLSLLiteralExpression * node);
1057     virtual void VisitIdentifierExpression(HLSLIdentifierExpression * node);
1058     virtual void VisitConstructorExpression(HLSLConstructorExpression * node);
1059     virtual void VisitMemberAccess(HLSLMemberAccess * node);
1060     virtual void VisitArrayAccess(HLSLArrayAccess * node);
1061     virtual void VisitFunctionCall(HLSLFunctionCall * node);
1062     virtual void VisitStateAssignment(HLSLStateAssignment * node);
1063     virtual void VisitSamplerState(HLSLSamplerState * node);
1064     virtual void VisitPass(HLSLPass * node);
1065     virtual void VisitTechnique(HLSLTechnique * node);
1066     virtual void VisitPipeline(HLSLPipeline * node);
1067 
1068 
1069     virtual void VisitFunctions(HLSLRoot * root);
1070     virtual void VisitParameters(HLSLRoot * root);
1071 
1072     HLSLFunction * FindFunction(HLSLRoot * root, const char * name);
1073     HLSLDeclaration * FindGlobalDeclaration(HLSLRoot * root, const char * name);
1074     HLSLStruct * FindGlobalStruct(HLSLRoot * root, const char * name);
1075 };
1076 
1077 
1078 // Tree transformations:
1079 extern void PruneTree(HLSLTree* tree, const char* entryName0, const char* entryName1 = NULL);
1080 extern void SortTree(HLSLTree* tree);
1081 extern void GroupParameters(HLSLTree* tree);
1082 extern void HideUnusedArguments(HLSLFunction * function);
1083 extern bool EmulateAlphaTest(HLSLTree* tree, const char* entryName, float alphaRef = 0.5f);
1084 extern void FlattenExpressions(HLSLTree* tree);
1085 
1086 extern matrixCtor matrixCtorBuilder(HLSLType type, HLSLExpression *arguments);
1087 
1088 
1089 } // M4
1090 
1091 #endif
1092