/*
 * Medical Image Registration ToolKit (MIRTK)
 *
 * Copyright 2013-2015 Imperial College London
 * Copyright 2013-2015 Andreas Schuh
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mirtk/GradientDescent.h"

#include "mirtk/Math.h"
#include "mirtk/Memory.h"
#include "mirtk/ObjectFactory.h"

#include <algorithm>


namespace mirtk {


// Register optimizer with object factory during static initialization
mirtkAutoRegisterOptimizerMacro(GradientDescent);


// =============================================================================
// Construction/Destruction
// =============================================================================
// -----------------------------------------------------------------------------
GradientDescent::GradientDescent(ObjectiveFunction *f)
:
  LocalOptimizer(f),
  _NumberOfRestarts      (-1),
  _NumberOfFailedRestarts(-1),
  _LineSearchStrategy    (LS_Adaptive),
  _LineSearch            (NULL),
  _LineSearchOwner       (false),
  _Gradient              (NULL),
  _AllowSignChange       (NULL)
{
  _EventDelegate.Bind(MakeDelegate(this, &Observable::Broadcast));
  this->Epsilon(-this->Epsilon()); // relative to current best value
}
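// Note: A negative epsilon is this optimizer's convention for a convergence
// threshold which is interpreted relative to the current best function value;
// see Run(), which rescales it to abs(_Epsilon * value) after each line search.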

// -----------------------------------------------------------------------------
void GradientDescent::CopyAttributes(const GradientDescent &other)
{
  _NumberOfRestarts       = other._NumberOfRestarts;
  _NumberOfFailedRestarts = other._NumberOfFailedRestarts;
  _LineSearchStrategy     = other._LineSearchStrategy;
  _LineSearchParameter    = other._LineSearchParameter;
  Deallocate(_Gradient);
  Deallocate(_AllowSignChange);
  if (_LineSearchOwner) Delete(_LineSearch);
  if (!other._LineSearchOwner) {
    _LineSearch      = other._LineSearch;
    _LineSearchOwner = false;
  }
}
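// Note: A line search object owned by the copied-from optimizer is
// deliberately not copied; Delete is expected to reset _LineSearch to NULL
// (cf. mirtk/Memory.h), and Initialize() then creates a new line search
// matching _LineSearchStrategy before the next Run(). Only a line search
// which the other optimizer merely references is shared.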

// -----------------------------------------------------------------------------
GradientDescent::GradientDescent(const GradientDescent &other)
:
  LocalOptimizer(other),
  _Gradient       (NULL),
  _AllowSignChange(NULL)
{
  _EventDelegate.Bind(MakeDelegate(this, &Observable::Broadcast));
  CopyAttributes(other);
}

// -----------------------------------------------------------------------------
GradientDescent &GradientDescent::operator =(const GradientDescent &other)
{
  if (this != &other) {
    LocalOptimizer::operator =(other);
    CopyAttributes(other);
  }
  return *this;
}

// -----------------------------------------------------------------------------
GradientDescent::~GradientDescent()
{
  Deallocate(_Gradient);
  Deallocate(_AllowSignChange);
  if (_LineSearchOwner) Delete(_LineSearch);
}

// -----------------------------------------------------------------------------
void GradientDescent::Function(ObjectiveFunction *f)
{
  // A previously allocated gradient vector must be released when the
  // new function has a different number of DoFs
  if (!f || !this->Function() || this->Function()->NumberOfDOFs() != f->NumberOfDOFs()) {
    Deallocate(_Gradient);
    Deallocate(_AllowSignChange);
  }
  LocalOptimizer::Function(f);
}
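// Note: The released vectors are reallocated with the new number of DoFs
// during the next call of Initialize().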

// -----------------------------------------------------------------------------
void GradientDescent::LineSearch(class LineSearch *s, bool transfer_ownership)
{
  if (_LineSearchOwner) {
    Delete(_LineSearch);
    _LineSearchOwner = false;
  }
  if (s) {
    _LineSearchStrategy = s->Strategy();
    _LineSearch         = s;
    _LineSearchOwner    = transfer_ownership;
  }
}
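// A minimal usage sketch (hypothetical caller code; assumes an
// AdaptiveLineSearch class corresponding to the default LS_Adaptive
// strategy, and an existing ObjectiveFunction instance named energy):
//
//   GradientDescent optimizer(&energy);
//   optimizer.LineSearch(new AdaptiveLineSearch(), true); // optimizer deletes it
//   optimizer.Run();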

// =============================================================================
// Parameters
// =============================================================================

// -----------------------------------------------------------------------------
bool GradientDescent::Set(const char *name, const char *value)
{
  if (strcmp(name, "Maximum no. of restarts")    == 0 ||
      strcmp(name, "Maximum number of restarts") == 0 ||
      strcmp(name, "No. of restarts")            == 0 ||
      strcmp(name, "Number of restarts")         == 0) {
    return FromString(value, _NumberOfRestarts);
  }
  if (strcmp(name, "Maximum no. of failed restarts")          == 0 ||
      strcmp(name, "Maximum number of failed restarts")       == 0 ||
      strcmp(name, "No. of failed restarts")                  == 0 ||
      strcmp(name, "Number of failed restarts")               == 0 ||
      strcmp(name, "Maximum no. of unsuccessful restarts")    == 0 ||
      strcmp(name, "Maximum number of unsuccessful restarts") == 0 ||
      strcmp(name, "No. of unsuccessful restarts")            == 0 ||
      strcmp(name, "Number of unsuccessful restarts")         == 0 ||
      strcmp(name, "Maximum streak of failed restarts")       == 0 ||
      strcmp(name, "Maximum streak of unsuccessful restarts") == 0 ||
      strcmp(name, "Maximum streak of rejected restarts")     == 0) {
    return FromString(value, _NumberOfFailedRestarts);
  }
  if (strcmp(name, "Line search strategy") == 0) {
    return FromString(value, _LineSearchStrategy);
  }
  if (strstr(name, "line search")     != NULL ||
      strstr(name, "line iterations") != NULL ||
      strcmp(name, "Maximum streak of rejected steps")     == 0 ||
      strcmp(name, "Length of steps")                      == 0 ||
      strcmp(name, "Minimum length of steps")              == 0 ||
      strcmp(name, "Maximum length of steps")              == 0 ||
      strcmp(name, "Strict step length range")             == 0 ||
      strcmp(name, "Strict total step length range")       == 0 ||
      strcmp(name, "Strict incremental step length range") == 0 ||
      strcmp(name, "Strict accumulated step length range") == 0 ||
      strcmp(name, "Step length rise")                     == 0 ||
      strcmp(name, "Step length drop")                     == 0 ||
      strcmp(name, "Reuse previous step length")           == 0) {
    Insert(_LineSearchParameter, name, value);
    return true;
  }
  return LocalOptimizer::Set(name, value);
}
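// For illustration, a caller may configure the optimizer generically by
// parameter name (the values below are hypothetical):
//
//   optimizer.Set("Maximum no. of restarts",        "10");
//   optimizer.Set("Maximum no. of failed restarts", "3");
//   optimizer.Set("Maximum length of steps",        "1"); // kept for line search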

// -----------------------------------------------------------------------------
ParameterList GradientDescent::Parameter() const
{
  ParameterList params = LocalOptimizer::Parameter();
  if (_LineSearch && _LineSearch->Strategy() == _LineSearchStrategy) {
    Insert(params, _LineSearch->Parameter());
  }
  Insert(params, _LineSearchParameter);
  Insert(params, "Maximum no. of restarts",        _NumberOfRestarts);
  Insert(params, "Maximum no. of failed restarts", _NumberOfFailedRestarts);
  Insert(params, "Line search strategy",           _LineSearchStrategy);
  return params;
}

// =============================================================================
// Optimization
// =============================================================================

// -----------------------------------------------------------------------------
void GradientDescent::Initialize()
{
  // Initialize base class -- checks that _Function is valid
  LocalOptimizer::Initialize();
  // Default values
  if (_NumberOfRestarts       < 0) _NumberOfRestarts       = 100;
  if (_NumberOfFailedRestarts < 0) _NumberOfFailedRestarts = 5;
  // Reuse an existing line search object of matching strategy, e.g., one
  // created by a prior call of Initialize before Run so that the line search
  // parameters could be set directly through its setters instead of the
  // generic Set
  if (!_LineSearch || _LineSearch->Strategy() != _LineSearchStrategy) {
    if (_LineSearchOwner) Delete(_LineSearch);
    _LineSearch      = LineSearch::New(_LineSearchStrategy);
    _LineSearchOwner = true;
  }
  // Allocate memory for gradient vector if not done before
  if (!_Gradient)        Allocate(_Gradient,        Function()->NumberOfDOFs());
  if (!_AllowSignChange) Allocate(_AllowSignChange, Function()->NumberOfDOFs());
  // Initialize line search object
  _LineSearch->Function  (Function());
  _LineSearch->Parameter (_LineSearchParameter);
  _LineSearch->Delta     (_Delta);
  _LineSearch->Epsilon   (max(.0, _Epsilon));
  _LineSearch->Direction (_Gradient);
  _LineSearch->Revert    (true);
  _LineSearch->AddObserver(_EventDelegate);
  _LineSearch->Initialize();
  // Check line search parameters
  if (_LineSearch->MaxStepLength() == .0) {
    cerr << this->NameOfClass() << "::Initialize: Line search interval length is zero!" << endl;
    cerr << "  Check the \"Minimum/Maximum length of steps\" line search parameters." << endl;
    exit(1);
  }
}
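// Note: Because _EventDelegate is bound to Observable::Broadcast in the
// constructors, registering it as observer of the line search forwards all
// line search events to the observers of this optimizer.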

// -----------------------------------------------------------------------------
double GradientDescent::Run()
{
  double value     = numeric_limits<double>::max();
  int    nrestarts = 0; // Number of restarts after convergence
  int    nfailed   = 0; // Number of consecutive restarts without improvement

  // Initialize
  this->Initialize();

  // Initial line search parameters
  const double initial_delta    = _LineSearch->Delta();
  const double initial_epsilon  = _LineSearch->Epsilon();
  const double initial_step     = _LineSearch->StepLength();
  const double initial_min_step = _LineSearch->MinStepLength();
  const double initial_max_step = _LineSearch->MaxStepLength();

  // Perform initial update of energy function before StartEvent because
  // it may trigger some further delayed initialization with LogEvents
  Function()->Update(true);

  // Notify observers about start of optimization
  Broadcast(StartEvent);

  // Total number of performed gradient steps
  Iteration step(0, _NumberOfSteps);

  // Repeat gradient descent optimization, once for each restart with a
  // modified energy function. If the energy function remains unmodified,
  // only one iteration of this outer loop is performed.
  while (true) {

    // Get initial value of (modified) energy function
    value = _Function->Value();
    _LastValues.clear();
    _LastValues.push_back(value);

    // Current number of iterations
    const int iter = step.Iter();

    // Descent along computed gradient
    _Converged = false;
    while (!_Converged && step.Next()) {

      // Notify observers about start of gradient descent iteration
      Broadcast(IterationStartEvent, &step);

      // Update current best value
      _LineSearch->CurrentValue(value);

      // Compute gradient of objective function
      if (step.Iter() > 1) Function()->Update(true);
      this->Gradient(_Gradient, _LineSearch->StepLength(), _AllowSignChange);

      // Adjust step length range if necessary
      //
      // This is required for the optimization of the registration cost function
      // with L1-norm sparsity constraint on the multi-level free-form deformation
      // parameters (Wenzhe et al.'s Sparse FFD). Furthermore, if the gradient of
      // an energy term is approximated using finite differences, this energy term
      // will set min_step = max_step such that the step length corresponds to
      // the one with which the gradient was approximated.
      double max_norm = Function()->GradientNorm(_Gradient);
      double min_step = _LineSearch->MinStepLength();
      double max_step = _LineSearch->MaxStepLength();
      Function()->GradientStep(_Gradient, min_step, max_step);

      // Set line search range
      _LineSearch->MinStepLength (min_step);
      _LineSearch->MaxStepLength (max_step);
      _LineSearch->StepLengthUnit(max_norm);

      // Perform line search along computed gradient direction
      value = _LineSearch->Run();

      // Adjust epsilon if relative to current best value, i.e.,
      // epsilon parameter is set to a negative value
      if (_Epsilon < .0) _LineSearch->Epsilon(abs(_Epsilon * value));

      // Check convergence
      if (_LineSearch->StepLength() > 0.) {
        _Converged = this->Converged(step.Iter(), value, _Gradient);
      } else {
        _Converged = true;
      }

      // Notify observers about end of gradient descent iteration
      Broadcast(IterationEndEvent, &step);
    }

    // Stop if the previous restart did not bring any improvement
    if (step.Iter() == (iter + 1) && value == _LineSearch->CurrentValue()) {
      ++nfailed;
      if (nfailed >= _NumberOfFailedRestarts) break;
    } else {
      nfailed = 0;
    }

    // Update current best value
    _LineSearch->CurrentValue(value);

    // Stop if maximum number of iterations exceeded
    if (step.End()) break;

    // Stop if maximum number of allowed restarts exceeded
    if (nrestarts >= _NumberOfRestarts) break;
    ++nrestarts;

    // If there was no improvement compared to the previous converged
    // iterative gradient descent, or the objective function remains
    // unmodified, stop here. Otherwise, start another optimization of
    // the amended objective function (cf. FiducialRegistrationError).
    // Such a restart realizes an alternating optimization, where
    // ObjectiveFunction::Upgrade performs the optimization w.r.t. some
    // function parameters other than the "public" DoFs.
    if (!_Function->Upgrade()) break;

    // Update energy function for initial value evaluation as well as for
    // the first gradient computation
    Function()->Update(true);

    // Reset line search parameters
    _LineSearch->Delta        (initial_delta);
    _LineSearch->Epsilon      (initial_epsilon);
    _LineSearch->StepLength   (initial_step);
    _LineSearch->MinStepLength(initial_min_step);
    _LineSearch->MaxStepLength(initial_max_step);

    // Notify observers about restart
    Broadcast(RestartEvent);
  }

  // Notify observers about end of optimization
  Broadcast(EndEvent, &value);

  // Finalize
  this->Finalize();

  return value;
}

// -----------------------------------------------------------------------------
void GradientDescent::Gradient(double *gradient, double step, bool *sgn_chg)
{
  Function()->Gradient(gradient, step, sgn_chg);
}
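// Note: This virtual hook allows a subclass to post-process the steepest
// descent direction before the line search, e.g., a conjugate gradient
// variant could combine the new gradient with previous search directions here.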

// -----------------------------------------------------------------------------
void GradientDescent::Finalize()
{
  Deallocate(_Gradient);
  Deallocate(_AllowSignChange);
  _LineSearch->DeleteObserver(_EventDelegate);
  if (_LineSearchOwner) Delete(_LineSearch);
}


} // namespace mirtk