1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup op_det
20 //! @{
21
22
23
24 template<typename T1>
25 inline
26 bool
apply_direct(typename T1::elem_type & out_val,const Base<typename T1::elem_type,T1> & expr)27 op_det::apply_direct(typename T1::elem_type& out_val, const Base<typename T1::elem_type,T1>& expr)
28 {
29 arma_extra_debug_sigprint();
30
31 typedef typename T1::elem_type eT;
32 typedef typename T1::pod_type T;
33
34 if(strip_diagmat<T1>::do_diagmat)
35 {
36 const strip_diagmat<T1> strip(expr.get_ref());
37
38 out_val = op_det::apply_diagmat(strip.M);
39
40 return true;
41 }
42
43 if(strip_trimat<T1>::do_trimat)
44 {
45 const strip_trimat<T1> strip(expr.get_ref());
46
47 out_val = op_det::apply_trimat(strip.M);
48
49 return true;
50 }
51
52 Mat<eT> A(expr.get_ref());
53
54 arma_debug_check( (A.is_square() == false), "det(): given matrix must be square sized" );
55
56 if((A.n_rows <= 4) && is_cx<eT>::no)
57 {
58 constexpr T det_min = std::numeric_limits<T>::epsilon();
59 constexpr T det_max = T(1) / std::numeric_limits<T>::epsilon();
60
61 const eT det_val = op_det::apply_tiny(A);
62 const T abs_det_val = std::abs(det_val);
63
64 if((abs_det_val > det_min) && (abs_det_val < det_max)) { out_val = det_val; return true; }
65
66 // fallthrough if det_val is suspect
67 }
68
69 if(A.is_diagmat()) { out_val = op_det::apply_diagmat(A); return true; }
70
71 const bool is_triu = trimat_helper::is_triu(A);
72 const bool is_tril = is_triu ? false : trimat_helper::is_tril(A);
73
74 if(is_triu || is_tril) { out_val = op_det::apply_trimat(A); return true; }
75
76 return auxlib::det(out_val, A);
77 }
78
79
80
81 template<typename T1>
82 inline
83 typename T1::elem_type
apply_diagmat(const Base<typename T1::elem_type,T1> & expr)84 op_det::apply_diagmat(const Base<typename T1::elem_type,T1>& expr)
85 {
86 arma_extra_debug_sigprint();
87
88 typedef typename T1::elem_type eT;
89
90 const diagmat_proxy<T1> A(expr.get_ref());
91
92 arma_debug_check( (A.n_rows != A.n_cols), "det(): given matrix must be square sized" );
93
94 const uword N = (std::min)(A.n_rows, A.n_cols);
95
96 eT val = eT(1);
97
98 for(uword i=0; i<N; ++i) { val *= A[i]; }
99
100 return val;
101 }
102
103
104
105 template<typename T1>
106 inline
107 typename T1::elem_type
apply_trimat(const Base<typename T1::elem_type,T1> & expr)108 op_det::apply_trimat(const Base<typename T1::elem_type,T1>& expr)
109 {
110 arma_extra_debug_sigprint();
111
112 typedef typename T1::elem_type eT;
113
114 const Proxy<T1> P(expr.get_ref());
115
116 const uword N = P.get_n_rows();
117
118 arma_debug_check( (N != P.get_n_cols()), "det(): given matrix must be square sized" );
119
120 eT val = eT(1);
121
122 for(uword i=0; i<N; ++i) { val *= P.at(i,i); }
123
124 return val;
125 }
126
127
128
129 template<typename eT>
130 arma_cold
131 inline
132 eT
apply_tiny(const Mat<eT> & X)133 op_det::apply_tiny(const Mat<eT>& X)
134 {
135 arma_extra_debug_sigprint();
136
137 // NOTE: assuming matrix X is square sized
138
139 const uword N = X.n_rows;
140 const eT* Xm = X.memptr();
141
142 if(N == 0) { return eT(1); }
143
144 if(N == 1) { return Xm[0]; }
145
146 if(N == 2)
147 {
148 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
149 }
150
151 if(N == 3)
152 {
153 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
154 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
155 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
156 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
157 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
158 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
159 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
160
161 const eT val1 = Xm[pos<0,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]);
162 const eT val2 = Xm[pos<1,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]);
163 const eT val3 = Xm[pos<2,0>::n3]*(Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]);
164
165 return ( val1 - val2 + val3 );
166 }
167
168 if(N == 4)
169 {
170 const eT val_03_12 = Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4];
171 const eT val_02_13 = Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4];
172 const eT val_03_11 = Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4];
173
174 const eT val_01_13 = Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4];
175 const eT val_02_11 = Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4];
176 const eT val_01_12 = Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4];
177
178 const eT val_03_10 = Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4];
179 const eT val_00_13 = Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4];
180 const eT val_02_10 = Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4];
181 const eT val_00_12 = Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4];
182
183 const eT val_01_10 = Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4];
184 const eT val_00_11 = Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4];
185
186 const eT val_21_30 = Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4];
187 const eT val_22_30 = Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4];
188 const eT val_23_30 = Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4];
189
190 const eT val_20_31 = Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4];
191 const eT val_22_31 = Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4];
192 const eT val_23_31 = Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4];
193
194 const eT val_20_32 = Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4];
195 const eT val_21_32 = Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4];
196 const eT val_23_32 = Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4];
197
198 const eT val_20_33 = Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4];
199 const eT val_21_33 = Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4];
200 const eT val_22_33 = Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4];
201
202 const eT val = \
203 val_03_12 * val_21_30 \
204 - val_02_13 * val_21_30 \
205 - val_03_11 * val_22_30 \
206 + val_01_13 * val_22_30 \
207 + val_02_11 * val_23_30 \
208 - val_01_12 * val_23_30 \
209 - val_03_12 * val_20_31 \
210 + val_02_13 * val_20_31 \
211 + val_03_10 * val_22_31 \
212 - val_00_13 * val_22_31 \
213 - val_02_10 * val_23_31 \
214 + val_00_12 * val_23_31 \
215 + val_03_11 * val_20_32 \
216 - val_01_13 * val_20_32 \
217 - val_03_10 * val_21_32 \
218 + val_00_13 * val_21_32 \
219 + val_01_10 * val_23_32 \
220 - val_00_11 * val_23_32 \
221 - val_02_11 * val_20_33 \
222 + val_01_12 * val_20_33 \
223 + val_02_10 * val_21_33 \
224 - val_00_12 * val_21_33 \
225 - val_01_10 * val_22_33 \
226 + val_00_11 * val_22_33 \
227 ;
228
229 return val;
230 }
231
232 return eT(0);
233 }
234
235
236
237 //! @}
238