Source code for cait.versatile.iterators.iteratorbase

import itertools
from abc import ABC, abstractmethod
from typing import Callable, List, Union

import numpy as np

# Have to import like this to avoid circular import
import cait.versatile as vai

from ...serialize import SerializingMixin
from .batchresolver import BatchResolver


#### HELPER FUNCTIONS ####
def _ensure_array(x):
    if isinstance(x, str): x = [x]
    elif isinstance(x, int): x = [x]
    elif isinstance(x, np.integer): x = [int(x)]
    return np.array(x)
    
def _ensure_not_array(x):
    if isinstance(x, np.ndarray): x = x.tolist()
    if isinstance(x, np.integer): x = int(x)
    if isinstance(x, str): x = str(x)
    return x

[docs] class IteratorBaseClass(SerializingMixin, ABC): """ Baseclass for all iterators. Defines behavior shared among all event iterators. .. automethod:: __len__ .. automethod:: __add__ .. automethod:: __getitem__ """ def __init__(self, inds: List[int], batch_size: int = None, **kwargs): super().__init__(inds=inds, batch_size=batch_size, **kwargs) self.fncs = list() self.__n_events = len(inds) # self._inds will be a list of batches. If we just take the inds list, we have batches of size 1, if we take [inds] # all inds are in one batch, otherwise it is a list of lists where each list is a batch if batch_size is None or batch_size == 1: self.__inds = inds self._uses_batches = False elif batch_size == -1: self.__inds = [inds] self._uses_batches = True else: self.__inds = [inds[i:i+batch_size] for i in range(0, len(inds), batch_size)] self._uses_batches = True self._n_batches = len(self._inds)
[docs] def __len__(self): """Return the number of events in the iterator.""" return self.__n_events
def __enter__(self): return self def __exit__(self, typ, val, tb): ... @abstractmethod def __iter__(self): ... def __next__(self): return self._apply_processing(self._next_raw()) @abstractmethod def _next_raw(self): # This is to be used as the __next__ method by child classes. # The purpose of having it separately is that the IteratorBaseClass # then automatically applies the preprocessing ... def _apply_processing(self, out): for fnc in self.fncs: out = fnc(out) if not self.uses_batches else BatchResolver(fnc, self.n_channels)(out) return out def __repr__(self): out = f"{self.__class__.__name__}(n_events={len(self)}, n_channels={self.n_channels}, uses_batches={self.uses_batches}, record_length={self.record_length}, dt_us={self.dt_us}" if len(self.fncs) > 0: out += f", preprocessing: {self.fncs})" else: out += ")" return out
[docs] def __getitem__(self, val): """ Slice iterator as if it was laid out as a numpy.ndarray and return a new iterator. The first argument slices the channel, the second slices the list of events in the iterator. **Example:** .. code-block:: python # Starting from an iterator 'it' of multiple channels, you can # - access only the first channel it[0] # - access the last 1000 events of the first channel it[0, -1000:] # - access every second event from all channels it[:, ::2] # ... etc. """ # Slice Iterator as if it was layed out as a numpy.ndarray. # The first argument slices the channel/key/... and the second slices the remaining list of events in the iterator. if not isinstance(val, tuple): val = (val,) if len(val) > 2: raise IndexError(f"Too many indices for iterator: 2 values can be indexed (channel list and event list) but {len(val)} were indexed.") # First index will numpy-slice channel (or equivalent) slice_channel = val[0] # Second index will numpy-slice indices (or equivalent) slice_inds = val[1] if len(val)>1 else slice(None) # Get parameters to reconstruct iterator and make a copy params, keys = self._slice_info new_params = dict(**params) # Slice relevant arguments new_params[keys[0]] = _ensure_not_array(_ensure_array(params[keys[0]])[slice_channel]) new_params[keys[1]] = _ensure_not_array(_ensure_array(params[keys[1]])[slice_inds]) # Create new instance of iterator and add processing new_iterator = self.__class__(**new_params) new_iterator.add_processing(self.fncs.copy()) # Return new iterator return new_iterator
[docs] def __add__(self, other): """ Add two iterators sequentially. E.g. given two iterators ``it1`` and ``it2``, the sum ``it1 + it2`` returns an iterator that first iterates through ``it1``, and then through ``it2``, once ``it1`` is consumed. **Example:** .. code-block:: python # Given two iterators 'it1' and 'it2', they can be sequentially combined into # a single iterator by combined_it = it1 + it2 """ if isinstance(self, IteratorCollection): l = [i.with_processing(self.fncs) for i in self.iterators] else: l = [self] if isinstance(other, IteratorBaseClass): if isinstance(other, IteratorCollection): l += [i.with_processing(other.fncs) for i in other.iterators] else: l += [other] else: raise TypeError(f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'") return IteratorCollection(l)
def __radd__(self, other): # Used to return iterator for situation '0 + it = it' # This way, we can use the built-in sum() to sum a list of iterators if other == 0: return self else: return other.__add__(self)
[docs] def add_processing(self, f: Union[Callable, List[Callable]]): """ Add functions to be applied to each event before returning it. Batches are supported, i.e. if the iterator returns events in batches, the specified functions are applied to all events in a batch separately. However, the user is responsible for handling multiple channels correctly: Events are passed to the functions directly, even if it includes multiple channels. :param f: Function(s) to be applied. Function signature: f(event: np.ndarray) -> np.ndarray :type f: Union[Callable, List[Callable]] **Example:** .. code-block:: python import cait.versatile as vai def f1(event): return event + 1 def f2(event): return event*2 it = vai.MockData().get_event_iterator() it.add_processing([f1, f2]) """ if not isinstance(f, list): f = [f] self.fncs += f # return instance such that it is chainable and can be used in one-liners return self
[docs] def with_processing(self, f: Union[Callable, List[Callable]]): """ Same as ``add_processing`` but it returns a new iterator instead of modifying the original one. :param f: Function(s) to be applied. Function signature: f(event: np.ndarray) -> np.ndarray :type f: Union[Callable, List[Callable]] **Example:** .. code-block:: python import cait.versatile as vai def f1(event): return event + 1 def f2(event): return event*2 it = vai.MockData().get_event_iterator() new_it = it.with_processing([f1, f2]) """ return self[:,:].add_processing(f)
[docs] def pop_processing(self): """ Removes all processing functions from the iterator and returns them as a list. """ fncs = self.fncs.copy() self.fncs = list() return fncs
[docs] def with_batchsize(self, batch_size: int): """ Returns an identical iterator but with a different batch size. :param batch_size: The new batch size. :type batch_size: int """ params, _ = self._slice_info if "batch_size" not in params.keys(): raise Exception(f"{type(self)} does not support changing batch size.") new_params = params.copy() new_params["batch_size"] = batch_size new_iterator = self.__class__(**new_params) new_iterator.add_processing(self.fncs.copy()) return new_iterator
[docs] def flatten(self): """ Returns an identical iterator but without batches. Has no effect if iterator didn't use batches before. """ return self.with_batchsize(1)
[docs] def grab(self, which: Union[int, list]): """ Grab specified event(s) and return it/them as numpy array. :param which: Events of interest. :type which: Union[int, list] **Example:** .. code-block:: python import cait.versatile as vai it = vai.MockData().get_event_iterator() # Get events from mock data selected_event = it.grab(-1) # Get the last event in the iterator selected_events = it.grab([1,7,9]) # Get events with indices 1, 7, 9 """ with self: # so that all events are read without re-opening the file return np.squeeze(np.array(list(self[:, which])))[()]
@property def t(self): """ Return the time axis (record window) of the events in the iterator. It is a millisecond array with 0 being at 1/4th of the window. """ return (np.arange(self.record_length) - self.record_length/4)*self.dt_us/1000 @property def _inds(self): return self.__inds @property def uses_batches(self): """ Returns True if the iterator returns batches. """ return self._uses_batches @property def n_batches(self): """ Returns the number of batches in the iterator. """ return self._n_batches @property def has_processing(self): """ Returns True if one or more processing functions have been added to the iterator. """ return len(self.fncs) > 0 @property def hours(self): """ Returns the times (in hours) of the events in this iterators since the start of the underlying datasource. """ return (self.timestamps - self.ds_start_us)/1e6/3600 @property @abstractmethod def record_length(self): """ Returns the record length (in samples) of the events in the iterator. """ ... @property @abstractmethod def dt_us(self): """ Returns the time base (in microseconds) of the events in the iterator. """ ... @property def sample_frequency(self): """ Returns the sampling frequency (in Hz) of the events in the iterator. :return: Sampling frequency (Hz) :rtype: int """ return int(1e6//self.dt_us) @property @abstractmethod def ds_start_us(self): """ The microsecond timestamp of the start of the recording for the datasource underlying this iterator object. """ ... @property @abstractmethod def timestamps(self): """ Returns microsecond timestamps corresponding to the trigger times of the events in the iterator. """ ... @property @abstractmethod def n_channels(self): """ Returns the number of channels in the iterator. """ ... @property @abstractmethod def _slice_info(self): # Returns a tuple containing a dictionary of input arguments used to construct # the iterator and another tuple which specifies the dictionary keys corresponding # to the channel- and index-equivalents. Note that those cannot have None-values # in the dictionary! E.g. # return ( {'path_h5': 'path/to/file', # 'dataset': 'events', # 'channels': 0, # 'inds': [1,2,3,4], # 'batch_size': 1 # }, # ('channel', 'inds') # ) ...
class IteratorCollection(IteratorBaseClass): """ Iterator object that chains multiple iterators. :param iterators: Iterator or List of Iterators to chain. :type iterators: Union[IteratorBaseClass, List[IteratorBaseClass]] :return: Iterable object :rtype: IteratorCollection .. code-block:: python it = H5Iterator(dh, "events", "event") it_collection = IteratorCollection([it, it]) # Or simply (output of iterator addition is IteratorCollection) it_collection = it + it """ def __init__(self, iterators: Union[IteratorBaseClass, List[IteratorBaseClass]]): super(IteratorBaseClass, self).__init__(iterators=iterators) # We do not construct the superclass because batching is handled differently self.fncs = list() # Check if all elements are IteratorBaseClass instances if isinstance(iterators, list): for it in iterators: if not isinstance(it, IteratorBaseClass): raise TypeError(f"All iterators must be child classes of 'IteratorBaseClass'. Not '{type(it)}'.") else: if isinstance(iterators, IteratorBaseClass): iterators = [iterators] else: raise TypeError(f"Unsupported type '{type(iterators)}' for input argument 'iterators'.") # Check if batch usage, number of channels, record_length, dt_us and the time axis are consistent batch_usage = [it.uses_batches for it in iterators] channel_usage = [it.n_channels for it in iterators] rec_usage = [it.record_length for it in iterators] dt_usage = [it.dt_us for it in iterators] t_usage = [it.t for it in iterators] if len(set(batch_usage)) != 1: raise ValueError(f"Either all iterators must use batches or none of them. Got {batch_usage}") if len(set(channel_usage)) != 1: raise ValueError(f"All iterators must contain the same number of channels. Got {channel_usage}") if len(set(rec_usage)) != 1: raise ValueError(f"All iterators must have the same record length. Got {rec_usage}") if len(set(dt_usage)) != 1: raise ValueError(f"All iterators must have the same time base. Got {dt_usage}") if not np.all(np.isclose(t_usage, t_usage[0])): raise ValueError(f"All iterators must have the same time axis (it.t).") self._iterators = iterators self._uses_batches = batch_usage[0] # made sure that batch usage is consistent above self._n_channels = channel_usage[0] # made sure that number of channels is consistent above self._dt_us = dt_usage[0] # made sure that time base is consistent above self._record_length = rec_usage[0] # made sure that record length is consistent above self._t = t_usage[0] # made sure that all time arrays are consistent above # Overrides superclass def __len__(self): return sum([len(it) for it in self._iterators]) def __repr__(self): out = f"{self.__class__.__name__}(n_events={len(self)})[" for it in self._iterators: out += f"\n\t- {it}" out += "\n\t]" if len(self.fncs) > 0: out += f"(preprocessing: {self.fncs})" return out def __enter__(self): for it in self._iterators: it.__enter__() return self def __exit__(self, typ, val, tb): for it in self._iterators: it.__exit__(typ, val, tb) def __iter__(self): self._chain = itertools.chain.from_iterable(self._iterators) return self def _next_raw(self): return next(self._chain) def __getitem__(self, val): # overriding IteratorBaseClass behavior if not isinstance(val, tuple): val = (val,) if len(val) > 2: raise IndexError(f"Too many indices for iterator: 2 values can be indexed (channel list and event list) but {len(val)} were indexed.") slice_channel = val[0] slice_inds = val[1] if len(val)>1 else slice(None) channels = np.arange(self._n_channels) # Make array of all indices channels_sliced = channels[slice_channel] # Indices that survive the slice channels_bool = np.zeros(self._n_channels, dtype=bool) # Turn surviving indices into channels_bool[channels_sliced] = True # boolean array inds = np.arange(len(self)) # Make array of all channels inds_sliced = inds[slice_inds] # Channels that survive the slice inds_bool = np.zeros(len(self), dtype=bool) # Turn surviving channels into inds_bool[inds_sliced] = True # boolean array # Boolean array for channels is identical for all iterators in collection # But we have to split the boolean array of surviving events and forward the pieces # to iterators in collection lens = np.array([len(it) for it in self._iterators]) sub_arrays = np.split(inds_bool, np.cumsum(lens)[:-1]) # Slice iterators in collection new_iterators = [it[channels_bool, a] for it, a in zip(self._iterators, sub_arrays)] # Create new collection and add processing new_collection = IteratorCollection(new_iterators) new_collection.add_processing(self.fncs.copy()) return new_collection # Forwards .with_record_length, .with_alignment, and .with_extended_window # to StreamIterator. def __getattr__(self, name): if name in ["with_record_length", "with_alignment", "with_extended_window"]: if not all( isinstance(x, (vai.iterators.StreamIterator, vai.iterators.PulseSimIterator)) for x in self.iterators ): raise NotImplementedError(f"Method '{name}' is only available if all iterators in the IteratorCollection are StreamIterators or PulseSimIterators (based on StreamIterators). At least one of the iterators in this IteratorCollection is neither. Got iterators {[it.__class__.__name__ for it in self.iterators]}.") if self.has_processing: raise NotImplementedError(f"Cannot use method '{name}' on iterators with processing because processing might depend on window size and/or alignment and cause obscure issues. Manually remove processing first using 'old_processing = it.pop_processing()', call '{name}' on the iterator without processing, then add the 'old_processing' again if it does not depend on window size and/or alignment, or add it again after adjusting its parameters to work with the new size/alignment.") return lambda *args, **kwargs: self.__class__([getattr(it, name)(*args, **kwargs) for it in self.iterators]) else: raise AttributeError(f"{self.__class__.__name__} has no attribute '{name}'.") # overrides default behavior def with_batchsize(self, batch_size: int): """ Returns an identical iterator but with a different batch size. :param batch_size: The new batch size. :type batch_size: int """ new_iterator = self.__class__([it.with_batchsize(batch_size) for it in self._iterators]) new_iterator.add_processing(self.fncs.copy()) return new_iterator @property def t(self): """Return the time axis (record window) of the events in the iterator.""" return self._t @property def record_length(self): return self._record_length @property def dt_us(self): return self._dt_us @property def ds_start_us(self): return np.min([it.ds_start_us for it in self._iterators]) @property def timestamps(self): return np.concatenate([it.timestamps for it in self._iterators]) # Overrides superclass @property def uses_batches(self): return self._uses_batches # Overrides superclass @property def n_batches(self): return sum([it.n_batches for it in self._iterators]) @property def n_channels(self): return self._n_channels @property def _slice_info(self): ... # Not needed here because __getitem__ was overridden @property def iterators(self): return self._iterators