1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup spglue_schur
20 //! @{
21
22
23
24 template<typename T1, typename T2>
25 arma_hot
26 inline
27 void
apply(SpMat<typename T1::elem_type> & out,const SpGlue<T1,T2,spglue_schur> & X)28 spglue_schur::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_schur>& X)
29 {
30 arma_extra_debug_sigprint();
31
32 typedef typename T1::elem_type eT;
33
34 const SpProxy<T1> pa(X.A);
35 const SpProxy<T2> pb(X.B);
36
37 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
38
39 if(is_alias == false)
40 {
41 spglue_schur::apply_noalias(out, pa, pb);
42 }
43 else
44 {
45 SpMat<eT> tmp;
46
47 spglue_schur::apply_noalias(tmp, pa, pb);
48
49 out.steal_mem(tmp);
50 }
51 }
52
53
54
55 template<typename eT, typename T1, typename T2>
56 arma_hot
57 inline
58 void
apply_noalias(SpMat<eT> & out,const SpProxy<T1> & pa,const SpProxy<T2> & pb)59 spglue_schur::apply_noalias(SpMat<eT>& out, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
60 {
61 arma_extra_debug_sigprint();
62
63 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
64
65 if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) )
66 {
67 out.zeros(pa.get_n_rows(), pa.get_n_cols());
68 return;
69 }
70
71 const uword max_n_nonzero = (std::min)(pa.get_n_nonzero(), pb.get_n_nonzero());
72
73 // Resize memory to upper bound
74 out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero);
75
76 // Now iterate across both matrices.
77 typename SpProxy<T1>::const_iterator_type x_it = pa.begin();
78 typename SpProxy<T1>::const_iterator_type x_end = pa.end();
79
80 typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
81 typename SpProxy<T2>::const_iterator_type y_end = pb.end();
82
83 uword count = 0;
84
85 while( (x_it != x_end) || (y_it != y_end) )
86 {
87 const uword x_it_row = x_it.row();
88 const uword x_it_col = x_it.col();
89
90 const uword y_it_row = y_it.row();
91 const uword y_it_col = y_it.col();
92
93 if(x_it == y_it)
94 {
95 const eT out_val = (*x_it) * (*y_it);
96
97 if(out_val != eT(0))
98 {
99 access::rw(out.values[count]) = out_val;
100
101 access::rw(out.row_indices[count]) = x_it_row;
102 access::rw(out.col_ptrs[x_it_col + 1])++;
103 ++count;
104 }
105
106 ++x_it;
107 ++y_it;
108 }
109 else
110 {
111 if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
112 {
113 ++x_it;
114 }
115 else
116 {
117 ++y_it;
118 }
119 }
120
121 arma_check( (count > max_n_nonzero), "internal error: spglue_schur::apply_noalias(): count > max_n_nonzero" );
122 }
123
124 const uword out_n_cols = out.n_cols;
125
126 uword* col_ptrs = access::rwp(out.col_ptrs);
127
128 // Fix column pointers to be cumulative.
129 for(uword c = 1; c <= out_n_cols; ++c)
130 {
131 col_ptrs[c] += col_ptrs[c - 1];
132 }
133
134 if(count < max_n_nonzero)
135 {
136 if(count <= (max_n_nonzero/2))
137 {
138 out.mem_resize(count);
139 }
140 else
141 {
142 // quick resize without reallocating memory and copying data
143 access::rw( out.n_nonzero) = count;
144 access::rw( out.values[count]) = eT(0);
145 access::rw(out.row_indices[count]) = uword(0);
146 }
147 }
148 }
149
150
151
152 template<typename eT>
153 arma_hot
154 inline
155 void
apply_noalias(SpMat<eT> & out,const SpMat<eT> & A,const SpMat<eT> & B)156 spglue_schur::apply_noalias(SpMat<eT>& out, const SpMat<eT>& A, const SpMat<eT>& B)
157 {
158 arma_extra_debug_sigprint();
159
160 const SpProxy< SpMat<eT> > pa(A);
161 const SpProxy< SpMat<eT> > pb(B);
162
163 spglue_schur::apply_noalias(out, pa, pb);
164 }
165
166
167
168 //
169 //
170 //
171
172
173
174 template<typename T1, typename T2>
175 inline
176 void
dense_schur_sparse(SpMat<typename T1::elem_type> & out,const T1 & x,const T2 & y)177 spglue_schur_misc::dense_schur_sparse(SpMat<typename T1::elem_type>& out, const T1& x, const T2& y)
178 {
179 arma_extra_debug_sigprint();
180
181 typedef typename T1::elem_type eT;
182
183 const Proxy<T1> pa(x);
184 const SpProxy<T2> pb(y);
185
186 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
187
188 const uword max_n_nonzero = pb.get_n_nonzero();
189
190 // Resize memory to upper bound.
191 out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero);
192
193 uword count = 0;
194
195 typename SpProxy<T2>::const_iterator_type it = pb.begin();
196 typename SpProxy<T2>::const_iterator_type it_end = pb.end();
197
198 while(it != it_end)
199 {
200 const uword it_row = it.row();
201 const uword it_col = it.col();
202
203 const eT val = (*it) * pa.at(it_row, it_col);
204
205 if(val != eT(0))
206 {
207 access::rw( out.values[count]) = val;
208 access::rw( out.row_indices[count]) = it_row;
209 access::rw(out.col_ptrs[it_col + 1])++;
210 ++count;
211 }
212
213 ++it;
214
215 arma_check( (count > max_n_nonzero), "internal error: spglue_schur_misc::dense_schur_sparse(): count > max_n_nonzero" );
216 }
217
218 // Fix column pointers.
219 for(uword c = 1; c <= out.n_cols; ++c)
220 {
221 access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1];
222 }
223
224 if(count < max_n_nonzero)
225 {
226 if(count <= (max_n_nonzero/2))
227 {
228 out.mem_resize(count);
229 }
230 else
231 {
232 // quick resize without reallocating memory and copying data
233 access::rw( out.n_nonzero) = count;
234 access::rw( out.values[count]) = eT(0);
235 access::rw(out.row_indices[count]) = uword(0);
236 }
237 }
238 }
239
240
241
242 //
243
244
245
246 template<typename T1, typename T2>
247 inline
248 void
apply(SpMat<typename eT_promoter<T1,T2>::eT> & out,const mtSpGlue<typename eT_promoter<T1,T2>::eT,T1,T2,spglue_schur_mixed> & expr)249 spglue_schur_mixed::apply(SpMat<typename eT_promoter<T1,T2>::eT>& out, const mtSpGlue<typename eT_promoter<T1,T2>::eT, T1, T2, spglue_schur_mixed>& expr)
250 {
251 arma_extra_debug_sigprint();
252
253 typedef typename T1::elem_type eT1;
254 typedef typename T2::elem_type eT2;
255
256 typedef typename promote_type<eT1,eT2>::result out_eT;
257
258 promote_type<eT1,eT2>::check();
259
260 if( (is_same_type<eT1,out_eT>::no) && (is_same_type<eT2,out_eT>::yes) )
261 {
262 // upgrade T1
263
264 const unwrap_spmat<T1> UA(expr.A);
265 const unwrap_spmat<T2> UB(expr.B);
266
267 const SpMat<eT1>& A = UA.M;
268 const SpMat<eT2>& B = UB.M;
269
270 SpMat<out_eT> AA(arma_layout_indicator(), A);
271
272 for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); }
273
274 const SpMat<out_eT>& BB = reinterpret_cast< const SpMat<out_eT>& >(B);
275
276 out = AA % BB;
277 }
278 else
279 if( (is_same_type<eT1,out_eT>::yes) && (is_same_type<eT2,out_eT>::no) )
280 {
281 // upgrade T2
282
283 const unwrap_spmat<T1> UA(expr.A);
284 const unwrap_spmat<T2> UB(expr.B);
285
286 const SpMat<eT1>& A = UA.M;
287 const SpMat<eT2>& B = UB.M;
288
289 const SpMat<out_eT>& AA = reinterpret_cast< const SpMat<out_eT>& >(A);
290
291 SpMat<out_eT> BB(arma_layout_indicator(), B);
292
293 for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); }
294
295 out = AA % BB;
296 }
297 else
298 {
299 // upgrade T1 and T2
300
301 const unwrap_spmat<T1> UA(expr.A);
302 const unwrap_spmat<T2> UB(expr.B);
303
304 const SpMat<eT1>& A = UA.M;
305 const SpMat<eT2>& B = UB.M;
306
307 SpMat<out_eT> AA(arma_layout_indicator(), A);
308 SpMat<out_eT> BB(arma_layout_indicator(), B);
309
310 for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); }
311 for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); }
312
313 out = AA % BB;
314 }
315 }
316
317
318
319 template<typename T1, typename T2>
320 inline
321 void
dense_schur_sparse(SpMat<typename promote_type<typename T1::elem_type,typename T2::elem_type>::result> & out,const T1 & X,const T2 & Y)322 spglue_schur_mixed::dense_schur_sparse(SpMat< typename promote_type<typename T1::elem_type, typename T2::elem_type >::result>& out, const T1& X, const T2& Y)
323 {
324 arma_extra_debug_sigprint();
325
326 typedef typename T1::elem_type eT1;
327 typedef typename T2::elem_type eT2;
328
329 typedef typename promote_type<eT1,eT2>::result out_eT;
330
331 promote_type<eT1,eT2>::check();
332
333 const Proxy<T1> pa(X);
334 const SpProxy<T2> pb(Y);
335
336 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
337
338 // count new size
339 uword new_n_nonzero = 0;
340
341 typename SpProxy<T2>::const_iterator_type it = pb.begin();
342 typename SpProxy<T2>::const_iterator_type it_end = pb.end();
343
344 while(it != it_end)
345 {
346 if( (out_eT(*it) * out_eT(pa.at(it.row(), it.col()))) != out_eT(0) ) { ++new_n_nonzero; }
347
348 ++it;
349 }
350
351 // Resize memory accordingly.
352 out.reserve(pa.get_n_rows(), pa.get_n_cols(), new_n_nonzero);
353
354 uword count = 0;
355
356 typename SpProxy<T2>::const_iterator_type it2 = pb.begin();
357
358 while(it2 != it_end)
359 {
360 const uword it2_row = it2.row();
361 const uword it2_col = it2.col();
362
363 const out_eT val = out_eT(*it2) * out_eT(pa.at(it2_row, it2_col));
364
365 if(val != out_eT(0))
366 {
367 access::rw( out.values[count]) = val;
368 access::rw( out.row_indices[count]) = it2_row;
369 access::rw(out.col_ptrs[it2_col + 1])++;
370 ++count;
371 }
372
373 ++it2;
374 }
375
376 // Fix column pointers.
377 for(uword c = 1; c <= out.n_cols; ++c)
378 {
379 access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1];
380 }
381 }
382
383
384
385 //! @}
386