1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup spglue_schur
20 //! @{
21 
22 
23 
24 template<typename T1, typename T2>
25 arma_hot
26 inline
27 void
apply(SpMat<typename T1::elem_type> & out,const SpGlue<T1,T2,spglue_schur> & X)28 spglue_schur::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_schur>& X)
29   {
30   arma_extra_debug_sigprint();
31 
32   typedef typename T1::elem_type eT;
33 
34   const SpProxy<T1> pa(X.A);
35   const SpProxy<T2> pb(X.B);
36 
37   const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
38 
39   if(is_alias == false)
40     {
41     spglue_schur::apply_noalias(out, pa, pb);
42     }
43   else
44     {
45     SpMat<eT> tmp;
46 
47     spglue_schur::apply_noalias(tmp, pa, pb);
48 
49     out.steal_mem(tmp);
50     }
51   }
52 
53 
54 
55 template<typename eT, typename T1, typename T2>
56 arma_hot
57 inline
58 void
apply_noalias(SpMat<eT> & out,const SpProxy<T1> & pa,const SpProxy<T2> & pb)59 spglue_schur::apply_noalias(SpMat<eT>& out, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
60   {
61   arma_extra_debug_sigprint();
62 
63   arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
64 
65   if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) )
66     {
67     out.zeros(pa.get_n_rows(), pa.get_n_cols());
68     return;
69     }
70 
71   const uword max_n_nonzero = (std::min)(pa.get_n_nonzero(), pb.get_n_nonzero());
72 
73   // Resize memory to upper bound
74   out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero);
75 
76   // Now iterate across both matrices.
77   typename SpProxy<T1>::const_iterator_type x_it  = pa.begin();
78   typename SpProxy<T1>::const_iterator_type x_end = pa.end();
79 
80   typename SpProxy<T2>::const_iterator_type y_it  = pb.begin();
81   typename SpProxy<T2>::const_iterator_type y_end = pb.end();
82 
83   uword count = 0;
84 
85   while( (x_it != x_end) || (y_it != y_end) )
86     {
87     const uword x_it_row = x_it.row();
88     const uword x_it_col = x_it.col();
89 
90     const uword y_it_row = y_it.row();
91     const uword y_it_col = y_it.col();
92 
93     if(x_it == y_it)
94       {
95       const eT out_val = (*x_it) * (*y_it);
96 
97       if(out_val != eT(0))
98         {
99         access::rw(out.values[count]) = out_val;
100 
101         access::rw(out.row_indices[count]) = x_it_row;
102         access::rw(out.col_ptrs[x_it_col + 1])++;
103         ++count;
104         }
105 
106       ++x_it;
107       ++y_it;
108       }
109     else
110       {
111       if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
112         {
113         ++x_it;
114         }
115       else
116         {
117         ++y_it;
118         }
119       }
120 
121     arma_check( (count > max_n_nonzero), "internal error: spglue_schur::apply_noalias(): count > max_n_nonzero" );
122     }
123 
124   const uword out_n_cols = out.n_cols;
125 
126   uword* col_ptrs = access::rwp(out.col_ptrs);
127 
128   // Fix column pointers to be cumulative.
129   for(uword c = 1; c <= out_n_cols; ++c)
130     {
131     col_ptrs[c] += col_ptrs[c - 1];
132     }
133 
134   if(count < max_n_nonzero)
135     {
136     if(count <= (max_n_nonzero/2))
137       {
138       out.mem_resize(count);
139       }
140     else
141       {
142       // quick resize without reallocating memory and copying data
143       access::rw(         out.n_nonzero) = count;
144       access::rw(     out.values[count]) = eT(0);
145       access::rw(out.row_indices[count]) = uword(0);
146       }
147     }
148   }
149 
150 
151 
152 template<typename eT>
153 arma_hot
154 inline
155 void
apply_noalias(SpMat<eT> & out,const SpMat<eT> & A,const SpMat<eT> & B)156 spglue_schur::apply_noalias(SpMat<eT>& out, const SpMat<eT>& A, const SpMat<eT>& B)
157   {
158   arma_extra_debug_sigprint();
159 
160   const SpProxy< SpMat<eT> > pa(A);
161   const SpProxy< SpMat<eT> > pb(B);
162 
163   spglue_schur::apply_noalias(out, pa, pb);
164   }
165 
166 
167 
168 //
169 //
170 //
171 
172 
173 
174 template<typename T1, typename T2>
175 inline
176 void
dense_schur_sparse(SpMat<typename T1::elem_type> & out,const T1 & x,const T2 & y)177 spglue_schur_misc::dense_schur_sparse(SpMat<typename T1::elem_type>& out, const T1& x, const T2& y)
178   {
179   arma_extra_debug_sigprint();
180 
181   typedef typename T1::elem_type eT;
182 
183   const   Proxy<T1> pa(x);
184   const SpProxy<T2> pb(y);
185 
186   arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
187 
188   const uword max_n_nonzero = pb.get_n_nonzero();
189 
190   // Resize memory to upper bound.
191   out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero);
192 
193   uword count = 0;
194 
195   typename SpProxy<T2>::const_iterator_type it     = pb.begin();
196   typename SpProxy<T2>::const_iterator_type it_end = pb.end();
197 
198   while(it != it_end)
199     {
200     const uword it_row = it.row();
201     const uword it_col = it.col();
202 
203     const eT val = (*it) * pa.at(it_row, it_col);
204 
205     if(val != eT(0))
206       {
207       access::rw(        out.values[count]) = val;
208       access::rw(   out.row_indices[count]) = it_row;
209       access::rw(out.col_ptrs[it_col + 1])++;
210       ++count;
211       }
212 
213     ++it;
214 
215     arma_check( (count > max_n_nonzero), "internal error: spglue_schur_misc::dense_schur_sparse(): count > max_n_nonzero" );
216     }
217 
218   // Fix column pointers.
219   for(uword c = 1; c <= out.n_cols; ++c)
220     {
221     access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1];
222     }
223 
224   if(count < max_n_nonzero)
225     {
226     if(count <= (max_n_nonzero/2))
227       {
228       out.mem_resize(count);
229       }
230     else
231       {
232       // quick resize without reallocating memory and copying data
233       access::rw(         out.n_nonzero) = count;
234       access::rw(     out.values[count]) = eT(0);
235       access::rw(out.row_indices[count]) = uword(0);
236       }
237     }
238   }
239 
240 
241 
242 //
243 
244 
245 
246 template<typename T1, typename T2>
247 inline
248 void
apply(SpMat<typename eT_promoter<T1,T2>::eT> & out,const mtSpGlue<typename eT_promoter<T1,T2>::eT,T1,T2,spglue_schur_mixed> & expr)249 spglue_schur_mixed::apply(SpMat<typename eT_promoter<T1,T2>::eT>& out, const mtSpGlue<typename eT_promoter<T1,T2>::eT, T1, T2, spglue_schur_mixed>& expr)
250   {
251   arma_extra_debug_sigprint();
252 
253   typedef typename T1::elem_type eT1;
254   typedef typename T2::elem_type eT2;
255 
256   typedef typename promote_type<eT1,eT2>::result out_eT;
257 
258   promote_type<eT1,eT2>::check();
259 
260   if( (is_same_type<eT1,out_eT>::no) && (is_same_type<eT2,out_eT>::yes) )
261     {
262     // upgrade T1
263 
264     const unwrap_spmat<T1> UA(expr.A);
265     const unwrap_spmat<T2> UB(expr.B);
266 
267     const SpMat<eT1>& A = UA.M;
268     const SpMat<eT2>& B = UB.M;
269 
270     SpMat<out_eT> AA(arma_layout_indicator(), A);
271 
272     for(uword i=0; i < A.n_nonzero; ++i)  { access::rw(AA.values[i]) = out_eT(A.values[i]); }
273 
274     const SpMat<out_eT>& BB = reinterpret_cast< const SpMat<out_eT>& >(B);
275 
276     out = AA % BB;
277     }
278   else
279   if( (is_same_type<eT1,out_eT>::yes) && (is_same_type<eT2,out_eT>::no) )
280     {
281     // upgrade T2
282 
283     const unwrap_spmat<T1> UA(expr.A);
284     const unwrap_spmat<T2> UB(expr.B);
285 
286     const SpMat<eT1>& A = UA.M;
287     const SpMat<eT2>& B = UB.M;
288 
289     const SpMat<out_eT>& AA = reinterpret_cast< const SpMat<out_eT>& >(A);
290 
291     SpMat<out_eT> BB(arma_layout_indicator(), B);
292 
293     for(uword i=0; i < B.n_nonzero; ++i)  { access::rw(BB.values[i]) = out_eT(B.values[i]); }
294 
295     out = AA % BB;
296     }
297   else
298     {
299     // upgrade T1 and T2
300 
301     const unwrap_spmat<T1> UA(expr.A);
302     const unwrap_spmat<T2> UB(expr.B);
303 
304     const SpMat<eT1>& A = UA.M;
305     const SpMat<eT2>& B = UB.M;
306 
307     SpMat<out_eT> AA(arma_layout_indicator(), A);
308     SpMat<out_eT> BB(arma_layout_indicator(), B);
309 
310     for(uword i=0; i < A.n_nonzero; ++i)  { access::rw(AA.values[i]) = out_eT(A.values[i]); }
311     for(uword i=0; i < B.n_nonzero; ++i)  { access::rw(BB.values[i]) = out_eT(B.values[i]); }
312 
313     out = AA % BB;
314     }
315   }
316 
317 
318 
319 template<typename T1, typename T2>
320 inline
321 void
dense_schur_sparse(SpMat<typename promote_type<typename T1::elem_type,typename T2::elem_type>::result> & out,const T1 & X,const T2 & Y)322 spglue_schur_mixed::dense_schur_sparse(SpMat< typename promote_type<typename T1::elem_type, typename T2::elem_type >::result>& out, const T1& X, const T2& Y)
323   {
324   arma_extra_debug_sigprint();
325 
326   typedef typename T1::elem_type eT1;
327   typedef typename T2::elem_type eT2;
328 
329   typedef typename promote_type<eT1,eT2>::result out_eT;
330 
331   promote_type<eT1,eT2>::check();
332 
333   const   Proxy<T1> pa(X);
334   const SpProxy<T2> pb(Y);
335 
336   arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
337 
338   // count new size
339   uword new_n_nonzero = 0;
340 
341   typename SpProxy<T2>::const_iterator_type it     = pb.begin();
342   typename SpProxy<T2>::const_iterator_type it_end = pb.end();
343 
344   while(it != it_end)
345     {
346     if( (out_eT(*it) * out_eT(pa.at(it.row(), it.col()))) != out_eT(0) )  { ++new_n_nonzero; }
347 
348     ++it;
349     }
350 
351   // Resize memory accordingly.
352   out.reserve(pa.get_n_rows(), pa.get_n_cols(), new_n_nonzero);
353 
354   uword count = 0;
355 
356   typename SpProxy<T2>::const_iterator_type it2 = pb.begin();
357 
358   while(it2 != it_end)
359     {
360     const uword it2_row = it2.row();
361     const uword it2_col = it2.col();
362 
363     const out_eT val = out_eT(*it2) * out_eT(pa.at(it2_row, it2_col));
364 
365     if(val != out_eT(0))
366       {
367       access::rw(        out.values[count]) = val;
368       access::rw(   out.row_indices[count]) = it2_row;
369       access::rw(out.col_ptrs[it2_col + 1])++;
370       ++count;
371       }
372 
373     ++it2;
374     }
375 
376   // Fix column pointers.
377   for(uword c = 1; c <= out.n_cols; ++c)
378     {
379     access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1];
380     }
381   }
382 
383 
384 
385 //! @}
386