1 // -*- C++ -*-
2 /**
3 * @brief Class for a quadratic numerical math gradient implementation
4 *
5 * Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6 *
7 * This library is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU Lesser General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public License
18 * along with this library. If not, see <http://www.gnu.org/licenses/>.
19 *
20 */
21 #include "openturns/LinearGradient.hxx"
22 #include "openturns/PersistentObjectFactory.hxx"
23 #include "openturns/Os.hxx"
24 #include "openturns/Lapack.hxx"
25
26
27 BEGIN_NAMESPACE_OPENTURNS
28
29
30
31 CLASSNAMEINIT(LinearGradient)
32
33 static const Factory<LinearGradient> Factory_LinearGradient;
34
35 /* Default constructor */
LinearGradient()36 LinearGradient::LinearGradient()
37 : GradientImplementation()
38 {
39 // Nothing to do
40 }
41
42 /* Parameter constructor */
LinearGradient(const Point & center,const Matrix & constant,const SymmetricTensor & linear)43 LinearGradient::LinearGradient(const Point & center,
44 const Matrix & constant,
45 const SymmetricTensor & linear)
46 : GradientImplementation()
47 , center_(center)
48 , constant_(constant)
49 , linear_(linear)
50 {
51 /* Check if the dimensions of the constant term is compatible with the linear term */
52 if ((constant.getNbRows() != linear.getNbRows()) || (constant.getNbColumns() != linear.getNbSheets())) throw InvalidDimensionException(HERE) << "Constant term dimensions are incompatible with the linear term";
53 /* Check if the dimensions of the center term is compatible with the linear term */
54 if ((center.getDimension() != constant.getNbRows()) || (center.getDimension() != linear.getNbRows())) throw InvalidDimensionException(HERE) << "Center term dimensions are incompatible with the constant term or the linear term";
55 }
56
57 /* Virtual constructor */
clone() const58 LinearGradient * LinearGradient::clone() const
59 {
60 return new LinearGradient(*this);
61 }
62
63 /* Comparison operator */
operator ==(const LinearGradient & other) const64 Bool LinearGradient::operator ==(const LinearGradient & other) const
65 {
66 return ((linear_ == other.linear_) && (constant_ == other.constant_) && (center_ == other.center_));
67 }
68
69 /* String converter */
__repr__() const70 String LinearGradient::__repr__() const
71 {
72 OSS oss(true);
73 oss << "class=" << LinearGradient::GetClassName()
74 << " name=" << getName()
75 << " center=" << center_.__repr__()
76 << " constant=" << constant_.__repr__()
77 << " linear=" << linear_;
78 return oss;
79 }
80
/* Human-readable string converter
 * @param offset prefix written at the start of each line; nested members are
 *               printed with offset + " " as their own prefix
 * @return the class name followed by the center, constant and linear terms,
 *         each introduced by a label line
 */
String LinearGradient::__str__(const String & offset) const
{
  const String memberOffset(offset + " ");
  OSS oss(false);
  oss << LinearGradient::GetClassName() << Os::GetEndOfLine() << offset;
  oss << " center :" << Os::GetEndOfLine() << offset << center_.__str__(memberOffset) << Os::GetEndOfLine();
  oss << " constant :" << Os::GetEndOfLine() << offset << constant_.__str__(memberOffset) << Os::GetEndOfLine();
  oss << " linear :" << Os::GetEndOfLine() << offset << linear_.__str__(memberOffset) << Os::GetEndOfLine();
  return oss;
}
90
91 /* Accessor for the center */
getCenter() const92 Point LinearGradient::getCenter() const
93 {
94 return center_;
95 }
96
97 /* Accessor for the constant term */
getConstant() const98 Matrix LinearGradient::getConstant() const
99 {
100 return constant_;
101 }
102
103 /* Accessor for the linear term */
getLinear() const104 SymmetricTensor LinearGradient::getLinear() const
105 {
106 return linear_;
107 }
108
/* Here is the interface that all derived classes must implement */
110
111 /* Gradient() */
gradient(const Point & inP) const112 Matrix LinearGradient::gradient(const Point & inP) const
113 {
114 if (inP.getDimension() != constant_.getNbRows()) throw InvalidArgumentException(HERE) << "Invalid input dimension";
115 Matrix value(constant_);
116 // Add the linear term <linear, x>
117 const UnsignedInteger nbSheets = linear_.getNbSheets();
118 const UnsignedInteger nbRows = linear_.getNbRows();
119 if (nbSheets == 0 || nbRows == 0)
120 return value;
121 const Point delta(inP - center_);
122 char uplo('L');
123 int n(nbRows);
124 int one(1);
125 double alpha(1.0);
126 double beta(1.0);
127 int luplo(1);
128 for(UnsignedInteger k = 0; k < nbSheets; ++k)
129 dsymv_(&uplo, &n, &alpha, const_cast<double*>(&(linear_(0, 0, k))), &n, const_cast<double*>(&(delta[0])), &one, &beta, &value(0, k), &one, &luplo);
130 callsNumber_.increment();
131 return value;
132 }
133
134 /* Accessor for input point dimension */
getInputDimension() const135 UnsignedInteger LinearGradient::getInputDimension() const
136 {
137 return center_.getDimension();
138 }
139
140 /* Accessor for output point dimension */
getOutputDimension() const141 UnsignedInteger LinearGradient::getOutputDimension() const
142 {
143 return constant_.getNbColumns();
144 }
145
146 /* Method save() stores the object through the StorageManager */
save(Advocate & adv) const147 void LinearGradient::save(Advocate & adv) const
148 {
149 GradientImplementation::save(adv);
150 adv.saveAttribute( "center_", center_ );
151 adv.saveAttribute( "constant_", constant_ );
152 adv.saveAttribute( "linear_", linear_ );
153 }
154
155 /* Method load() reloads the object from the StorageManager */
load(Advocate & adv)156 void LinearGradient::load(Advocate & adv)
157 {
158 GradientImplementation::load(adv);
159 adv.loadAttribute( "center_", center_ );
160 adv.loadAttribute( "constant_", constant_ );
161 adv.loadAttribute( "linear_", linear_ );
162 }
163
164 END_NAMESPACE_OPENTURNS
165