1 /* Copyright (C) 2012-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12 #ifndef HELIB_PTRMATRIX_H
13 #define HELIB_PTRMATRIX_H
14 /**
15 * @file PtrMatrix.h
16 * @brief Convenience class templates providing a unified interface
17 * for a matrix of objects, returning pointers to these objects.
18 **/
19 #include <initializer_list>
20 #include <helib/PtrVector.h>
21
22 namespace helib {
23
24 //! @brief An abstract class for an array of PtrVectors
25 template <typename T>
26 struct PtrMatrix
27 {
28 virtual PtrVector<T>& operator[](long) = 0; // returns a row
29 virtual const PtrVector<T>& operator[](long) const = 0; // returns a row
30 virtual long size() const = 0; // How many rows
31 // FIXME: Make this pure virtual
32 #pragma GCC diagnostic push
33 #pragma GCC diagnostic ignored "-Wunused-parameter"
resizePtrMatrix34 virtual void resize(long newSize) // reset the number of rows
35 {
36 throw LogicError("Cannot resize generic PtrMatrix");
37 }
38 #pragma GCC diagnostic pop
~PtrMatrixPtrMatrix39 virtual ~PtrMatrix() {}
40
41 // Return a pointer to some non-Null T, if it can find one.
42 // This is convenient since T may not have an empty constructor
ptr2nonNullPtrMatrix43 virtual const T* ptr2nonNull() const
44 {
45 for (long i = 0; i < size(); i++) {
46 const T* pt = (*this)[i].ptr2nonNull();
47 if (pt != nullptr)
48 return pt;
49 }
50 return nullptr;
51 }
52 };
53
54 template <typename T>
lsize(const PtrMatrix<T> & v)55 long lsize(const PtrMatrix<T>& v)
56 {
57 return v.size();
58 }
59 template <typename T>
resize(PtrMatrix<T> & v,long newSize)60 void resize(PtrMatrix<T>& v, long newSize)
61 {
62 v.resize(newSize);
63 }
64 template <typename T>
setLengthZero(PtrMatrix<T> & v)65 void setLengthZero(PtrMatrix<T>& v)
66 {
67 v.resize(0);
68 }
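// The free functions above only forward to the virtual PtrMatrix<T> members;
// concrete resize() implementations live in the derived classes below.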
70
71 // This header provides some implementations of these interfaces, but
72 // users can define their own as needed. The ones defined here are:
73
74 // struct PtrMatrix_Vec; // NTL::Vec<NTL::Vec<T>>
75 // struct PtrMatrix_vector; // std::vector<std::vector<T>>
76 // struct PtrMatrix_ptVec; // NTL::Vec<NTL::Vec<T>*>
77 // struct PtrMatrix_ptvector; // std::vector<std::vector<T>*>
78
79 template <typename T>
ptr2nonNull(std::initializer_list<const PtrVector<T> * > list)80 const T* ptr2nonNull(std::initializer_list<const PtrVector<T>*> list)
81 {
82 for (auto elem : list) {
83 const T* ptr = elem->ptr2nonNull();
84 if (ptr != nullptr)
85 return ptr;
86 }
87 return nullptr;
88 }
89
90 /*******************************************************************/
91 /* Implementation details: applications should not care about them */
92 /*******************************************************************/
93
94 //! @brief An implementation of PtrMatrix using Vec< Vec<T> >
95 template <typename T>
96 struct PtrMatrix_Vec : PtrMatrix<T>
97 {
98 NTL::Vec<NTL::Vec<T>>& buffer;
99 std::vector<PtrVector_VecT<T>> rows;
100 // rows[i] is a PtrVector_VecT<T> object 'pointing' to buffer[i]
101 // the above uses std::vector to be able to use emplace
102
PtrMatrix_VecPtrMatrix_Vec103 PtrMatrix_Vec(NTL::Vec<NTL::Vec<T>>& mat) : buffer(mat)
104 {
105 rows.reserve(lsize(mat)); // allocate memory
106 for (int i = 0; i < lsize(mat); i++) // initialize
107 rows.emplace_back(buffer[i]);
108 }
109 PtrVector<T>& operator[](long i) override // returns a row
110 {
111 return rows[i];
112 }
113 const PtrVector<T>& operator[](long i) const override // returns a row
114 {
115 return rows[i];
116 }
sizePtrMatrix_Vec117 long size() const override { return lsize(rows); } // How many rows
resizePtrMatrix_Vec118 void resize(long newSize) override // reset the number of rows
119 {
120 long oldSize = size();
121 if (oldSize == newSize)
122 return; // nothing to do
123
124 buffer.SetLength(newSize); // resize buffer, then add/delete 'pointers'
125 if (newSize > oldSize) {
126 rows.reserve(newSize);
127 for (int i = oldSize; i < newSize; i++)
128 rows.emplace_back(buffer[i]);
129 }
130 // else rows.resize(newSize);
131 // Can't shrink without operator=
132 else {
133 std::cerr << "Attempt to shrink PtrMatrix_Vec failed\n";
134 }
135 }
136 };
137
138 //! @brief An implementation of PtrMatrix using Vec< Vec<T>* >
139 template <typename T>
140 struct PtrMatrix_ptVec : PtrMatrix<T>
141 {
142 NTL::Vec<NTL::Vec<T>*>& buffer;
143 std::vector<PtrVector_VecT<T>> rows;
144 // rows[i] is a PtrVector_VecT<T> object 'pointing' to *buffer[i]
145 // the above uses std::vector to be able to use emplace
146
PtrMatrix_ptVecPtrMatrix_ptVec147 PtrMatrix_ptVec(NTL::Vec<NTL::Vec<T>*>& mat) : buffer(mat)
148 {
149 rows.reserve(lsize(mat)); // allocate memory
150 for (int i = 0; i < lsize(mat); i++) // initialize
151 rows.emplace_back(*(buffer[i]));
152 }
153 PtrVector<T>& operator[](long i) override // returns a row
154 {
155 return rows[i];
156 }
157 const PtrVector<T>& operator[](long i) const override // returns a row
158 {
159 return rows[i];
160 }
sizePtrMatrix_ptVec161 long size() const override { return lsize(rows); } // How many rows
162 };
163
164 //! @brief An implementation of PtrMatrix using vector< vector<T> >
165 template <typename T>
166 struct PtrMatrix_vector : PtrMatrix<T>
167 {
168 std::vector<std::vector<T>>& buffer;
169 std::vector<PtrVector_vectorT<T>> rows;
170 // rows[i] is a PtrVector_vectorT<T> object 'pointing' to buffer[i]
171
PtrMatrix_vectorPtrMatrix_vector172 PtrMatrix_vector(std::vector<std::vector<T>>& mat) : buffer(mat)
173 {
174 rows.reserve(lsize(mat)); // allocate memory
175 for (int i = 0; i < lsize(mat); i++) // initialize
176 rows.emplace_back(buffer[i]);
177 }
178 PtrVector<T>& operator[](long i) override // returns a row
179 {
180 return rows[i];
181 }
182 const PtrVector<T>& operator[](long i) const override // returns a row
183 {
184 return rows[i];
185 }
sizePtrMatrix_vector186 long size() const override { return lsize(rows); } // How many rows
resizePtrMatrix_vector187 void resize(long newSize) override // reset the number of rows
188 {
189 long oldSize = size();
190 if (oldSize == newSize)
191 return; // nothing to do
192
193 buffer.resize(newSize); // resize buffer, then add/delete 'pointers'
194 if (newSize > oldSize) {
195 rows.reserve(newSize);
196 for (int i = oldSize; i < newSize; i++)
197 rows.emplace_back(buffer[i]);
198 }
199 // else rows.resize(newSize);
200 // Can't shrink without operator=
201 else {
202 std::cerr << "Attempt to shrink PtrMatrix_vector failed\n";
203 }
204 }
205 };
206
207 //! @brief An implementation of PtrMatrix using vector< vector<T>* >
208 template <typename T>
209 struct PtrMatrix_ptvector : PtrMatrix<T>
210 {
211 std::vector<std::vector<T>*>& buffer;
212 std::vector<PtrVector_vectorT<T>> rows;
213 // rows[i] is a PtrVector_vectorT<T> object 'pointing' to *buffer[i]
214
PtrMatrix_ptvectorPtrMatrix_ptvector215 PtrMatrix_ptvector(std::vector<std::vector<T>*>& mat) : buffer(mat)
216 {
217 rows.reserve(lsize(mat)); // allocate memory
218 for (int i = 0; i < lsize(mat); i++) // initialize
219 rows.emplace_back(*(buffer[i]));
220 }
221 PtrVector<T>& operator[](long i) override // returns a row
222 {
223 return rows[i];
224 }
225 const PtrVector<T>& operator[](long i) const override // returns a row
226 {
227 return rows[i];
228 }
sizePtrMatrix_ptvector229 long size() const override { return lsize(rows); } // How many rows
230 };
231
232 } // namespace helib
233
234 #endif // ifndef HELIB_PTRMATRIX_H
235