1# fit.py
2# fitting plotter
3
4#    Copyright (C) 2005 Jeremy S. Sanders
5#    Email: Jeremy Sanders <jeremy@jeremysanders.net>
6#
7#    This program is free software; you can redistribute it and/or modify
8#    it under the terms of the GNU General Public License as published by
9#    the Free Software Foundation; either version 2 of the License, or
10#    (at your option) any later version.
11#
12#    This program is distributed in the hope that it will be useful,
13#    but WITHOUT ANY WARRANTY; without even the implied warranty of
14#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15#    GNU General Public License for more details.
16#
17#    You should have received a copy of the GNU General Public License along
18#    with this program; if not, write to the Free Software Foundation, Inc.,
19#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
20###############################################################################
21
22from __future__ import division, absolute_import, print_function
23import re
24import sys
25
26import numpy as N
27
28from ..compat import czip, cstr
29from .. import document
30from .. import setting
31from .. import utils
32from .. import qtall as qt
33
34from .function import FunctionPlotter
35from . import widget
36
37# try importing iminuit first, then minuit, then None
38try:
39    import iminuit as minuit
40except ImportError:
41    try:
42        import minuit
43    except ImportError:
44        minuit = None
45
def _(text, disambiguation=None, context='Fit'):
    """Return text translated via Qt, using the 'Fit' context by default.

    disambiguation is passed through to QCoreApplication.translate
    unchanged.
    """
    translate = qt.QCoreApplication.translate
    return translate(context, text, disambiguation)
49
def minuitFit(evalfunc, params, names, values, xvals, yvals, yserr):
    """Do fitting with minuit (if installed).

    evalfunc(params, xvals) evaluates the model for a numpy vector of
    parameter values, ordered as in names. params is the starting
    parameter vector; only its length is used here, to compute the
    degrees of freedom. values maps each parameter name to its starting
    value and is passed to minuit.Minuit as keyword arguments. xvals,
    yvals and yserr are the data arrays being fitted.

    Returns (vals, chi2, dof): a dict of best-fit values, the final
    chi^2 and the number of degrees of freedom (which may be <= 0).

    Raises ValueError if a parameter name is not a plain identifier.
    """
    import keyword

    def chi2(params):
        """Return chi^2 for the parameter vector params, printing the
        progress of each iteration while the main fit is running."""
        c = ((evalfunc(params, xvals) - yvals)**2 / yserr**2).sum()
        if chi2.runningFit:
            chi2.iters += 1
            p = [chi2.iters, c] + params.tolist()
            line = ("%5i " + "%8g " * (len(params)+1)) % tuple(p)
            print(line)
        return c

    chi2.runningFit = False
    chi2.iters = 0

    # Minuit takes parameter names from the signature of the function it
    # minimises, so a function with one named argument per parameter is
    # built as source text and eval'd. The names come from the document
    # and are interpolated into code, so only plain identifiers are
    # accepted (anything else, e.g. a name containing "#", could inject
    # arbitrary code into the eval).
    identifier = re.compile(r'[^\W\d]\w*\Z', re.UNICODE)
    for name in names:
        if not identifier.match(name) or keyword.iskeyword(name):
            raise ValueError('Invalid fit parameter name: %r' % (name,))

    # The lambda body refers to a single global helper. Pick a name that
    # no parameter can shadow (a parameter called "N" or "chi2" would
    # otherwise hide the globals the body needs).
    helper = '_chi2vec'
    while helper in names:
        helper += '_'
    namestr = ', '.join(names)
    fnstr = 'lambda %s: %s([%s])' % (namestr, helper, namestr)
    fn = eval(fnstr, {helper: lambda vec: chi2(N.array(vec))})

    print(_('Fitting via Minuit:'))
    m = minuit.Minuit(fn, **values)

    # run the fit, printing progress
    chi2.iters = 0
    chi2.runningFit = True
    m.migrad()

    # do some error analysis: hesse gives symmetric errors, minos
    # asymmetric ones. Failures are reported but are only fatal if they
    # show that the main fit did not converge.
    have_symerr, have_err = False, False
    try:
        chi2.runningFit = False
        m.hesse()
        have_symerr = True
        m.minos()
        have_err = True
    except Exception as e:
        print(e)
        if str(e).startswith('Discovered a new minimum'):
            # the initial fit really failed
            raise

    # print the results
    retchi2 = m.fval
    dof = len(yvals) - len(params)

    if have_err:
        print(_('Fit results:\n') + "\n".join([
                    u"    %s = %g \u00b1 %g (+%g / %g)"
                    % (n, m.values[n], m.errors[n], m.merrors[(n, 1.0)],
                       m.merrors[(n, -1.0)]) for n in names]))
    elif have_symerr:
        print(_('Fit results:\n') + "\n".join([
                    u"    %s = %g \u00b1 %g" % (n, m.values[n], m.errors[n])
                    for n in names]))
        print(_('MINOS error estimate not available.'))
    else:
        print(_('Fit results:\n') + "\n".join([
                    '    %s = %g' % (n, m.values[n]) for n in names]))
        print(_('No error analysis available: fit quality uncertain'))

    if dof > 0:
        print("chi^2 = %g, dof = %i, reduced-chi^2 = %g" % (
            retchi2, dof, retchi2 / dof))
    else:
        # reduced chi^2 is undefined without positive degrees of freedom
        print("chi^2 = %g, dof = %i" % (retchi2, dof))

    vals = dict(m.values)
    return vals, retchi2, dof
117
class Fit(FunctionPlotter):
    """A plotter to fit a function to data."""

    typename='fit'
    allowusercreation=True
    description=_('Fit a function to data')

    def __init__(self, parent, name=None):
        """Initialise the widget and register its 'fit' action."""
        FunctionPlotter.__init__(self, parent, name=name)

        self.addAction( widget.Action('fit', self.actionFit,
                                      descr = _('Fit function'),
                                      usertext = _('Fit function')) )

    @classmethod
    def addSettings(klass, s):
        """Construct list of settings."""
        FunctionPlotter.addSettings(s)

        s.add( setting.FloatDict(
                'values',
                {'a': 0.0, 'b': 1.0},
                descr = _('Variables and fit values'),
                usertext=_('Parameters')), 1 )
        s.add( setting.DatasetExtended(
                'xData', 'x',
                descr = _('X data to fit (dataset name, list of values '
                          'or expression)'),
                usertext=_('X data')), 2 )
        s.add( setting.DatasetExtended(
                'yData', 'y',
                descr = _('Y data to fit (dataset name, list of values '
                          'or expression)'),
                usertext=_('Y data')), 3 )
        s.add( setting.Bool(
                'fitRange', False,
                descr = _('Fit only the data between the '
                          'minimum and maximum of the axis for '
                          'the function variable'),
                usertext=_('Fit only range')),
               4 )
        s.add( setting.WidgetChoice(
                'outLabel', '',
                descr=_('Write best fit parameters to this text label '
                        'after fitting'),
                widgettypes=('label',),
                usertext=_('Output label')),
               5 )
        s.add( setting.Str('outExpr', '',
                           descr = _('Output best fitting expression'),
                           usertext=_('Output expression')),
               6, readonly=True )
        s.add( setting.Float('chi2', -1,
                             descr = _('Output chi^2 from fitting'),
                             usertext=_('Fit &chi;<sup>2</sup>')),
               7, readonly=True )
        s.add( setting.Int('dof', -1,
                           descr = _('Output degrees of freedom from fitting'),
                           usertext=_('Fit d.o.f.')),
               8, readonly=True )
        s.add( setting.Float('redchi2', -1,
                             descr = _('Output reduced-chi-squared from fitting'),
                             usertext=_('Fit reduced &chi;<sup>2</sup>')),
               9, readonly=True )

        f = s.get('function')
        f.newDefault('a + b*x')
        f.descr = _('Function to fit')

        # modify description
        s.get('min').usertext=_('Min. fit range')
        s.get('max').usertext=_('Max. fit range')

    def affectsAxisRange(self):
        """This widget provides range information about these axes."""
        s = self.settings
        return ( (s.xAxis, 'sx'), (s.yAxis, 'sy') )

    def getRange(self, axis, depname, axrange):
        """Update range with range of data.

        depname is 'sx' or 'sy'; axrange is a two-item [min, max] list
        which is widened in place to include the data range.
        """
        dataname = {'sx': 'xData', 'sy': 'yData'}[depname]
        data = self.settings.get(dataname).getData(self.document)
        if data:
            drange = data.getRange()
            if drange:
                axrange[0] = min(axrange[0], drange[0])
                axrange[1] = max(axrange[1], drange[1])

    def initEnviron(self):
        """Return a copy of the evaluation context with the current
        parameter values added."""
        env = self.document.evaluate.context.copy()
        env.update( self.settings.values )
        return env

    def updateOutputLabel(self, ops, vals, chi2, dof):
        """Use best fit parameters to update text label.

        ops is a list of operations which is appended to in place; vals
        maps parameter names to best-fit values. Nothing is done if no
        output label widget is set.
        """
        s = self.settings
        labelwidget = s.get('outLabel').findWidget()
        if labelwidget is None:
            return

        loc = self.document.locale

        def fmt(v):
            return utils.formatNumber(v, '%.4Vg', locale=loc)

        # build up a set of X=Y values
        txt = ['%s = %s' % (name, fmt(v)) for name, v in sorted(vals.items())]

        # add chi2 output; the reduced value is undefined unless there
        # are positive degrees of freedom
        redchi2 = fmt(chi2/dof) if dof > 0 else 'n/a'
        txt.append( r'\chi^{2}_{\nu} = %s/%i = %s' % (fmt(chi2), dof, redchi2) )

        # update label with text
        text = r'\\'.join(txt)
        ops.append( document.OperationSettingSet(
                labelwidget.settings.get('label') , text ) )

    def _getFitData(self):
        """Return (xvals, yvals, yserr) for fitting, or None if a dataset
        is unavailable.

        xvals are the values of the function variable (the independent
        variable), yvals the values being fitted and yserr their errors.
        """
        s = self.settings
        d = self.document

        # the function variable picks which dataset is independent
        if s.variable == 'x':
            indepname, depname = 'xData', 'yData'
        else:
            indepname, depname = 'yData', 'xData'
        indep = s.get(indepname).getData(d)
        ydata = s.get(depname).getData(d)
        if indep is None or ydata is None:
            sys.stderr.write(_('No data values. Not fitting.\n'))
            return None

        xvals = indep.data
        yvals = ydata.data
        yserr = ydata.serr

        # if there are no errors on data
        if yserr is None:
            if ydata.perr is not None and ydata.nerr is not None:
                print("Warning: Symmeterising positive and negative errors")
                yserr = N.sqrt( 0.5*(ydata.perr**2 + ydata.nerr**2) )
            else:
                print("Warning: No errors on y values. Assuming 5% errors.")
                # use the magnitude, otherwise negative values would get
                # negative errors and be clamped to the tiny minimum
                yserr = N.abs(yvals)*0.05
                yserr[yserr < 1e-8] = 1e-8

        return xvals, yvals, yserr

    def actionFit(self):
        """Fit the data.

        Uses Minuit if available, otherwise the simple L-M fitter, then
        applies the fitted values, fit statistics, output expression and
        output label as a single undoable operation.
        """
        s = self.settings
        d = self.document

        # check and get compiled form of function
        compiled = d.evaluate.compileCheckedExpression(s.function)
        if compiled is None:
            return

        # populate the input parameters
        paramnames = sorted(s.values)
        params = N.array( [s.values[p] for p in paramnames] )

        data = self._getFitData()
        if data is None:
            return
        xvals, yvals, yserr = data

        # lengths must agree before any masking can be done
        if len(xvals) != len(yvals) or len(xvals) != len(yserr):
            sys.stderr.write(_('Fit data not equal in length. Not fitting.\n'))
            return

        # All range selections apply to the function variable, whose
        # values are always held in xvals (for variable 'y' these come
        # from the y dataset).
        keep = N.ones(len(xvals), dtype=bool)

        # if the fitRange parameter is on, we chop out data outside the
        # range of the axis
        if s.fitRange:
            axisname = s.xAxis if s.variable == 'x' else s.yAxis
            drange = self.parent.getAxes((axisname,))[0].getPlottedRange()
            keep &= (xvals >= drange[0]) & (xvals <= drange[1])
            print("Fitting %s from %g to %g" % (s.variable,
                                                drange[0], drange[1]))

        # minimum and maximum set for fitting
        if s.min != 'Auto':
            keep &= xvals >= s.min
        if s.max != 'Auto':
            keep &= xvals <= s.max
        if s.min != 'Auto' or s.max != 'Auto':
            print("Fitting %s between %s and %s" % (s.variable, s.min, s.max))

        # only consider finite values
        keep &= N.isfinite(xvals) & N.isfinite(yvals) & N.isfinite(yserr)
        xvals, yvals, yserr = xvals[keep], yvals[keep], yserr[keep]

        # check length after all exclusions
        if len(xvals) == 0:
            sys.stderr.write(_('No data values. Not fitting.\n'))
            return
        if len(params) > len(xvals):
            sys.stderr.write(_('No degrees of freedom for fit. Not fitting\n'))
            return

        evalenv = self.initEnviron()
        def evalfunc(params, xvals):
            # update environment with variable and parameters
            evalenv[s.variable] = xvals
            evalenv.update( czip(paramnames, params) )

            try:
                # adding xvals*0 broadcasts constant results to an array
                return eval(compiled, evalenv) + xvals*0.
            except Exception as e:
                d.log(cstr(e))
                return N.nan

        # actually do the fit, either via Minuit or our own LM fitter
        if minuit is not None:
            vals, chi2, dof = minuitFit(
                evalfunc, params, paramnames, s.values,
                xvals, yvals, yserr)
        else:
            print(_('Minuit not available, falling back to simple L-M fitting:'))
            retn, chi2, dof = utils.fitLM(
                evalfunc, params, xvals, yvals, yserr)
            vals = {name: float(v) for name, v in czip(paramnames, retn)}

        # list of operations so we can undo the changes
        operations = []

        # populate the return parameters
        operations.append( document.OperationSettingSet(s.get('values'), vals) )

        # populate the read-only fit quality params
        operations.append( document.OperationSettingSet(s.get('chi2'), float(chi2)) )
        operations.append( document.OperationSettingSet(s.get('dof'), int(dof)) )
        if dof <= 0:
            print(_('No degrees of freedom in fit.\n'))
            redchi2 = -1.
        else:
            redchi2 = float(chi2/dof)
        operations.append( document.OperationSettingSet(s.get('redchi2'), redchi2) )

        # expression for fit
        expr = self.generateOutputExpr(vals)
        operations.append( document.OperationSettingSet(s.get('outExpr'), expr) )

        self.updateOutputLabel(operations, vals, chi2, dof)

        # actually change all the settings
        d.applyOperation(
            document.OperationMultiple(operations, descr=_('fit')) )

    def generateOutputExpr(self, vals):
        """Try to generate text form of output expression.

        vals is a dict of variable: value pairs
        returns the function expression with each parameter replaced by
        its value and the function variable replaced by its data setting
        """
        paramvals = dict(vals)
        s = self.settings

        # also substitute in data name for variable
        if s.variable == 'x':
            paramvals['x'] = s.xData
        else:
            paramvals['y'] = s.yData

        # Split the expression into name/number tokens separated by other
        # characters (the separators are kept). Underscores belong to
        # tokens, so "a_1" is matched whole rather than as "a", "_", "1".
        parts = re.split(r'([^\w.])', s.function, flags=re.UNICODE)
        simple = re.compile(r'[\w.]+\Z', re.UNICODE)

        # replace part by things in paramvals, if they exist
        for i, p in enumerate(parts):
            if p not in paramvals:
                continue
            val = paramvals[p]
            text = cstr(val)
            # bracket substitutions which could bind differently to the
            # neighbouring operators, e.g. a negative value before "**"
            # or a data expression such as "x+1"
            if isinstance(val, float):
                needparens = val < 0
            else:
                needparens = not simple.match(text)
            parts[i] = '(%s)' % text if needparens else text

        return ''.join(parts)
404
# Register the Fit widget class with the document's widget factory.
document.thefactory.register(Fit)
407