1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup op_index_max
20 //! @{
21
22
23
24 template<typename T1>
25 inline
26 void
apply(Mat<uword> & out,const mtOp<uword,T1,op_index_max> & in)27 op_index_max::apply(Mat<uword>& out, const mtOp<uword,T1,op_index_max>& in)
28 {
29 arma_extra_debug_sigprint();
30
31 typedef typename T1::elem_type eT;
32
33 const uword dim = in.aux_uword_a;
34 arma_debug_check( (dim > 1), "index_max(): parameter 'dim' must be 0 or 1" );
35
36 const quasi_unwrap<T1> U(in.m);
37 const Mat<eT>& X = U.M;
38
39 if(U.is_alias(out) == false)
40 {
41 op_index_max::apply_noalias(out, X, dim);
42 }
43 else
44 {
45 Mat<uword> tmp;
46
47 op_index_max::apply_noalias(tmp, X, dim);
48
49 out.steal_mem(tmp);
50 }
51 }
52
53
54
55 template<typename eT>
56 inline
57 void
apply_noalias(Mat<uword> & out,const Mat<eT> & X,const uword dim)58 op_index_max::apply_noalias(Mat<uword>& out, const Mat<eT>& X, const uword dim)
59 {
60 arma_extra_debug_sigprint();
61
62 typedef typename get_pod_type<eT>::result T;
63
64 const uword X_n_rows = X.n_rows;
65 const uword X_n_cols = X.n_cols;
66
67 if(dim == 0)
68 {
69 arma_extra_debug_print("op_index_max::apply(): dim = 0");
70
71 out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols);
72
73 if(X_n_rows == 0) { return; }
74
75 uword* out_mem = out.memptr();
76
77 for(uword col=0; col < X_n_cols; ++col)
78 {
79 op_max::direct_max( X.colptr(col), X_n_rows, out_mem[col] );
80 }
81 }
82 else
83 if(dim == 1)
84 {
85 arma_extra_debug_print("op_index_max::apply(): dim = 1");
86
87 out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0);
88
89 if(X_n_cols == 0) { return; }
90
91 uword* out_mem = out.memptr();
92
93 Col<T> tmp(X_n_rows, arma_nozeros_indicator());
94
95 T* tmp_mem = tmp.memptr();
96
97 if(is_cx<eT>::yes)
98 {
99 const eT* col_mem = X.colptr(0);
100
101 for(uword row=0; row < X_n_rows; ++row)
102 {
103 tmp_mem[row] = eop_aux::arma_abs(col_mem[row]);
104 }
105 }
106 else
107 {
108 arrayops::copy(tmp_mem, (T*)(X.colptr(0)), X_n_rows);
109 }
110
111 for(uword col=1; col < X_n_cols; ++col)
112 {
113 const eT* col_mem = X.colptr(col);
114
115 for(uword row=0; row < X_n_rows; ++row)
116 {
117 T& max_val = tmp_mem[row];
118 T col_val = (is_cx<eT>::yes) ? T(eop_aux::arma_abs(col_mem[row])) : T(access::tmp_real(col_mem[row]));
119
120 if(max_val < col_val)
121 {
122 max_val = col_val;
123
124 out_mem[row] = col;
125 }
126 }
127 }
128 }
129 }
130
131
132
133 template<typename T1>
134 inline
135 void
apply(Cube<uword> & out,const mtOpCube<uword,T1,op_index_max> & in)136 op_index_max::apply(Cube<uword>& out, const mtOpCube<uword, T1, op_index_max>& in)
137 {
138 arma_extra_debug_sigprint();
139
140 const uword dim = in.aux_uword_a;
141 arma_debug_check( (dim > 2), "index_max(): parameter 'dim' must be 0 or 1 or 2" );
142
143 const unwrap_cube<T1> U(in.m);
144
145 if(U.is_alias(out) == false)
146 {
147 op_index_max::apply_noalias(out, U.M, dim);
148 }
149 else
150 {
151 Cube<uword> tmp;
152
153 op_index_max::apply_noalias(tmp, U.M, dim);
154
155 out.steal_mem(tmp);
156 }
157 }
158
159
160
161 template<typename eT>
162 inline
163 void
apply_noalias(Cube<uword> & out,const Cube<eT> & X,const uword dim,const typename arma_not_cx<eT>::result * junk)164 op_index_max::apply_noalias(Cube<uword>& out, const Cube<eT>& X, const uword dim, const typename arma_not_cx<eT>::result* junk)
165 {
166 arma_extra_debug_sigprint();
167 arma_ignore(junk);
168
169 const uword X_n_rows = X.n_rows;
170 const uword X_n_cols = X.n_cols;
171 const uword X_n_slices = X.n_slices;
172
173 if(dim == 0)
174 {
175 arma_extra_debug_print("op_index_max::apply(): dim = 0");
176
177 out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices);
178
179 if(out.is_empty() || X.is_empty()) { return; }
180
181 for(uword slice=0; slice < X_n_slices; ++slice)
182 {
183 uword* out_mem = out.slice_memptr(slice);
184
185 for(uword col=0; col < X_n_cols; ++col)
186 {
187 op_max::direct_max( X.slice_colptr(slice,col), X_n_rows, out_mem[col] );
188 }
189 }
190 }
191 else
192 if(dim == 1)
193 {
194 arma_extra_debug_print("op_index_max::apply(): dim = 1");
195
196 out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices);
197
198 if(out.is_empty() || X.is_empty()) { return; }
199
200 Col<eT> tmp(X_n_rows, arma_nozeros_indicator());
201
202 eT* tmp_mem = tmp.memptr();
203
204 for(uword slice=0; slice < X_n_slices; ++slice)
205 {
206 uword* out_mem = out.slice_memptr(slice);
207
208 arrayops::copy(tmp_mem, X.slice_colptr(slice,0), X_n_rows);
209
210 for(uword col=1; col < X_n_cols; ++col)
211 {
212 const eT* col_mem = X.slice_colptr(slice,col);
213
214 for(uword row=0; row < X_n_rows; ++row)
215 {
216 const eT val = col_mem[row];
217
218 if(val > tmp_mem[row])
219 {
220 tmp_mem[row] = val;
221 out_mem[row] = col;
222 }
223 }
224 }
225 }
226 }
227 else
228 if(dim == 2)
229 {
230 arma_extra_debug_print("op_index_max::apply(): dim = 2");
231
232 out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0);
233
234 if(out.is_empty() || X.is_empty()) { return; }
235
236 Mat<eT> tmp(X.slice_memptr(0), X_n_rows, X_n_cols); // copy slice 0
237
238 eT* tmp_mem = tmp.memptr();
239 uword* out_mem = out.memptr();
240
241 const uword N = X.n_elem_slice;
242
243 for(uword slice=1; slice < X_n_slices; ++slice)
244 {
245 const eT* X_slice_mem = X.slice_memptr(slice);
246
247 for(uword i=0; i < N; ++i)
248 {
249 const eT val = X_slice_mem[i];
250
251 if(val > tmp_mem[i])
252 {
253 tmp_mem[i] = val;
254 out_mem[i] = slice;
255 }
256 }
257 }
258 }
259 }
260
261
262
263 template<typename eT>
264 inline
265 void
apply_noalias(Cube<uword> & out,const Cube<eT> & X,const uword dim,const typename arma_cx_only<eT>::result * junk)266 op_index_max::apply_noalias(Cube<uword>& out, const Cube<eT>& X, const uword dim, const typename arma_cx_only<eT>::result* junk)
267 {
268 arma_extra_debug_sigprint();
269 arma_ignore(junk);
270
271 typedef typename get_pod_type<eT>::result T;
272
273 const uword X_n_rows = X.n_rows;
274 const uword X_n_cols = X.n_cols;
275 const uword X_n_slices = X.n_slices;
276
277 if(dim == 0)
278 {
279 arma_extra_debug_print("op_index_max::apply(): dim = 0");
280
281 out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols, X_n_slices);
282
283 if(out.is_empty() || X.is_empty()) { return; }
284
285 for(uword slice=0; slice < X_n_slices; ++slice)
286 {
287 uword* out_mem = out.slice_memptr(slice);
288
289 for(uword col=0; col < X_n_cols; ++col)
290 {
291 op_max::direct_max( X.slice_colptr(slice,col), X_n_rows, out_mem[col] );
292 }
293 }
294 }
295 else
296 if(dim == 1)
297 {
298 arma_extra_debug_print("op_index_max::apply(): dim = 1");
299
300 out.zeros(X_n_rows, (X_n_cols > 0) ? 1 : 0, X_n_slices);
301
302 if(out.is_empty() || X.is_empty()) { return; }
303
304 Col<T> tmp(X_n_rows, arma_nozeros_indicator());
305
306 T* tmp_mem = tmp.memptr();
307
308 for(uword slice=0; slice < X_n_slices; ++slice)
309 {
310 uword* out_mem = out.slice_memptr(slice);
311
312 const eT* col0_mem = X.slice_colptr(slice,0);
313
314 for(uword row=0; row < X_n_rows; ++row)
315 {
316 tmp_mem[row] = std::abs( col0_mem[row] );
317 }
318
319 for(uword col=1; col < X_n_cols; ++col)
320 {
321 const eT* col_mem = X.slice_colptr(slice,col);
322
323 for(uword row=0; row < X_n_rows; ++row)
324 {
325 const T val = std::abs( col_mem[row] );
326
327 if(val > tmp_mem[row])
328 {
329 tmp_mem[row] = val;
330 out_mem[row] = col;
331 }
332 }
333 }
334 }
335 }
336 else
337 if(dim == 2)
338 {
339 arma_extra_debug_print("op_index_max::apply(): dim = 2");
340
341 out.zeros(X_n_rows, X_n_cols, (X_n_slices > 0) ? 1 : 0);
342
343 if(out.is_empty() || X.is_empty()) { return; }
344
345 uword* out_mem = out.memptr();
346
347 Mat<T> tmp(X_n_rows, X_n_cols, arma_nozeros_indicator());
348
349 T* tmp_mem = tmp.memptr();
350 const eT* X_slice0_mem = X.slice_memptr(0);
351
352 const uword N = X.n_elem_slice;
353
354 for(uword i=0; i<N; ++i)
355 {
356 tmp_mem[i] = std::abs( X_slice0_mem[i] );
357 }
358
359 for(uword slice=1; slice < X_n_slices; ++slice)
360 {
361 const eT* X_slice_mem = X.slice_memptr(slice);
362
363 for(uword i=0; i < N; ++i)
364 {
365 const T val = std::abs( X_slice_mem[i] );
366
367 if(val > tmp_mem[i])
368 {
369 tmp_mem[i] = val;
370 out_mem[i] = slice;
371 }
372 }
373 }
374 }
375 }
376
377
378
379 template<typename T1>
380 inline
381 void
apply(Mat<uword> & out,const SpBase<typename T1::elem_type,T1> & expr,const uword dim)382 op_index_max::apply(Mat<uword>& out, const SpBase<typename T1::elem_type,T1>& expr, const uword dim)
383 {
384 arma_extra_debug_sigprint();
385
386 typedef typename T1::elem_type eT;
387
388 arma_debug_check( (dim > 1), "index_max(): parameter 'dim' must be 0 or 1" );
389
390 const unwrap_spmat<T1> U(expr.get_ref());
391 const SpMat<eT>& X = U.M;
392
393 const uword X_n_rows = X.n_rows;
394 const uword X_n_cols = X.n_cols;
395
396 if(dim == 0)
397 {
398 arma_extra_debug_print("op_index_max::apply(): dim = 0");
399
400 out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols);
401
402 if(X_n_rows == 0) { return; }
403
404 uword* out_mem = out.memptr();
405
406 for(uword col=0; col < X_n_cols; ++col)
407 {
408 out_mem[col] = X.col(col).index_max();
409 }
410 }
411 else
412 if(dim == 1)
413 {
414 arma_extra_debug_print("op_index_max::apply(): dim = 1");
415
416 out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0);
417
418 if(X_n_cols == 0) { return; }
419
420 uword* out_mem = out.memptr();
421
422 const SpMat<eT> Xt = X.st();
423
424 for(uword row=0; row < X_n_rows; ++row)
425 {
426 out_mem[row] = Xt.col(row).index_max();
427 }
428 }
429 }
430
431
432
433 //! @}
434