1 /* Copyright (C) 2012-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12 #include <iostream>
13 #include <NTL/tools.h>
14 #include <helib/NumbTh.h>
15 #include <helib/PtrVector.h>
16 #include <helib/PtrMatrix.h>
17 #include <helib/debugging.h>
18
19 #include "gtest/gtest.h"
20 #include "test_common.h"
21
22 namespace {
23
// A class with no default constructor. Because default construction is
// deleted, any code path in the PtrVector / PtrMatrix wrappers exercised by
// this test that tried to default-construct an element would fail to compile.
class MyClass
{
  int myInt;

public:
  // Deleted (rather than private and empty) so no constructor exists that
  // would leave myInt uninitialized.
  MyClass() = delete;

  // explicit: an int should never silently turn into a MyClass.
  explicit MyClass(int i) : myInt(i) {}

  // Returns the stored value.
  int get() const { return myInt; }

  // Replaces the stored value with i.
  void set(int i) { myInt = i; }
};
34
35 class GTestPtrVector : public ::testing::Test
36 {
37 protected:
GTestPtrVector()38 GTestPtrVector() : vLength(6), zero(0){};
39
TearDown()40 virtual void TearDown() override { helib::cleanupDebugGlobals(); }
41
42 const int vLength;
43 MyClass zero;
44 };
45
46 typedef helib::PtrVector<MyClass> MyPtrVec;
47 typedef helib::PtrVector_VecT<MyClass> MyPtrVec_Vec;
48 typedef helib::PtrVector_VecPt<MyClass> MyPtrVec_VecPt;
49 typedef helib::PtrVector_vectorT<MyClass> MyPtrVec_vector;
50 typedef helib::PtrVector_vectorPt<MyClass> MyPtrVec_vectorPt;
51
52 typedef helib::PtrVector_slice<MyClass> MyPtrVec_slice; // A slice of MyPtrVec
53
54 typedef helib::PtrMatrix<MyClass> MyPtrMat;
55 typedef helib::PtrMatrix_Vec<MyClass> MyPtrMat_Vec;
56 typedef helib::PtrMatrix_ptVec<MyClass> MyPtrMat_ptVec;
57 typedef helib::PtrMatrix_vector<MyClass> MyPtrMat_vector;
58 typedef helib::PtrMatrix_ptvector<MyClass> MyPtrMat_ptvector;
59
60 // compare a "generic" vectors to pointers to vector to objects
61 template <typename T2>
pointersEqual(const MyPtrVec & a,T2 & b)62 ::testing::AssertionResult pointersEqual(const MyPtrVec& a, T2& b)
63 {
64 if (helib::lsize(a) != helib::lsize(b)) {
65 return ::testing::AssertionFailure()
66 << "sizes do not match (" << helib::lsize(a) << " vs "
67 << helib::lsize(b) << ")";
68 }
69 for (long i = 0; i < helib::lsize(b); i++) {
70 if (a[i] != &b[i]) {
71 return ::testing::AssertionFailure()
72 << "difference found in the " << i << "th position: " << a[i]
73 << " vs " << &b[i];
74 }
75 }
76 return ::testing::AssertionSuccess();
77 }
78
test1(MyClass array[],int length,const MyPtrVec & ptrs)79 void test1(MyClass array[], int length, const MyPtrVec& ptrs)
80 {
81 if (helib_test::verbose) {
82 std::cout << "test1 " << std::flush;
83 }
84 for (int i = 0; i < length; i++)
85 ASSERT_EQ(ptrs[i], &(array[i]));
86 ASSERT_EQ(ptrs.numNonNull(), length);
87 ASSERT_EQ(ptrs.size(), length);
88 const MyClass* pt = ptrs.ptr2nonNull();
89 if (length > 0) {
90 ASSERT_NE(pt, nullptr) << "but length > 0";
91 } else if (length <= 0) {
92 ASSERT_EQ(pt, nullptr) << "however length <= 0";
93 }
94 }
95
test2(MyClass * array[],int length,const MyPtrVec & ptrs)96 void test2(MyClass* array[], int length, const MyPtrVec& ptrs)
97 {
98 if (helib_test::verbose) {
99 std::cout << "test2 " << std::flush;
100 }
101 for (int i = 0; i < length; i++)
102 ASSERT_EQ(ptrs[i], array[i]);
103 ASSERT_EQ(ptrs.size(), length);
104
105 ASSERT_EQ(ptrs.numNonNull(), std::min(4, length));
106 ASSERT_NE(ptrs.ptr2nonNull(), nullptr);
107 }
108
printPtrVector(const MyPtrVec & ptrs)109 void printPtrVector(const MyPtrVec& ptrs)
110 {
111 if (helib_test::verbose) {
112 for (int i = 0; i < ptrs.size(); i++) {
113 std::cout << ((i == 0) ? '[' : ',');
114 MyClass* pt = ptrs[i];
115 if (pt == nullptr)
116 std::cout << "null";
117 else
118 std::cout << pt->get();
119 }
120 std::cout << ']';
121 }
122 }
test3(MyPtrVec & ptrs)123 void test3(MyPtrVec& ptrs)
124 {
125 if (helib_test::verbose) {
126 std::cout << "\nBefore resize: ";
127 printPtrVector(ptrs);
128
129 int length = ptrs.size();
130 ptrs.resize(length + 1);
131 ptrs[length]->set(length + 1);
132
133 std::cout << "\n After resize: ";
134 printPtrVector(ptrs);
135 std::cout << std::endl;
136 }
137 }
138
139 template <typename T>
test4(const MyPtrMat & mat,const T & array)140 void test4(const MyPtrMat& mat, const T& array)
141 {
142 if (helib_test::verbose) {
143 std::cout << "test4 " << std::flush;
144 }
145 ASSERT_EQ(mat.size(), helib::lsize(array));
146 for (int i = 0; i < helib::lsize(array); i++)
147 ASSERT_TRUE(pointersEqual(mat[i], array[i]));
148 }
149
150 template <typename T>
test5(const MyPtrMat & mat,const T & array)151 void test5(const MyPtrMat& mat, const T& array)
152 {
153 if (helib_test::verbose) {
154 std::cout << "test5 " << std::flush;
155 }
156 ASSERT_EQ(mat.size(), helib::lsize(array));
157 for (int i = 0; i < helib::lsize(array); i++)
158 ASSERT_TRUE(pointersEqual(mat[i], *array[i]));
159 }
160
// Builds std::vector / NTL::Vec containers of MyClass objects and of MyClass
// pointers, then wraps them in the PtrVector / PtrMatrix adaptors (including
// slices). Checks that each adaptor exposes exactly the underlying elements.
TEST_F(GTestPtrVector, pointerVectorsRemainConsistent)
{
  // Object containers, filled with copies of the fixture's `zero`. A local
  // `zero` with the same value used to shadow that member.
  std::vector<MyClass> v1(vLength, zero);
  NTL::Vec<MyClass> v2(NTL::INIT_SIZE, vLength, zero);

  // Pointer containers: entries 1..vLength-2 point into v1 / v2; the first
  // and last entries stay null.
  std::vector<MyClass*> v3(vLength, nullptr);
  for (int i = 1; i < vLength - 1; i++)
    v3[i] = &(v1[i]);

  NTL::Vec<MyClass*> v4(NTL::INIT_SIZE, vLength, nullptr);
  for (int i = 1; i < vLength - 1; i++)
    v4[i] = &(v2[i]);

  MyPtrVec_vector vv1(v1);
  MyPtrVec_VecPt vv4(v4);

  ASSERT_NO_FATAL_FAILURE(test1(&v1[0], vLength, vv1));
  ASSERT_NO_FATAL_FAILURE(test1(&v2[0], vLength, MyPtrVec_Vec(v2)));

  // vs1 starts at index 1 of vv1. vss1 is a 3-entry view starting at index 1
  // of vs1, i.e. v1[2..4].
  MyPtrVec_slice vs1(vv1, 1);
  ASSERT_NO_FATAL_FAILURE(test1(&v1[1], vLength - 1, vs1));
  MyPtrVec_slice vss1(vs1, 1, 3);
  ASSERT_NO_FATAL_FAILURE(test1(&v1[2], 3, vss1));

  ASSERT_NO_FATAL_FAILURE(test2(&v3[0], vLength, MyPtrVec_vectorPt(v3)));
  ASSERT_NO_FATAL_FAILURE(test2(&v4[0], vLength, vv4));

  MyPtrVec_slice vs4(vv4, 1);
  ASSERT_NO_FATAL_FAILURE(test2(&v4[1], vLength - 1, vs4));
  MyPtrVec_slice vss4(vs4, 1, 3);
  ASSERT_NO_FATAL_FAILURE(test2(&v4[2], 3, vss4));

  // test3 may grow v1 through vv1, so pointers into v1 (v3, vs1, vss1) must
  // not be used after this call.
  ASSERT_NO_FATAL_FAILURE(test3(vv1));

  // Matrices: mat2 / mat4 hold pointers to the rows of mat1 / mat3 in
  // reverse order.
  const long numRows = 6;
  std::vector<std::vector<MyClass>> mat1(numRows);
  std::vector<std::vector<MyClass>*> mat2(numRows);
  NTL::Vec<NTL::Vec<MyClass>> mat3(NTL::INIT_SIZE, numRows);
  NTL::Vec<NTL::Vec<MyClass>*> mat4(NTL::INIT_SIZE, numRows);
  for (long i = 0; i < numRows; i++) {
    mat1[i].resize(4, MyClass(i));
    mat2[numRows - 1 - i] = &mat1[i];

    mat3[i].SetLength(3, MyClass(i + 10));
    mat4[numRows - 1 - i] = &mat3[i];
  }

  ASSERT_NO_FATAL_FAILURE(test4(MyPtrMat_vector(mat1), mat1));
  ASSERT_NO_FATAL_FAILURE(test4(MyPtrMat_Vec(mat3), mat3));

  ASSERT_NO_FATAL_FAILURE(test5(MyPtrMat_ptvector(mat2), mat2));
  ASSERT_NO_FATAL_FAILURE(test5(MyPtrMat_ptVec(mat4), mat4));
  if (helib_test::verbose) {
    // Tidy up newline-free previous output
    std::cout << std::endl;
  }
}
218
219 } // namespace
220