1 // Copyright Contributors to the OpenVDB Project
2 // SPDX-License-Identifier: MPL-2.0
3 
4 #include <openvdb_ax/ast/AST.h>
5 #include <openvdb_ax/ast/Scanners.h>
6 #include <openvdb_ax/ast/PrintTree.h>
7 #include <openvdb_ax/Exceptions.h>
8 
9 #include "../util.h"
10 
11 #include <cppunit/extensions/HelperMacros.h>
12 
13 #include <string>
14 
15 using namespace openvdb::ax::ast;
16 using namespace openvdb::ax::ast::tokens;
17 
namespace {

/// Table of AX code snippets, each paired with the AST that is expected for
/// the first statement of that snippet. The same table drives both
/// testSyntax (via TEST_SYNTAX_PASSES) and testASTNode (which reads
/// `.first` as the code string and `.second` as the expected node).
static const unittest_util::CodeTests tests =
{
    // Named component access on locals: .x/.y/.z and .r/.g/.b are both
    // expected as int32 component indices 0/1/2.
    { "a.x;",             Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(0))) },
    { "a.y;",             Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(1))) },
    { "a.z;",             Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(2))) },
    { "a.r;",             Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(0))) },
    { "a.g;",             Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(1))) },
    { "a.b;",             Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(2))) },
    // A local whose name is itself a component letter.
    { "x.x;",             Node::Ptr(new ArrayUnpack(new Local("x"), new Value<int32_t>(0))) },
    // Named component access on attributes. An untyped `@name` is expected as
    // Attribute(name, CoreType::FLOAT, true).
    // NOTE(review): the trailing `true` presumably flags the type as inferred
    // rather than explicitly written - confirm against the Attribute ctor.
    { "@x.x;",            Node::Ptr(new ArrayUnpack(new Attribute("x", CoreType::FLOAT, true), new Value<int32_t>(0))) },
    { "@a.x;",            Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(0))) },
    { "@b.y;",            Node::Ptr(new ArrayUnpack(new Attribute("b", CoreType::FLOAT, true), new Value<int32_t>(1))) },
    { "@c.z;",            Node::Ptr(new ArrayUnpack(new Attribute("c", CoreType::FLOAT, true), new Value<int32_t>(2))) },
    { "@a.r;",            Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(0))) },
    { "@a.g;",            Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(1))) },
    { "@a.b;",            Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(2))) },
    // Single bracketed component. The expected Value node keeps the type of
    // the literal as written (0l -> int64_t, 0 -> int32_t, 0.0f -> float,
    // 0.0 -> double, "str" -> std::string, true/false -> bool); a bare
    // identifier is expected as a Local.
    { "@a[0l];",          Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int64_t>(0))) },
    { "@a[0];",           Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(0))) },
    { "@a[1];",           Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(1))) },
    { "@a[2];",           Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(2))) },
    { "@a[0.0f];",        Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<float>(0.0f))) },
    { "@a[0.0];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<double>(0.0))) },
    { "@a[\"str\"];",     Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<std::string>("str"))) },
    { "@a[true];",        Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<bool>(true))) },
    { "@a[false];",       Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<bool>(false))) },
    { "@a[a];",           Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("a"))) },
    // Two bracketed components, passed to ArrayUnpack as two separate
    // component expressions, using literals and/or locals.
    { "@a[0,0];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(0), new Value<int32_t>(0))) },
    { "@a[1,0];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(1), new Value<int32_t>(0))) },
    { "@a[2,0];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Value<int32_t>(2), new Value<int32_t>(0))) },
    { "a[0,0];",          Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(0), new Value<int32_t>(0))) },
    { "a[1,0];",          Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(1), new Value<int32_t>(0))) },
    { "a[2,0];",          Node::Ptr(new ArrayUnpack(new Local("a"), new Value<int32_t>(2), new Value<int32_t>(0))) },
    { "@a[a,0];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("a"), new Value<int32_t>(0))) },
    { "@a[b,1];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("b"), new Value<int32_t>(1))) },
    { "@a[c,2];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("c"), new Value<int32_t>(2))) },
    { "@a[a,d];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("a"), new Local("d"))) },
    { "@a[b,e];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("b"), new Local("e"))) },
    { "@a[c,f];",         Node::Ptr(new ArrayUnpack(new Attribute("a", CoreType::FLOAT, true), new Local("c"), new Local("f"))) },
    // Components that are themselves compound expressions: parenthesised
    // identifiers, binary/unary operators, assignments, function calls,
    // external variables, crements, nested unpacks and comma operators.
    { "a[(a),1+1];",      Node::Ptr(new ArrayUnpack(new Local("a"),
                             new Local("a"),
                             new BinaryOperator(new Value<int32_t>(1), new Value<int32_t>(1), OperatorToken::PLUS)))
    },
    { "a[!0,a=b];",       Node::Ptr(new ArrayUnpack(new Local("a"),
                             new UnaryOperator(new Value<int32_t>(0), OperatorToken::NOT),
                             new AssignExpression(new Local("a"), new Local("b"))))
    },
    { "a[test(),$A];",    Node::Ptr(new ArrayUnpack(new Local("a"),
                             new FunctionCall("test"),
                             new ExternalVariable("A", CoreType::FLOAT)))
    },
    { "a[a++,++a];",      Node::Ptr(new ArrayUnpack(new Local("a"),
                             new Crement(new Local("a"), Crement::Operation::Increment, true),
                             new Crement(new Local("a"), Crement::Operation::Increment, false)))
    },
    { "a[a[0,0],0];",     Node::Ptr(new ArrayUnpack(new Local("a"),
                             new ArrayUnpack(new Local("a"), new Value<int32_t>(0), new Value<int32_t>(0)),
                             new Value<int32_t>(0)))
    },
    // A parenthesised comma list inside the brackets is expected as a single
    // CommaOperator component, not as separate unpack components.
    { "a[(1,2,3)];",    Node::Ptr(new ArrayUnpack(new Local("a"),
                            new CommaOperator({
                                new Value<int32_t>(1),
                                new Value<int32_t>(2),
                                new Value<int32_t>(3)
                            })
                        ))
    },
    { "a[(1,2,3),(4,5,6)];",    Node::Ptr(new ArrayUnpack(new Local("a"),
                                    new CommaOperator({
                                        new Value<int32_t>(1),
                                        new Value<int32_t>(2),
                                        new Value<int32_t>(3),
                                    }),
                                    new CommaOperator({
                                        new Value<int32_t>(4),
                                        new Value<int32_t>(5),
                                        new Value<int32_t>(6),
                                    })
                                ))
    },
    { "a[a[0,0],a[0]];",  Node::Ptr(new ArrayUnpack(new Local("a"),
                             new ArrayUnpack(new Local("a"), new Value<int32_t>(0), new Value<int32_t>(0)),
                             new ArrayUnpack(new Local("a"), new Value<int32_t>(0))))
    }
    // @todo  should this be a syntax error
    // { "@a[{1,2,3},{1,2,3,4}];", }
};

}
109 
/// CppUnit fixture for the ArrayUnpack AST node. Both test cases iterate the
/// file-local `tests` table of code snippets and expected trees.
class TestArrayUnpackNode : public CppUnit::TestCase
{
public:

    // Suite declaration: lists the member functions CppUnit should run as
    // test cases of this fixture.
    CPPUNIT_TEST_SUITE(TestArrayUnpackNode);
    CPPUNIT_TEST(testSyntax);
    CPPUNIT_TEST(testASTNode);
    CPPUNIT_TEST_SUITE_END();

    /// Runs the TEST_SYNTAX_PASSES helper over every entry of `tests`.
    /// NOTE(review): the helper is defined outside this file (presumably in
    /// ../util.h) and presumably asserts that each snippet parses - confirm.
    void testSyntax() { TEST_SYNTAX_PASSES(tests); }
    /// Compares the parsed tree of each snippet against its expected tree;
    /// defined out of line below.
    void testASTNode();
};

// Registers the suite with CppUnit so that the test runner picks it up.
CPPUNIT_TEST_SUITE_REGISTRATION(TestArrayUnpackNode);
124 
testASTNode()125 void TestArrayUnpackNode::testASTNode()
126 {
127     for (const auto& test : tests) {
128         const std::string& code = test.first;
129         const Node* expected = test.second.get();
130         const Tree::ConstPtr tree = parse(code.c_str());
131         CPPUNIT_ASSERT_MESSAGE(ERROR_MSG("No AST returned", code), static_cast<bool>(tree));
132 
133         // get the first statement
134         const Node* result = tree->child(0)->child(0);
135         CPPUNIT_ASSERT(result);
136         CPPUNIT_ASSERT_MESSAGE(ERROR_MSG("Invalid AST node", code),
137             Node::ArrayUnpackNode == result->nodetype());
138 
139         std::vector<const Node*> resultList, expectedList;
140         linearize(*result, resultList);
141         linearize(*expected, expectedList);
142 
143         if (!unittest_util::compareLinearTrees(expectedList, resultList)) {
144             std::ostringstream os;
145             os << "\nExpected:\n";
146             openvdb::ax::ast::print(*expected, true, os);
147             os << "Result:\n";
148             openvdb::ax::ast::print(*result, true, os);
149             CPPUNIT_FAIL(ERROR_MSG("Mismatching Trees for Array Unpack code", code) + os.str());
150         }
151     }
152 }
153 
154