1 /* _______________________________________________________________________
2
3 DAKOTA: Design Analysis Kit for Optimization and Terascale Applications
4 Copyright 2014-2020 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
5 This software is distributed under the GNU Lesser General Public License.
6 For more information, see the README file in the top Dakota directory.
7 _______________________________________________________________________ */
8
9 #include "ReducedBasis.hpp"
10 #include "dakota_linear_algebra.hpp"
11 #include "dakota_global_defs.hpp"
12
13 #include <Teuchos_SerialDenseHelpers.hpp>
14
15 namespace Dakota {
16
17 // ------------------------------------------
18
ReducedBasis()19 ReducedBasis::ReducedBasis() :
20 col_means_computed(false), is_centered(false), is_valid_svd(false)
21 {
22 }
23
24 // ------------------------------------------
25
26 void
set_matrix(const RealMatrix & mat)27 ReducedBasis::set_matrix(const RealMatrix & mat)
28 {
29 matrix = mat;
30 col_means_computed = false;
31 is_centered = false;
32 is_valid_svd = false;
33 }
34
35 // ------------------------------------------
36
37 void
center_matrix()38 ReducedBasis::center_matrix()
39 {
40 if ( is_centered )
41 return;
42
43 compute_column_means(matrix, column_means);
44 col_means_computed = true;
45
46 // working vector
47 RealVector column_vec(matrix.numRows());
48
49 for( int i=0; i<matrix.numCols(); ++i ) {
50 column_vec.putScalar(column_means(i));
51 RealVector matrix_column = Teuchos::getCol(Teuchos::View, matrix, i);
52 matrix_column -= column_vec;
53 }
54
55 is_centered = true;
56 is_valid_svd = false;
57 }
58
59 // ------------------------------------------
60
61 void
update_svd(bool do_center)62 ReducedBasis::update_svd(bool do_center)
63 {
64 if( is_valid_svd )
65 return;
66
67 if( matrix.empty() )
68 throw std::runtime_error("Matrix is empty. Make sure to call set_matrix(...) first.");
69
70 if( do_center )
71 center_matrix();
72
73 workingMatrix = matrix; // because the matrix gets overwritten by U_matrix values
74 svd(workingMatrix, S_values, VT_matrix);
75 U_matrix = workingMatrix;
76
77 RealVector ones(S_values.length());
78 ones = 1.0;
79 singular_values_sum = ones.dot(S_values);
80
81 eigen_values_sum = 0.0;
82 for( int i=0; i<S_values.length(); ++i )
83 eigen_values_sum += S_values(i)*S_values(i);
84
85 is_valid_svd = true;
86 }
87
88 // ------------------------------------------
89
90 RealVector
get_singular_values(const TruncationCondition & truncation_cond) const91 ReducedBasis::get_singular_values(const TruncationCondition & truncation_cond) const
92 {
93 int num_values = truncation_cond.get_num_components(*this);
94
95 RealVector vec(num_values);
96 for( int i=0; i<num_values; ++i )
97 vec(i) = S_values(i);
98
99 return vec;
100 }
101
102 // ------------------------------------------
103 // -------- Truncation Conditions ----------
104 // ------------------------------------------
105
106
107 void
sanity_check(const ReducedBasis & basis) const108 ReducedBasis::TruncationCondition::sanity_check(const ReducedBasis & basis) const
109 {
110 if( !basis.is_valid() ) {
111 Cerr << "\nError: Truncation condition cannot be applied before computing a valid ReducedBasis SVD."
112 << std::endl;
113 abort_handler(-1);
114 }
115 }
116
117 // ------------------------------------------
118
Untruncated()119 ReducedBasis::Untruncated::Untruncated() :
120 TruncationCondition()
121 { }
122
get_num_components(const ReducedBasis & basis) const123 int ReducedBasis::Untruncated::get_num_components(const ReducedBasis & basis) const
124 {
125 sanity_check(basis);
126 return basis.get_singular_values().length();
127 }
128
129 // ------------------------------------------
130
VarianceExplained(Real var_exp)131 ReducedBasis::VarianceExplained::VarianceExplained(Real var_exp) :
132 TruncationCondition(),
133 variance_explained(var_exp)
134 {
135 if( (0.0 > var_exp) || (1.0 < var_exp) ) {
136 Cerr << "\nError: VarianceExplained Truncation condition must be in the range (0.0, 1,0)."
137 << std::endl;
138 abort_handler(-1);
139 }
140 }
141
get_num_components(const ReducedBasis & basis) const142 int ReducedBasis::VarianceExplained::get_num_components(const ReducedBasis & basis) const
143 {
144 sanity_check(basis);
145
146 Real total_sum = basis.get_eigen_values_sum();
147 const RealVector & singular_vals = basis.get_singular_values();
148 int num_comp = 0;
149 Real partial_sum = 0.0;
150
151 while( partial_sum/total_sum < variance_explained )
152 partial_sum += singular_vals(num_comp)*singular_vals(num_comp++);
153
154 return num_comp;
155 }
156
157 // ------------------------------------------
158
HeuristicVarianceExplained(Real var_exp)159 ReducedBasis::HeuristicVarianceExplained::HeuristicVarianceExplained(Real var_exp) :
160 TruncationCondition(),
161 variance_explained(var_exp)
162 {
163 if( (0.0 > var_exp) || (1.0 < var_exp) ) {
164 Cerr << "\nError: HeuristicVarianceExplained Truncation condition must be in the range (0.0, 1,0)."
165 << std::endl;
166 abort_handler(-1);
167 }
168 }
169
get_num_components(const ReducedBasis & basis) const170 int ReducedBasis::HeuristicVarianceExplained::get_num_components(const ReducedBasis & basis) const
171 {
172 sanity_check(basis);
173
174 const RealVector & singular_vals = basis.get_singular_values();
175 Real largest_eig_val = singular_vals(0)*singular_vals(0);
176 int num_comp = 0;
177 Real ratio = 1.0;
178
179 while( ratio > (1.0-variance_explained) )
180 ratio = singular_vals(num_comp)*singular_vals(num_comp++)/largest_eig_val;
181
182 return num_comp;
183 }
184
185 // ------------------------------------------
186
NumComponents(int num_comp)187 ReducedBasis::NumComponents::NumComponents(int num_comp) :
188 TruncationCondition(),
189 num_components(num_comp)
190 { }
191
get_num_components(const ReducedBasis & basis) const192 int ReducedBasis::NumComponents::get_num_components(const ReducedBasis & basis) const
193 {
194 sanity_check(basis);
195 return num_components;
196 }
197
198 } // namespace Dakota
199