1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup op_log_det
20 //! @{
21 
22 
23 
24 template<typename T1>
25 inline
26 bool
apply_direct(typename T1::elem_type & out_val,typename T1::pod_type & out_sign,const Base<typename T1::elem_type,T1> & expr)27 op_log_det::apply_direct(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base<typename T1::elem_type,T1>& expr)
28   {
29   arma_extra_debug_sigprint();
30 
31   typedef typename T1::elem_type eT;
32 
33   if(strip_diagmat<T1>::do_diagmat)
34     {
35     const strip_diagmat<T1> strip(expr.get_ref());
36 
37     return op_log_det::apply_diagmat(out_val, out_sign, strip.M);
38     }
39 
40   if(strip_trimat<T1>::do_trimat)
41     {
42     const strip_trimat<T1> strip(expr.get_ref());
43 
44     return op_log_det::apply_trimat(out_val, out_sign, strip.M);
45     }
46 
47   Mat<eT> A(expr.get_ref());
48 
49   arma_debug_check( (A.is_square() == false), "log_det(): given matrix must be square sized" );
50 
51   if(A.is_diagmat())  { return op_log_det::apply_diagmat(out_val, out_sign, A); }
52 
53   const bool is_triu =                   trimat_helper::is_triu(A);
54   const bool is_tril = is_triu ? false : trimat_helper::is_tril(A);
55 
56   if(is_triu || is_tril)  { return op_log_det::apply_trimat(out_val, out_sign, A); }
57 
58   return auxlib::log_det(out_val, out_sign, A);
59   }
60 
61 
62 
63 template<typename T1>
64 inline
65 bool
apply_diagmat(typename T1::elem_type & out_val,typename T1::pod_type & out_sign,const Base<typename T1::elem_type,T1> & expr)66 op_log_det::apply_diagmat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base<typename T1::elem_type,T1>& expr)
67   {
68   arma_extra_debug_sigprint();
69 
70   typedef typename T1::elem_type eT;
71   typedef typename T1::pod_type   T;
72 
73   const diagmat_proxy<T1> A(expr.get_ref());
74 
75   arma_debug_check( (A.n_rows != A.n_cols), "log_det(): given matrix must be square sized" );
76 
77   const uword N = (std::min)(A.n_rows, A.n_cols);
78 
79   if(N == 0)
80     {
81     out_val  = eT(0);
82     out_sign =  T(1);
83 
84     return true;
85     }
86 
87   eT x = A[0];
88 
89   T  sign = (is_cx<eT>::no) ?         ( (access::tmp_real(x) < T(0)) ?   T(-1) : T(1) ) : T(1);
90   eT val  = (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x    ) : std::log(x);
91 
92   for(uword i=1; i<N; ++i)
93     {
94     x = A[i];
95 
96     sign *= (is_cx<eT>::no) ?         ( (access::tmp_real(x) < T(0)) ?   T(-1) : T(1) ) : T(1);
97     val  += (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x    ) : std::log(x);
98     }
99 
100   out_val  = val;
101   out_sign = sign;
102 
103   return (arma_isnan(out_val) == false);
104   }
105 
106 
107 
108 template<typename T1>
109 inline
110 bool
apply_trimat(typename T1::elem_type & out_val,typename T1::pod_type & out_sign,const Base<typename T1::elem_type,T1> & expr)111 op_log_det::apply_trimat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base<typename T1::elem_type,T1>& expr)
112   {
113   arma_extra_debug_sigprint();
114 
115   typedef typename T1::elem_type eT;
116   typedef typename T1::pod_type   T;
117 
118   const Proxy<T1> P(expr.get_ref());
119 
120   const uword N = P.get_n_rows();
121 
122   arma_debug_check( (N != P.get_n_cols()), "log_det(): given matrix must be square sized" );
123 
124   if(N == 0)
125     {
126     out_val  = eT(0);
127     out_sign =  T(1);
128 
129     return true;
130     }
131 
132   eT x = P.at(0,0);
133 
134   T  sign = (is_cx<eT>::no) ?         ( (access::tmp_real(x) < T(0)) ?   T(-1) : T(1) ) : T(1);
135   eT val  = (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x    ) : std::log(x);
136 
137   for(uword i=1; i<N; ++i)
138     {
139     x = P.at(i,i);
140 
141     sign *= (is_cx<eT>::no) ?         ( (access::tmp_real(x) < T(0)) ?   T(-1) : T(1) ) : T(1);
142     val  += (is_cx<eT>::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x    ) : std::log(x);
143     }
144 
145   out_val  = val;
146   out_sign = sign;
147 
148   return (arma_isnan(out_val) == false);
149   }
150 
151 
152 
153 //
154 
155 
156 
157 template<typename T1>
158 inline
159 bool
apply_direct(typename T1::pod_type & out_val,const Base<typename T1::elem_type,T1> & expr)160 op_log_det_sympd::apply_direct(typename T1::pod_type& out_val, const Base<typename T1::elem_type,T1>& expr)
161   {
162   arma_extra_debug_sigprint();
163 
164   typedef typename T1::elem_type eT;
165 
166   Mat<eT> A(expr.get_ref());
167 
168   arma_debug_check( (A.is_square() == false), "log_det_sympd(): given matrix must be square sized" );
169 
170   if((arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false))
171     {
172     if(is_cx<eT>::no )  { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not symmetric"); }
173     if(is_cx<eT>::yes)  { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not hermitian"); }
174     }
175 
176   return auxlib::log_det_sympd(out_val, A);
177   }
178 
179 
180 
181 //! @}
182