1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup fn_find
20 //! @{
21
22
23
24 template<typename T1>
25 arma_warn_unused
26 inline
27 typename
28 enable_if2
29 <
30 is_arma_type<T1>::value,
31 const mtOp<uword, T1, op_find_simple>
32 >::result
find(const T1 & X)33 find(const T1& X)
34 {
35 arma_extra_debug_sigprint();
36
37 return mtOp<uword, T1, op_find_simple>(X);
38 }
39
40
41
42 template<typename T1>
43 arma_warn_unused
44 inline
45 const mtOp<uword, T1, op_find>
find(const Base<typename T1::elem_type,T1> & X,const uword k,const char * direction="first")46 find(const Base<typename T1::elem_type,T1>& X, const uword k, const char* direction = "first")
47 {
48 arma_extra_debug_sigprint();
49
50 const char sig = (direction != nullptr) ? direction[0] : char(0);
51
52 arma_debug_check
53 (
54 ( (sig != 'f') && (sig != 'F') && (sig != 'l') && (sig != 'L') ),
55 "find(): direction must be \"first\" or \"last\""
56 );
57
58 const uword type = ( (sig == 'f') || (sig == 'F') ) ? 0 : 1;
59
60 return mtOp<uword, T1, op_find>(X.get_ref(), k, type);
61 }
62
63
64
65 //
66
67
68
69 template<typename T1>
70 arma_warn_unused
71 inline
72 uvec
find(const BaseCube<typename T1::elem_type,T1> & X)73 find(const BaseCube<typename T1::elem_type,T1>& X)
74 {
75 arma_extra_debug_sigprint();
76
77 typedef typename T1::elem_type eT;
78
79 const unwrap_cube<T1> tmp(X.get_ref());
80
81 const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
82
83 return find(R);
84 }
85
86
87
88 template<typename T1>
89 arma_warn_unused
90 inline
91 uvec
find(const BaseCube<typename T1::elem_type,T1> & X,const uword k,const char * direction="first")92 find(const BaseCube<typename T1::elem_type,T1>& X, const uword k, const char* direction = "first")
93 {
94 arma_extra_debug_sigprint();
95
96 typedef typename T1::elem_type eT;
97
98 const unwrap_cube<T1> tmp(X.get_ref());
99
100 const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
101
102 return find(R, k, direction);
103 }
104
105
106
107 template<typename T1, typename op_rel_type>
108 arma_warn_unused
109 inline
110 uvec
find(const mtOpCube<uword,T1,op_rel_type> & X,const uword k=0,const char * direction="first")111 find(const mtOpCube<uword, T1, op_rel_type>& X, const uword k = 0, const char* direction = "first")
112 {
113 arma_extra_debug_sigprint();
114
115 typedef typename T1::elem_type eT;
116
117 const unwrap_cube<T1> tmp(X.m);
118
119 const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
120
121 return find( mtOp<uword, Mat<eT>, op_rel_type>(R, X.aux), k, direction );
122 }
123
124
125
126 template<typename T1, typename T2, typename glue_rel_type>
127 arma_warn_unused
128 inline
129 uvec
find(const mtGlueCube<uword,T1,T2,glue_rel_type> & X,const uword k=0,const char * direction="first")130 find(const mtGlueCube<uword, T1, T2, glue_rel_type>& X, const uword k = 0, const char* direction = "first")
131 {
132 arma_extra_debug_sigprint();
133
134 typedef typename T1::elem_type eT1;
135 typedef typename T2::elem_type eT2;
136
137 const unwrap_cube<T1> tmp1(X.A);
138 const unwrap_cube<T2> tmp2(X.B);
139
140 arma_debug_assert_same_size( tmp1.M, tmp2.M, "relational operator" );
141
142 const Mat<eT1> R1( const_cast< eT1* >(tmp1.M.memptr()), tmp1.M.n_elem, 1, false );
143 const Mat<eT2> R2( const_cast< eT2* >(tmp2.M.memptr()), tmp2.M.n_elem, 1, false );
144
145 return find( mtGlue<uword, Mat<eT1>, Mat<eT2>, glue_rel_type>(R1, R2), k, direction );
146 }
147
148
149
150 //
151
152
153
154 template<typename T1>
155 arma_warn_unused
156 inline
157 Col<uword>
find(const SpBase<typename T1::elem_type,T1> & X,const uword k=0)158 find(const SpBase<typename T1::elem_type,T1>& X, const uword k = 0)
159 {
160 arma_extra_debug_sigprint();
161
162 const SpProxy<T1> P(X.get_ref());
163
164 const uword n_rows = P.get_n_rows();
165 const uword n_nz = P.get_n_nonzero();
166
167 Mat<uword> tmp(n_nz, 1, arma_nozeros_indicator());
168
169 uword* tmp_mem = tmp.memptr();
170
171 typename SpProxy<T1>::const_iterator_type it = P.begin();
172
173 for(uword i=0; i<n_nz; ++i)
174 {
175 const uword index = it.row() + it.col()*n_rows;
176
177 tmp_mem[i] = index;
178
179 ++it;
180 }
181
182 Col<uword> out;
183
184 const uword count = (k == 0) ? uword(n_nz) : uword( (std::min)(n_nz, k) );
185
186 out.steal_mem_col(tmp, count);
187
188 return out;
189 }
190
191
192
193 template<typename T1>
194 arma_warn_unused
195 inline
196 Col<uword>
find(const SpBase<typename T1::elem_type,T1> & X,const uword k,const char * direction)197 find(const SpBase<typename T1::elem_type,T1>& X, const uword k, const char* direction)
198 {
199 arma_extra_debug_sigprint();
200
201 arma_ignore(X);
202 arma_ignore(k);
203 arma_ignore(direction);
204
205 arma_check(true, "find(SpBase,k,direction): not implemented yet"); // TODO
206
207 Col<uword> out;
208
209 return out;
210 }
211
212
213
214 //
215
216
217
218 template<typename T1>
219 arma_warn_unused
220 inline
221 typename
222 enable_if2
223 <
224 is_arma_type<T1>::value,
225 const mtOp<uword, T1, op_find_finite>
226 >::result
find_finite(const T1 & X)227 find_finite(const T1& X)
228 {
229 arma_extra_debug_sigprint();
230
231 return mtOp<uword, T1, op_find_finite>(X);
232 }
233
234
235
236 template<typename T1>
237 arma_warn_unused
238 inline
239 typename
240 enable_if2
241 <
242 is_arma_type<T1>::value,
243 const mtOp<uword, T1, op_find_nonfinite>
244 >::result
find_nonfinite(const T1 & X)245 find_nonfinite(const T1& X)
246 {
247 arma_extra_debug_sigprint();
248
249 return mtOp<uword, T1, op_find_nonfinite>(X);
250 }
251
252
253
254 //
255
256
257
258 template<typename T1>
259 arma_warn_unused
260 inline
261 uvec
find_finite(const BaseCube<typename T1::elem_type,T1> & X)262 find_finite(const BaseCube<typename T1::elem_type,T1>& X)
263 {
264 arma_extra_debug_sigprint();
265
266 typedef typename T1::elem_type eT;
267
268 const unwrap_cube<T1> tmp(X.get_ref());
269
270 const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
271
272 return find_finite(R);
273 }
274
275
276
277 template<typename T1>
278 arma_warn_unused
279 inline
280 uvec
find_nonfinite(const BaseCube<typename T1::elem_type,T1> & X)281 find_nonfinite(const BaseCube<typename T1::elem_type,T1>& X)
282 {
283 arma_extra_debug_sigprint();
284
285 typedef typename T1::elem_type eT;
286
287 const unwrap_cube<T1> tmp(X.get_ref());
288
289 const Mat<eT> R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false );
290
291 return find_nonfinite(R);
292 }
293
294
295
296 //
297
298
299
300 template<typename T1>
301 arma_warn_unused
302 inline
303 Col<uword>
find_finite(const SpBase<typename T1::elem_type,T1> & X)304 find_finite(const SpBase<typename T1::elem_type,T1>& X)
305 {
306 arma_extra_debug_sigprint();
307
308 const SpProxy<T1> P(X.get_ref());
309
310 const uword n_rows = P.get_n_rows();
311 const uword n_nz = P.get_n_nonzero();
312
313 Mat<uword> tmp(n_nz, 1, arma_nozeros_indicator());
314
315 uword* tmp_mem = tmp.memptr();
316
317 typename SpProxy<T1>::const_iterator_type it = P.begin();
318
319 uword count = 0;
320
321 for(uword i=0; i<n_nz; ++i)
322 {
323 if(arma_isfinite(*it))
324 {
325 const uword index = it.row() + it.col()*n_rows;
326
327 tmp_mem[count] = index;
328
329 ++count;
330 }
331
332 ++it;
333 }
334
335 Col<uword> out;
336
337 if(count > 0) { out.steal_mem_col(tmp, count); }
338
339 return out;
340 }
341
342
343
344 template<typename T1>
345 arma_warn_unused
346 inline
347 Col<uword>
find_nonfinite(const SpBase<typename T1::elem_type,T1> & X)348 find_nonfinite(const SpBase<typename T1::elem_type,T1>& X)
349 {
350 arma_extra_debug_sigprint();
351
352 const SpProxy<T1> P(X.get_ref());
353
354 const uword n_rows = P.get_n_rows();
355 const uword n_nz = P.get_n_nonzero();
356
357 Mat<uword> tmp(n_nz, 1, arma_nozeros_indicator());
358
359 uword* tmp_mem = tmp.memptr();
360
361 typename SpProxy<T1>::const_iterator_type it = P.begin();
362
363 uword count = 0;
364
365 for(uword i=0; i<n_nz; ++i)
366 {
367 if(arma_isfinite(*it) == false)
368 {
369 const uword index = it.row() + it.col()*n_rows;
370
371 tmp_mem[count] = index;
372
373 ++count;
374 }
375
376 ++it;
377 }
378
379 Col<uword> out;
380
381 if(count > 0) { out.steal_mem_col(tmp, count); }
382
383 return out;
384 }
385
386
387
388 //! @}
389