1# fit.py 2# fitting plotter 3 4# Copyright (C) 2005 Jeremy S. Sanders 5# Email: Jeremy Sanders <jeremy@jeremysanders.net> 6# 7# This program is free software; you can redistribute it and/or modify 8# it under the terms of the GNU General Public License as published by 9# the Free Software Foundation; either version 2 of the License, or 10# (at your option) any later version. 11# 12# This program is distributed in the hope that it will be useful, 13# but WITHOUT ANY WARRANTY; without even the implied warranty of 14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15# GNU General Public License for more details. 16# 17# You should have received a copy of the GNU General Public License along 18# with this program; if not, write to the Free Software Foundation, Inc., 19# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 20############################################################################### 21 22from __future__ import division, absolute_import, print_function 23import re 24import sys 25 26import numpy as N 27 28from ..compat import czip, cstr 29from .. import document 30from .. import setting 31from .. import utils 32from .. import qtall as qt 33 34from .function import FunctionPlotter 35from . import widget 36 37# try importing iminuit first, then minuit, then None 38try: 39 import iminuit as minuit 40except ImportError: 41 try: 42 import minuit 43 except ImportError: 44 minuit = None 45 46def _(text, disambiguation=None, context='Fit'): 47 """Translate text.""" 48 return qt.QCoreApplication.translate(context, text, disambiguation) 49 50def minuitFit(evalfunc, params, names, values, xvals, yvals, yserr): 51 """Do fitting with minuit (if installed).""" 52 53 def chi2(params): 54 """generate a lambda function to impedance-match between PyMinuit's 55 use of multiple parameters versus our use of a single numpy vector.""" 56 c = ((evalfunc(params, xvals) - yvals)**2 / yserr**2).sum() 57 if chi2.runningFit: 58 chi2.iters += 1 59 p = [chi2.iters, c] + params.tolist() 60 str = ("%5i " + "%8g " * (len(params)+1)) % tuple(p) 61 print(str) 62 63 return c 64 65 namestr = ', '.join(names) 66 fnstr = 'lambda %s: chi2(N.array([%s]))' % (namestr, namestr) 67 68 # this is safe because the only user-controlled variable is len(names) 69 fn = eval(fnstr, {'chi2' : chi2, 'N' : N}) 70 71 print(_('Fitting via Minuit:')) 72 m = minuit.Minuit(fn, **values) 73 74 # run the fit 75 chi2.runningFit = True 76 chi2.iters = 0 77 m.migrad() 78 79 # do some error analysis 80 have_symerr, have_err = False, False 81 try: 82 chi2.runningFit = False 83 m.hesse() 84 have_symerr = True 85 m.minos() 86 have_err = True 87 except Exception as e: 88 print(e) 89 if str(e).startswith('Discovered a new minimum'): 90 # the initial fit really failed 91 raise 92 93 # print the results 94 retchi2 = m.fval 95 dof = len(yvals) - len(params) 96 redchi2 = retchi2 / dof 97 98 if have_err: 99 print(_('Fit results:\n') + "\n".join([ 100 u" %s = %g \u00b1 %g (+%g / %g)" 101 % (n, m.values[n], m.errors[n], m.merrors[(n, 1.0)], 102 m.merrors[(n, -1.0)]) for n in names])) 103 elif have_symerr: 104 print(_('Fit results:\n') + "\n".join([ 105 u" %s = %g \u00b1 %g" % (n, m.values[n], m.errors[n]) 106 for n in names])) 107 print(_('MINOS error estimate not available.')) 108 else: 109 print(_('Fit results:\n') + "\n".join([ 110 ' %s = %g' % (n, m.values[n]) for n in names])) 111 print(_('No error analysis available: fit quality uncertain')) 112 113 print("chi^2 = %g, dof = %i, reduced-chi^2 = %g" % (retchi2, dof, redchi2)) 114 115 vals = dict(m.values) 116 return vals, retchi2, dof 117 118class Fit(FunctionPlotter): 119 """A plotter to fit a function to data.""" 120 121 typename='fit' 122 allowusercreation=True 123 description=_('Fit a function to data') 124 125 def __init__(self, parent, name=None): 126 FunctionPlotter.__init__(self, parent, name=name) 127 128 self.addAction( widget.Action('fit', self.actionFit, 129 descr = _('Fit function'), 130 usertext = _('Fit function')) ) 131 132 @classmethod 133 def addSettings(klass, s): 134 """Construct list of settings.""" 135 FunctionPlotter.addSettings(s) 136 137 s.add( setting.FloatDict( 138 'values', 139 {'a': 0.0, 'b': 1.0}, 140 descr = _('Variables and fit values'), 141 usertext=_('Parameters')), 1 ) 142 s.add( setting.DatasetExtended( 143 'xData', 'x', 144 descr = _('X data to fit (dataset name, list of values ' 145 'or expression)'), 146 usertext=_('X data')), 2 ) 147 s.add( setting.DatasetExtended( 148 'yData', 'y', 149 descr = _('Y data to fit (dataset name, list of values ' 150 'or expression)'), 151 usertext=_('Y data')), 3 ) 152 s.add( setting.Bool( 153 'fitRange', False, 154 descr = _('Fit only the data between the ' 155 'minimum and maximum of the axis for ' 156 'the function variable'), 157 usertext=_('Fit only range')), 158 4 ) 159 s.add( setting.WidgetChoice( 160 'outLabel', '', 161 descr=_('Write best fit parameters to this text label ' 162 'after fitting'), 163 widgettypes=('label',), 164 usertext=_('Output label')), 165 5 ) 166 s.add( setting.Str('outExpr', '', 167 descr = _('Output best fitting expression'), 168 usertext=_('Output expression')), 169 6, readonly=True ) 170 s.add( setting.Float('chi2', -1, 171 descr = 'Output chi^2 from fitting', 172 usertext=_('Fit χ<sup>2</sup>')), 173 7, readonly=True ) 174 s.add( setting.Int('dof', -1, 175 descr = _('Output degrees of freedom from fitting'), 176 usertext=_('Fit d.o.f.')), 177 8, readonly=True ) 178 s.add( setting.Float('redchi2', -1, 179 descr = _('Output reduced-chi-squared from fitting'), 180 usertext=_('Fit reduced χ<sup>2</sup>')), 181 9, readonly=True ) 182 183 f = s.get('function') 184 f.newDefault('a + b*x') 185 f.descr = _('Function to fit') 186 187 # modify description 188 s.get('min').usertext=_('Min. fit range') 189 s.get('max').usertext=_('Max. fit range') 190 191 def affectsAxisRange(self): 192 """This widget provides range information about these axes.""" 193 s = self.settings 194 return ( (s.xAxis, 'sx'), (s.yAxis, 'sy') ) 195 196 def getRange(self, axis, depname, axrange): 197 """Update range with range of data.""" 198 dataname = {'sx': 'xData', 'sy': 'yData'}[depname] 199 data = self.settings.get(dataname).getData(self.document) 200 if data: 201 drange = data.getRange() 202 if drange: 203 axrange[0] = min(axrange[0], drange[0]) 204 axrange[1] = max(axrange[1], drange[1]) 205 206 def initEnviron(self): 207 """Copy data into environment.""" 208 env = self.document.evaluate.context.copy() 209 env.update( self.settings.values ) 210 return env 211 212 def updateOutputLabel(self, ops, vals, chi2, dof): 213 """Use best fit parameters to update text label.""" 214 s = self.settings 215 labelwidget = s.get('outLabel').findWidget() 216 217 if labelwidget is not None: 218 # build up a set of X=Y values 219 loc = self.document.locale 220 txt = [] 221 for l, v in sorted(vals.items()): 222 val = utils.formatNumber(v, '%.4Vg', locale=loc) 223 txt.append( '%s = %s' % (l, val) ) 224 # add chi2 output 225 txt.append( r'\chi^{2}_{\nu} = %s/%i = %s' % ( 226 utils.formatNumber(chi2, '%.4Vg', locale=loc), 227 dof, 228 utils.formatNumber(chi2/dof, '%.4Vg', locale=loc) )) 229 230 # update label with text 231 text = r'\\'.join(txt) 232 ops.append( document.OperationSettingSet( 233 labelwidget.settings.get('label') , text ) ) 234 235 def actionFit(self): 236 """Fit the data.""" 237 238 s = self.settings 239 240 # check and get compiled for of function 241 compiled = self.document.evaluate.compileCheckedExpression(s.function) 242 if compiled is None: 243 return 244 245 # populate the input parameters 246 paramnames = sorted(s.values) 247 params = N.array( [s.values[p] for p in paramnames] ) 248 249 # FIXME: loads of error handling!! 250 d = self.document 251 252 # choose dataset depending on fit variable 253 if s.variable == 'x': 254 xvals = s.get('xData').getData(d).data 255 ydata = s.get('yData').getData(d) 256 else: 257 xvals = s.get('yData').getData(d).data 258 ydata = s.get('xData').getData(d) 259 yvals = ydata.data 260 yserr = ydata.serr 261 262 # if there are no errors on data 263 if yserr is None: 264 if ydata.perr is not None and ydata.nerr is not None: 265 print("Warning: Symmeterising positive and negative errors") 266 yserr = N.sqrt( 0.5*(ydata.perr**2 + ydata.nerr**2) ) 267 else: 268 print("Warning: No errors on y values. Assuming 5% errors.") 269 yserr = yvals*0.05 270 yserr[yserr < 1e-8] = 1e-8 271 272 # if the fitRange parameter is on, we chop out data outside the 273 # range of the axis 274 if s.fitRange: 275 # get ranges for axes 276 if s.variable == 'x': 277 drange = self.parent.getAxes((s.xAxis,))[0].getPlottedRange() 278 mask = N.logical_and(xvals >= drange[0], xvals <= drange[1]) 279 else: 280 drange = self.parent.getAxes((s.yAxis,))[0].getPlottedRange() 281 mask = N.logical_and(yvals >= drange[0], yvals <= drange[1]) 282 xvals, yvals, yserr = xvals[mask], yvals[mask], yserr[mask] 283 print("Fitting %s from %g to %g" % (s.variable, 284 drange[0], drange[1])) 285 286 evalenv = self.initEnviron() 287 def evalfunc(params, xvals): 288 # update environment with variable and parameters 289 evalenv[self.settings.variable] = xvals 290 evalenv.update( czip(paramnames, params) ) 291 292 try: 293 return eval(compiled, evalenv) + xvals*0. 294 except Exception as e: 295 self.document.log(cstr(e)) 296 return N.nan 297 298 # minimum set for fitting 299 if s.min != 'Auto': 300 if s.variable == 'x': 301 mask = xvals >= s.min 302 else: 303 mask = yvals >= s.min 304 xvals, yvals, yserr = xvals[mask], yvals[mask], yserr[mask] 305 306 # maximum set for fitting 307 if s.max != 'Auto': 308 if s.variable == 'x': 309 mask = xvals <= s.max 310 else: 311 mask = yvals <= s.max 312 xvals, yvals, yserr = xvals[mask], yvals[mask], yserr[mask] 313 314 if s.min != 'Auto' or s.max != 'Auto': 315 print("Fitting %s between %s and %s" % (s.variable, s.min, s.max)) 316 317 # various error checks 318 if len(xvals) != len(yvals) or len(xvals) != len(yserr): 319 sys.stderr.write(_('Fit data not equal in length. Not fitting.\n')) 320 return 321 if len(params) > len(xvals): 322 sys.stderr.write(_('No degrees of freedom for fit. Not fitting\n')) 323 return 324 325 # actually do the fit, either via Minuit or our own LM fitter 326 chi2 = 1 327 dof = 1 328 329 # only consider finite values 330 finite = N.isfinite(xvals) & N.isfinite(yvals) & N.isfinite(yserr) 331 xvals = xvals[finite] 332 yvals = yvals[finite] 333 yserr = yserr[finite] 334 335 # check length after excluding non-finite values 336 if len(xvals) == 0: 337 sys.stderr.write(_('No data values. Not fitting.\n')) 338 return 339 340 if minuit is not None: 341 vals, chi2, dof = minuitFit( 342 evalfunc, params, paramnames, s.values, 343 xvals, yvals, yserr) 344 else: 345 print(_('Minuit not available, falling back to simple L-M fitting:')) 346 retn, chi2, dof = utils.fitLM( 347 evalfunc, params, xvals, yvals, yserr) 348 vals = {} 349 for i, v in czip(paramnames, retn): 350 vals[i] = float(v) 351 352 # list of operations do we can undo the changes 353 operations = [] 354 355 # populate the return parameters 356 operations.append( document.OperationSettingSet(s.get('values'), vals) ) 357 358 # populate the read-only fit quality params 359 operations.append( document.OperationSettingSet(s.get('chi2'), float(chi2)) ) 360 operations.append( document.OperationSettingSet(s.get('dof'), int(dof)) ) 361 if dof <= 0: 362 print(_('No degrees of freedom in fit.\n')) 363 redchi2 = -1. 364 else: 365 redchi2 = float(chi2/dof) 366 operations.append( document.OperationSettingSet(s.get('redchi2'), redchi2) ) 367 368 # expression for fit 369 expr = self.generateOutputExpr(vals) 370 operations.append( document.OperationSettingSet(s.get('outExpr'), expr) ) 371 372 self.updateOutputLabel(operations, vals, chi2, dof) 373 374 # actually change all the settings 375 d.applyOperation( 376 document.OperationMultiple(operations, descr=_('fit')) ) 377 378 def generateOutputExpr(self, vals): 379 """Try to generate text form of output expression. 380 381 vals is a dict of variable: value pairs 382 returns the expression 383 """ 384 385 paramvals = dict(vals) 386 s = self.settings 387 388 # also substitute in data name for variable 389 if s.variable == 'x': 390 paramvals['x'] = s.xData 391 else: 392 paramvals['y'] = s.yData 393 394 # split expression up into parts of text and nums, separated 395 # by non-text/nums 396 parts = re.split('([^A-Za-z0-9.])', s.function) 397 398 # replace part by things in paramvals, if they exist 399 for i, p in enumerate(parts): 400 if p in paramvals: 401 parts[i] = str(paramvals[p]) 402 403 return ''.join(parts) 404 405# allow the factory to instantiate an x,y plotter 406document.thefactory.register( Fit ) 407