from typing import List
import numpy as np
from scipy.linalg import LinAlgError
from scipy.optimize import minimize
from ..functionbase import ScalarFncBaseclass
from ..processing.removebaseline import RemoveBaseline
from .templatefit import _TemplateCachePoly, _TemplateCacheSimple, shift_arrays
############################
###### HELPER CLASSES ######
############################
class _TemplateCacheCorrelated:
    """
    Helper class that performs the correlated template fit for multiple channels with a (non-)trivial baseline. It caches the matrices used for solving the minimization problem for the regular (not truncated) fit to increase computational efficiency.

    All channels share one onset shift: the shift is obtained by minimizing the summed chi-squared of the channels with ``fit_onset=True``, and afterwards every channel is solved at that common shift.

    See https://edoc.ub.uni-muenchen.de/23762/ and https://mediatum.ub.tum.de/?id=1294132 for details.

    :param sev: The reference event (at least two channels).
    :type sev: np.ndarray
    :param xdata: The x-data to use for the baseline model evaluation (exactly 1-d).
    :type xdata: np.ndarray
    :param order: A list of the orders of the baseline polynomials to be fitted for each channel separately. If an entry is None, this channel's baseline will not be fitted.
    :type order: List[Optional[int]]
    :param fit_onset: List of as many entries as there are channels in 'sev'. For any True, the onset value for the respective channel is fitted. If False, the respective channel just moves together with the fitted channels (but does not contribute to the minimization).
    :type fit_onset: List[bool]
    :param max_shift: The maximum shift value (in samples) to search for a minimum. The onset fit searches the minimum for shifts within the bounds ``[-max_shift, +max_shift]``. Defaults to 50 samples.
    :type max_shift: int, optional
    """

    # Messages of ValueErrors raised during a fit that indicate a failed fit
    # (-> error output) rather than a programming error (-> re-raised).
    # "array must not contain infs or NaNs" is the message of scipy.linalg's
    # finiteness check; the "zero-size array" message comes from numpy
    # reducing an empty array (presumably a fully flagged/empty data slice).
    _RECOVERABLE_VALUE_ERRORS = (
        "array must not contain infs or NaNs",
        "zero-size array to reduction operation maximum which has no identity",
    )

    def __init__(self,
                 sev: np.ndarray,
                 xdata: np.ndarray,
                 order: List[int],
                 fit_onset: List[bool],
                 max_shift: int = 50):
        self._sev = sev
        self._n_channels = sev.shape[0]
        self._xdata = xdata
        self._order = order
        self._fit_onset = fit_onset
        self._max_shift = max_shift

        # Construct one single-channel template fit per channel: channels
        # without a baseline model (order None) use the simple fit, all
        # others the polynomial-baseline fit of the given order.
        self._template_fits = []
        for i in range(self._n_channels):
            if self._order[i] is None:
                self._template_fits.append(
                    _TemplateCacheSimple(
                        sev=self._sev[i],
                        fit_onset=self._fit_onset[i],
                        max_shift=self._max_shift
                    )
                )
            else:
                self._template_fits.append(
                    _TemplateCachePoly(
                        sev=self._sev[i],
                        xdata=self._xdata,
                        order=self._order[i],
                        fit_onset=self._fit_onset[i],
                        max_shift=self._max_shift
                    )
                )

        # Determine output shape (pad with zeros if channels have different baseline models)
        # [[ph_ch0, poly0_ch0, poly1_ch0, poly2_ch0, poly3_ch0]
        #  [ph_ch1, poly0_ch1, poly1_ch1, poly2_ch1, poly3_ch1]]
        # One column for the amplitude plus (max_order + 1) polynomial
        # coefficients; if no channel has a baseline model (max_order = -1)
        # only the amplitude column remains.
        max_order = max((x for x in self._order if x is not None), default=-1)
        self._output_shape = (self._n_channels, 2+max_order)

        # Define error outputs (if fit fails, these will be the
        # fit results that tell you that the fit failed).
        # Never return these arrays directly; use _failed_output().
        self._error_output = (np.zeros(self._output_shape), 0, -404*np.ones(self._n_channels))

    def _failed_output(self):
        """
        Return a fresh copy of the fit results that signal a failed fit.

        Copies are returned so that callers which modify the returned arrays in place cannot corrupt the error output of subsequent failed fits.

        :return: Tuple of all-zero fit parameters, shift 0, and RMS values of -404.
        :rtype: Tuple[np.ndarray, int, np.ndarray]
        """
        opt_param, opt_shift, rms = self._error_output
        return opt_param.copy(), opt_shift, rms.copy()

    def __call__(self, ev: np.ndarray, flag: np.ndarray = None):
        """
        Performs the correlated template fit including a (non-)trivial baseline.

        See https://edoc.ub.uni-muenchen.de/23762/ and https://mediatum.ub.tum.de/?id=1294132 for details.

        :param ev: The event to be fitted (at least 2 channels)
        :type ev: np.ndarray
        :param flag: The flag (for each channel) to apply to the data (used for truncated fit). Defaults to None, i.e. no slicing on either channel
        :type flag: np.ndarray, optional
        :return: Tuple of fit result, optimal shift, and RMS value ``([[amplitude_ch0, constant_bl_coeff_ch0, linear_bl_coeff_ch0, ...], [amplitude_ch1, constant_bl_coeff_ch1, linear_bl_coeff_ch1, ...], ...], shift, [rms_ch0, rms_ch1, ...])``. If the fit was unsuccessful (``LinAlgError`` or one of the recognised ``ValueError`` messages), the RMS values are set to -404 and all fit parameters are 0. If different orders of baseline polynomials are used for different channels, the output is extended to the largest used polynomial order (i.e. unused coefficients are 0).
        :rtype: Tuple[np.ndarray, int, np.ndarray]
        :raises ValueError: For ValueErrors not recognised as a failed fit.
        """
        if flag is None:
            flag = [None]*self._n_channels

        try:
            if any(self._fit_onset):
                # Common shift minimizing the summed chi-squared of all
                # channels with fit_onset=True (see _chij2).
                res = minimize(self._chij2,
                               x0=0,
                               args=(ev, flag),
                               method="Powell",
                               bounds=[(-self._max_shift, self._max_shift)])
                opt_shift = int(res.x[0])
            else:
                opt_shift = 0

            results = self._solve(j=opt_shift, ev=ev, flag=flag)

            opt_param = np.zeros(self._output_shape)
            rms = np.zeros(self._n_channels)
            for i, (op, r) in enumerate(results):
                rms[i] = r
                # Single-channel fits may return a scalar or an array;
                # flatten to 1d. Columns beyond op.size stay 0 (padding for
                # channels with a lower baseline order).
                op = np.ravel(op)
                opt_param[i, :op.size] = op
        except LinAlgError:
            return self._failed_output()
        except ValueError as err:
            # Guard on err.args: a ValueError without arguments must be
            # re-raised, not turned into an IndexError.
            if err.args and err.args[0] in self._RECOVERABLE_VALUE_ERRORS:
                return self._failed_output()
            raise

        return opt_param, opt_shift, rms

    def _solve(self, j: int, ev: np.ndarray, flag: list):
        """
        Solve the minimization problem of every channel for a given shift value j.

        :param j: The shift.
        :type j: int
        :param ev: The event to be fitted (one row per channel).
        :type ev: np.ndarray
        :param flag: The flag for each channel to apply to the data (used for truncated fit); an entry of None means no slicing for that channel.
        :type flag: list
        :return: List with one ``(fit parameters, rms)`` tuple per channel, as returned by the single-channel template fits.
        :rtype: List[Tuple]
        """
        # Delegate to the single-channel solvers; all channels use the same shift j.
        return [
            tf._solve(j=j, ev=ev[i], flag=flag[i])
            for i, tf in enumerate(self._template_fits)
        ]

    def _chij2(self, j: int, ev: np.ndarray, flag: list):
        """
        Returns the chi-squared value for a given shift after fitting ``sev`` to ``ev``. The result is the sum of the chi-squares of all channels whose onset is to be fitted (i.e. chi-squared is minimized together if multiple channels use fit_onset=True).

        See https://edoc.ub.uni-muenchen.de/23762/ and https://mediatum.ub.tum.de/?id=1294132 for details.

        :param j: The shift. When called from ``scipy.optimize.minimize`` this is the one-element parameter array.
        :type j: int
        :param ev: The event to be fitted (one row per channel).
        :type ev: np.ndarray
        :param flag: The flag for each channel to apply to the data (used for truncated fit); an entry of None means no slicing for that channel.
        :type flag: list
        :return: The summed chi-squared value.
        :rtype: float
        """
        # Channels with fit_onset=False move along with the fitted channels
        # but do not contribute to the objective.
        return sum(
            tf._chij2(j=j, ev=ev[i], flag=flag[i])
            for i, tf in enumerate(self._template_fits)
            if self._fit_onset[i]
        )
########################
### CLASS DEFINITION ###
########################