1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup op_det
20 //! @{
21 
22 
23 
24 template<typename T1>
25 inline
26 bool
apply_direct(typename T1::elem_type & out_val,const Base<typename T1::elem_type,T1> & expr)27 op_det::apply_direct(typename T1::elem_type& out_val, const Base<typename T1::elem_type,T1>& expr)
28   {
29   arma_extra_debug_sigprint();
30 
31   typedef typename T1::elem_type eT;
32   typedef typename T1::pod_type   T;
33 
34   if(strip_diagmat<T1>::do_diagmat)
35     {
36     const strip_diagmat<T1> strip(expr.get_ref());
37 
38     out_val = op_det::apply_diagmat(strip.M);
39 
40     return true;
41     }
42 
43   if(strip_trimat<T1>::do_trimat)
44     {
45     const strip_trimat<T1> strip(expr.get_ref());
46 
47     out_val = op_det::apply_trimat(strip.M);
48 
49     return true;
50     }
51 
52   Mat<eT> A(expr.get_ref());
53 
54   arma_debug_check( (A.is_square() == false), "det(): given matrix must be square sized" );
55 
56   if((A.n_rows <= 4) && is_cx<eT>::no)
57     {
58     constexpr T det_min =        std::numeric_limits<T>::epsilon();
59     constexpr T det_max = T(1) / std::numeric_limits<T>::epsilon();
60 
61     const eT     det_val = op_det::apply_tiny(A);
62     const  T abs_det_val = std::abs(det_val);
63 
64     if((abs_det_val > det_min) && (abs_det_val < det_max))  { out_val = det_val; return true; }
65 
66     // fallthrough if det_val is suspect
67     }
68 
69   if(A.is_diagmat())  { out_val = op_det::apply_diagmat(A); return true; }
70 
71   const bool is_triu =                   trimat_helper::is_triu(A);
72   const bool is_tril = is_triu ? false : trimat_helper::is_tril(A);
73 
74   if(is_triu || is_tril)  { out_val = op_det::apply_trimat(A); return true; }
75 
76   return auxlib::det(out_val, A);
77   }
78 
79 
80 
81 template<typename T1>
82 inline
83 typename T1::elem_type
apply_diagmat(const Base<typename T1::elem_type,T1> & expr)84 op_det::apply_diagmat(const Base<typename T1::elem_type,T1>& expr)
85   {
86   arma_extra_debug_sigprint();
87 
88   typedef typename T1::elem_type eT;
89 
90   const diagmat_proxy<T1> A(expr.get_ref());
91 
92   arma_debug_check( (A.n_rows != A.n_cols), "det(): given matrix must be square sized" );
93 
94   const uword N = (std::min)(A.n_rows, A.n_cols);
95 
96   eT val = eT(1);
97 
98   for(uword i=0; i<N; ++i)  { val *= A[i]; }
99 
100   return val;
101   }
102 
103 
104 
105 template<typename T1>
106 inline
107 typename T1::elem_type
apply_trimat(const Base<typename T1::elem_type,T1> & expr)108 op_det::apply_trimat(const Base<typename T1::elem_type,T1>& expr)
109   {
110   arma_extra_debug_sigprint();
111 
112   typedef typename T1::elem_type eT;
113 
114   const Proxy<T1> P(expr.get_ref());
115 
116   const uword N = P.get_n_rows();
117 
118   arma_debug_check( (N != P.get_n_cols()), "det(): given matrix must be square sized" );
119 
120   eT val = eT(1);
121 
122   for(uword i=0; i<N; ++i)  { val *= P.at(i,i); }
123 
124   return val;
125   }
126 
127 
128 
129 template<typename eT>
130 arma_cold
131 inline
132 eT
apply_tiny(const Mat<eT> & X)133 op_det::apply_tiny(const Mat<eT>& X)
134   {
135   arma_extra_debug_sigprint();
136 
137   // NOTE: assuming matrix X is square sized
138 
139   const uword N  = X.n_rows;
140   const eT*   Xm = X.memptr();
141 
142   if(N == 0)  { return eT(1); }
143 
144   if(N == 1)  { return Xm[0]; }
145 
146   if(N == 2)
147     {
148     return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
149     }
150 
151   if(N == 3)
152     {
153     // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
154     // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
155     // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
156     // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
157     // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
158     // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
159     // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
160 
161     const eT val1 = Xm[pos<0,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]);
162     const eT val2 = Xm[pos<1,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]);
163     const eT val3 = Xm[pos<2,0>::n3]*(Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]);
164 
165     return ( val1 - val2 + val3 );
166     }
167 
168   if(N == 4)
169     {
170     const eT val_03_12 = Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4];
171     const eT val_02_13 = Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4];
172     const eT val_03_11 = Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4];
173 
174     const eT val_01_13 = Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4];
175     const eT val_02_11 = Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4];
176     const eT val_01_12 = Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4];
177 
178     const eT val_03_10 = Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4];
179     const eT val_00_13 = Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4];
180     const eT val_02_10 = Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4];
181     const eT val_00_12 = Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4];
182 
183     const eT val_01_10 = Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4];
184     const eT val_00_11 = Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4];
185 
186     const eT val_21_30 = Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4];
187     const eT val_22_30 = Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4];
188     const eT val_23_30 = Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4];
189 
190     const eT val_20_31 = Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4];
191     const eT val_22_31 = Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4];
192     const eT val_23_31 = Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4];
193 
194     const eT val_20_32 = Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4];
195     const eT val_21_32 = Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4];
196     const eT val_23_32 = Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4];
197 
198     const eT val_20_33 = Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4];
199     const eT val_21_33 = Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4];
200     const eT val_22_33 = Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4];
201 
202     const eT val = \
203         val_03_12 * val_21_30 \
204       - val_02_13 * val_21_30 \
205       - val_03_11 * val_22_30 \
206       + val_01_13 * val_22_30 \
207       + val_02_11 * val_23_30 \
208       - val_01_12 * val_23_30 \
209       - val_03_12 * val_20_31 \
210       + val_02_13 * val_20_31 \
211       + val_03_10 * val_22_31 \
212       - val_00_13 * val_22_31 \
213       - val_02_10 * val_23_31 \
214       + val_00_12 * val_23_31 \
215       + val_03_11 * val_20_32 \
216       - val_01_13 * val_20_32 \
217       - val_03_10 * val_21_32 \
218       + val_00_13 * val_21_32 \
219       + val_01_10 * val_23_32 \
220       - val_00_11 * val_23_32 \
221       - val_02_11 * val_20_33 \
222       + val_01_12 * val_20_33 \
223       + val_02_10 * val_21_33 \
224       - val_00_12 * val_21_33 \
225       - val_01_10 * val_22_33 \
226       + val_00_11 * val_22_33 \
227       ;
228 
229     return val;
230     }
231 
232   return eT(0);
233   }
234 
235 
236 
237 //! @}
238