1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package org.apache.commons.math3.fitting;
18 
19 import java.util.ArrayList;
20 import java.util.Collection;
21 import java.util.Collections;
22 import java.util.Comparator;
23 import java.util.List;
24 
25 import org.apache.commons.math3.analysis.function.Gaussian;
26 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
27 import org.apache.commons.math3.exception.NullArgumentException;
28 import org.apache.commons.math3.exception.NumberIsTooSmallException;
29 import org.apache.commons.math3.exception.OutOfRangeException;
30 import org.apache.commons.math3.exception.ZeroException;
31 import org.apache.commons.math3.exception.util.LocalizedFormats;
32 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
33 import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
34 import org.apache.commons.math3.linear.DiagonalMatrix;
35 import org.apache.commons.math3.util.FastMath;
36 
37 /**
38  * Fits points to a {@link
39  * org.apache.commons.math3.analysis.function.Gaussian.Parametric Gaussian}
40  * function.
41  * <br/>
42  * The {@link #withStartPoint(double[]) initial guess values} must be passed
43  * in the following order:
44  * <ul>
45  *  <li>Normalization</li>
46  *  <li>Mean</li>
47  *  <li>Sigma</li>
48  * </ul>
49  * The optimal values will be returned in the same order.
50  *
51  * <p>
52  * Usage example:
53  * <pre>
54  *   WeightedObservedPoints obs = new WeightedObservedPoints();
55  *   obs.add(4.0254623,  531026.0);
56  *   obs.add(4.03128248, 984167.0);
57  *   obs.add(4.03839603, 1887233.0);
58  *   obs.add(4.04421621, 2687152.0);
59  *   obs.add(4.05132976, 3461228.0);
60  *   obs.add(4.05326982, 3580526.0);
61  *   obs.add(4.05779662, 3439750.0);
62  *   obs.add(4.0636168,  2877648.0);
63  *   obs.add(4.06943698, 2175960.0);
64  *   obs.add(4.07525716, 1447024.0);
65  *   obs.add(4.08237071, 717104.0);
66  *   obs.add(4.08366408, 620014.0);
67  *   double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
68  * </pre>
69  *
70  * @since 3.3
71  */
72 public class GaussianCurveFitter extends AbstractCurveFitter {
73     /** Parametric function to be fitted. */
74     private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
75             /** {@inheritDoc} */
76             @Override
77             public double value(double x, double ... p) {
78                 double v = Double.POSITIVE_INFINITY;
79                 try {
80                     v = super.value(x, p);
81                 } catch (NotStrictlyPositiveException e) { // NOPMD
82                     // Do nothing.
83                 }
84                 return v;
85             }
86 
87             /** {@inheritDoc} */
88             @Override
89             public double[] gradient(double x, double ... p) {
90                 double[] v = { Double.POSITIVE_INFINITY,
91                                Double.POSITIVE_INFINITY,
92                                Double.POSITIVE_INFINITY };
93                 try {
94                     v = super.gradient(x, p);
95                 } catch (NotStrictlyPositiveException e) { // NOPMD
96                     // Do nothing.
97                 }
98                 return v;
99             }
100         };
101     /** Initial guess. */
102     private final double[] initialGuess;
103     /** Maximum number of iterations of the optimization algorithm. */
104     private final int maxIter;
105 
106     /**
107      * Contructor used by the factory methods.
108      *
109      * @param initialGuess Initial guess. If set to {@code null}, the initial guess
110      * will be estimated using the {@link ParameterGuesser}.
111      * @param maxIter Maximum number of iterations of the optimization algorithm.
112      */
GaussianCurveFitter(double[] initialGuess, int maxIter)113     private GaussianCurveFitter(double[] initialGuess,
114                                 int maxIter) {
115         this.initialGuess = initialGuess;
116         this.maxIter = maxIter;
117     }
118 
119     /**
120      * Creates a default curve fitter.
121      * The initial guess for the parameters will be {@link ParameterGuesser}
122      * computed automatically, and the maximum number of iterations of the
123      * optimization algorithm is set to {@link Integer#MAX_VALUE}.
124      *
125      * @return a curve fitter.
126      *
127      * @see #withStartPoint(double[])
128      * @see #withMaxIterations(int)
129      */
create()130     public static GaussianCurveFitter create() {
131         return new GaussianCurveFitter(null, Integer.MAX_VALUE);
132     }
133 
134     /**
135      * Configure the start point (initial guess).
136      * @param newStart new start point (initial guess)
137      * @return a new instance.
138      */
withStartPoint(double[] newStart)139     public GaussianCurveFitter withStartPoint(double[] newStart) {
140         return new GaussianCurveFitter(newStart.clone(),
141                                        maxIter);
142     }
143 
144     /**
145      * Configure the maximum number of iterations.
146      * @param newMaxIter maximum number of iterations
147      * @return a new instance.
148      */
withMaxIterations(int newMaxIter)149     public GaussianCurveFitter withMaxIterations(int newMaxIter) {
150         return new GaussianCurveFitter(initialGuess,
151                                        newMaxIter);
152     }
153 
154     /** {@inheritDoc} */
155     @Override
getProblem(Collection<WeightedObservedPoint> observations)156     protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
157 
158         // Prepare least-squares problem.
159         final int len = observations.size();
160         final double[] target  = new double[len];
161         final double[] weights = new double[len];
162 
163         int i = 0;
164         for (WeightedObservedPoint obs : observations) {
165             target[i]  = obs.getY();
166             weights[i] = obs.getWeight();
167             ++i;
168         }
169 
170         final AbstractCurveFitter.TheoreticalValuesFunction model =
171                 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
172 
173         final double[] startPoint = initialGuess != null ?
174             initialGuess :
175             // Compute estimation.
176             new ParameterGuesser(observations).guess();
177 
178         // Return a new least squares problem set up to fit a Gaussian curve to the
179         // observed points.
180         return new LeastSquaresBuilder().
181                 maxEvaluations(Integer.MAX_VALUE).
182                 maxIterations(maxIter).
183                 start(startPoint).
184                 target(target).
185                 weight(new DiagonalMatrix(weights)).
186                 model(model.getModelFunction(), model.getModelFunctionJacobian()).
187                 build();
188 
189     }
190 
191     /**
192      * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
193      * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
194      * based on the specified observed points.
195      */
196     public static class ParameterGuesser {
197         /** Normalization factor. */
198         private final double norm;
199         /** Mean. */
200         private final double mean;
201         /** Standard deviation. */
202         private final double sigma;
203 
204         /**
205          * Constructs instance with the specified observed points.
206          *
207          * @param observations Observed points from which to guess the
208          * parameters of the Gaussian.
209          * @throws NullArgumentException if {@code observations} is
210          * {@code null}.
211          * @throws NumberIsTooSmallException if there are less than 3
212          * observations.
213          */
ParameterGuesser(Collection<WeightedObservedPoint> observations)214         public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
215             if (observations == null) {
216                 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
217             }
218             if (observations.size() < 3) {
219                 throw new NumberIsTooSmallException(observations.size(), 3, true);
220             }
221 
222             final List<WeightedObservedPoint> sorted = sortObservations(observations);
223             final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
224 
225             norm = params[0];
226             mean = params[1];
227             sigma = params[2];
228         }
229 
230         /**
231          * Gets an estimation of the parameters.
232          *
233          * @return the guessed parameters, in the following order:
234          * <ul>
235          *  <li>Normalization factor</li>
236          *  <li>Mean</li>
237          *  <li>Standard deviation</li>
238          * </ul>
239          */
guess()240         public double[] guess() {
241             return new double[] { norm, mean, sigma };
242         }
243 
244         /**
245          * Sort the observations.
246          *
247          * @param unsorted Input observations.
248          * @return the input observations, sorted.
249          */
sortObservations(Collection<WeightedObservedPoint> unsorted)250         private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
251             final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);
252 
253             final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
254                 /** {@inheritDoc} */
255                 public int compare(WeightedObservedPoint p1,
256                                    WeightedObservedPoint p2) {
257                     if (p1 == null && p2 == null) {
258                         return 0;
259                     }
260                     if (p1 == null) {
261                         return -1;
262                     }
263                     if (p2 == null) {
264                         return 1;
265                     }
266                     final int cmpX = Double.compare(p1.getX(), p2.getX());
267                     if (cmpX < 0) {
268                         return -1;
269                     }
270                     if (cmpX > 0) {
271                         return 1;
272                     }
273                     final int cmpY = Double.compare(p1.getY(), p2.getY());
274                     if (cmpY < 0) {
275                         return -1;
276                     }
277                     if (cmpY > 0) {
278                         return 1;
279                     }
280                     final int cmpW = Double.compare(p1.getWeight(), p2.getWeight());
281                     if (cmpW < 0) {
282                         return -1;
283                     }
284                     if (cmpW > 0) {
285                         return 1;
286                     }
287                     return 0;
288                 }
289             };
290 
291             Collections.sort(observations, cmp);
292             return observations;
293         }
294 
295         /**
296          * Guesses the parameters based on the specified observed points.
297          *
298          * @param points Observed points, sorted.
299          * @return the guessed parameters (normalization factor, mean and
300          * sigma).
301          */
basicGuess(WeightedObservedPoint[] points)302         private double[] basicGuess(WeightedObservedPoint[] points) {
303             final int maxYIdx = findMaxY(points);
304             final double n = points[maxYIdx].getY();
305             final double m = points[maxYIdx].getX();
306 
307             double fwhmApprox;
308             try {
309                 final double halfY = n + ((m - n) / 2);
310                 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
311                 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
312                 fwhmApprox = fwhmX2 - fwhmX1;
313             } catch (OutOfRangeException e) {
314                 // TODO: Exceptions should not be used for flow control.
315                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
316             }
317             final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
318 
319             return new double[] { n, m, s };
320         }
321 
322         /**
323          * Finds index of point in specified points with the largest Y.
324          *
325          * @param points Points to search.
326          * @return the index in specified points array.
327          */
findMaxY(WeightedObservedPoint[] points)328         private int findMaxY(WeightedObservedPoint[] points) {
329             int maxYIdx = 0;
330             for (int i = 1; i < points.length; i++) {
331                 if (points[i].getY() > points[maxYIdx].getY()) {
332                     maxYIdx = i;
333                 }
334             }
335             return maxYIdx;
336         }
337 
338         /**
339          * Interpolates using the specified points to determine X at the
340          * specified Y.
341          *
342          * @param points Points to use for interpolation.
343          * @param startIdx Index within points from which to start the search for
344          * interpolation bounds points.
345          * @param idxStep Index step for searching interpolation bounds points.
346          * @param y Y value for which X should be determined.
347          * @return the value of X for the specified Y.
348          * @throws ZeroException if {@code idxStep} is 0.
349          * @throws OutOfRangeException if specified {@code y} is not within the
350          * range of the specified {@code points}.
351          */
interpolateXAtY(WeightedObservedPoint[] points, int startIdx, int idxStep, double y)352         private double interpolateXAtY(WeightedObservedPoint[] points,
353                                        int startIdx,
354                                        int idxStep,
355                                        double y)
356             throws OutOfRangeException {
357             if (idxStep == 0) {
358                 throw new ZeroException();
359             }
360             final WeightedObservedPoint[] twoPoints
361                 = getInterpolationPointsForY(points, startIdx, idxStep, y);
362             final WeightedObservedPoint p1 = twoPoints[0];
363             final WeightedObservedPoint p2 = twoPoints[1];
364             if (p1.getY() == y) {
365                 return p1.getX();
366             }
367             if (p2.getY() == y) {
368                 return p2.getX();
369             }
370             return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
371                                 (p2.getY() - p1.getY()));
372         }
373 
374         /**
375          * Gets the two bounding interpolation points from the specified points
376          * suitable for determining X at the specified Y.
377          *
378          * @param points Points to use for interpolation.
379          * @param startIdx Index within points from which to start search for
380          * interpolation bounds points.
381          * @param idxStep Index step for search for interpolation bounds points.
382          * @param y Y value for which X should be determined.
383          * @return the array containing two points suitable for determining X at
384          * the specified Y.
385          * @throws ZeroException if {@code idxStep} is 0.
386          * @throws OutOfRangeException if specified {@code y} is not within the
387          * range of the specified {@code points}.
388          */
getInterpolationPointsForY(WeightedObservedPoint[] points, int startIdx, int idxStep, double y)389         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
390                                                                    int startIdx,
391                                                                    int idxStep,
392                                                                    double y)
393             throws OutOfRangeException {
394             if (idxStep == 0) {
395                 throw new ZeroException();
396             }
397             for (int i = startIdx;
398                  idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
399                  i += idxStep) {
400                 final WeightedObservedPoint p1 = points[i];
401                 final WeightedObservedPoint p2 = points[i + idxStep];
402                 if (isBetween(y, p1.getY(), p2.getY())) {
403                     if (idxStep < 0) {
404                         return new WeightedObservedPoint[] { p2, p1 };
405                     } else {
406                         return new WeightedObservedPoint[] { p1, p2 };
407                     }
408                 }
409             }
410 
411             // Boundaries are replaced by dummy values because the raised
412             // exception is caught and the message never displayed.
413             // TODO: Exceptions should not be used for flow control.
414             throw new OutOfRangeException(y,
415                                           Double.NEGATIVE_INFINITY,
416                                           Double.POSITIVE_INFINITY);
417         }
418 
419         /**
420          * Determines whether a value is between two other values.
421          *
422          * @param value Value to test whether it is between {@code boundary1}
423          * and {@code boundary2}.
424          * @param boundary1 One end of the range.
425          * @param boundary2 Other end of the range.
426          * @return {@code true} if {@code value} is between {@code boundary1} and
427          * {@code boundary2} (inclusive), {@code false} otherwise.
428          */
isBetween(double value, double boundary1, double boundary2)429         private boolean isBetween(double value,
430                                   double boundary1,
431                                   double boundary2) {
432             return (value >= boundary1 && value <= boundary2) ||
433                 (value >= boundary2 && value <= boundary1);
434         }
435     }
436 }
437