1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup op_log_det
20 //! @{
21
22
23
24 template<typename T1>
25 inline
26 bool
apply_direct(typename T1::elem_type & out_val,typename T1::pod_type & out_sign,const Base<typename T1::elem_type,T1> & expr)27 op_log_det::apply_direct(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base<typename T1::elem_type,T1>& expr)
28 {
29 arma_extra_debug_sigprint();
30
31 typedef typename T1::elem_type eT;
32
33 if(strip_diagmat<T1>::do_diagmat)
34 {
35 const strip_diagmat<T1> strip(expr.get_ref());
36
37 return op_log_det::apply_diagmat(out_val, out_sign, strip.M);
38 }
39
40 if(strip_trimat<T1>::do_trimat)
41 {
42 const strip_trimat<T1> strip(expr.get_ref());
43
44 return op_log_det::apply_trimat(out_val, out_sign, strip.M);
45 }
46
47 Mat<eT> A(expr.get_ref());
48
49 arma_debug_check( (A.is_square() == false), "log_det(): given matrix must be square sized" );
50
51 if(A.is_diagmat()) { return op_log_det::apply_diagmat(out_val, out_sign, A); }
52
53 const bool is_triu = trimat_helper::is_triu(A);
54 const bool is_tril = is_triu ? false : trimat_helper::is_tril(A);
55
56 if(is_triu || is_tril) { return op_log_det::apply_trimat(out_val, out_sign, A); }
57
58 return auxlib::log_det(out_val, out_sign, A);
59 }
60
61
62
63 template<typename T1>
64 inline
65 bool
apply_diagmat(typename T1::elem_type & out_val,typename T1::pod_type & out_sign,const Base<typename T1::elem_type,T1> & expr)66 op_log_det::apply_diagmat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base<typename T1::elem_type,T1>& expr)
67 {
68 arma_extra_debug_sigprint();
69
70 typedef typename T1::elem_type eT;
71 typedef typename T1::pod_type T;
72
73 const diagmat_proxy<T1> A(expr.get_ref());
74
75 arma_debug_check( (A.n_rows != A.n_cols), "log_det(): given matrix must be square sized" );
76
77 const uword N = (std::min)(A.n_rows, A.n_cols);
78
79 if(N == 0)
80 {
81 out_val = eT(0);
82 out_sign = T(1);
83
84 return true;
85 }
86
87 eT x = A[0];
88
89 T sign = (is_cx<eT>::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1);
90 eT val = (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
91
92 for(uword i=1; i<N; ++i)
93 {
94 x = A[i];
95
96 sign *= (is_cx<eT>::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1);
97 val += (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
98 }
99
100 out_val = val;
101 out_sign = sign;
102
103 return (arma_isnan(out_val) == false);
104 }
105
106
107
108 template<typename T1>
109 inline
110 bool
apply_trimat(typename T1::elem_type & out_val,typename T1::pod_type & out_sign,const Base<typename T1::elem_type,T1> & expr)111 op_log_det::apply_trimat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base<typename T1::elem_type,T1>& expr)
112 {
113 arma_extra_debug_sigprint();
114
115 typedef typename T1::elem_type eT;
116 typedef typename T1::pod_type T;
117
118 const Proxy<T1> P(expr.get_ref());
119
120 const uword N = P.get_n_rows();
121
122 arma_debug_check( (N != P.get_n_cols()), "log_det(): given matrix must be square sized" );
123
124 if(N == 0)
125 {
126 out_val = eT(0);
127 out_sign = T(1);
128
129 return true;
130 }
131
132 eT x = P.at(0,0);
133
134 T sign = (is_cx<eT>::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1);
135 eT val = (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
136
137 for(uword i=1; i<N; ++i)
138 {
139 x = P.at(i,i);
140
141 sign *= (is_cx<eT>::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1);
142 val += (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
143 }
144
145 out_val = val;
146 out_sign = sign;
147
148 return (arma_isnan(out_val) == false);
149 }
150
151
152
153 //
154
155
156
157 template<typename T1>
158 inline
159 bool
apply_direct(typename T1::pod_type & out_val,const Base<typename T1::elem_type,T1> & expr)160 op_log_det_sympd::apply_direct(typename T1::pod_type& out_val, const Base<typename T1::elem_type,T1>& expr)
161 {
162 arma_extra_debug_sigprint();
163
164 typedef typename T1::elem_type eT;
165
166 Mat<eT> A(expr.get_ref());
167
168 arma_debug_check( (A.is_square() == false), "log_det_sympd(): given matrix must be square sized" );
169
170 if((arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false))
171 {
172 if(is_cx<eT>::no ) { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not symmetric"); }
173 if(is_cx<eT>::yes) { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not hermitian"); }
174 }
175
176 return auxlib::log_det_sympd(out_val, A);
177 }
178
179
180
181 //! @}
182