1 /*
2 * Medical Image Registration ToolKit (MIRTK)
3 *
4 * Copyright 2013-2017 Imperial College London
5 * Copyright 2013-2017 Andreas Schuh
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "mirtk/ConjugateGradientDescent.h"
21
22 #include "mirtk/Math.h"
23 #include "mirtk/Memory.h"
24 #include "mirtk/ObjectFactory.h"
25
26
27 namespace mirtk {
28
29
30 // Register energy term with object factory during static initialization
31 mirtkAutoRegisterOptimizerMacro(ConjugateGradientDescent);
32
33
34 // =============================================================================
35 // Construction/Destruction
36 // =============================================================================
37
38 // -----------------------------------------------------------------------------
ConjugateGradientDescent(ObjectiveFunction * f)39 ConjugateGradientDescent::ConjugateGradientDescent(ObjectiveFunction *f)
40 :
41 GradientDescent(f),
42 _UseConjugateGradient(true),
43 _ConjugateTotalGradient(true),
44 _g(nullptr), _h(nullptr)
45 {
46 }
47
48 // -----------------------------------------------------------------------------
CopyAttributes(const ConjugateGradientDescent & other)49 void ConjugateGradientDescent::CopyAttributes(const ConjugateGradientDescent &other)
50 {
51 _UseConjugateGradient = other._UseConjugateGradient;
52 _ConjugateTotalGradient = other._ConjugateTotalGradient;
53
54 Deallocate(_g);
55 if (other._g && _Function) {
56 Allocate(_g, _Function->NumberOfDOFs());
57 memcpy(_g, other._g, _Function->NumberOfDOFs() * sizeof(double));
58 }
59
60 Deallocate(_h);
61 if (other._h && _Function) {
62 Allocate(_h, _Function->NumberOfDOFs());
63 memcpy(_h, other._h, _Function->NumberOfDOFs() * sizeof(double));
64 }
65 }
66
67 // -----------------------------------------------------------------------------
// Copy constructor: copy the base optimizer, then this subclass' settings and
// a deep copy of the conjugate direction buffers (via CopyAttributes).
ConjugateGradientDescent::ConjugateGradientDescent(const ConjugateGradientDescent &other)
: GradientDescent(other)
, _UseConjugateGradient(other._UseConjugateGradient)
, _ConjugateTotalGradient(other._ConjugateTotalGradient)
, _g(nullptr) // must be null before CopyAttributes calls Deallocate on it
, _h(nullptr)
{
  CopyAttributes(other);
}
75
76 // -----------------------------------------------------------------------------
// Copy assignment: assign base optimizer state, then this subclass' settings
// and direction buffers. Assigning an object to itself leaves it unchanged.
ConjugateGradientDescent &ConjugateGradientDescent::operator =(const ConjugateGradientDescent &other)
{
  if (&other == this) return *this;
  GradientDescent::operator =(other);
  CopyAttributes(other);
  return *this;
}
85
86 // -----------------------------------------------------------------------------
~ConjugateGradientDescent()87 ConjugateGradientDescent::~ConjugateGradientDescent()
88 {
89 Deallocate(_g);
90 Deallocate(_h);
91 }
92
93 // =============================================================================
94 // Parameters
95 // =============================================================================
96
97 // -----------------------------------------------------------------------------
Set(const char * name,const char * value)98 bool ConjugateGradientDescent::Set(const char *name, const char *value)
99 {
100 if (strcmp(name, "Conjugate total energy gradient") == 0 ||
101 strcmp(name, "Conjugate total gradient") == 0) {
102 return FromString(value, _ConjugateTotalGradient);
103 }
104 return GradientDescent::Set(name, value);
105 }
106
107 // -----------------------------------------------------------------------------
Parameter() const108 ParameterList ConjugateGradientDescent::Parameter() const
109 {
110 ParameterList params = GradientDescent::Parameter();
111 Insert(params, "Conjugate total energy gradient", _ConjugateTotalGradient);
112 return params;
113 }
114
115 // =============================================================================
116 // Optimization
117 // =============================================================================
118
119 // -----------------------------------------------------------------------------
Initialize()120 void ConjugateGradientDescent::Initialize()
121 {
122 GradientDescent::Initialize();
123 Deallocate(_g), Allocate(_g, _Function->NumberOfDOFs());
124 Deallocate(_h), Allocate(_h, _Function->NumberOfDOFs());
125 ResetConjugateGradient();
126 }
127
128 // -----------------------------------------------------------------------------
Finalize()129 void ConjugateGradientDescent::Finalize()
130 {
131 GradientDescent::Finalize();
132 Deallocate(_g);
133 Deallocate(_h);
134 }
135
136 // -----------------------------------------------------------------------------
Gradient(double * gradient,double step,bool * sgn_chg)137 void ConjugateGradientDescent::Gradient(double *gradient, double step, bool *sgn_chg)
138 {
139 // Compute gradient of objective function
140 Function()->DataFidelityGradient(gradient, step, sgn_chg);
141 if (!_UseConjugateGradient || _ConjugateTotalGradient) {
142 Function()->AddConstraintGradient(gradient, step, sgn_chg);
143 }
144
145 // Conjugate gradient
146 if (_UseConjugateGradient) {
147 ConjugateGradient(gradient);
148 } else {
149 ResetConjugateGradient();
150 }
151
152 // Add non-conjugated constraint gradient
153 if (_UseConjugateGradient && !_ConjugateTotalGradient) {
154 Function()->AddConstraintGradient(gradient, step, sgn_chg);
155 }
156 }
157
158 // -----------------------------------------------------------------------------
ConjugateGradient(double * gradient)159 void ConjugateGradientDescent::ConjugateGradient(double *gradient)
160 {
161 const int ndofs = _Function->NumberOfDOFs();
162 if (IsNaN(_g[0])) {
163 for (int i = 0; i < ndofs; ++i) _g[i] = -gradient[i];
164 memcpy(_h, _g, ndofs * sizeof(double));
165 } else {
166 double gg = .0;
167 double dgg = .0;
168 for (int i = 0; i < ndofs; ++i) {
169 gg += _g[i] * _g[i];
170 dgg += (gradient[i] + _g[i]) * gradient[i];
171 }
172 double gamma = max(dgg / gg, .0);
173 for (int i = 0; i < ndofs; ++i) {
174 _g[i] = -gradient[i];
175 _h[i] = _g[i] + gamma * _h[i];
176 gradient[i] = -_h[i];
177 }
178 }
179 }
180
181
182 } // namespace mirtk
183