1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #include "mirtk/InexactLineSearch.h"
21 
22 #include "mirtk/Memory.h"
23 #include "mirtk/String.h"
24 
25 
26 namespace mirtk {
27 
28 
29 // =============================================================================
30 // Construction/Destruction
31 // =============================================================================
32 
33 // -----------------------------------------------------------------------------
InexactLineSearch(ObjectiveFunction * f)34 InexactLineSearch::InexactLineSearch(ObjectiveFunction *f)
35 :
36   LineSearch(f),
37   _MaxRejectedStreak      (-1),
38   _ReusePreviousStepLength(true),
39   _StrictStepLengthRange  (1),
40   _AllowSignChange        (NULL),
41   _CurrentDoFValues       (NULL),
42   _ScaledDirection        (NULL)
43 {
44   if (f) {
45     Allocate(_CurrentDoFValues, f->NumberOfDOFs());
46     Allocate(_ScaledDirection,  f->NumberOfDOFs());
47   }
48 }
49 
50 // -----------------------------------------------------------------------------
CopyAttributes(const InexactLineSearch & other)51 void InexactLineSearch::CopyAttributes(const InexactLineSearch &other)
52 {
53   Deallocate(_CurrentDoFValues);
54   Deallocate(_ScaledDirection);
55 
56   _MaxRejectedStreak       = other._MaxRejectedStreak;
57   _ReusePreviousStepLength = other._ReusePreviousStepLength;
58   _StrictStepLengthRange   = other._StrictStepLengthRange;
59   _AllowSignChange         = other._AllowSignChange;
60 
61   if (Function()) {
62     Allocate(_CurrentDoFValues, Function()->NumberOfDOFs());
63     Allocate(_ScaledDirection,  Function()->NumberOfDOFs());
64   }
65 }
66 
67 // -----------------------------------------------------------------------------
// Copy constructor. The pointer members start out NULL so that
// CopyAttributes() can safely deallocate before allocating new buffers.
InexactLineSearch::InexactLineSearch(const InexactLineSearch &other)
:
  LineSearch       (other),
  _AllowSignChange (NULL),
  _CurrentDoFValues(NULL),
  _ScaledDirection (NULL)
{
  CopyAttributes(other);
}
76 
77 // -----------------------------------------------------------------------------
// Copy assignment: base class state first, then this class's attributes
InexactLineSearch &InexactLineSearch::operator =(const InexactLineSearch &other)
{
  if (this == &other) return *this; // self-assignment is a no-op
  LineSearch::operator =(other);
  CopyAttributes(other);
  return *this;
}
86 
87 // -----------------------------------------------------------------------------
~InexactLineSearch()88 InexactLineSearch::~InexactLineSearch()
89 {
90   Deallocate(_CurrentDoFValues);
91   Deallocate(_ScaledDirection);
92 }
93 
94 // -----------------------------------------------------------------------------
Function(ObjectiveFunction * f)95 void InexactLineSearch::Function(ObjectiveFunction *f)
96 {
97   if (Function() != f) {
98     Deallocate(_CurrentDoFValues);
99     Deallocate(_ScaledDirection);
100     LineSearch::Function(f);
101     if (f) {
102       Allocate(_CurrentDoFValues, f->NumberOfDOFs());
103       Allocate(_ScaledDirection,  f->NumberOfDOFs());
104     }
105   }
106 }
107 
108 // =============================================================================
109 // Parameters
110 // =============================================================================
111 
112 // -----------------------------------------------------------------------------
Set(const char * name,const char * value)113 bool InexactLineSearch::Set(const char *name, const char *value)
114 {
115   // Maximum number of consectutive rejections
116   if (strcmp(name, "Maximum streak of rejected steps") == 0) {
117     return FromString(value, _MaxRejectedStreak);
118   // Whether to start new search using step length of previous search
119   }
120   if (strcmp(name, "Reuse previous step length") == 0) {
121     return FromString(value, _ReusePreviousStepLength);
122   // Whether [min, max] step length range is strict
123   }
124   if (strcmp(name, "Strict step length range")             == 0 ||
125       strcmp(name, "Strict incremental step length range") == 0) {
126     bool limit_increments;
127     if (!FromString(value, limit_increments)) return false;
128     if (limit_increments) {
129       _StrictStepLengthRange |= 1;
130     } else {
131       _StrictStepLengthRange &= ~1;
132     }
133     return true;
134   }
135   if (strcmp(name, "Strict total step length range")       == 0 ||
136       strcmp(name, "Strict accumulated step length range") == 0) {
137     bool limit_step;
138     if (!FromString(value, limit_step)) return false;
139     if (limit_step) {
140       _StrictStepLengthRange |= 2;
141     } else {
142       _StrictStepLengthRange &= ~2;
143     }
144     return true;
145   }
146   return LineSearch::Set(name, value);
147 }
148 
149 // -----------------------------------------------------------------------------
Parameter() const150 ParameterList InexactLineSearch::Parameter() const
151 {
152   ParameterList params = LineSearch::Parameter();
153   Insert(params, "Maximum streak of rejected steps", _MaxRejectedStreak);
154   Insert(params, "Reuse previous step length",       _ReusePreviousStepLength);
155   Insert(params, "Strict incremental step length range", (_StrictStepLengthRange & 1) != 0);
156   Insert(params, "Strict total step length range", (_StrictStepLengthRange & 2) != 0);
157   return params;
158 }
159 
160 // =============================================================================
161 // Optimization
162 // =============================================================================
163 
164 // -----------------------------------------------------------------------------
Advance(double alpha)165 double InexactLineSearch::Advance(double alpha)
166 {
167   if (_StepLengthUnit == .0 || alpha == .0) return .0;
168   // Backup current function parameter values
169   Function()->Get(_CurrentDoFValues);
170   // Compute gradient for given step length
171   alpha /= _StepLengthUnit;
172   if (_Revert) alpha *= -1.0;
173   const int ndofs = Function()->NumberOfDOFs();
174   for (int dof = 0; dof < ndofs; ++dof) {
175     _ScaledDirection[dof] = alpha * _Direction[dof];
176   }
177   // Set scaled gradient to the negative of the current parameter value if sign
178   // changes are not allowed for this parameter s.t. updated value is zero
179   //
180   // Note: This is used for the optimization of the registration cost
181   //       function with L1-norm sparsity constraint on the multi-level
182   //       free-form deformation parameters (Wenzhe et al.'s Sparse FFD).
183   if (_AllowSignChange) {
184     double next_value;
185     for (int dof = 0; dof < ndofs; ++dof) {
186       if (_AllowSignChange[dof]) continue;
187       next_value = _CurrentDoFValues[dof] + _ScaledDirection[dof];
188       if ((_CurrentDoFValues[dof] * next_value) <= .0) {
189         _ScaledDirection[dof] = - _CurrentDoFValues[dof];
190       }
191     }
192   }
193   // Update all parameters at once to only trigger a single modified event
194   return Function()->Step(_ScaledDirection);
195 }
196 
197 // -----------------------------------------------------------------------------
Retreat(double alpha)198 void InexactLineSearch::Retreat(double alpha)
199 {
200   if (_StepLengthUnit == .0 || alpha == .0) return;
201   Function()->Put(_CurrentDoFValues);
202 }
203 
204 // -----------------------------------------------------------------------------
Value(double alpha,double * delta)205 double InexactLineSearch::Value(double alpha, double *delta)
206 {
207   const double max_delta = Advance(alpha);
208   if (delta) *delta = max_delta;
209   Function()->Update(false);
210   const double value = Function()->Value();
211   Retreat(alpha);
212   return value;
213 }
214 
215 
216 } // namespace mirtk
217