1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #include <iostream>
13 #include <NTL/tools.h>
14 #include <helib/NumbTh.h>
15 #include <helib/PtrVector.h>
16 #include <helib/PtrMatrix.h>
17 #include <helib/debugging.h>
18 
19 #include "gtest/gtest.h"
20 #include "test_common.h"
21 
22 namespace {
23 
// A class with no *accessible* default constructor: the default constructor
// is private, so any MyClass object outside the class must be created from an
// int or copied from an existing MyClass.
class MyClass
{
  int myInt = 0; // initialized even on the private default-constructor path
  MyClass() = default; // private
public:
  explicit MyClass(int i) : myInt(i) {}
  // Return the stored value.
  int get() const { return myInt; }
  // Overwrite the stored value.
  void set(int i) { myInt = i; }
};
34 
35 class GTestPtrVector : public ::testing::Test
36 {
37 protected:
GTestPtrVector()38   GTestPtrVector() : vLength(6), zero(0){};
39 
TearDown()40   virtual void TearDown() override { helib::cleanupDebugGlobals(); }
41 
42   const int vLength;
43   MyClass zero;
44 };
45 
46 typedef helib::PtrVector<MyClass> MyPtrVec;
47 typedef helib::PtrVector_VecT<MyClass> MyPtrVec_Vec;
48 typedef helib::PtrVector_VecPt<MyClass> MyPtrVec_VecPt;
49 typedef helib::PtrVector_vectorT<MyClass> MyPtrVec_vector;
50 typedef helib::PtrVector_vectorPt<MyClass> MyPtrVec_vectorPt;
51 
52 typedef helib::PtrVector_slice<MyClass> MyPtrVec_slice; // A slice of MyPtrVec
53 
54 typedef helib::PtrMatrix<MyClass> MyPtrMat;
55 typedef helib::PtrMatrix_Vec<MyClass> MyPtrMat_Vec;
56 typedef helib::PtrMatrix_ptVec<MyClass> MyPtrMat_ptVec;
57 typedef helib::PtrMatrix_vector<MyClass> MyPtrMat_vector;
58 typedef helib::PtrMatrix_ptvector<MyClass> MyPtrMat_ptvector;
59 
60 // compare a "generic" vectors to pointers to vector to objects
61 template <typename T2>
pointersEqual(const MyPtrVec & a,T2 & b)62 ::testing::AssertionResult pointersEqual(const MyPtrVec& a, T2& b)
63 {
64   if (helib::lsize(a) != helib::lsize(b)) {
65     return ::testing::AssertionFailure()
66            << "sizes do not match (" << helib::lsize(a) << " vs "
67            << helib::lsize(b) << ")";
68   }
69   for (long i = 0; i < helib::lsize(b); i++) {
70     if (a[i] != &b[i]) {
71       return ::testing::AssertionFailure()
72              << "difference found in the " << i << "th position: " << a[i]
73              << " vs " << &b[i];
74     }
75   }
76   return ::testing::AssertionSuccess();
77 }
78 
test1(MyClass array[],int length,const MyPtrVec & ptrs)79 void test1(MyClass array[], int length, const MyPtrVec& ptrs)
80 {
81   if (helib_test::verbose) {
82     std::cout << "test1 " << std::flush;
83   }
84   for (int i = 0; i < length; i++)
85     ASSERT_EQ(ptrs[i], &(array[i]));
86   ASSERT_EQ(ptrs.numNonNull(), length);
87   ASSERT_EQ(ptrs.size(), length);
88   const MyClass* pt = ptrs.ptr2nonNull();
89   if (length > 0) {
90     ASSERT_NE(pt, nullptr) << "but length > 0";
91   } else if (length <= 0) {
92     ASSERT_EQ(pt, nullptr) << "however length <= 0";
93   }
94 }
95 
test2(MyClass * array[],int length,const MyPtrVec & ptrs)96 void test2(MyClass* array[], int length, const MyPtrVec& ptrs)
97 {
98   if (helib_test::verbose) {
99     std::cout << "test2 " << std::flush;
100   }
101   for (int i = 0; i < length; i++)
102     ASSERT_EQ(ptrs[i], array[i]);
103   ASSERT_EQ(ptrs.size(), length);
104 
105   ASSERT_EQ(ptrs.numNonNull(), std::min(4, length));
106   ASSERT_NE(ptrs.ptr2nonNull(), nullptr);
107 }
108 
printPtrVector(const MyPtrVec & ptrs)109 void printPtrVector(const MyPtrVec& ptrs)
110 {
111   if (helib_test::verbose) {
112     for (int i = 0; i < ptrs.size(); i++) {
113       std::cout << ((i == 0) ? '[' : ',');
114       MyClass* pt = ptrs[i];
115       if (pt == nullptr)
116         std::cout << "null";
117       else
118         std::cout << pt->get();
119     }
120     std::cout << ']';
121   }
122 }
test3(MyPtrVec & ptrs)123 void test3(MyPtrVec& ptrs)
124 {
125   if (helib_test::verbose) {
126     std::cout << "\nBefore resize: ";
127     printPtrVector(ptrs);
128 
129     int length = ptrs.size();
130     ptrs.resize(length + 1);
131     ptrs[length]->set(length + 1);
132 
133     std::cout << "\n After resize: ";
134     printPtrVector(ptrs);
135     std::cout << std::endl;
136   }
137 }
138 
139 template <typename T>
test4(const MyPtrMat & mat,const T & array)140 void test4(const MyPtrMat& mat, const T& array)
141 {
142   if (helib_test::verbose) {
143     std::cout << "test4 " << std::flush;
144   }
145   ASSERT_EQ(mat.size(), helib::lsize(array));
146   for (int i = 0; i < helib::lsize(array); i++)
147     ASSERT_TRUE(pointersEqual(mat[i], array[i]));
148 }
149 
150 template <typename T>
test5(const MyPtrMat & mat,const T & array)151 void test5(const MyPtrMat& mat, const T& array)
152 {
153   if (helib_test::verbose) {
154     std::cout << "test5 " << std::flush;
155   }
156   ASSERT_EQ(mat.size(), helib::lsize(array));
157   for (int i = 0; i < helib::lsize(array); i++)
158     ASSERT_TRUE(pointersEqual(mat[i], *array[i]));
159 }
160 
// Exercise every PtrVector / PtrMatrix adapter against the container it wraps.
// Checks that element pointers, sizes and non-null counts stay consistent,
// including for slices and after a resize.
TEST_F(GTestPtrVector, pointerVectorsRemainConsistent)
{
  // Containers of objects, filled from the fixture's `zero` (MyClass has no
  // accessible default constructor). Using the fixture member avoids a local
  // that would shadow it.
  std::vector<MyClass> v1(vLength, zero);
  NTL::Vec<MyClass> v2(NTL::INIT_SIZE, vLength, zero);

  // Containers of pointers: the first and last entries stay null, and the
  // interior entries point into v1 / v2.
  std::vector<MyClass*> v3(vLength, nullptr);
  for (int i = 1; i < vLength - 1; i++)
    v3[i] = &(v1[i]);

  NTL::Vec<MyClass*> v4(NTL::INIT_SIZE, vLength, nullptr);
  for (int i = 1; i < vLength - 1; i++)
    v4[i] = &(v2[i]);

  MyPtrVec_vector vv1(v1);
  MyPtrVec_VecPt vv4(v4);

  ASSERT_NO_FATAL_FAILURE(test1(&v1[0], vLength, vv1));
  ASSERT_NO_FATAL_FAILURE(test1(&v2[0], vLength, MyPtrVec_Vec(v2)));

  // vs1 starts at index 1 of vv1. vss1 holds 3 elements starting at index 1
  // of vs1, which is index 2 of vv1.
  MyPtrVec_slice vs1(vv1, 1);
  ASSERT_NO_FATAL_FAILURE(test1(&v1[1], vLength - 1, vs1));
  MyPtrVec_slice vss1(vs1, 1, 3);
  ASSERT_NO_FATAL_FAILURE(test1(&v1[2], 3, vss1));

  ASSERT_NO_FATAL_FAILURE(test2(&v3[0], vLength, MyPtrVec_vectorPt(v3)));
  ASSERT_NO_FATAL_FAILURE(test2(&v4[0], vLength, vv4));

  MyPtrVec_slice vs4(vv4, 1);
  ASSERT_NO_FATAL_FAILURE(test2(&v4[1], vLength - 1, vs4));
  MyPtrVec_slice vss4(vs4, 1, 3);
  ASSERT_NO_FATAL_FAILURE(test2(&v4[2], 3, vss4));

  // Resizes v1 through vv1. Nothing below reads v1, v3 or the slices of vv1.
  ASSERT_NO_FATAL_FAILURE(test3(vv1));

  // Matrices: mat1/mat3 hold rows by value. mat2/mat4 hold pointers to those
  // rows in reverse order.
  const long numRows = 6;
  std::vector<std::vector<MyClass>> mat1(numRows);
  std::vector<std::vector<MyClass>*> mat2(numRows);
  NTL::Vec<NTL::Vec<MyClass>> mat3(NTL::INIT_SIZE, numRows);
  NTL::Vec<NTL::Vec<MyClass>*> mat4(NTL::INIT_SIZE, numRows);
  for (long i = 0; i < numRows; i++) {
    mat1[i].resize(4, MyClass(i));
    mat2[numRows - 1 - i] = &mat1[i];

    mat3[i].SetLength(3, MyClass(i + 10));
    mat4[numRows - 1 - i] = &mat3[i];
  }

  ASSERT_NO_FATAL_FAILURE(test4(MyPtrMat_vector(mat1), mat1));
  ASSERT_NO_FATAL_FAILURE(test4(MyPtrMat_Vec(mat3), mat3));

  ASSERT_NO_FATAL_FAILURE(test5(MyPtrMat_ptvector(mat2), mat2));
  ASSERT_NO_FATAL_FAILURE(test5(MyPtrMat_ptVec(mat4), mat4));
  if (helib_test::verbose) {
    // Tidy up newline-free previous output
    std::cout << std::endl;
  }
}
218 
219 } // namespace
220