1 /*
2 * Medical Image Registration ToolKit (MIRTK)
3 *
4 * Copyright 2013-2015 Imperial College London
5 * Copyright 2013-2015 Andreas Schuh
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "mirtk/InexactLineSearch.h"
21
22 #include "mirtk/Memory.h"
23 #include "mirtk/String.h"
24
25
26 namespace mirtk {
27
28
29 // =============================================================================
30 // Construction/Destruction
31 // =============================================================================
32
33 // -----------------------------------------------------------------------------
InexactLineSearch(ObjectiveFunction * f)34 InexactLineSearch::InexactLineSearch(ObjectiveFunction *f)
35 :
36 LineSearch(f),
37 _MaxRejectedStreak (-1),
38 _ReusePreviousStepLength(true),
39 _StrictStepLengthRange (1),
40 _AllowSignChange (NULL),
41 _CurrentDoFValues (NULL),
42 _ScaledDirection (NULL)
43 {
44 if (f) {
45 Allocate(_CurrentDoFValues, f->NumberOfDOFs());
46 Allocate(_ScaledDirection, f->NumberOfDOFs());
47 }
48 }
49
50 // -----------------------------------------------------------------------------
CopyAttributes(const InexactLineSearch & other)51 void InexactLineSearch::CopyAttributes(const InexactLineSearch &other)
52 {
53 Deallocate(_CurrentDoFValues);
54 Deallocate(_ScaledDirection);
55
56 _MaxRejectedStreak = other._MaxRejectedStreak;
57 _ReusePreviousStepLength = other._ReusePreviousStepLength;
58 _StrictStepLengthRange = other._StrictStepLengthRange;
59 _AllowSignChange = other._AllowSignChange;
60
61 if (Function()) {
62 Allocate(_CurrentDoFValues, Function()->NumberOfDOFs());
63 Allocate(_ScaledDirection, Function()->NumberOfDOFs());
64 }
65 }
66
67 // -----------------------------------------------------------------------------
// Copy constructor: copies the LineSearch base state, then the attributes of
// this class. Scratch buffers are freshly allocated, not shared.
InexactLineSearch::InexactLineSearch(const InexactLineSearch &other)
:
  LineSearch       (other),
  _CurrentDoFValues(NULL),
  _ScaledDirection (NULL)
{
  // CopyAttributes deallocates the buffers first, so they must start as NULL
  CopyAttributes(other);
}
76
77 // -----------------------------------------------------------------------------
// Copy assignment: assigns the LineSearch base state and then the attributes
// of this class. Self-assignment leaves the object unchanged.
InexactLineSearch &InexactLineSearch::operator =(const InexactLineSearch &other)
{
  if (this == &other) return *this;
  LineSearch::operator =(other);
  CopyAttributes(other);
  return *this;
}
86
87 // -----------------------------------------------------------------------------
~InexactLineSearch()88 InexactLineSearch::~InexactLineSearch()
89 {
90 Deallocate(_CurrentDoFValues);
91 Deallocate(_ScaledDirection);
92 }
93
94 // -----------------------------------------------------------------------------
Function(ObjectiveFunction * f)95 void InexactLineSearch::Function(ObjectiveFunction *f)
96 {
97 if (Function() != f) {
98 Deallocate(_CurrentDoFValues);
99 Deallocate(_ScaledDirection);
100 LineSearch::Function(f);
101 if (f) {
102 Allocate(_CurrentDoFValues, f->NumberOfDOFs());
103 Allocate(_ScaledDirection, f->NumberOfDOFs());
104 }
105 }
106 }
107
108 // =============================================================================
109 // Parameters
110 // =============================================================================
111
112 // -----------------------------------------------------------------------------
Set(const char * name,const char * value)113 bool InexactLineSearch::Set(const char *name, const char *value)
114 {
115 // Maximum number of consectutive rejections
116 if (strcmp(name, "Maximum streak of rejected steps") == 0) {
117 return FromString(value, _MaxRejectedStreak);
118 // Whether to start new search using step length of previous search
119 }
120 if (strcmp(name, "Reuse previous step length") == 0) {
121 return FromString(value, _ReusePreviousStepLength);
122 // Whether [min, max] step length range is strict
123 }
124 if (strcmp(name, "Strict step length range") == 0 ||
125 strcmp(name, "Strict incremental step length range") == 0) {
126 bool limit_increments;
127 if (!FromString(value, limit_increments)) return false;
128 if (limit_increments) {
129 _StrictStepLengthRange |= 1;
130 } else {
131 _StrictStepLengthRange &= ~1;
132 }
133 return true;
134 }
135 if (strcmp(name, "Strict total step length range") == 0 ||
136 strcmp(name, "Strict accumulated step length range") == 0) {
137 bool limit_step;
138 if (!FromString(value, limit_step)) return false;
139 if (limit_step) {
140 _StrictStepLengthRange |= 2;
141 } else {
142 _StrictStepLengthRange &= ~2;
143 }
144 return true;
145 }
146 return LineSearch::Set(name, value);
147 }
148
149 // -----------------------------------------------------------------------------
Parameter() const150 ParameterList InexactLineSearch::Parameter() const
151 {
152 ParameterList params = LineSearch::Parameter();
153 Insert(params, "Maximum streak of rejected steps", _MaxRejectedStreak);
154 Insert(params, "Reuse previous step length", _ReusePreviousStepLength);
155 Insert(params, "Strict incremental step length range", (_StrictStepLengthRange & 1) != 0);
156 Insert(params, "Strict total step length range", (_StrictStepLengthRange & 2) != 0);
157 return params;
158 }
159
160 // =============================================================================
161 // Optimization
162 // =============================================================================
163
164 // -----------------------------------------------------------------------------
Advance(double alpha)165 double InexactLineSearch::Advance(double alpha)
166 {
167 if (_StepLengthUnit == .0 || alpha == .0) return .0;
168 // Backup current function parameter values
169 Function()->Get(_CurrentDoFValues);
170 // Compute gradient for given step length
171 alpha /= _StepLengthUnit;
172 if (_Revert) alpha *= -1.0;
173 const int ndofs = Function()->NumberOfDOFs();
174 for (int dof = 0; dof < ndofs; ++dof) {
175 _ScaledDirection[dof] = alpha * _Direction[dof];
176 }
177 // Set scaled gradient to the negative of the current parameter value if sign
178 // changes are not allowed for this parameter s.t. updated value is zero
179 //
180 // Note: This is used for the optimization of the registration cost
181 // function with L1-norm sparsity constraint on the multi-level
182 // free-form deformation parameters (Wenzhe et al.'s Sparse FFD).
183 if (_AllowSignChange) {
184 double next_value;
185 for (int dof = 0; dof < ndofs; ++dof) {
186 if (_AllowSignChange[dof]) continue;
187 next_value = _CurrentDoFValues[dof] + _ScaledDirection[dof];
188 if ((_CurrentDoFValues[dof] * next_value) <= .0) {
189 _ScaledDirection[dof] = - _CurrentDoFValues[dof];
190 }
191 }
192 }
193 // Update all parameters at once to only trigger a single modified event
194 return Function()->Step(_ScaledDirection);
195 }
196
197 // -----------------------------------------------------------------------------
Retreat(double alpha)198 void InexactLineSearch::Retreat(double alpha)
199 {
200 if (_StepLengthUnit == .0 || alpha == .0) return;
201 Function()->Put(_CurrentDoFValues);
202 }
203
204 // -----------------------------------------------------------------------------
Value(double alpha,double * delta)205 double InexactLineSearch::Value(double alpha, double *delta)
206 {
207 const double max_delta = Advance(alpha);
208 if (delta) *delta = max_delta;
209 Function()->Update(false);
210 const double value = Function()->Value();
211 Retreat(alpha);
212 return value;
213 }
214
215
216 } // namespace mirtk
217