1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_ASTNODE
9 #define SKSL_ASTNODE
10 
11 #include "src/sksl/SkSLLexer.h"
12 #include "src/sksl/SkSLString.h"
13 #include "src/sksl/ir/SkSLModifiers.h"
14 
15 #include <vector>
16 
17 namespace SkSL {
18 
19 // std::max isn't constexpr in some compilers
Max(size_t a,size_t b)20 static constexpr size_t Max(size_t a, size_t b) {
21     return a > b ? a : b;
22 }
23 
24 /**
25  * Represents a node in the abstract syntax tree (AST). The AST is based directly on the parse tree;
26  * it is a parsed-but-not-yet-analyzed version of the program.
27  */
28 struct ASTNode {
29     class ID {
30     public:
InvalidASTNode31         static ID Invalid() {
32             return ID();
33         }
34 
35         bool operator==(const ID& other) {
36             return fValue == other.fValue;
37         }
38 
39         bool operator!=(const ID& other) {
40             return fValue != other.fValue;
41         }
42 
43         MOZ_IMPLICIT operator bool() const { return fValue >= 0; }
44 
45     private:
IDASTNode46         ID()
47             : fValue(-1) {}
48 
IDASTNode49         ID(int value)
50             : fValue(value) {}
51 
52         int fValue;
53 
54         friend struct ASTFile;
55         friend struct ASTNode;
56         friend class Parser;
57     };
58 
59     enum class Kind {
60         // data: operator(Token), children: left, right
61         kBinary,
62         // children: statements
63         kBlock,
64         // data: value(bool)
65         kBool,
66         kBreak,
67         // children: target, arg1, arg2...
68         kCall,
69         kContinue,
70         kDiscard,
71         // children: statement, test
72         kDo,
73         // data: name(StringFragment), children: enumCases
74         kEnum,
75         // data: name(StringFragment), children: value?
76         kEnumCase,
77         // data: name(StringFragment)
78         kExtension,
79         // data: field(StringFragment), children: base
80         kField,
81         // children: declarations
82         kFile,
83         // data: value(float)
84         kFloat,
85         // children: init, test, next, statement
86         kFor,
87         // data: FunctionData, children: returnType, parameters, statement?
88         kFunction,
89         // data: name(StringFragment)
90         kIdentifier,
91         // children: base, index?
92         kIndex,
93         // data: isStatic(bool), children: test, ifTrue, ifFalse?
94         kIf,
95         // value(data): int
96         kInt,
97         // data: InterfaceBlockData, children: declaration1, declaration2, ..., size1, size2, ...
98         kInterfaceBlock,
99         // data: Modifiers
100         kModifiers,
101         kNull,
102         // data: ParameterData, children: type, arraySize1, arraySize2, ..., value?
103         kParameter,
104         // data: operator(Token), children: operand
105         kPostfix,
106         // data: operator(Token), children: operand
107         kPrefix,
108         // children: value
109         kReturn,
110         // ...
111         kSection,
112         // children: value, statement 1, statement 2...
113         kSwitchCase,
114         // children: value, case 1, case 2...
115         kSwitch,
116         // children: test, ifTrue, ifFalse
117         kTernary,
118         // data: TypeData, children: sizes
119         kType,
120         // data: VarData, children: arraySize1, arraySize2, ..., value?
121         kVarDeclaration,
122         // children: modifiers, type, varDeclaration1, varDeclaration2, ...
123         kVarDeclarations,
124         // children: test, statement
125         kWhile,
126     };
127 
128     class iterator {
129     public:
130         iterator operator++() {
131             SkASSERT(fID);
132             fID = (**this).fNext;
133             return *this;
134         }
135 
136         iterator operator++(int) {
137             SkASSERT(fID);
138             iterator old = *this;
139             fID = (**this).fNext;
140             return old;
141         }
142 
143         iterator operator+=(int count) {
144             SkASSERT(count >= 0);
145             for (; count > 0; --count) {
146                 ++(*this);
147             }
148             return *this;
149         }
150 
151         iterator operator+(int count) {
152             iterator result(*this);
153             return result += count;
154         }
155 
156         bool operator==(const iterator& other) const {
157             return fID == other.fID;
158         }
159 
160         bool operator!=(const iterator& other) const {
161             return fID != other.fID;
162         }
163 
164         ASTNode& operator*() {
165             SkASSERT(fID);
166             return (*fNodes)[fID.fValue];
167         }
168 
169         ASTNode* operator->() {
170             SkASSERT(fID);
171             return &(*fNodes)[fID.fValue];
172         }
173 
174     private:
iteratorASTNode175         iterator(std::vector<ASTNode>* nodes, ID id)
176             : fNodes(nodes)
177             , fID(id) {}
178 
179         std::vector<ASTNode>* fNodes;
180 
181         ID fID;
182 
183         friend struct ASTNode;
184     };
185 
186     struct TypeData {
TypeDataASTNode::TypeData187         TypeData() {}
188 
TypeDataASTNode::TypeData189         TypeData(StringFragment name, bool isStructDeclaration, bool isNullable)
190             : fName(name)
191             , fIsStructDeclaration(isStructDeclaration)
192             , fIsNullable(isNullable) {}
193 
194         StringFragment fName;
195         bool fIsStructDeclaration;
196         bool fIsNullable;
197     };
198 
199     struct ParameterData {
ParameterDataASTNode::ParameterData200         ParameterData() {}
201 
ParameterDataASTNode::ParameterData202         ParameterData(Modifiers modifiers, StringFragment name, size_t sizeCount)
203             : fModifiers(modifiers)
204             , fName(name)
205             , fSizeCount(sizeCount) {}
206 
207         Modifiers fModifiers;
208         StringFragment fName;
209         size_t fSizeCount;
210     };
211 
212     struct VarData {
VarDataASTNode::VarData213         VarData() {}
214 
VarDataASTNode::VarData215         VarData(StringFragment name, size_t sizeCount)
216             : fName(name)
217             , fSizeCount(sizeCount) {}
218 
219         StringFragment fName;
220         size_t fSizeCount;
221     };
222 
223     struct FunctionData {
FunctionDataASTNode::FunctionData224         FunctionData() {}
225 
FunctionDataASTNode::FunctionData226         FunctionData(Modifiers modifiers, StringFragment name, size_t parameterCount)
227             : fModifiers(modifiers)
228             , fName(name)
229             , fParameterCount(parameterCount) {}
230 
231         Modifiers fModifiers;
232         StringFragment fName;
233         size_t fParameterCount;
234     };
235 
236     struct InterfaceBlockData {
InterfaceBlockDataASTNode::InterfaceBlockData237         InterfaceBlockData() {}
238 
InterfaceBlockDataASTNode::InterfaceBlockData239         InterfaceBlockData(Modifiers modifiers, StringFragment typeName, size_t declarationCount,
240                            StringFragment instanceName, size_t sizeCount)
241             : fModifiers(modifiers)
242             , fTypeName(typeName)
243             , fDeclarationCount(declarationCount)
244             , fInstanceName(instanceName)
245             , fSizeCount(sizeCount) {}
246 
247         Modifiers fModifiers;
248         StringFragment fTypeName;
249         size_t fDeclarationCount;
250         StringFragment fInstanceName;
251         size_t fSizeCount;
252     };
253 
254     struct SectionData {
SectionDataASTNode::SectionData255         SectionData() {}
256 
SectionDataASTNode::SectionData257         SectionData(StringFragment name, StringFragment argument, StringFragment text)
258             : fName(name)
259             , fArgument(argument)
260             , fText(text) {}
261 
262         StringFragment fName;
263         StringFragment fArgument;
264         StringFragment fText;
265     };
266 
267     struct NodeData {
268         char fBytes[Max(sizeof(Token),
269                     Max(sizeof(StringFragment),
270                     Max(sizeof(bool),
271                     Max(sizeof(SKSL_INT),
272                     Max(sizeof(SKSL_FLOAT),
273                     Max(sizeof(Modifiers),
274                     Max(sizeof(TypeData),
275                     Max(sizeof(FunctionData),
276                     Max(sizeof(ParameterData),
277                     Max(sizeof(VarData),
278                     Max(sizeof(InterfaceBlockData),
279                         sizeof(SectionData))))))))))))];
280 
281         enum class Kind {
282             kToken,
283             kStringFragment,
284             kBool,
285             kInt,
286             kFloat,
287             kModifiers,
288             kTypeData,
289             kFunctionData,
290             kParameterData,
291             kVarData,
292             kInterfaceBlockData,
293             kSectionData
294         } fKind;
295 
296         NodeData() = default;
297 
NodeDataASTNode::NodeData298         NodeData(Token data)
299             : fKind(Kind::kToken) {
300             memcpy(fBytes, &data, sizeof(data));
301         }
302 
NodeDataASTNode::NodeData303         NodeData(StringFragment data)
304             : fKind(Kind::kStringFragment) {
305             memcpy(fBytes, &data, sizeof(data));
306         }
307 
NodeDataASTNode::NodeData308         NodeData(bool data)
309             : fKind(Kind::kBool) {
310             memcpy(fBytes, &data, sizeof(data));
311         }
312 
NodeDataASTNode::NodeData313         NodeData(SKSL_INT data)
314             : fKind(Kind::kInt) {
315             memcpy(fBytes, &data, sizeof(data));
316         }
317 
NodeDataASTNode::NodeData318         NodeData(SKSL_FLOAT data)
319             : fKind(Kind::kFloat) {
320             memcpy(fBytes, &data, sizeof(data));
321         }
322 
NodeDataASTNode::NodeData323         NodeData(Modifiers data)
324             : fKind(Kind::kModifiers) {
325             memcpy(fBytes, &data, sizeof(data));
326         }
327 
NodeDataASTNode::NodeData328         NodeData(TypeData data)
329             : fKind(Kind::kTypeData) {
330             memcpy(fBytes, &data, sizeof(data));
331         }
332 
NodeDataASTNode::NodeData333         NodeData(FunctionData data)
334             : fKind(Kind::kFunctionData) {
335             memcpy(fBytes, &data, sizeof(data));
336         }
337 
NodeDataASTNode::NodeData338         NodeData(VarData data)
339             : fKind(Kind::kVarData) {
340             memcpy(fBytes, &data, sizeof(data));
341         }
342 
NodeDataASTNode::NodeData343         NodeData(ParameterData data)
344             : fKind(Kind::kParameterData) {
345             memcpy(fBytes, &data, sizeof(data));
346         }
347 
NodeDataASTNode::NodeData348         NodeData(InterfaceBlockData data)
349             : fKind(Kind::kInterfaceBlockData) {
350             memcpy(fBytes, &data, sizeof(data));
351         }
352 
NodeDataASTNode::NodeData353         NodeData(SectionData data)
354             : fKind(Kind::kSectionData) {
355             memcpy(fBytes, &data, sizeof(data));
356         }
357     };
358 
ASTNodeASTNode359     ASTNode()
360         : fOffset(-1)
361         , fKind(Kind::kNull) {}
362 
ASTNodeASTNode363     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind)
364         : fNodes(nodes)
365         , fOffset(offset)
366             , fKind(kind) {
367         switch (kind) {
368             case Kind::kBinary:
369             case Kind::kPostfix:
370             case Kind::kPrefix:
371                 fData.fKind = NodeData::Kind::kToken;
372                 break;
373 
374             case Kind::kBool:
375             case Kind::kIf:
376             case Kind::kSwitch:
377                 fData.fKind = NodeData::Kind::kBool;
378                 break;
379 
380             case Kind::kEnum:
381             case Kind::kEnumCase:
382             case Kind::kExtension:
383             case Kind::kField:
384             case Kind::kIdentifier:
385                 fData.fKind = NodeData::Kind::kStringFragment;
386                 break;
387 
388             case Kind::kFloat:
389                 fData.fKind = NodeData::Kind::kFloat;
390                 break;
391 
392             case Kind::kFunction:
393                 fData.fKind = NodeData::Kind::kFunctionData;
394                 break;
395 
396             case Kind::kInt:
397                 fData.fKind = NodeData::Kind::kInt;
398                 break;
399 
400             case Kind::kInterfaceBlock:
401                 fData.fKind = NodeData::Kind::kInterfaceBlockData;
402                 break;
403 
404             case Kind::kModifiers:
405                 fData.fKind = NodeData::Kind::kModifiers;
406                 break;
407 
408             case Kind::kParameter:
409                 fData.fKind = NodeData::Kind::kParameterData;
410                 break;
411 
412             case Kind::kVarDeclaration:
413                 fData.fKind = NodeData::Kind::kVarData;
414                 break;
415 
416             case Kind::kType:
417                 fData.fKind = NodeData::Kind::kTypeData;
418                 break;
419 
420             default:
421                 break;
422         }
423     }
424 
ASTNodeASTNode425     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, Token t)
426         : fNodes(nodes)
427         , fData(t)
428         , fOffset(offset)
429         , fKind(kind) {}
430 
ASTNodeASTNode431     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, StringFragment s)
432         : fNodes(nodes)
433         , fData(s)
434         , fOffset(offset)
435         , fKind(kind) {}
436 
ASTNodeASTNode437     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, const char* s)
438         : fNodes(nodes)
439         , fData(StringFragment(s))
440         , fOffset(offset)
441         , fKind(kind) {}
442 
ASTNodeASTNode443     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, bool b)
444         : fNodes(nodes)
445         , fData(b)
446         , fOffset(offset)
447         , fKind(kind) {}
448 
ASTNodeASTNode449     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, SKSL_INT i)
450         : fNodes(nodes)
451         , fData(i)
452         , fOffset(offset)
453         , fKind(kind) {}
454 
ASTNodeASTNode455     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, SKSL_FLOAT f)
456         : fNodes(nodes)
457         , fData(f)
458         , fOffset(offset)
459         , fKind(kind) {}
460 
ASTNodeASTNode461     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, Modifiers m)
462         : fNodes(nodes)
463         , fData(m)
464         , fOffset(offset)
465         , fKind(kind) {}
466 
ASTNodeASTNode467     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, TypeData td)
468         : fNodes(nodes)
469         , fData(td)
470         , fOffset(offset)
471         , fKind(kind) {}
472 
ASTNodeASTNode473     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, SectionData s)
474         : fNodes(nodes)
475         , fData(s)
476         , fOffset(offset)
477         , fKind(kind) {}
478 
479     MOZ_IMPLICIT operator bool() const {
480         return fKind != Kind::kNull;
481     }
482 
getTokenASTNode483     Token getToken() const {
484         SkASSERT(fData.fKind == NodeData::Kind::kToken);
485         Token result;
486         memcpy(&result, fData.fBytes, sizeof(result));
487         return result;
488     }
489 
getBoolASTNode490     bool getBool() const {
491         SkASSERT(fData.fKind == NodeData::Kind::kBool);
492         bool result;
493         memcpy(&result, fData.fBytes, sizeof(result));
494         return result;
495     }
496 
getIntASTNode497     SKSL_INT getInt() const {
498         SkASSERT(fData.fKind == NodeData::Kind::kInt);
499         SKSL_INT result;
500         memcpy(&result, fData.fBytes, sizeof(result));
501         return result;
502     }
503 
getFloatASTNode504     SKSL_FLOAT getFloat() const {
505         SkASSERT(fData.fKind == NodeData::Kind::kFloat);
506         SKSL_FLOAT result;
507         memcpy(&result, fData.fBytes, sizeof(result));
508         return result;
509     }
510 
getStringASTNode511     StringFragment getString() const {
512         SkASSERT(fData.fKind == NodeData::Kind::kStringFragment);
513         StringFragment result;
514         memcpy(&result, fData.fBytes, sizeof(result));
515         return result;
516     }
517 
getModifiersASTNode518     Modifiers getModifiers() const {
519         SkASSERT(fData.fKind == NodeData::Kind::kModifiers);
520         Modifiers result;
521         memcpy(&result, fData.fBytes, sizeof(result));
522         return result;
523     }
524 
setModifiersASTNode525     void setModifiers(const Modifiers& m) {
526         memcpy(fData.fBytes, &m, sizeof(m));
527     }
528 
getTypeDataASTNode529     TypeData getTypeData() const {
530         SkASSERT(fData.fKind == NodeData::Kind::kTypeData);
531         TypeData result;
532         memcpy(&result, fData.fBytes, sizeof(result));
533         return result;
534     }
535 
setTypeDataASTNode536     void setTypeData(const ASTNode::TypeData& td) {
537         SkASSERT(fData.fKind == NodeData::Kind::kTypeData);
538         memcpy(fData.fBytes, &td, sizeof(td));
539     }
540 
getParameterDataASTNode541     ParameterData getParameterData() const {
542         SkASSERT(fData.fKind == NodeData::Kind::kParameterData);
543         ParameterData result;
544         memcpy(&result, fData.fBytes, sizeof(result));
545         return result;
546     }
547 
setParameterDataASTNode548     void setParameterData(const ASTNode::ParameterData& pd) {
549         SkASSERT(fData.fKind == NodeData::Kind::kParameterData);
550         memcpy(fData.fBytes, &pd, sizeof(pd));
551     }
552 
getVarDataASTNode553     VarData getVarData() const {
554         SkASSERT(fData.fKind == NodeData::Kind::kVarData);
555         VarData result;
556         memcpy(&result, fData.fBytes, sizeof(result));
557         return result;
558     }
559 
setVarDataASTNode560     void setVarData(const ASTNode::VarData& vd) {
561         SkASSERT(fData.fKind == NodeData::Kind::kVarData);
562         memcpy(fData.fBytes, &vd, sizeof(vd));
563     }
564 
getFunctionDataASTNode565     FunctionData getFunctionData() const {
566         SkASSERT(fData.fKind == NodeData::Kind::kFunctionData);
567         FunctionData result;
568         memcpy(&result, fData.fBytes, sizeof(result));
569         return result;
570     }
571 
setFunctionDataASTNode572     void setFunctionData(const ASTNode::FunctionData& fd) {
573         SkASSERT(fData.fKind == NodeData::Kind::kFunctionData);
574         memcpy(fData.fBytes, &fd, sizeof(fd));
575     }
576 
getInterfaceBlockDataASTNode577     InterfaceBlockData getInterfaceBlockData() const {
578         SkASSERT(fData.fKind == NodeData::Kind::kInterfaceBlockData);
579         InterfaceBlockData result;
580         memcpy(&result, fData.fBytes, sizeof(result));
581         return result;
582     }
583 
setInterfaceBlockDataASTNode584     void setInterfaceBlockData(const ASTNode::InterfaceBlockData& id) {
585         SkASSERT(fData.fKind == NodeData::Kind::kInterfaceBlockData);
586         memcpy(fData.fBytes, &id, sizeof(id));
587     }
588 
getSectionDataASTNode589     SectionData getSectionData() const {
590         SkASSERT(fData.fKind == NodeData::Kind::kSectionData);
591         SectionData result;
592         memcpy(&result, fData.fBytes, sizeof(result));
593         return result;
594     }
595 
addChildASTNode596     void addChild(ID id) {
597         SkASSERT(!(*fNodes)[id.fValue].fNext);
598         if (fLastChild) {
599             SkASSERT(!(*fNodes)[fLastChild.fValue].fNext);
600             (*fNodes)[fLastChild.fValue].fNext = id;
601         } else {
602             fFirstChild = id;
603         }
604         fLastChild = id;
605         SkASSERT(!(*fNodes)[fLastChild.fValue].fNext);
606     }
607 
beginASTNode608     iterator begin() const {
609         return iterator(fNodes, fFirstChild);
610     }
611 
endASTNode612     iterator end() const {
613         return iterator(fNodes, ID(-1));
614     }
615 
616     String description() const;
617 
618     std::vector<ASTNode>* fNodes;
619 
620     NodeData fData;
621 
622     int fOffset;
623 
624     Kind fKind;
625 
626     ID fFirstChild;
627 
628     ID fLastChild;
629 
630     ID fNext;
631 };
632 
633 } // namespace
634 
635 #endif
636