1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #ifndef HELIB_PTRMATRIX_H
13 #define HELIB_PTRMATRIX_H
14 /**
15  * @file PtrMatrix.h
16  * @brief Convenience class templates providing a unified interface
17  *    for a matrix of objects, returning pointers to these objects.
18  **/
19 #include <initializer_list>
20 #include <helib/PtrVector.h>
21 
22 namespace helib {
23 
24 //! @brief An abstract class for an array of PtrVectors
25 template <typename T>
26 struct PtrMatrix
27 {
28   virtual PtrVector<T>& operator[](long) = 0;             // returns a row
29   virtual const PtrVector<T>& operator[](long) const = 0; // returns a row
30   virtual long size() const = 0;                          // How many rows
31   // FIXME: Make this pure virtual
32 #pragma GCC diagnostic push
33 #pragma GCC diagnostic ignored "-Wunused-parameter"
resizePtrMatrix34   virtual void resize(long newSize) // reset the number of rows
35   {
36     throw LogicError("Cannot resize generic PtrMatrix");
37   }
38 #pragma GCC diagnostic pop
~PtrMatrixPtrMatrix39   virtual ~PtrMatrix() {}
40 
41   // Return a pointer to some non-Null T, if it can find one.
42   // This is convenient since T may not have an empty constructor
ptr2nonNullPtrMatrix43   virtual const T* ptr2nonNull() const
44   {
45     for (long i = 0; i < size(); i++) {
46       const T* pt = (*this)[i].ptr2nonNull();
47       if (pt != nullptr)
48         return pt;
49     }
50     return nullptr;
51   }
52 };
53 
54 template <typename T>
lsize(const PtrMatrix<T> & v)55 long lsize(const PtrMatrix<T>& v)
56 {
57   return v.size();
58 }
59 template <typename T>
resize(PtrMatrix<T> & v,long newSize)60 void resize(PtrMatrix<T>& v, long newSize)
61 {
62   v.resize(newSize);
63 }
64 template <typename T>
setLengthZero(PtrMatrix<T> & v)65 void setLengthZero(PtrMatrix<T>& v)
66 {
67   v.resize(0);
68 }
// resize() itself is implemented by the concrete PtrMatrix classes below
// (those that support resizing override it)
70 
71 // This header provides some implementations of these interfaces, but
72 // users can define their own as needed. The ones defined here are:
73 
74 // struct PtrMatrix_Vec;    // NTL::Vec<NTL::Vec<T>>
75 // struct PtrMatrix_vector; // std::vector<std::vector<T>>
76 // struct PtrMatrix_ptVec;    // NTL::Vec<NTL::Vec<T>*>
77 // struct PtrMatrix_ptvector; // std::vector<std::vector<T>*>
78 
79 template <typename T>
ptr2nonNull(std::initializer_list<const PtrVector<T> * > list)80 const T* ptr2nonNull(std::initializer_list<const PtrVector<T>*> list)
81 {
82   for (auto elem : list) {
83     const T* ptr = elem->ptr2nonNull();
84     if (ptr != nullptr)
85       return ptr;
86   }
87   return nullptr;
88 }
89 
90 /*******************************************************************/
91 /* Implementation details: applications should not care about them */
92 /*******************************************************************/
93 
94 //! @brief An implementation of PtrMatrix using Vec< Vec<T> >
95 template <typename T>
96 struct PtrMatrix_Vec : PtrMatrix<T>
97 {
98   NTL::Vec<NTL::Vec<T>>& buffer;
99   std::vector<PtrVector_VecT<T>> rows;
100   // rows[i] is a PtrVector_VecT<T> object 'pointing' to buffer[i]
101   // the above uses std::vector to be able to use emplace
102 
PtrMatrix_VecPtrMatrix_Vec103   PtrMatrix_Vec(NTL::Vec<NTL::Vec<T>>& mat) : buffer(mat)
104   {
105     rows.reserve(lsize(mat));            // allocate memory
106     for (int i = 0; i < lsize(mat); i++) // initialize
107       rows.emplace_back(buffer[i]);
108   }
109   PtrVector<T>& operator[](long i) override // returns a row
110   {
111     return rows[i];
112   }
113   const PtrVector<T>& operator[](long i) const override // returns a row
114   {
115     return rows[i];
116   }
sizePtrMatrix_Vec117   long size() const override { return lsize(rows); } // How many rows
resizePtrMatrix_Vec118   void resize(long newSize) override                 // reset the number of rows
119   {
120     long oldSize = size();
121     if (oldSize == newSize)
122       return; // nothing to do
123 
124     buffer.SetLength(newSize); // resize buffer, then add/delete 'pointers'
125     if (newSize > oldSize) {
126       rows.reserve(newSize);
127       for (int i = oldSize; i < newSize; i++)
128         rows.emplace_back(buffer[i]);
129     }
130     //    else rows.resize(newSize);
131     // Can't shrink without operator=
132     else {
133       std::cerr << "Attempt to shrink PtrMatrix_Vec failed\n";
134     }
135   }
136 };
137 
138 //! @brief An implementation of PtrMatrix using Vec< Vec<T>* >
139 template <typename T>
140 struct PtrMatrix_ptVec : PtrMatrix<T>
141 {
142   NTL::Vec<NTL::Vec<T>*>& buffer;
143   std::vector<PtrVector_VecT<T>> rows;
144   // rows[i] is a PtrVector_VecT<T> object 'pointing' to *buffer[i]
145   // the above uses std::vector to be able to use emplace
146 
PtrMatrix_ptVecPtrMatrix_ptVec147   PtrMatrix_ptVec(NTL::Vec<NTL::Vec<T>*>& mat) : buffer(mat)
148   {
149     rows.reserve(lsize(mat));            // allocate memory
150     for (int i = 0; i < lsize(mat); i++) // initialize
151       rows.emplace_back(*(buffer[i]));
152   }
153   PtrVector<T>& operator[](long i) override // returns a row
154   {
155     return rows[i];
156   }
157   const PtrVector<T>& operator[](long i) const override // returns a row
158   {
159     return rows[i];
160   }
sizePtrMatrix_ptVec161   long size() const override { return lsize(rows); } // How many rows
162 };
163 
164 //! @brief An implementation of PtrMatrix using vector< vector<T> >
165 template <typename T>
166 struct PtrMatrix_vector : PtrMatrix<T>
167 {
168   std::vector<std::vector<T>>& buffer;
169   std::vector<PtrVector_vectorT<T>> rows;
170   // rows[i] is a PtrVector_vectorT<T> object 'pointing' to buffer[i]
171 
PtrMatrix_vectorPtrMatrix_vector172   PtrMatrix_vector(std::vector<std::vector<T>>& mat) : buffer(mat)
173   {
174     rows.reserve(lsize(mat));            // allocate memory
175     for (int i = 0; i < lsize(mat); i++) // initialize
176       rows.emplace_back(buffer[i]);
177   }
178   PtrVector<T>& operator[](long i) override // returns a row
179   {
180     return rows[i];
181   }
182   const PtrVector<T>& operator[](long i) const override // returns a row
183   {
184     return rows[i];
185   }
sizePtrMatrix_vector186   long size() const override { return lsize(rows); } // How many rows
resizePtrMatrix_vector187   void resize(long newSize) override                 // reset the number of rows
188   {
189     long oldSize = size();
190     if (oldSize == newSize)
191       return; // nothing to do
192 
193     buffer.resize(newSize); // resize buffer, then add/delete 'pointers'
194     if (newSize > oldSize) {
195       rows.reserve(newSize);
196       for (int i = oldSize; i < newSize; i++)
197         rows.emplace_back(buffer[i]);
198     }
199     //    else rows.resize(newSize);
200     // Can't shrink without operator=
201     else {
202       std::cerr << "Attempt to shrink PtrMatrix_vector failed\n";
203     }
204   }
205 };
206 
207 //! @brief An implementation of PtrMatrix using vector< vector<T>* >
208 template <typename T>
209 struct PtrMatrix_ptvector : PtrMatrix<T>
210 {
211   std::vector<std::vector<T>*>& buffer;
212   std::vector<PtrVector_vectorT<T>> rows;
213   // rows[i] is a PtrVector_vectorT<T> object 'pointing' to *buffer[i]
214 
PtrMatrix_ptvectorPtrMatrix_ptvector215   PtrMatrix_ptvector(std::vector<std::vector<T>*>& mat) : buffer(mat)
216   {
217     rows.reserve(lsize(mat));            // allocate memory
218     for (int i = 0; i < lsize(mat); i++) // initialize
219       rows.emplace_back(*(buffer[i]));
220   }
221   PtrVector<T>& operator[](long i) override // returns a row
222   {
223     return rows[i];
224   }
225   const PtrVector<T>& operator[](long i) const override // returns a row
226   {
227     return rows[i];
228   }
sizePtrMatrix_ptvector229   long size() const override { return lsize(rows); } // How many rows
230 };
231 
232 } // namespace helib
233 
234 #endif // ifndef HELIB_PTRMATRIX_H
235