1 //                                               -*- C++ -*-
2 /**
 *  @brief Class for a linear numerical math gradient implementation (the gradient of a quadratic function)
4  *
5  *  Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6  *
7  *  This library is free software: you can redistribute it and/or modify
8  *  it under the terms of the GNU Lesser General Public License as published by
9  *  the Free Software Foundation, either version 3 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This library is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU Lesser General Public License for more details.
16  *
17  *  You should have received a copy of the GNU Lesser General Public License
18  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
19  *
20  */
21 #include "openturns/LinearGradient.hxx"
22 #include "openturns/PersistentObjectFactory.hxx"
23 #include "openturns/Os.hxx"
24 #include "openturns/Lapack.hxx"
25 
26 
27 BEGIN_NAMESPACE_OPENTURNS
28 
29 
30 
// Class-name boilerplate for LinearGradient. CLASSNAMEINIT is a project macro
// defined elsewhere; presumably it provides the static GetClassName() used by
// __repr__/__str__ below — confirm against its definition.
CLASSNAMEINIT(LinearGradient)

// File-local factory instance. NOTE(review): presumably constructing it registers
// LinearGradient with the persistence machinery (PersistentObjectFactory.hxx) so
// that load() can be used to recreate instances — confirm against Factory<>.
static const Factory<LinearGradient> Factory_LinearGradient;
34 
35 /* Default constructor */
LinearGradient()36 LinearGradient::LinearGradient()
37   : GradientImplementation()
38 {
39   // Nothing to do
40 }
41 
42 /* Parameter constructor */
LinearGradient(const Point & center,const Matrix & constant,const SymmetricTensor & linear)43 LinearGradient::LinearGradient(const Point & center,
44                                const Matrix & constant,
45                                const SymmetricTensor & linear)
46   : GradientImplementation()
47   , center_(center)
48   , constant_(constant)
49   , linear_(linear)
50 {
51   /* Check if the dimensions of the constant term is compatible with the linear term */
52   if ((constant.getNbRows() != linear.getNbRows()) || (constant.getNbColumns() != linear.getNbSheets())) throw InvalidDimensionException(HERE) << "Constant term dimensions are incompatible with the linear term";
53   /* Check if the dimensions of the center term is compatible with the linear term */
54   if ((center.getDimension() != constant.getNbRows()) || (center.getDimension() != linear.getNbRows())) throw InvalidDimensionException(HERE) << "Center term dimensions are incompatible with the constant term or the linear term";
55 }
56 
57 /* Virtual constructor */
clone() const58 LinearGradient * LinearGradient::clone() const
59 {
60   return new LinearGradient(*this);
61 }
62 
63 /* Comparison operator */
operator ==(const LinearGradient & other) const64 Bool LinearGradient::operator ==(const LinearGradient & other) const
65 {
66   return ((linear_ == other.linear_) && (constant_ == other.constant_) && (center_ == other.center_));
67 }
68 
69 /* String converter */
__repr__() const70 String LinearGradient::__repr__() const
71 {
72   OSS oss(true);
73   oss << "class=" << LinearGradient::GetClassName()
74       << " name=" << getName()
75       << " center=" << center_.__repr__()
76       << " constant=" << constant_.__repr__()
77       << " linear=" << linear_;
78   return oss;
79 }
80 
/* String converter (human-readable, multi-line representation)
 *
 * Prints the class name followed by three labelled sections: center,
 * constant and linear.
 *
 * @param offset  prefix written before each section header and nested value,
 *                and forwarded (extended by two spaces) to the nested __str__
 *                calls, so the output can be embedded in an enclosing printout
 * @return        the multi-line description
 */
String LinearGradient::__str__(const String & offset) const
{
  OSS oss(false);
  // Every section header is preceded by the caller's offset so that the
  // three headers stay aligned whatever the offset is.
  oss << LinearGradient::GetClassName() << Os::GetEndOfLine()
      << offset << "  center :"  << Os::GetEndOfLine() << offset << center_.__str__(offset + "  ") << Os::GetEndOfLine()
      << offset << "  constant :" << Os::GetEndOfLine() << offset << constant_.__str__(offset + "  ") << Os::GetEndOfLine()
      << offset << "  linear :" << Os::GetEndOfLine() << offset << linear_.__str__(offset + "  ") << Os::GetEndOfLine();
  return oss;
}
90 
91 /* Accessor for the center */
getCenter() const92 Point LinearGradient::getCenter() const
93 {
94   return center_;
95 }
96 
97 /* Accessor for the constant term */
getConstant() const98 Matrix LinearGradient::getConstant() const
99 {
100   return constant_;
101 }
102 
103 /* Accessor for the linear term */
getLinear() const104 SymmetricTensor LinearGradient::getLinear() const
105 {
106   return linear_;
107 }
108 
/* Here is the interface that all derived classes must implement */
110 
111 /* Gradient() */
gradient(const Point & inP) const112 Matrix LinearGradient::gradient(const Point & inP) const
113 {
114   if (inP.getDimension() != constant_.getNbRows()) throw InvalidArgumentException(HERE) << "Invalid input dimension";
115   Matrix value(constant_);
116   // Add the linear term <linear, x>
117   const UnsignedInteger nbSheets = linear_.getNbSheets();
118   const UnsignedInteger nbRows = linear_.getNbRows();
119   if (nbSheets == 0 || nbRows == 0)
120     return value;
121   const Point delta(inP - center_);
122   char uplo('L');
123   int n(nbRows);
124   int one(1);
125   double alpha(1.0);
126   double beta(1.0);
127   int luplo(1);
128   for(UnsignedInteger k = 0; k < nbSheets; ++k)
129     dsymv_(&uplo, &n, &alpha, const_cast<double*>(&(linear_(0, 0, k))), &n, const_cast<double*>(&(delta[0])), &one, &beta, &value(0, k), &one, &luplo);
130   callsNumber_.increment();
131   return value;
132 }
133 
134 /* Accessor for input point dimension */
getInputDimension() const135 UnsignedInteger LinearGradient::getInputDimension() const
136 {
137   return center_.getDimension();
138 }
139 
140 /* Accessor for output point dimension */
getOutputDimension() const141 UnsignedInteger LinearGradient::getOutputDimension() const
142 {
143   return constant_.getNbColumns();
144 }
145 
146 /* Method save() stores the object through the StorageManager */
save(Advocate & adv) const147 void LinearGradient::save(Advocate & adv) const
148 {
149   GradientImplementation::save(adv);
150   adv.saveAttribute( "center_", center_ );
151   adv.saveAttribute( "constant_", constant_ );
152   adv.saveAttribute( "linear_", linear_ );
153 }
154 
155 /* Method load() reloads the object from the StorageManager */
load(Advocate & adv)156 void LinearGradient::load(Advocate & adv)
157 {
158   GradientImplementation::load(adv);
159   adv.loadAttribute( "center_", center_ );
160   adv.loadAttribute( "constant_", constant_ );
161   adv.loadAttribute( "linear_", linear_ );
162 }
163 
164 END_NAMESPACE_OPENTURNS
165