/*
 * Medical Image Registration ToolKit (MIRTK)
 *
 * Copyright 2013-2017 Imperial College London
 * Copyright 2013-2017 Andreas Schuh
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mirtk/ConjugateGradientDescent.h"

#include "mirtk/Math.h"
#include "mirtk/Memory.h"
#include "mirtk/ObjectFactory.h"


namespace mirtk {


// Register optimizer with object factory during static initialization
mirtkAutoRegisterOptimizerMacro(ConjugateGradientDescent);


// =============================================================================
// Construction/Destruction
// =============================================================================

// -----------------------------------------------------------------------------
ConjugateGradientDescent::ConjugateGradientDescent(ObjectiveFunction *f)
:
  GradientDescent(f),
  _UseConjugateGradient(true),
  _ConjugateTotalGradient(true),
  _g(nullptr), _h(nullptr)
{
}

// -----------------------------------------------------------------------------
void ConjugateGradientDescent::CopyAttributes(const ConjugateGradientDescent &other)
{
  _UseConjugateGradient   = other._UseConjugateGradient;
  _ConjugateTotalGradient = other._ConjugateTotalGradient;

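  // Deep copy of the conjugate gradient state: _g holds the negated objective
  // gradient of the previous iteration and _h the previous search direction.
  // The buffers are copied only when the source has them and an objective
  // function is set, because their size is the number of DOFs of _Function.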
  Deallocate(_g);
  if (other._g && _Function) {
    Allocate(_g, _Function->NumberOfDOFs());
    memcpy(_g, other._g, _Function->NumberOfDOFs() * sizeof(double));
  }

  Deallocate(_h);
  if (other._h && _Function) {
    Allocate(_h, _Function->NumberOfDOFs());
    memcpy(_h, other._h, _Function->NumberOfDOFs() * sizeof(double));
  }
}

// -----------------------------------------------------------------------------
ConjugateGradientDescent::ConjugateGradientDescent(const ConjugateGradientDescent &other)
:
  GradientDescent(other),
  _g(nullptr), _h(nullptr)
{
  CopyAttributes(other);
}

// -----------------------------------------------------------------------------
ConjugateGradientDescent &ConjugateGradientDescent::operator =(const ConjugateGradientDescent &other)
{
  if (this != &other) {
    GradientDescent::operator =(other);
    CopyAttributes(other);
  }
  return *this;
}

// -----------------------------------------------------------------------------
ConjugateGradientDescent::~ConjugateGradientDescent()
{
  Deallocate(_g);
  Deallocate(_h);
}

// =============================================================================
// Parameters
// =============================================================================

// -----------------------------------------------------------------------------
bool ConjugateGradientDescent::Set(const char *name, const char *value)
{
  if (strcmp(name, "Conjugate total energy gradient") == 0 ||
      strcmp(name, "Conjugate total gradient") == 0) {
    return FromString(value, _ConjugateTotalGradient);
  }
  return GradientDescent::Set(name, value);
}

// -----------------------------------------------------------------------------
ParameterList ConjugateGradientDescent::Parameter() const
{
  ParameterList params = GradientDescent::Parameter();
  Insert(params, "Conjugate total energy gradient", _ConjugateTotalGradient);
  return params;
}
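
// Example usage (assumption: MIRTK parameter files use the "<name> = <value>"
// syntax; the option name itself is taken from Set()/Parameter() above):
//
//   Conjugate total energy gradient = No
//
// disables conjugation of the constraint (penalty) gradient, so that only the
// data fidelity gradient is conjugated in Gradient() below.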

// =============================================================================
// Optimization
// =============================================================================

// -----------------------------------------------------------------------------
void ConjugateGradientDescent::Initialize()
{
  GradientDescent::Initialize();
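  // (Re-)allocate the conjugate gradient state for the current number of DOFs
  // and reset it so that the first iteration uses the plain steepest descent
  // direction (cf. ConjugateGradient() below)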
  Deallocate(_g), Allocate(_g, _Function->NumberOfDOFs());
  Deallocate(_h), Allocate(_h, _Function->NumberOfDOFs());
  ResetConjugateGradient();
}

// -----------------------------------------------------------------------------
void ConjugateGradientDescent::Finalize()
{
  GradientDescent::Finalize();
  Deallocate(_g);
  Deallocate(_h);
}

// -----------------------------------------------------------------------------
void ConjugateGradientDescent::Gradient(double *gradient, double step, bool *sgn_chg)
{
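  // Note: When _ConjugateTotalGradient is enabled (the default), the sum of
  // data fidelity and constraint gradients is conjugated as a whole; otherwise
  // only the data fidelity gradient is conjugated and the constraint (penalty)
  // gradient is added unmodified afterwards.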
  // Compute gradient of objective function
  Function()->DataFidelityGradient(gradient, step, sgn_chg);
  if (!_UseConjugateGradient || _ConjugateTotalGradient) {
    Function()->AddConstraintGradient(gradient, step, sgn_chg);
  }

  // Conjugate gradient
  if (_UseConjugateGradient) {
    ConjugateGradient(gradient);
  } else {
    ResetConjugateGradient();
  }

  // Add non-conjugated constraint gradient
  if (_UseConjugateGradient && !_ConjugateTotalGradient) {
    Function()->AddConstraintGradient(gradient, step, sgn_chg);
  }
}

// -----------------------------------------------------------------------------
void ConjugateGradientDescent::ConjugateGradient(double *gradient)
{
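  // Polak-Ribiere style update with automatic restart: _g stores the negated
  // objective gradient of the previous call and _h the previous search
  // direction. With g = current gradient and g0 = previous gradient,
  //
  //   gamma = max(0, ((g - g0) . g) / (g0 . g0))
  //   h     = -g + gamma * h_prev
  //
  // and the input/output array is overwritten with -h so that the caller,
  // which steps against the gradient, effectively moves along h. It is
  // assumed that ResetConjugateGradient() invalidates the state by setting
  // _g[0] to NaN; in that case the direction is reset to steepest descent.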
  const int ndofs = _Function->NumberOfDOFs();
  if (IsNaN(_g[0])) {
    for (int i = 0; i < ndofs; ++i) _g[i] = -gradient[i];
    memcpy(_h, _g, ndofs * sizeof(double));
  } else {
    double gg  = .0;
    double dgg = .0;
    for (int i = 0; i < ndofs; ++i) {
      gg  += _g[i] * _g[i];
      dgg += (gradient[i] + _g[i]) * gradient[i];
    }
    double gamma = max(dgg / gg, .0);
    for (int i = 0; i < ndofs; ++i) {
      _g[i] = -gradient[i];
      _h[i] = _g[i] + gamma * _h[i];
      gradient[i] = -_h[i];
    }
  }
}


} // namespace mirtk