1 /*  _______________________________________________________________________
2 
3     DAKOTA: Design Analysis Kit for Optimization and Terascale Applications
4     Copyright 2014-2020 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
5     This software is distributed under the GNU Lesser General Public License.
6     For more information, see the README file in the top Dakota directory.
7     _______________________________________________________________________ */
8 
9 #include "ReducedBasis.hpp"
10 #include "dakota_linear_algebra.hpp"
11 #include "dakota_global_defs.hpp"
12 
13 #include <Teuchos_SerialDenseHelpers.hpp>
14 
15 namespace Dakota {
16 
17 // ------------------------------------------
18 
ReducedBasis()19 ReducedBasis::ReducedBasis() :
20   col_means_computed(false), is_centered(false), is_valid_svd(false)
21 {
22 }
23 
24 // ------------------------------------------
25 
26 void
set_matrix(const RealMatrix & mat)27 ReducedBasis::set_matrix(const RealMatrix & mat)
28 {
29   matrix = mat;
30   col_means_computed = false;
31   is_centered = false;
32   is_valid_svd = false;
33 }
34 
35 // ------------------------------------------
36 
37 void
center_matrix()38 ReducedBasis::center_matrix()
39 {
40   if ( is_centered )
41     return;
42 
43   compute_column_means(matrix, column_means);
44   col_means_computed = true;
45 
46   // working vector
47   RealVector column_vec(matrix.numRows());
48 
49   for( int i=0; i<matrix.numCols(); ++i ) {
50     column_vec.putScalar(column_means(i));
51     RealVector matrix_column = Teuchos::getCol(Teuchos::View, matrix, i);
52     matrix_column -= column_vec;
53   }
54 
55   is_centered = true;
56   is_valid_svd = false;
57 }
58 
59 // ------------------------------------------
60 
61 void
update_svd(bool do_center)62 ReducedBasis::update_svd(bool do_center)
63 {
64   if( is_valid_svd )
65     return;
66 
67   if( matrix.empty() )
68     throw std::runtime_error("Matrix is empty.  Make sure to call set_matrix(...) first.");
69 
70   if( do_center )
71     center_matrix();
72 
73   workingMatrix = matrix; // because the matrix gets overwritten by U_matrix values
74   svd(workingMatrix, S_values, VT_matrix);
75   U_matrix = workingMatrix;
76 
77   RealVector ones(S_values.length());
78   ones = 1.0;
79   singular_values_sum = ones.dot(S_values);
80 
81   eigen_values_sum = 0.0;
82   for( int i=0; i<S_values.length(); ++i )
83     eigen_values_sum += S_values(i)*S_values(i);
84 
85   is_valid_svd = true;
86 }
87 
88 // ------------------------------------------
89 
90 RealVector
get_singular_values(const TruncationCondition & truncation_cond) const91 ReducedBasis::get_singular_values(const TruncationCondition & truncation_cond) const
92 {
93   int num_values = truncation_cond.get_num_components(*this);
94 
95   RealVector vec(num_values);
96   for( int i=0; i<num_values; ++i )
97     vec(i) = S_values(i);
98 
99   return vec;
100 }
101 
102 // ------------------------------------------
103 // -------- Truncation Conditions  ----------
104 // ------------------------------------------
105 
106 
107 void
sanity_check(const ReducedBasis & basis) const108 ReducedBasis::TruncationCondition::sanity_check(const ReducedBasis & basis) const
109 {
110   if( !basis.is_valid() ) {
111     Cerr << "\nError: Truncation condition cannot be applied before computing a valid ReducedBasis SVD."
112       << std::endl;
113     abort_handler(-1);
114   }
115 }
116 
117 // ------------------------------------------
118 
Untruncated()119 ReducedBasis::Untruncated::Untruncated() :
120   TruncationCondition()
121 { }
122 
get_num_components(const ReducedBasis & basis) const123 int ReducedBasis::Untruncated::get_num_components(const ReducedBasis & basis) const
124 {
125   sanity_check(basis);
126   return basis.get_singular_values().length();
127 }
128 
129 // ------------------------------------------
130 
VarianceExplained(Real var_exp)131 ReducedBasis::VarianceExplained::VarianceExplained(Real var_exp) :
132   TruncationCondition(),
133   variance_explained(var_exp)
134 {
135   if( (0.0 > var_exp) || (1.0 < var_exp) ) {
136     Cerr << "\nError: VarianceExplained Truncation condition must be in the range (0.0, 1,0)."
137       << std::endl;
138     abort_handler(-1);
139   }
140 }
141 
get_num_components(const ReducedBasis & basis) const142 int ReducedBasis::VarianceExplained::get_num_components(const ReducedBasis & basis) const
143 {
144   sanity_check(basis);
145 
146   Real total_sum = basis.get_eigen_values_sum();
147   const RealVector & singular_vals = basis.get_singular_values();
148   int num_comp = 0;
149   Real partial_sum = 0.0;
150 
151   while( partial_sum/total_sum < variance_explained )
152     partial_sum += singular_vals(num_comp)*singular_vals(num_comp++);
153 
154   return num_comp;
155 }
156 
157 // ------------------------------------------
158 
HeuristicVarianceExplained(Real var_exp)159 ReducedBasis::HeuristicVarianceExplained::HeuristicVarianceExplained(Real var_exp) :
160   TruncationCondition(),
161   variance_explained(var_exp)
162 {
163   if( (0.0 > var_exp) || (1.0 < var_exp) ) {
164     Cerr << "\nError: HeuristicVarianceExplained Truncation condition must be in the range (0.0, 1,0)."
165       << std::endl;
166     abort_handler(-1);
167   }
168 }
169 
get_num_components(const ReducedBasis & basis) const170 int ReducedBasis::HeuristicVarianceExplained::get_num_components(const ReducedBasis & basis) const
171 {
172   sanity_check(basis);
173 
174   const RealVector & singular_vals = basis.get_singular_values();
175   Real largest_eig_val = singular_vals(0)*singular_vals(0);
176   int num_comp = 0;
177   Real ratio = 1.0;
178 
179   while( ratio > (1.0-variance_explained) )
180     ratio = singular_vals(num_comp)*singular_vals(num_comp++)/largest_eig_val;
181 
182   return num_comp;
183 }
184 
185 // ------------------------------------------
186 
NumComponents(int num_comp)187 ReducedBasis::NumComponents::NumComponents(int num_comp) :
188   TruncationCondition(),
189   num_components(num_comp)
190 { }
191 
get_num_components(const ReducedBasis & basis) const192 int ReducedBasis::NumComponents::get_num_components(const ReducedBasis & basis) const
193 {
194   sanity_check(basis);
195   return num_components;
196 }
197 
198 }  // namespace Dakota
199