/*
 * Medical Image Registration ToolKit (MIRTK)
 *
 * Copyright 2013-2015 Imperial College London
 * Copyright 2013-2015 Andreas Schuh
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mirtk/GradientDescent.h"

#include "mirtk/Math.h"
#include "mirtk/Memory.h"
#include "mirtk/ObjectFactory.h"

#include <algorithm>


namespace mirtk {

// Register optimizer with object factory during static initialization
mirtkAutoRegisterOptimizerMacro(GradientDescent);
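
// Hedged usage sketch of how the registration above is typically exploited:
// instead of naming this class directly, a caller can create the optimizer via
// the factory by optimization method. The exact factory call and enumeration
// value (LocalOptimizer::New, OM_GradientDescent) are assumptions and are not
// verified within this file.
//
//   ObjectiveFunction *f = /* hypothetical objective function */;
//   LocalOptimizer *optimizer = LocalOptimizer::New(OM_GradientDescent, f);
//   double value = optimizer->Run();
//   delete optimizer;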


// =============================================================================
// Construction/Destruction
// =============================================================================

// -----------------------------------------------------------------------------
GradientDescent::GradientDescent(ObjectiveFunction *f)
:
  LocalOptimizer(f),
  _NumberOfRestarts      (-1),
  _NumberOfFailedRestarts(-1),
  _LineSearchStrategy    (LS_Adaptive),
  _LineSearch            (NULL),
  _LineSearchOwner       (false),
  _Gradient              (NULL),
  _AllowSignChange       (NULL)
{
  _EventDelegate.Bind(MakeDelegate(this, &Observable::Broadcast));
  this->Epsilon(- this->Epsilon()); // relative to current best value
}
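
// Worked example of the sign convention established by the constructor above:
// a (hypothetical) default Epsilon of 1e-4 is stored as -1e-4, and Run() later
// turns it into an absolute tolerance relative to the current best value, i.e.
// |(-1e-4) * value|. For a current best value of 250 this yields an effective
// line search tolerance of 0.025 (see the `_Epsilon < .0` branch in Run()).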

// -----------------------------------------------------------------------------
void GradientDescent::CopyAttributes(const GradientDescent &other)
{
  _NumberOfRestarts       = other._NumberOfRestarts;
  _NumberOfFailedRestarts = other._NumberOfFailedRestarts;
  _LineSearchStrategy     = other._LineSearchStrategy;
  _LineSearchParameter    = other._LineSearchParameter;
  Deallocate(_Gradient);
  Deallocate(_AllowSignChange);
  if (_LineSearchOwner) Delete(_LineSearch);
  if (!other._LineSearchOwner) {
    _LineSearch      = other._LineSearch;
    _LineSearchOwner = false;
  }
}
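
// Note on the ownership semantics of CopyAttributes above: a line search that
// this optimizer owns is deleted rather than shared, and only an externally
// provided (non-owned) line search of the other optimizer is adopted, as a
// non-owning pointer. Otherwise, Initialize() later (re)creates a suitable
// line search from the copied strategy and parameter settings.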

// -----------------------------------------------------------------------------
GradientDescent::GradientDescent(const GradientDescent &other)
:
  LocalOptimizer(other),
  _LineSearch     (NULL),  // initialize before CopyAttributes reads these members
  _LineSearchOwner(false),
  _Gradient       (NULL),
  _AllowSignChange(NULL)
{
  _EventDelegate.Bind(MakeDelegate(this, &Observable::Broadcast));
  CopyAttributes(other);
}

// -----------------------------------------------------------------------------
GradientDescent &GradientDescent::operator =(const GradientDescent &other)
{
  if (this != &other) {
    LocalOptimizer::operator =(other);
    CopyAttributes(other);
  }
  return *this;
}

// -----------------------------------------------------------------------------
GradientDescent::~GradientDescent()
{
  Deallocate(_Gradient);
  Deallocate(_AllowSignChange);
  if (_LineSearchOwner) Delete(_LineSearch);
}

// -----------------------------------------------------------------------------
void GradientDescent::Function(ObjectiveFunction *f)
{
  // Reallocate the previously allocated gradient vectors if the new function
  // has a different number of DoFs
  if (!f || !this->Function() || this->Function()->NumberOfDOFs() != f->NumberOfDOFs()) {
    Deallocate(_Gradient);
    Deallocate(_AllowSignChange);
  }
  LocalOptimizer::Function(f);
}

// -----------------------------------------------------------------------------
void GradientDescent::LineSearch(class LineSearch *s, bool transfer_ownership)
{
  if (_LineSearchOwner) {
    Delete(_LineSearch);
    _LineSearchOwner = false;
  }
  if (s) {
    _LineSearchStrategy = s->Strategy();
    _LineSearch         = s;
    _LineSearchOwner    = transfer_ownership;
  }
}
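
// Hedged usage sketch for the setter above (caller code, not part of this
// file; MyLineSearch is a hypothetical LineSearch subclass): install a custom
// line search instance and transfer its ownership to the optimizer.
//
//   GradientDescent gd(&function);            // 'function' is hypothetical
//   gd.LineSearch(new MyLineSearch(), true);  // gd now owns the line search
//   gd.Run();                                 // deleted by gd when replaced or destroyed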

// =============================================================================
// Parameters
// =============================================================================

// -----------------------------------------------------------------------------
bool GradientDescent::Set(const char *name, const char *value)
{
  if (strcmp(name, "Maximum no. of restarts")    == 0 ||
      strcmp(name, "Maximum number of restarts") == 0 ||
      strcmp(name, "No. of restarts")            == 0 ||
      strcmp(name, "Number of restarts")         == 0) {
    return FromString(value, _NumberOfRestarts);
  }
  if (strcmp(name, "Maximum no. of failed restarts")          == 0 ||
      strcmp(name, "Maximum number of failed restarts")       == 0 ||
      strcmp(name, "No. of failed restarts")                  == 0 ||
      strcmp(name, "Number of failed restarts")               == 0 ||
      strcmp(name, "Maximum no. of unsuccessful restarts")    == 0 ||
      strcmp(name, "Maximum number of unsuccessful restarts") == 0 ||
      strcmp(name, "No. of unsuccessful restarts")            == 0 ||
      strcmp(name, "Number of unsuccessful restarts")         == 0 ||
      strcmp(name, "Maximum streak of failed restarts")       == 0 ||
      strcmp(name, "Maximum streak of unsuccessful restarts") == 0 ||
      strcmp(name, "Maximum streak of rejected restarts")     == 0) {
    return FromString(value, _NumberOfFailedRestarts);
  }
  if (strcmp(name, "Line search strategy") == 0) {
    return FromString(value, _LineSearchStrategy);
  }
  if (strstr(name, "line search")     != NULL ||
      strstr(name, "line iterations") != NULL ||
      strcmp(name, "Maximum streak of rejected steps")     == 0 ||
      strcmp(name, "Length of steps")                      == 0 ||
      strcmp(name, "Minimum length of steps")              == 0 ||
      strcmp(name, "Maximum length of steps")              == 0 ||
      strcmp(name, "Strict step length range")             == 0 ||
      strcmp(name, "Strict total step length range")       == 0 ||
      strcmp(name, "Strict incremental step length range") == 0 ||
      strcmp(name, "Strict accumulated step length range") == 0 ||
      strcmp(name, "Step length rise")                     == 0 ||
      strcmp(name, "Step length drop")                     == 0 ||
      strcmp(name, "Reuse previous step length")           == 0) {
    Insert(_LineSearchParameter, name, value);
    return true;
  }
  return LocalOptimizer::Set(name, value);
}
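
// Hedged example of the generic parameter interface handled by Set above;
// the parameter names follow the string comparisons in this function, while
// the string value accepted for the line search strategy is an assumption.
//
//   GradientDescent gd(&function);                      // 'function' is hypothetical
//   gd.Set("Maximum no. of restarts",        "10");
//   gd.Set("Maximum no. of failed restarts", "2");
//   gd.Set("Line search strategy",           "Adaptive");
//   gd.Set("Maximum length of steps",        "1.0");    // forwarded to the line search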

// -----------------------------------------------------------------------------
ParameterList GradientDescent::Parameter() const
{
  ParameterList params = LocalOptimizer::Parameter();
  if (_LineSearch && _LineSearch->Strategy() == _LineSearchStrategy) {
    Insert(params, _LineSearch->Parameter());
  }
  Insert(params, _LineSearchParameter);
  Insert(params, "Maximum no. of restarts",        _NumberOfRestarts);
  Insert(params, "Maximum no. of failed restarts", _NumberOfFailedRestarts);
  Insert(params, "Line search strategy",           _LineSearchStrategy);
  return params;
}

// =============================================================================
// Optimization
// =============================================================================

// -----------------------------------------------------------------------------
void GradientDescent::Initialize()
{
  // Initialize base class -- checks that _Function is valid
  LocalOptimizer::Initialize();
  // Default values
  if (_NumberOfRestarts       < 0) _NumberOfRestarts       = 100;
  if (_NumberOfFailedRestarts < 0) _NumberOfFailedRestarts =   5;
  // Reuse an existing line search object of matching strategy, e.g., one
  // created when the user called Initialize before Run in order to set the
  // line search parameters directly through its setters instead of the
  // generic Set.
  if (!_LineSearch || _LineSearch->Strategy() != _LineSearchStrategy) {
    if (_LineSearchOwner) Delete(_LineSearch);
    _LineSearch      = LineSearch::New(_LineSearchStrategy);
    _LineSearchOwner = true;
  }
  // Allocate memory for gradient vector if not done before
  if (!_Gradient)        Allocate(_Gradient,        Function()->NumberOfDOFs());
  if (!_AllowSignChange) Allocate(_AllowSignChange, Function()->NumberOfDOFs());
  // Initialize line search object
  _LineSearch->Function   (Function());
  _LineSearch->Parameter  (_LineSearchParameter);
  _LineSearch->Delta      (_Delta);
  _LineSearch->Epsilon    (max(.0, _Epsilon));
  _LineSearch->Direction  (_Gradient);
  _LineSearch->Revert     (true);
  _LineSearch->AddObserver(_EventDelegate);
  _LineSearch->Initialize();
  // Check line search parameters
  if (_LineSearch->MaxStepLength() == .0) {
    cerr << this->NameOfClass() << "::Initialize: Line search interval length is zero!" << endl;
    cerr << "  Check the \"Minimum/Maximum length of steps\" line search parameter." << endl;
    exit(1);
  }
}
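
// Hedged sketch of the pattern referred to by the comment in Initialize above
// (the LineSearch() getter and the setter names are assumptions): calling
// Initialize() explicitly before Run() creates the line search object, whose
// setters can then be used directly instead of the generic Set interface.
//
//   GradientDescent gd(&function);         // 'function' is hypothetical
//   gd.Initialize();                       // creates line search for the chosen strategy
//   gd.LineSearch()->MaxStepLength(2.0);   // assumed read-only accessor to the object
//   gd.Run();                              // reuses the existing line search object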

// -----------------------------------------------------------------------------
double GradientDescent::Run()
{
  double value     = numeric_limits<double>::max();
  int    nrestarts = 0; // Number of restarts after convergence
  int    nfailed   = 0; // Number of consecutive restarts without improvement

  // Initialize
  this->Initialize();

  // Initial line search parameters
  const double initial_delta    = _LineSearch->Delta();
  const double initial_epsilon  = _LineSearch->Epsilon();
  const double initial_step     = _LineSearch->StepLength();
  const double initial_min_step = _LineSearch->MinStepLength();
  const double initial_max_step = _LineSearch->MaxStepLength();

  // Perform initial update of the energy function before the StartEvent because
  // it may trigger further delayed initialization accompanied by LogEvents
  Function()->Update(true);

  // Notify observers about start of optimization
  Broadcast(StartEvent);

  // Total number of performed gradient steps
  Iteration step(0, _NumberOfSteps);

  // Repeat gradient descent optimization for each restart with a modified
  // energy function. If the energy function is never modified, only one pass is done.
  while (true) {

    // Get initial value of (modified) energy function
    value = _Function->Value();
    _LastValues.clear();
    _LastValues.push_back(value);

    // Current number of iterations
    const int iter = step.Iter();

    // Descent along computed gradient
    _Converged = false;
    while (!_Converged && step.Next()) {

      // Notify observers about start of gradient descent iteration
      Broadcast(IterationStartEvent, &step);

      // Update current best value
      _LineSearch->CurrentValue(value);

      // Compute gradient of objective function
      if (step.Iter() > 1) Function()->Update(true);
      this->Gradient(_Gradient, _LineSearch->StepLength(), _AllowSignChange);

      // Adjust step length range if necessary
      //
      // This is required for the optimization of the registration cost function
      // with L1-norm sparsity constraint on the multi-level free-form deformation
      // parameters (Wenzhe et al.'s Sparse FFD). Furthermore, if the gradient of
      // an energy term is approximated using finite differences, this energy term
      // will set min_step = max_step such that the step length corresponds to
      // the one with which the gradient was approximated.
      double max_norm = Function()->GradientNorm(_Gradient);
      double min_step = _LineSearch->MinStepLength();
      double max_step = _LineSearch->MaxStepLength();
      Function()->GradientStep(_Gradient, min_step, max_step);

      // Set line search range
      _LineSearch->MinStepLength (min_step);
      _LineSearch->MaxStepLength (max_step);
      _LineSearch->StepLengthUnit(max_norm);

      // Perform line search along computed gradient direction
      value = _LineSearch->Run();

      // Adjust epsilon if relative to current best value, i.e.,
      // epsilon parameter is set to a negative value
      if (_Epsilon < .0) _LineSearch->Epsilon(abs(_Epsilon * value));

      // Check convergence
      if (_LineSearch->StepLength() > 0.) {
        _Converged = this->Converged(step.Iter(), value, _Gradient);
      } else {
        _Converged = true;
      }

      // Notify observers about end of gradient descent iteration
      Broadcast(IterationEndEvent, &step);
    }

    // Stop if previous restart did not bring any improvement
    if (step.Iter() == (iter + 1) && value == _LineSearch->CurrentValue()) {
      ++nfailed;
      if (nfailed >= _NumberOfFailedRestarts) break;
    } else {
      nfailed = 0;
    }

    // Update current best value
    _LineSearch->CurrentValue(value);

    // Stop if maximum number of iterations exceeded
    if (step.End()) break;

    // Stop if maximum number of allowed restarts exceeded
    if (nrestarts >= _NumberOfRestarts) break;
    ++nrestarts;

    // If the objective function remains unmodified, stop here. Otherwise,
    // start another optimization of the amended objective
    // (cf. FiducialRegistrationError). Such a restart realizes an alternating
    // optimization, where ObjectiveFunction::Upgrade performs the optimization
    // w.r.t. function parameters other than the "public" DoFs.
    if (!_Function->Upgrade()) break;

    // Update energy function for initial value evaluation as well as for
    // the first gradient computation
    Function()->Update(true);

    // Reset line search parameters
    _LineSearch->Delta        (initial_delta);
    _LineSearch->Epsilon      (initial_epsilon);
    _LineSearch->StepLength   (initial_step);
    _LineSearch->MinStepLength(initial_min_step);
    _LineSearch->MaxStepLength(initial_max_step);

    // Notify observers about restart
    Broadcast(RestartEvent);
  }

  // Notify observers about end of optimization
  Broadcast(EndEvent, &value);

  // Finalize
  this->Finalize();

  return value;
}
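
// Hedged end-to-end sketch of driving this optimizer (the NumberOfSteps and
// Epsilon setters are assumed to be provided by LocalOptimizer or this class):
//
//   GradientDescent gd;
//   gd.Function(&function);        // 'function' is a hypothetical ObjectiveFunction
//   gd.NumberOfSteps(100);         // assumed inherited setter for _NumberOfSteps
//   gd.Epsilon(1e-4);              // absolute tolerance; a negative value is relative
//   double final_value = gd.Run(); // gradient descent incl. restarts via Upgrade()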

// -----------------------------------------------------------------------------
void GradientDescent::Gradient(double *gradient, double step, bool *sgn_chg)
{
  Function()->Gradient(gradient, step, sgn_chg);
}

// -----------------------------------------------------------------------------
void GradientDescent::Finalize()
{
  Deallocate(_Gradient);
  Deallocate(_AllowSignChange);
  _LineSearch->DeleteObserver(_EventDelegate);
  if (_LineSearchOwner) Delete(_LineSearch);
}


} // namespace mirtk