Source code for cait.versatile.eventfunctions.scalarfunctions.fitbaseline

from typing import Union, List
import warnings

import numpy as np
from scipy.optimize import curve_fit

from ..functionbase import FitFncBaseClass

# We do not use scipy's parameter error estimation anyways so we can suppress this warning
warnings.filterwarnings("ignore", "Covariance of the parameters could not be estimated")

def exponential_decay(x, a, b, c):
    return a*np.exp(-b*x) + c

[docs]class FitBaseline(FitFncBaseClass): """ Fit voltage traces with a polynomial or decaying exponential and return the fit parameters as well as the RMS. Also works for multiple channels simultaneously. :param model: Order of the polynomial or 'exponential'/'exp', defaults to 0, i.e. a constant baseline. :type model: Union[int, str] :param where: Specifies a subset of data points to be used in the fit: Either a boolean flag of the same length of the voltage traces, a slice object (e.g. slice(0,50) for using the first 50 data points), or a float. If a float `where` is passed, the first `int(where)*record_length` samples are used (e.g. if `where=1/8`, the first 1/8th of the record window is used). Defaults to `slice(None, None, None). :type where: Union[List[bool], slice, int] :param xdata: x-data to use for the fit (has no effect for `order=0`). Specifying xdata is not necessary in general but if you want your fit parameters to have physical units (e.g. time constants) instead of just samples, you may use this option. Defaults to `None`, in which case `xdata=np.linspace(0,1,record_length)`. :type xdata: List[float] :return: Fit parameter(s) and RMS as a tuple. :rtype: Tuple[Union[float, numpy.ndarray], float] **Example:** :: import cait.versatile as vai # Construct mock data (which provides event iterator) md = vai.MockData() it = md.get_event_iterator()[0] # View effect of fitting baseline on events # We specify that for the fit, 1/8th of the record window should be used, # that we fit with a degree-0-polynomial (i.e. constant) # Specifying the xdata is not necessary, but it lets us plot in terms of time vai.Preview(it, vai.FitBaseline(where=1/8, model=0, xdata=it.t), xlabel="Time (ms)") .. image:: media/FitBaseline_preview.png """ def __init__(self, model: Union[int, str] = 0, where: Union[List[bool], slice, float] = slice(None, None, None), xdata: List[float] = None): if type(model) not in [str, int]: raise NotImplementedError(f"Unsupported type '{type(model)}' for input 'order'.") elif type(model) is str and model not in ['exponential', 'exp']: raise NotImplementedError(f"Unrecognized baseline model '{model}'.") elif type(model) is int and model < 0: raise NotImplementedError(f"Polynomial order '{model}' is not supported, only non-zero integers are.") self._model = model self._where = where self._xdata = xdata self._A = None def __call__(self, event): # ATTENTION: this is set only once if isinstance(self._where, float): self._where = slice(0, int(np.array(event).shape[-1]*self._where)) # Shortcut for constant baseline model if self._model == 0: self._fitpar = np.mean(event[..., self._where], axis=-1) self._rms = np.std(event[..., self._where], axis=-1) else: # ATTENTION: This is only set once, i.e. data has to have same length if self._xdata is None: self._xdata = np.linspace(0, 1, np.array(event).shape[-1]) # Exponential fit if self._model in ['exponential', 'exp']: if np.array(event).ndim > 1: self._fitpar = np.array( [ curve_fit(exponential_decay, self._xdata[self._where], event[k, self._where], bounds=([0, 0, -np.inf],[np.inf,np.inf,np.inf]))[0] for k in range(np.array(event).shape[0]) ] ) self._rms = np.array( [ np.sqrt(np.mean((event[k, self._where] - exponential_decay(self._xdata[self._where], *self._fitpar[k]))**2)) for k in range(np.array(event).shape[0]) ] ) else: self._fitpar, *_ = curve_fit(exponential_decay, self._xdata[self._where], event[self._where], bounds=([0, 0, -np.inf],[np.inf, np.inf, np.inf])) self._rms = np.sqrt(np.mean((event[self._where] - exponential_decay(self._xdata[self._where], *self._fitpar))**2)) # Polynomial fit else: if self._A is None: self._A = np.array([self._xdata[self._where]**k for k in range(self._model+1)]).T par, err, *_ = np.linalg.lstsq(self._A, event[..., self._where].T, rcond=None) self._fitpar = par.T self._rms = np.sqrt(err) return self._fitpar, self._rms @property def batch_support(self): return 'trivial'
[docs] def model(self, x: List, par: List): """ """ par = np.array(par) if self._model in ['exponential', 'exp']: if par.shape[-1] != 3: raise ValueError(f"3 parameters are required to fully describe this model, {len(par)} given.") if par.ndim > 1: # i.e. we have multiple channels return np.array([exponential_decay(x, *par[k]) for k in range(par.shape[0])]) else: return exponential_decay(x, *par) else: if par.shape[-1] != self._model+1: raise ValueError(f"{self._model+1} parameter(s) are required to fully describe this model.") if par.ndim > 1: # i.e. we have multiple channels return np.array( [np.sum(np.array([par[k][j]*x**j for j in range(self._model+1)]), axis=0) for k in range(par.shape[0])] ) else: return np.sum(np.array([par[k]*x**k for k in range(self._model+1)]), axis=0)
def preview(self, event): # Call function (this will set all class attributes to be accessed for plotting) self(event) # This happens for constant baseline fit (self._xdata is never needed and therefore never # computed) if self._xdata is None: self._xdata = np.arange(np.array(event).shape[-1]) # self._fitpar is not an array for self._model = 0 par = np.array([self._fitpar]).T if self._model == 0 else self._fitpar # Reconstruct fit function fit = self.model(self._xdata, par) if np.ndim(event) > 1: d = dict() for i in range(np.ndim(event)): d[f'channel {i}'] = [self._xdata, event[i]] d[f'baseline fit channel {i}'] = [self._xdata, fit[i]] else: d = {'event': [self._xdata, event], 'baseline fit': [self._xdata, fit]} return dict(line = d)