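"""Source code for cait.versatile.iterators.impl_pulsesim.

Provides ``PulseSimIterator``, which superimposes scaled (and optionally
shifted) pulse templates on the events of another iterator.
"""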

from typing import List, Union

import numpy as np

import cait.versatile as vai

from ...fit import pulse_template
from .iteratorbase import IteratorBaseClass


class PulseSimIterator(IteratorBaseClass):
    """
    Iterator object that returns voltage traces superimposed with a SEV.

    The SEV can EITHER be specified by a template array OR by fit parameters
    [t0, An, At, tau_n, tau_in, tau_t] where the time constants are given in ms.
    The n-component pulse shape model with parameters
    [t0, A1, A2, ..., Ak, tau_in, tau_2, ..., tau_k, tau_n] is also supported.

    :param iterator: An iterator (of baselines, stream_chunks, etc.) that you want to superimpose the SEV on.
    :type iterator: IteratorBaseClass
    :param pulse_heights: The pulse heights to scale the SEV. One for each event in 'iterator' and each channel, i.e. with shape ``(iterator.n_channels, len(iterator))``.
    :type pulse_heights: List[List[float]]
    :param shift_samples: If specified, one value for every element in ``pulse_heights`` (i.e. also for all channels). The respective x-shift (in samples) is applied to the superimposed pulse. If you use a ``sev`` (see below), the edges of the shifted array are padded with the average of the first/last 10 samples. If you use ``sev_fitpars`` (see below), no padding is required because the fit can just be extrapolated.
    :type shift_samples: List[List[int]]
    :param shift_subsamples: If specified, one value for every element in ``pulse_heights`` (i.e. also for all channels). The respective x-shift (in fractional samples) is applied to the superimposed pulse. If you use a ``sev`` (see below), the array is linearly interpolated between samples. If you use ``sev_fitpars`` (see below), the model is evaluated at the intermediate points. Note that all values have to be in the interval [0, 1), corresponding to shifts between zero and one sample.
    :type shift_subsamples: List[List[float]]
    :param sev: The SEV to superimpose. Has to match the number of channels of ``iterator`` and its record length, i.e. requires shape ``(iterator.n_channels, iterator.record_length)``. Cannot be specified together with ``sev_fitpars``.
    :type sev: np.ndarray
    :param sev_fitpars: The fit parameters for the SEV to superimpose. Has to match the number of channels of ``iterator``. Cannot be specified together with ``sev``.
    :type sev_fitpars: List[List[float]]
    :param channels: The channels that we are interested in. Has to be a subset of iterator's channels. If None, all channels are considered. Defaults to None.
    :type channels: Union[int, List[int]]
    :param inds: The indices of 'iterator' that we want to iterate over. If None, all indices are considered. Defaults to None
    :type inds: Union[int, List[int]]
    :param batch_size: The number of events to be returned at once (these are all read together). There will be a trade-off: large batch_sizes cause faster read speed but increase the memory usage.
    :type batch_size: int

    :raises ValueError: If not exactly one of ``sev``/``sev_fitpars`` is given, or if the shapes of the inputs are inconsistent.
    :raises TypeError: If ``channels`` has an unsupported type.

    .. code-block:: python

        import numpy as np
        import scipy as sp
        import cait.versatile as vai
        from cait.versatile.iterators import PulseSimIterator

        # Use mock data. You will have a more meaningful iterator.
        md = vai.MockData()
        sev = md.sev[0]

        # This is just a cheeky way to get an iterator of random noise.
        # You will have actual noise.
        noise_it = md.get_event_iterator()[0].with_processing(
            lambda x: sp.stats.norm.rvs(loc=0, scale=0.1, size=md.record_length)
        )

        # Define random pulse heights
        sim_phs = sp.stats.uniform.rvs(size=len(noise_it))

        # Set up the iterator containing simulated pulses on top of
        # the noise traces.
        # Check out the docstring to learn about advanced ways to
        # simulate events, e.g. by adding template shifts.
        pulse_sim_it = PulseSimIterator(
            iterator=noise_it,
            pulse_heights=sim_phs,
            sev=sev,
        )

        # Preview your pulses.
        vai.Preview(pulse_sim_it)

        # In a next step, you could for example filter the simulated
        # events, determine their pulse heights, and estimate the baseline
        # resolution (the code below is a minimal example and definitely not
        # perfect!). In this case, it makes more sense to simulate a fixed
        # pulse height:
        pulse_sim_it_fixed_ph = PulseSimIterator(
            iterator=noise_it,
            pulse_heights=0.5*np.ones_like(sim_phs),
            sev=sev,
        )

        reconstructed_phs = vai.apply(
            np.max,
            pulse_sim_it_fixed_ph.with_processing(vai.OptimumFiltering(md.of[0]))
        )

        # Plot a histogram of reconstructed pulse heights.
        vai.Histogram(reconstructed_phs)
    """

    def __init__(self,
                 iterator: IteratorBaseClass,
                 pulse_heights: List[List[float]],
                 shift_samples: List[List[int]] = None,
                 shift_subsamples: List[List[float]] = None,
                 sev: np.ndarray = None,
                 sev_fitpars: List[List[float]] = None,
                 channels: Union[int, List[int]] = None,
                 inds: List[int] = None,
                 batch_size: int = None):

        # Exactly one of the two SEV descriptions has to be given.
        if (sev is None) == (sev_fitpars is None):
            raise ValueError("You have to specify EITHER a sev OR its fit parameters (exactly one of the two).")

        # We will use a flag 'using_fit' to distinguish between the two cases
        # and just call the variable collectively 'sev_or_pars'.
        self._using_fit = sev_fitpars is not None
        sev_or_pars = sev_fitpars if self._using_fit else sev

        N, nch = len(iterator), iterator.n_channels

        # These dimensionality checks have to look at the RAW inputs:
        # after np.atleast_2d every input is at least 2-dimensional and
        # the checks could never trigger.
        if nch > 1 and np.ndim(sev_or_pars) < 2:
            raise ValueError("For multi-channel iterators, also 'sev'/'sev_fitpars' must be multi-channel.")
        if nch > 1 and np.ndim(pulse_heights) < 2:
            raise ValueError("For multi-channel iterators, also 'pulse_heights' must be multi-channel.")

        # check if dimensions for sev, iterator and pulse_heights match
        sev_or_pars, phs = np.atleast_2d(sev_or_pars), np.atleast_2d(pulse_heights)

        if (nch != sev_or_pars.shape[0]) or (nch != phs.shape[0]):
            raise ValueError(f"Number of channels in 'iterator', 'sev'/'sev_fitpars', and 'pulse_heights' must be equal. Got {[nch, sev_or_pars.shape[0], phs.shape[0]]}.")
        if N != phs.shape[-1]:
            raise ValueError(f"Number of events in 'iterator', and 'pulse_heights' must be equal. Got {[N, phs.shape[-1]]}.")
        if (not self._using_fit) and (sev_or_pars.shape[-1] != iterator.record_length):
            raise ValueError(f"Length of 'sev' has to match the record length of 'iterator'. Got {[sev_or_pars.shape[-1], iterator.record_length]}.")
        if self._using_fit and ((sev_or_pars.shape[-1] - 2) % 2 != 0):
            raise ValueError(f"Number of parameters in 'sev_fitpars' has to be 2k+2 for k>1. Got {sev_or_pars.shape[-1]}.")

        # Sanitize shifts: same shape as the (2d) pulse heights.
        if shift_samples is not None:
            shift_samples = np.atleast_2d(shift_samples).astype(np.int32)
            if shift_samples.shape != phs.shape:
                raise ValueError(f"The shapes of 'shift_samples' and 'pulse_heights' have to be identical. Got {shift_samples.shape} and {phs.shape}.")

        if shift_subsamples is not None:
            shift_subsamples = np.atleast_2d(shift_subsamples).astype(np.float32)
            if shift_subsamples.shape != phs.shape:
                raise ValueError(f"The shapes of 'shift_subsamples' and 'pulse_heights' have to be identical. Got {shift_subsamples.shape} and {phs.shape}.")
            if not np.all((shift_subsamples >= 0)*(shift_subsamples < 1)):
                raise ValueError("All values in 'shift_subsamples' have to be in the interval [0, 1).")

        # Normalize 'channels' to either a single int or a list of ints.
        # numpy integers/arrays and tuples are accepted as well.
        if channels is None:
            channels = list(range(iterator.n_channels))

        if isinstance(channels, (int, np.integer)):
            channels = int(channels)
            self._channels = channels
            self._n_channels = 1
        elif isinstance(channels, (list, tuple, np.ndarray)):
            channels = [int(c) for c in channels]
            # A single-element selection is stored as a plain int so that
            # indexing drops the channel axis (single-channel behavior).
            self._channels = channels if len(channels) > 1 else channels[0]
            self._n_channels = len(channels)
        else:
            raise TypeError(f"Unsupported type {type(channels)} for input argument 'channels'")

        # Normalize 'inds' to a list of Python ints (scalars, incl. numpy
        # integers, become a one-element list).
        if inds is None:
            inds = np.arange(len(iterator))
        inds = [int(inds)] if np.ndim(inds) == 0 else [int(i) for i in inds]

        # Does batch handling and creates properties self._inds, self.uses_batches, and self.n_batches
        super().__init__(
            inds=inds,
            batch_size=batch_size,
            iterator=iterator,
            sev=np.array(sev).tolist() if sev is not None else None,
            sev_fitpars=np.array(sev_fitpars).tolist() if sev_fitpars is not None else None,
            pulse_heights=np.array(phs).tolist(),
            shift_samples=np.array(shift_samples).tolist() if shift_samples is not None else None,
            shift_subsamples=np.array(shift_subsamples).tolist() if shift_subsamples is not None else None,
            channels=channels,
        )

        # Save values to reconstruct iterator:
        self._params = {'iterator': iterator,
                        'sev': sev,
                        'sev_fitpars': sev_fitpars,
                        'pulse_heights': phs,
                        'shift_samples': shift_samples,
                        'shift_subsamples': shift_subsamples,
                        'channels': self._channels,
                        'inds': inds,
                        'batch_size': batch_size}

        # We use the tools implemented by iterator to already
        # select the correct channels, indices, and batch_size.
        # If we do so, we can just iterate it to get the correct
        # (batched) events.
        if batch_size is None:
            self._it = iterator[self._channels, inds].flatten()
        else:
            self._it = iterator[self._channels, inds].with_batchsize(batch_size)

        # For sev and phs, we have to manually select the channels.
        # Furthermore, we make use of the self._inds created by
        # super().__init__() to slice the correct (batched) pulse_heights
        # in _next_raw().
        self._sev_or_pars = sev_or_pars[self._channels]
        self._phs = phs[self._channels]

        # We will always have shift samples (easier handling) and
        # just set them to zero if not specified
        if shift_samples is not None:
            self._shift_samples = shift_samples[self._channels]
        else:
            self._shift_samples = np.zeros(self._phs.shape, dtype=np.int32)

        if shift_subsamples is not None:
            self._shift_subsamples = shift_subsamples[self._channels]
        else:
            self._shift_subsamples = np.zeros(self._phs.shape, dtype=np.float32)

        # If fit pars are used, we have to evaluate the pulse model at some
        # point for which the time array of the iterator is used:
        self._fit_t = iterator.t

    def __enter__(self):
        """Enter the underlying iterator's context and return this iterator."""
        self._it.__enter__()
        return self

    def __exit__(self, typ, val, tb):
        """Exit the underlying iterator's context."""
        self._it.__exit__(typ, val, tb)

    def __iter__(self):
        """Reset the batch counter and start iterating the underlying iterator."""
        self._current_batch_ind = 0
        # start iteration of underlying iterator
        # (this way we can retrieve elements faster)
        self._itit = self._it.__iter__()
        return self

    def __getattr__(self, name):
        """
        Forward ``with_record_length``, ``with_alignment`` and
        ``with_extended_window`` to the underlying StreamIterator/IteratorCollection.

        Returns a callable that builds a new ``PulseSimIterator`` on top of the
        modified underlying iterator. If the SEV was given as an array, it is
        adapted (zero-padded, cropped, or rolled) to the new window; fit
        parameters are just evaluated on the new time axis.

        :raises AttributeError: For every other attribute name.
        :raises NotImplementedError: If the underlying iterator does not support
            the method or if this iterator has processing attached.
        """
        # Guard clause first: every other missing attribute is a plain
        # AttributeError (this also keeps copy/pickle lookups on bare
        # instances from touching uninitialized state).
        if name not in ("with_record_length", "with_alignment", "with_extended_window"):
            raise AttributeError(f"{self.__class__.__name__} has no attribute '{name}'.")

        if not isinstance(self._it, (vai.iterators.StreamIterator, vai.iterators.IteratorCollection)):
            raise NotImplementedError(f"Method '{name}' is only available if the PulseSimIterator is based on a StreamIterator or IteratorCollection thereof. Instead based on {self._it.__class__.__name__}.")
        if self.has_processing:
            raise NotImplementedError(f"Cannot use method '{name}' on iterators with processing because processing might depend on window size and/or alignment and cause obscure issues. Manually remove processing first using 'old_processing = it.pop_processing()', call '{name}' on the iterator without processing, then add the 'old_processing' again if it does not depend on window size and/or alignment, or add it again after adjusting its parameters to work with the new size/alignment.")

        params, _ = self._slice_info
        new_params = params.copy()
        old_it = new_params.pop("iterator")
        old_sev = new_params.pop("sev", None)

        # If fit parameters are given, we don't have to do anything
        # (because they will just be evaluated on a new time array,
        # which is 0-aligned with the previous one). If the SEV is
        # given as an array, we pad it as required.
        if old_sev is None:
            return lambda *args, **kwargs: self.__class__(
                **{
                    **new_params,
                    **dict(iterator=getattr(old_it, name)(*args, **kwargs)),
                }
            )

        # The stored SEV is whatever the user passed in (possibly a nested
        # list); array operations below (.shape, slicing with ...) need an array.
        old_sev = np.asarray(old_sev)

        # Helper variables that let us correctly modify the SEV:
        # ind0 is the sample closest to t=0, al the corresponding
        # fractional alignment within the current record window.
        ind0 = int(np.argmin(np.abs(self.t)))
        rl = self.record_length
        al = ind0/rl
        k = old_sev.ndim - 1

        if name == "with_record_length":
            def new_sev(record_length, *args, **kwargs):
                new_ind0 = int(al * record_length)
                diff0 = new_ind0 - ind0
                if record_length > rl:  # diff0 >= 0
                    # Place the old SEV such that ind0 lands on new_ind0 and
                    # zero-pad the rest.
                    sev = np.zeros((*old_sev.shape[:-1], record_length))
                    sev[..., diff0:diff0+rl] = old_sev
                    return sev
                if record_length < rl:  # diff0 <= 0
                    # Crop exactly 'record_length' samples starting such that
                    # ind0 lands on new_ind0. (Computing the end separately via
                    # int((1-al)*record_length) loses a sample whenever both
                    # products are fractional.)
                    start = -diff0
                    return old_sev[..., start:start + record_length]
                return old_sev

        elif name == "with_alignment":
            def new_sev(alignment, *args, **kwargs):
                if alignment == al:
                    return old_sev
                diff0 = int(alignment * rl) - ind0
                sev = np.roll(old_sev, diff0, axis=-1)
                # Zero the samples that wrapped around. Decide on the sign of
                # the actual shift: due to rounding, diff0 can be 0 (or have the
                # "wrong" sign) although alignment != al, and e.g.
                # sev[..., 0:] = 0 would wipe the entire SEV.
                if diff0 > 0:
                    sev[..., :diff0] = 0
                elif diff0 < 0:
                    sev[..., diff0:] = 0
                return sev

        else:  # name == "with_extended_window"
            def new_sev(*args, **kwargs):
                # Pad with zeros left and right (one record length each).
                return np.pad(old_sev, (*([(0, 0)]*k), (rl, rl)))

        return lambda *args, **kwargs: self.__class__(
            **{
                **new_params,
                **dict(
                    iterator=getattr(old_it, name)(*args, **kwargs),
                    sev=new_sev(*args, **kwargs),
                ),
            }
        )

    def _shift_fit_pars_and_eval(self, pars: np.ndarray, ks: np.ndarray, zs: np.ndarray):
        """
        Evaluate the pulse model for every parameter tuple in ``pars`` on
        ``self.t``, with t0 shifted by the corresponding sample shift in ``ks``
        plus subsample shift in ``zs`` (converted via ``self.dt_us/1000``).

        :param pars: Fit parameters; the last axis holds one parameter tuple.
        :param ks: Integer sample shifts, one per parameter tuple.
        :param zs: Fractional sample shifts, one per parameter tuple.
        :return: Array of shape ``pars.shape[:-1] + (self.record_length,)``.
        """
        flat_pars = pars.reshape(-1, pars.shape[-1])
        flat_k = ks.flatten()
        flat_z = zs.flatten()
        new_array = np.zeros((flat_pars.shape[0], self.record_length))

        for j in range(flat_pars.shape[0]):
            p = flat_pars[j]
            # Shifting t0 by (k+z) samples moves the whole pulse later in time.
            new_array[j, :] = pulse_template(
                self.t,
                *(p[0] + (flat_k[j] + flat_z[j])*self.dt_us/1000, *p[1:]),
            )

        return np.reshape(new_array, tuple(list(pars.shape)[:-1] + [self.record_length]))

    def _shift_array(self, pulses: np.ndarray, ks: np.ndarray, zs: np.ndarray):
        """
        Return ``pulses`` with every trace shifted by k samples (from ``ks``)
        and a fraction z of a sample (from ``zs``).

        Edges uncovered by integer shifts are padded with the mean of the
        first/last 10 samples; subsample shifts use linear interpolation
        between neighboring samples.
        """
        flat_pulses = pulses.reshape(-1, pulses.shape[-1])
        flat_k = ks.flatten()
        flat_z = zs.flatten()
        new_array = np.zeros(flat_pulses.shape)

        for j in range(flat_pulses.shape[0]):
            k = flat_k[j]
            z = flat_z[j]

            if k == 0:
                new_array[j, :] = flat_pulses[j, :]
            elif k > 0:
                new_array[j, k:] = flat_pulses[j, :-k]
                new_array[j, :k] = np.mean(flat_pulses[j, :10])
            else:  # Note that k is negative here
                new_array[j, :k] = flat_pulses[j, -k:]
                new_array[j, k:] = np.mean(flat_pulses[j, -10:])

            # Interpolate template if necessary (subsample shift):
            # sample i becomes z*x[i-1] + (1-z)*x[i], i.e. the value at i-z.
            if z > 0:
                new_array[j] = np.hstack([
                    new_array[j][0],
                    new_array[j][:-1] + (1 - z)*np.diff(new_array[j])
                ])

        return np.reshape(new_array, pulses.shape)

    def _next_raw(self):
        """
        Return the next (batch of) event(s) of the underlying iterator with the
        scaled and shifted SEV added; raise StopIteration after the last batch.
        """
        if self._current_batch_ind < self.n_batches:
            event_inds_in_batch = self._inds[self._current_batch_ind]
            self._current_batch_ind += 1

            # Select the per-event values and append an axis so they broadcast
            # against the sample axis of the events.
            sim_phs = self._phs[..., event_inds_in_batch].T[..., None]
            shifts = self._shift_samples[..., event_inds_in_batch].T[..., None]
            subshifts = self._shift_subsamples[..., event_inds_in_batch].T[..., None]

            events = next(self._itit)

            if self._using_fit:
                pulse = self._shift_fit_pars_and_eval(
                    # Extend the fitpars array so that we can easier treat different cases
                    # (batches, no batches, single channel, multi channel, ...)
                    np.broadcast_to(
                        self._sev_or_pars,
                        tuple(list(events.shape)[:-1] + [self._sev_or_pars.shape[-1]])
                    ),
                    shifts,
                    subshifts,
                )
            else:
                pulse = self._shift_array(
                    # Extend the sev array so that we can easier treat different cases
                    # (batches, no batches, single channel, multi channel, ...)
                    np.broadcast_to(self._sev_or_pars, events.shape),
                    shifts,
                    subshifts,
                )

            return sim_phs*pulse + events

        raise StopIteration

    @property
    def t(self):
        """Return the time axis (record window) of the events in the iterator."""
        return self._it.t

    @property
    def record_length(self):
        """Record length of the underlying iterator."""
        return self._it.record_length

    @property
    def dt_us(self):
        """Sampling interval (in microseconds) of the underlying iterator."""
        return self._it.dt_us

    @property
    def ds_start_us(self):
        """Forwarded ``ds_start_us`` of the underlying iterator."""
        return self._it.ds_start_us

    @property
    def timestamps(self):
        """Timestamps of the underlying iterator."""
        return self._it.timestamps

    @property
    def n_channels(self):
        """Number of selected channels."""
        return self._n_channels

    @property
    def _slice_info(self):
        # Parameters needed to rebuild the iterator, plus the names of the
        # parameters that slicing may modify.
        return (self._params, ('channels', 'inds'))