1 // Copyright Contributors to the OpenVDB Project
2 // SPDX-License-Identifier: MPL-2.0
3 
4 #include <openvdb_ax/ast/AST.h>
5 #include <openvdb_ax/ast/Scanners.h>
6 #include <openvdb_ax/ast/PrintTree.h>
7 #include <openvdb_ax/Exceptions.h>
8 
9 #include "../util.h"
10 
11 #include <cppunit/extensions/HelperMacros.h>
12 
13 #include <string>
14 
15 using namespace openvdb::ax::ast;
16 using namespace openvdb::ax::ast::tokens;
17 
18 namespace {
19 
20 static const unittest_util::CodeTests tests =
21 {
22     { "a + b;",                 Node::Ptr(
23                                     new BinaryOperator(
24                                         new Local("a"),
25                                         new Local("b"),
26                                         OperatorToken::PLUS
27                                     )
28                                 )
29     },
30     { "a - b;",                 Node::Ptr(
31                                     new BinaryOperator(
32                                         new Local("a"),
33                                         new Local("b"),
34                                         OperatorToken::MINUS
35                                     )
36                                 )
37     },
38     { "a * b;",                 Node::Ptr(
39                                     new BinaryOperator(
40                                         new Local("a"),
41                                         new Local("b"),
42                                         OperatorToken::MULTIPLY
43                                     )
44                                 )
45     },
46     { "a / b;",                 Node::Ptr(
47                                     new BinaryOperator(
48                                         new Local("a"),
49                                         new Local("b"),
50                                         OperatorToken::DIVIDE
51                                     )
52                                 )
53     },
54     { "a % b;",                 Node::Ptr(
55                                     new BinaryOperator(
56                                         new Local("a"),
57                                         new Local("b"),
58                                         OperatorToken::MODULO
59                                     )
60                                 )
61     },
62     { "a << b;",                Node::Ptr(
63                                     new BinaryOperator(
64                                         new Local("a"),
65                                         new Local("b"),
66                                         OperatorToken::SHIFTLEFT
67                                     )
68                                 )
69     },
70     { "a >> b;",                Node::Ptr(
71                                     new BinaryOperator(
72                                         new Local("a"),
73                                         new Local("b"),
74                                         OperatorToken::SHIFTRIGHT
75                                     )
76                                 )
77     },
78     { "a & b;",                 Node::Ptr(
79                                     new BinaryOperator(
80                                         new Local("a"),
81                                         new Local("b"),
82                                         OperatorToken::BITAND
83                                     )
84                                 )
85     },
86     { "a | b;",                 Node::Ptr(
87                                     new BinaryOperator(
88                                         new Local("a"),
89                                         new Local("b"),
90                                         OperatorToken::BITOR
91                                     )
92                                 )
93     },
94     { "a ^ b;",                 Node::Ptr(
95                                     new BinaryOperator(
96                                         new Local("a"),
97                                         new Local("b"),
98                                         OperatorToken::BITXOR
99                                     )
100                                 )
101     },
102     { "a && b;",                Node::Ptr(
103                                     new BinaryOperator(
104                                         new Local("a"),
105                                         new Local("b"),
106                                         OperatorToken::AND
107                                     )
108                                 )
109     },
110     { "a || b;",                Node::Ptr(
111                                     new BinaryOperator(
112                                         new Local("a"),
113                                         new Local("b"),
114                                         OperatorToken::OR
115                                     )
116                                 )
117     },
118     { "a == b;",                Node::Ptr(
119                                     new BinaryOperator(
120                                         new Local("a"),
121                                         new Local("b"),
122                                         OperatorToken::EQUALSEQUALS
123                                     )
124                                 )
125     },
126     { "a != b;",                Node::Ptr(
127                                     new BinaryOperator(
128                                         new Local("a"),
129                                         new Local("b"),
130                                         OperatorToken::NOTEQUALS
131                                     )
132                                 )
133     },
134     { "a > b;",                 Node::Ptr(
135                                     new BinaryOperator(
136                                         new Local("a"),
137                                         new Local("b"),
138                                         OperatorToken::MORETHAN
139                                     )
140                                 )
141     },
142     { "a < b;",                 Node::Ptr(
143                                     new BinaryOperator(
144                                         new Local("a"),
145                                         new Local("b"),
146                                         OperatorToken::LESSTHAN
147                                     )
148                                 )
149     },
150     { "a >= b;",                Node::Ptr(
151                                     new BinaryOperator(
152                                         new Local("a"),
153                                         new Local("b"),
154                                         OperatorToken::MORETHANOREQUAL
155                                     )
156                                 )
157     },
158     { "a <= b;",                Node::Ptr(
159                                     new BinaryOperator(
160                                         new Local("a"),
161                                         new Local("b"),
162                                         OperatorToken::LESSTHANOREQUAL
163                                     )
164                                 )
165     },
166     { "(a) + (a);",             Node::Ptr(
167                                     new BinaryOperator(
168                                         new Local("a"),
169                                         new Local("a"),
170                                         OperatorToken::PLUS
171                                     )
172                                 )
173     },
174     { "(a,b,c) + (d,e,f);",     Node::Ptr(
175                                     new BinaryOperator(
176                                         new CommaOperator({
177                                             new Local("a"), new Local("b"), new Local("c")
178                                         }),
179                                         new CommaOperator({
180                                             new Local("d"), new Local("e"), new Local("f")
181                                         }),
182                                         OperatorToken::PLUS
183                                     )
184                                 )
185     },
186     { "func1() + func2();",      Node::Ptr(
187                                     new BinaryOperator(
188                                         new FunctionCall("func1"),
189                                         new FunctionCall("func2"),
190                                         OperatorToken::PLUS
191                                     )
192                                 )
193     },
194     { "a + b - c;",             Node::Ptr(
195                                     new BinaryOperator(
196                                         new BinaryOperator(
197                                             new Local("a"),
198                                             new Local("b"),
199                                             OperatorToken::PLUS
200                                         ),
201                                         new Local("c"),
202                                         OperatorToken::MINUS
203                                     )
204                                 )
205     },
206     { "~a + !b;",               Node::Ptr(
207                                     new BinaryOperator(
208                                         new UnaryOperator(new Local("a"), OperatorToken::BITNOT),
209                                         new UnaryOperator(new Local("b"), OperatorToken::NOT),
210                                         OperatorToken::PLUS
211                                     )
212                                 )
213     },
214     { "++a - --b;",             Node::Ptr(
215                                     new BinaryOperator(
216                                         new Crement(new Local("a"), Crement::Operation::Increment, false),
217                                         new Crement(new Local("b"), Crement::Operation::Decrement, false),
218                                         OperatorToken::MINUS
219                                     )
220                                 )
221     },
222     { "a-- + b++;",             Node::Ptr(
223                                     new BinaryOperator(
224                                         new Crement(new Local("a"), Crement::Operation::Decrement, true),
225                                         new Crement(new Local("b"), Crement::Operation::Increment, true),
226                                         OperatorToken::PLUS
227                                     )
228                                 )
229     },
230     { "int(a) + float(b);",     Node::Ptr(
231                                     new BinaryOperator(
232                                         new Cast(new Local("a"), CoreType::INT32),
233                                         new Cast(new Local("b"), CoreType::FLOAT),
234                                         OperatorToken::PLUS
235                                     )
236                                 )
237     },
238     { "{a,b,c} + {d,e,f};",     Node::Ptr(
239                                     new BinaryOperator(
240                                         new ArrayPack({
241                                             new Local("a"),
242                                             new Local("b"),
243                                             new Local("c")
244                                         }),
245                                         new ArrayPack({
246                                             new Local("d"),
247                                             new Local("e"),
248                                             new Local("f")
249                                         }),
250                                         OperatorToken::PLUS
251                                     )
252                                 )
253     },
254     { "a.x + b.y;",             Node::Ptr(
255                                     new BinaryOperator(
256                                         new ArrayUnpack(new Local("a"), new Value<int32_t>(0)),
257                                         new ArrayUnpack(new Local("b"), new Value<int32_t>(1)),
258                                         OperatorToken::PLUS
259                                     )
260                                 )
261     },
262     { "0 + 1;",                 Node::Ptr(
263                                     new BinaryOperator(
264                                         new Value<int32_t>(0),
265                                         new Value<int32_t>(1),
266                                         OperatorToken::PLUS
267                                     )
268                                 )
269     },
270     { "0.0f + 1.0;",            Node::Ptr(
271                                     new BinaryOperator(
272                                         new Value<float>(0.0),
273                                         new Value<double>(1.0),
274                                         OperatorToken::PLUS
275                                     )
276                                 )
277     },
278     { "@a + @b;",               Node::Ptr(
279                                     new BinaryOperator(
280                                         new Attribute("a", CoreType::FLOAT, true),
281                                         new Attribute("b", CoreType::FLOAT, true),
282                                         OperatorToken::PLUS
283                                     )
284                                 )
285     },
286     { "\"a\" + \"b\";",         Node::Ptr(
287                                     new BinaryOperator(
288                                         new Value<std::string>("a"),
289                                         new Value<std::string>("b"),
290                                         OperatorToken::PLUS
291                                     )
292                                 )
293     },
294 };
295 
296 }
297 
298 class TestBinaryOperatorNode : public CppUnit::TestCase
299 {
300 public:
301 
302     CPPUNIT_TEST_SUITE(TestBinaryOperatorNode);
303     CPPUNIT_TEST(testSyntax);
304     CPPUNIT_TEST(testASTNode);
305     CPPUNIT_TEST_SUITE_END();
306 
testSyntax()307     void testSyntax() { TEST_SYNTAX_PASSES(tests); }
308     void testASTNode();
309 };
310 
311 CPPUNIT_TEST_SUITE_REGISTRATION(TestBinaryOperatorNode);
312 
testASTNode()313 void TestBinaryOperatorNode::testASTNode()
314 {
315     for (const auto& test : tests) {
316         const std::string& code = test.first;
317         const Node* expected = test.second.get();
318         const Tree::ConstPtr tree = parse(code.c_str());
319         CPPUNIT_ASSERT_MESSAGE(ERROR_MSG("No AST returned", code), static_cast<bool>(tree));
320 
321         // get the first statement
322         const Node* result = tree->child(0)->child(0);
323         CPPUNIT_ASSERT(result);
324         CPPUNIT_ASSERT_MESSAGE(ERROR_MSG("Invalid AST node", code),
325             Node::BinaryOperatorNode == result->nodetype());
326 
327         std::vector<const Node*> resultList, expectedList;
328         linearize(*result, resultList);
329         linearize(*expected, expectedList);
330 
331         if (!unittest_util::compareLinearTrees(expectedList, resultList)) {
332             std::ostringstream os;
333             os << "\nExpected:\n";
334             openvdb::ax::ast::print(*expected, true, os);
335             os << "Result:\n";
336             openvdb::ax::ast::print(*result, true, os);
337             CPPUNIT_FAIL(ERROR_MSG("Mismatching Trees for Binary Operator code", code) + os.str());
338         }
339     }
340 }
341 
342