1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <algorithm>
#include <set>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "source/fuzz/equivalence_relation.h"
20 
21 namespace spvtools {
22 namespace fuzz {
23 namespace {
24 
// Equality functor for EquivalenceRelation: two pointers are considered equal
// when the uint32_t values they point at are equal (pointer identity is
// irrelevant).
struct UInt32Equals {
  bool operator()(const uint32_t* lhs, const uint32_t* rhs) const {
    const uint32_t lhs_value = *lhs;
    const uint32_t rhs_value = *rhs;
    return lhs_value == rhs_value;
  }
};
30 
// Hash functor for EquivalenceRelation: hashes the pointed-to uint32_t value
// (not the pointer), so it is consistent with UInt32Equals. The value itself
// is used directly as the hash.
struct UInt32Hash {
  size_t operator()(const uint32_t* element) const {
    const uint32_t value = *element;
    return size_t{value};
  }
};
36 
// Dereferences every pointer in |pointers| and returns the pointed-to values
// in the same order. The tests use this to compare the pointer-based results
// of EquivalenceRelation::GetEquivalenceClass against plain value vectors.
std::vector<uint32_t> ToUIntVector(
    const std::vector<const uint32_t*>& pointers) {
  std::vector<uint32_t> result;
  // The output size is known up front, so allocate exactly once.
  result.reserve(pointers.size());
  for (const uint32_t* pointer : pointers) {
    result.push_back(*pointer);
  }
  return result;
}
45 
TEST(EquivalenceRelationTest,BasicTest)46 TEST(EquivalenceRelationTest, BasicTest) {
47   EquivalenceRelation<uint32_t, UInt32Hash, UInt32Equals> relation;
48   ASSERT_TRUE(relation.GetAllKnownValues().empty());
49 
50   for (uint32_t element = 2; element < 80; element += 2) {
51     relation.MakeEquivalent(0, element);
52     relation.MakeEquivalent(element - 1, element + 1);
53   }
54 
55   for (uint32_t element = 82; element < 100; element += 2) {
56     relation.MakeEquivalent(80, element);
57     relation.MakeEquivalent(element - 1, element + 1);
58   }
59 
60   relation.MakeEquivalent(78, 80);
61 
62   std::vector<uint32_t> class1;
63   for (uint32_t element = 0; element < 98; element += 2) {
64     ASSERT_TRUE(relation.IsEquivalent(0, element));
65     ASSERT_TRUE(relation.IsEquivalent(element, element + 2));
66     class1.push_back(element);
67   }
68   class1.push_back(98);
69 
70   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(0)),
71               testing::WhenSorted(class1));
72   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(4)),
73               testing::WhenSorted(class1));
74   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(40)),
75               testing::WhenSorted(class1));
76 
77   std::vector<uint32_t> class2;
78   for (uint32_t element = 1; element < 79; element += 2) {
79     ASSERT_TRUE(relation.IsEquivalent(1, element));
80     ASSERT_TRUE(relation.IsEquivalent(element, element + 2));
81     class2.push_back(element);
82   }
83   class2.push_back(79);
84   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(1)),
85               testing::WhenSorted(class2));
86   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(11)),
87               testing::WhenSorted(class2));
88   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(31)),
89               testing::WhenSorted(class2));
90 
91   std::vector<uint32_t> class3;
92   for (uint32_t element = 81; element < 99; element += 2) {
93     ASSERT_TRUE(relation.IsEquivalent(81, element));
94     ASSERT_TRUE(relation.IsEquivalent(element, element + 2));
95     class3.push_back(element);
96   }
97   class3.push_back(99);
98   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(81)),
99               testing::WhenSorted(class3));
100   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(91)),
101               testing::WhenSorted(class3));
102   ASSERT_THAT(ToUIntVector(relation.GetEquivalenceClass(99)),
103               testing::WhenSorted(class3));
104 
105   bool first = true;
106   std::vector<const uint32_t*> previous_class;
107   for (auto representative : relation.GetEquivalenceClassRepresentatives()) {
108     std::vector<const uint32_t*> current_class =
109         relation.GetEquivalenceClass(*representative);
110     ASSERT_TRUE(std::find(current_class.begin(), current_class.end(),
111                           representative) != current_class.end());
112     if (!first) {
113       ASSERT_TRUE(std::find(previous_class.begin(), previous_class.end(),
114                             representative) == previous_class.end());
115     }
116     previous_class = current_class;
117     first = false;
118   }
119 }
120 
121 }  // namespace
122 }  // namespace fuzz
123 }  // namespace spvtools
124