* fit_func: Added support for lmfit-based minimization.
Tested minimally in the Cr2_analysis_cbs.py workbench (routine: Test_MC_FPTZ_20140206).
This commit is contained in:
@@ -15,6 +15,12 @@ import numpy
|
||||
import scipy.optimize
|
||||
from wpylib.db.result_base import result_base
|
||||
|
||||
try:
|
||||
import lmfit
|
||||
HAS_LMFIT = True
|
||||
except ImportError:
|
||||
HAS_LMFIT = False
|
||||
|
||||
last_fit_rslt = None
|
||||
last_chi_sqr = None
|
||||
|
||||
@@ -129,7 +135,7 @@ class fit_result(result_base):
|
||||
"""
|
||||
pass
|
||||
|
||||
def fit_func(Funct, Data=None, Guess=None,
|
||||
def fit_func(Funct, Data=None, Guess=None, Params=None,
|
||||
x=None, y=None,
|
||||
w=None, dy=None,
|
||||
debug=0,
|
||||
@@ -152,6 +158,12 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
size of y data below.
|
||||
For a 2-D fitting, for example, x should be a column array.
|
||||
|
||||
An input guess for the parameters can be specified via Guess argument.
|
||||
It is an ordered list of scalar values for these parameters.
|
||||
|
||||
The Params argument is reserved for lmfit-style fitting.
|
||||
It is ignored in other cases.
|
||||
|
||||
The "y" array is a 1-D array of length M, which contain the "measured"
|
||||
value of the function at every domain point given in "x".
|
||||
|
||||
@@ -172,6 +184,14 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
* via Data argument (which is a multi-column dataset, where the first row
|
||||
is the "y" argument).
|
||||
|
||||
OUTPUT
|
||||
|
||||
By default, only the fitted parameters are returned.
|
||||
To return as complete information as possible, use outfmt=0.
|
||||
The return value will be a struct-like object (class: fit_result).
|
||||
|
||||
DEBUGGING
|
||||
|
||||
Debugging and other investigations can be done with "Funct_hook", which,
|
||||
if defined, will be called every time right after "Funct" is called.
|
||||
It is called with the following signature:
|
||||
@@ -181,6 +201,23 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
r := f(C,x) - y
|
||||
Note that the reference to the hook object is passed as the first argument
|
||||
to facilitate object oriented programming.
|
||||
|
||||
|
||||
SUPPORT FOR LMFIT MODULE
|
||||
|
||||
This routine also supports the lmfit module.
|
||||
In this case, the Funct object must supply additional attributes:
|
||||
* param_names: an ordered list of parameter names.
|
||||
This is used in the case that Params argument is not defined.
|
||||
|
||||
Input parameters can be specified ahead of time in the following ways:
|
||||
* If Params is None (default), then unconstrained minimization is done
|
||||
and the necessary Parameter objects are created on-the fly.
|
||||
* If Params is a Parameters object, then they are used.
|
||||
Note that these parameters *will* be clobbered!
|
||||
|
||||
The input Guess parameter can be set to False, in which case Params *must*
|
||||
be defined and the initial values will be used as Guess.
|
||||
"""
|
||||
global last_fit_rslt, last_chi_sqr
|
||||
from scipy.optimize import fmin, fmin_bfgs, leastsq, anneal
|
||||
@@ -194,9 +231,27 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
x = numpy.asarray(x)
|
||||
y = numpy.asarray(y)
|
||||
|
||||
if debug >= 1:
|
||||
print "fit_func: using function=%s, minimizer=%s" \
|
||||
% (repr(Funct), method)
|
||||
|
||||
if debug >= 10:
|
||||
print "fit routine opts = ", opts
|
||||
print "Dimensionality of the domain is: ", len(x)
|
||||
|
||||
if method.startswith("lmfit:"):
|
||||
if not HAS_LMFIT:
|
||||
raise ValueError, \
|
||||
"Module lmfit is not found, cannot use `%s' minimization method." \
|
||||
% (method,)
|
||||
use_lmfit = True
|
||||
from lmfit import minimize, Parameters, Parameter
|
||||
param_names = Funct.param_names
|
||||
if debug >= 10:
|
||||
print "param names: ", param_names
|
||||
else:
|
||||
use_lmfit = False
|
||||
|
||||
if Guess != None:
|
||||
pass
|
||||
elif hasattr(Funct, "Guess_xy"):
|
||||
@@ -209,9 +264,30 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
elif Guess == None: # VERY OLD, DO NOT USE ANYMORE!
|
||||
Guess = [ y.mean() ] + [0.0, 0.0] * len(x)
|
||||
|
||||
if use_lmfit:
|
||||
if Params == None:
|
||||
# Creates a default list of Parameters for use later
|
||||
assert Guess != False
|
||||
Params = Parameters()
|
||||
for (g,pn) in zip(Guess, param_names):
|
||||
Params.add(pn, value=g)
|
||||
else:
|
||||
if Guess == None or Guess == False:
|
||||
# copy the Params' values to Guess
|
||||
Guess = [ Params[pn].value for pn in param_names ]
|
||||
else:
|
||||
# copy the Guess values onto the Params' values
|
||||
for (g,pn) in zip(Guess, param_names):
|
||||
Params[pn].value = g
|
||||
|
||||
if debug >= 10:
|
||||
print "lmfit guess parameters:"
|
||||
for k1 in Params:
|
||||
print " - ", Params[k1]
|
||||
|
||||
if debug >= 5:
|
||||
print "Guess params:"
|
||||
print Guess
|
||||
print " ", Guess
|
||||
|
||||
if Funct_hook != None:
|
||||
if not hasattr(Funct_hook, "__call__"):
|
||||
@@ -221,9 +297,9 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
reference data points:
|
||||
|
||||
* CC = current function parameters
|
||||
* xx = domain points of the ("experimental") data
|
||||
* yy = target points of the ("experimental") data
|
||||
* ww = weights of the ("experimental") data
|
||||
* xx = domain points of the ("measured") data
|
||||
* yy = target points of the ("measured") data
|
||||
* ww = weights of the ("measured") data (usually, 1/error**2 of the data)
|
||||
"""
|
||||
ff = Funct(CC,xx)
|
||||
r = (ff - yy) * ww
|
||||
@@ -255,6 +331,7 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
# Full result is stored in rec
|
||||
rec = fit_result()
|
||||
extra_keys = {}
|
||||
chi_sqr = None
|
||||
if method == 'leastsq':
|
||||
# modified Levenberg-Marquardt algorithm
|
||||
rslt = leastsq(fun_err,
|
||||
@@ -305,9 +382,72 @@ def fit_func(Funct, Data=None, Guess=None,
|
||||
**opts
|
||||
)
|
||||
keys = ('xopt', 'fopt', 'T', 'funcalls', 'iter', 'accept', 'retval')
|
||||
elif use_lmfit:
|
||||
submethod = method.split(":",1)[1]
|
||||
minrec = minimize(fun_err, Params,
|
||||
args=(x,y,sqrtw),
|
||||
method=submethod,
|
||||
# backend ("real" minimizer) options:
|
||||
full_output=1,
|
||||
**opts
|
||||
)
|
||||
|
||||
xopt = [ Params[k1].value for k1 in param_names ]
|
||||
keys = ('xopt', 'minobj', 'params')
|
||||
rslt = [ xopt, minrec, Params ]
|
||||
# map the output values (in the Minimizer instance, minrec), to
|
||||
# the same keyword as other methods for backward compatiblity.
|
||||
rec['funcalls'] = minrec.nfev
|
||||
try:
|
||||
chi_sqr = minrec.chi_sqr
|
||||
except:
|
||||
pass
|
||||
# These seem to be particular to leastsq:
|
||||
try:
|
||||
rec['ier'] = minrec.ier
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
rec['mesg'] = minrec.lmdif_message
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
rec['message'] = minrec.message
|
||||
except:
|
||||
pass
|
||||
|
||||
# Added estimate of fit parameter uncertainty (matching GNUPLOT parameter
|
||||
# uncertainty.
|
||||
# The error is estimated to be the diagonal of cov_x, multiplied by the WSSR
|
||||
# (chi_square below) and divided by the number of fit degrees of freedom.
|
||||
# I used newer scipy.optimize.curve_fit() routine as my cheat sheet here.
|
||||
if outfmt == 0:
|
||||
try:
|
||||
has_errorbars = minrec.errorbars
|
||||
except:
|
||||
has_errorbars = False
|
||||
|
||||
if has_errorbars:
|
||||
try:
|
||||
rec['xerr'] = [ Params[k1].stderr for k1 in param_names ]
|
||||
except:
|
||||
# it shouldn't fail like this!
|
||||
import warnings
|
||||
warnings.warn("wpylib.math.fitting.fit_func: Fail to get standard error of the fit parameters")
|
||||
|
||||
if debug >= 10:
|
||||
if 'xerr' in rec:
|
||||
print "param errorbars are found:"
|
||||
print " ", tuple(rec['xerr'])
|
||||
else:
|
||||
print "param errorbars are NOT found!"
|
||||
|
||||
else:
|
||||
raise ValueError, "Unsupported minimization method: %s" % method
|
||||
chi_sqr = fun_err2(rslt[0], x, y, sqrtw)
|
||||
|
||||
# Final common post-processing:
|
||||
if chi_sqr == None:
|
||||
chi_sqr = fun_err2(rslt[0], x, y, sqrtw)
|
||||
last_chi_sqr = chi_sqr
|
||||
last_fit_rslt = rslt
|
||||
if (debug >= 10):
|
||||
|
||||
Reference in New Issue
Block a user