# Source code for cait.versatile.iterators.iteratorbase

from abc import ABC, abstractmethod
from typing import Union, List, Callable
import itertools

import numpy as np

from .batchresolver import BatchResolver

#### HELPER FUNCTIONS ####
def _ensure_array(x):
    """
    Convert ``x`` to a ``numpy.ndarray``.

    A single string or integer (Python ``int`` or ``numpy.integer``) is first
    wrapped in a one-element list so that the result is one-dimensional
    instead of zero-dimensional. Anything else is handed to ``np.array``
    unchanged.

    :param x: Scalar or sequence to convert.
    :return: Array built from ``x``.
    :rtype: np.ndarray
    """
    if isinstance(x, (str, int)):
        # Plain Python scalars are wrapped as they are
        wrapped = [x]
    elif isinstance(x, np.integer):
        # NumPy integer scalars are converted to a Python int first
        wrapped = [int(x)]
    else:
        wrapped = x

    return np.array(wrapped)
    
def _ensure_not_array(x):
    """
    Convert NumPy containers/scalars back to built-in Python objects.

    ``numpy.ndarray`` becomes whatever ``ndarray.tolist()`` returns, a
    ``numpy.integer`` becomes ``int`` and any ``str`` (including NumPy string
    scalars) becomes a built-in ``str``. Other objects are returned unchanged.

    :param x: Object to convert.
    :return: The converted object.
    """
    # Unpack arrays first; the scalar checks below apply to the unpacked value
    value = x.tolist() if isinstance(x, np.ndarray) else x

    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, str):
        return str(value)
    return value

class IteratorBaseClass(ABC):
    """
    Abstract base class for event iterators.

    It stores the list of event indices (optionally split into batches), keeps
    a list of processing functions which are applied to everything returned by
    ``__next__``, and implements numpy-like slicing (``it[channels, events]``)
    as well as addition of iterators (which yields an ``IteratorCollection``).

    Child classes have to implement ``__iter__``, ``_next_raw`` and the
    abstract properties declared at the end of this class.

    :param inds: Indices of the events to iterate over.
    :type inds: List[int]
    :param batch_size: ``None`` or ``1`` disables batching, ``-1`` puts all
        indices into a single batch, any other positive value splits ``inds``
        into consecutive batches of (at most) that size. Defaults to None.
    :type batch_size: int, optional
    :raises ValueError: If ``batch_size`` is ``0`` or smaller than ``-1``.
    """
    def __init__(self, inds: List[int], batch_size: int = None):
        self.fncs = list()
        self.__n_events = len(inds)

        # self._inds will be a list of batches. If we just take the inds list, we have
        # batches of size 1, if we take [inds] all inds are in one batch, otherwise it
        # is a list of lists where each list is a batch
        if batch_size is None or batch_size == 1:
            self.__inds = inds
            self._uses_batches = False
        elif batch_size == -1:
            self.__inds = [inds]
            self._uses_batches = True
        elif batch_size < 1:
            # 0 would make range() fail with an unrelated message and values below -1
            # would silently produce zero batches although len(self) reports events.
            raise ValueError(f"'batch_size' has to be a positive integer, -1 (all events in one batch) or None. Got {batch_size}.")
        else:
            self.__inds = [inds[i:i+batch_size] for i in range(0, len(inds), batch_size)]
            self._uses_batches = True

        self._n_batches = len(self._inds)

    def __len__(self):
        # Number of events (not number of batches)
        return self.__n_events

    def __enter__(self):
        return self

    def __exit__(self, typ, val, tb):
        ...

    @abstractmethod
    def __iter__(self):
        ...

    def __next__(self):
        return self._apply_processing(self._next_raw())

    @abstractmethod
    def _next_raw(self):
        # This is to be used as the __next__ method by child classes.
        # The purpose of having it separately is that the IteratorBaseClass
        # then automatically applies the preprocessing
        ...

    def _apply_processing(self, out):
        # Apply all registered functions in the order in which they were added.
        # For batched iterators, the function is wrapped in a BatchResolver.
        for fnc in self.fncs:
            out = fnc(out) if not self.uses_batches else BatchResolver(fnc, self.n_channels)(out)

        return out

    def __repr__(self):
        out = f"{self.__class__.__name__}(n_events={len(self)}, n_channels={self.n_channels}, uses_batches={self.uses_batches}, record_length={self.record_length}, dt_us={self.dt_us}"
        if len(self.fncs) > 0:
            out += f", preprocessing: {self.fncs})"
        else:
            out += ")"
        return out

    def __getitem__(self, val):
        # Slice Iterator as if it was laid out as a numpy.ndarray.
        # The first argument slices the channel/key/... and the second slices the
        # remaining list of events in the iterator.
        if not isinstance(val, tuple):
            val = (val,)
        if len(val) > 2:
            raise IndexError(f"Too many indices for iterator: 2 values can be indexed (channel list and event list) but {len(val)} were indexed.")

        # First index will numpy-slice channel (or equivalent)
        slice_channel = val[0]
        # Second index will numpy-slice indices (or equivalent)
        slice_inds = val[1] if len(val)>1 else slice(None)

        # Get parameters to reconstruct iterator and make a copy
        params, keys = self._slice_info
        new_params = dict(**params)

        # Slice relevant arguments
        new_params[keys[0]] = _ensure_not_array(_ensure_array(params[keys[0]])[slice_channel])
        new_params[keys[1]] = _ensure_not_array(_ensure_array(params[keys[1]])[slice_inds])

        # Create new instance of iterator and add processing
        new_iterator = self.__class__(**new_params)
        new_iterator.add_processing(self.fncs.copy())

        # Return new iterator
        return new_iterator

    def __add__(self, other):
        # Adding iterators yields an IteratorCollection. Collections are flattened:
        # their members are re-created with the collection's processing attached.
        if isinstance(self, IteratorCollection):
            parts = [i.with_processing(self.fncs) for i in self.iterators]
        else:
            parts = [self]

        if isinstance(other, IteratorBaseClass):
            if isinstance(other, IteratorCollection):
                parts += [i.with_processing(other.fncs) for i in other.iterators]
            else:
                parts += [other]
        else:
            raise TypeError(f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'")

        return IteratorCollection(parts)

    def __radd__(self, other):
        # Used to return iterator for situation '0 + it = it'
        # This way, we can use the built-in sum() to sum a list of iterators
        if isinstance(other, IteratorBaseClass):
            return other.__add__(self)
        if other == 0:
            return self
        # Any other left operand is unsupported. Returning NotImplemented lets Python
        # raise the regular TypeError (instead of e.g. an AttributeError for operands
        # which have no __add__ at all).
        return NotImplemented

    def add_processing(self, f: Union[Callable, List[Callable]]):
        """
        Add functions to be applied to each event before returning it. Batches are supported, i.e. if the iterator returns events in batches, the specified functions are applied to all events in a batch separately. However, the user is responsible for handling multiple channels correctly: Events are passed to the functions directly, even if it includes multiple channels.

        :param f: Function(s) to be applied. Function signature: f(event: np.ndarray) -> np.ndarray
        :type f: Union[Callable, List[Callable]]

        :return: The iterator itself (to allow chaining).

        **Example:**
        ::

            import cait.versatile as vai

            def f1(event): return event + 1
            def f2(event): return event*2

            it = vai.MockData().get_event_iterator()
            it.add_processing([f1, f2])
        """
        if not isinstance(f, list):
            f = [f]

        self.fncs += f

        # return instance such that it is chainable and can be used in one-liners
        return self

    def with_processing(self, f: Union[Callable, List[Callable]]):
        """
        Same as ``add_processing`` but it returns a new iterator instead of modifying the original one.

        :param f: Function(s) to be applied. Function signature: f(event: np.ndarray) -> np.ndarray
        :type f: Union[Callable, List[Callable]]

        :return: New iterator with the additional processing attached.

        **Example:**
        ::

            import cait.versatile as vai

            def f1(event): return event + 1
            def f2(event): return event*2

            it = vai.MockData().get_event_iterator()
            new_it = it.with_processing([f1, f2])
        """
        # Slicing with [:, :] creates a new iterator which already carries a copy
        # of the current processing functions
        return self[:,:].add_processing(f)

    def grab(self, which: Union[int, list]):
        """
        Grab specified event(s) and return it/them as numpy array.

        :param which: Events of interest.
        :type which: Union[int, list]

        :return: The selected event(s) with all axes of length 1 removed.

        **Example:**
        ::

            import cait.versatile as vai

            it = vai.MockData().get_event_iterator()

            # Get events from mock data
            selected_event = it.grab(-1)        # Get the last event in the iterator
            selected_events = it.grab([1,7,9])  # Get events with indices 1, 7, 9
        """
        # Indexing with an empty tuple turns 0-dimensional arrays into scalars
        return np.squeeze(np.array(list(self[:, which])))[()]

    @property
    def t(self):
        """
        Return the time axis (record window) of the events in the iterator. It is a millisecond array with 0 being at 1/4th of the window.
        """
        return (np.arange(self.record_length) - self.record_length/4)*self.dt_us/1000

    @property
    def _inds(self):
        # Event indices (a flat list without batching, a list of batches otherwise)
        return self.__inds

    @property
    def uses_batches(self):
        """
        Returns True if the iterator returns batches.
        """
        return self._uses_batches

    @property
    def n_batches(self):
        """
        Returns the number of batches in the iterator.
        """
        return self._n_batches

    @property
    def hours(self):
        """
        Returns the times (in hours) of the events in this iterator since the start of the underlying datasource.
        """
        return (self.timestamps - self.ds_start_us)/1e6/3600

    @property
    @abstractmethod
    def record_length(self):
        """
        Returns the record length (in samples) of the events in the iterator.
        """
        ...

    @property
    @abstractmethod
    def dt_us(self):
        """
        Returns the time base (in microseconds) of the events in the iterator.
        """
        ...

    @property
    @abstractmethod
    def ds_start_us(self):
        """
        The microsecond timestamp of the start of the recording for the datasource underlying this iterator object.
        """
        ...

    @property
    @abstractmethod
    def timestamps(self):
        """
        Returns microsecond timestamps corresponding to the trigger times of the events in the iterator.
        """
        ...

    @property
    @abstractmethod
    def n_channels(self):
        """
        Returns the number of channels in the iterator.
        """
        ...

    @property
    @abstractmethod
    def _slice_info(self):
        # Returns a tuple containing a dictionary of input arguments used to construct
        # the iterator and another tuple which specifies the dictionary keys corresponding
        # to the channel- and index-equivalents. Note that those cannot have None-values
        # in the dictionary! E.g.
        # return ( {'path_h5': 'path/to/file',
        #           'dataset': 'events',
        #           'channels': 0,
        #           'inds': [1,2,3,4],
        #           'batch_size': 1
        #          },
        #          ('channels', 'inds')
        #        )
        ...
class IteratorCollection(IteratorBaseClass):
    """
    Iterator object that chains multiple iterators.

    :param iterators: Iterator or List of Iterators to chain.
    :type iterators: Union[IteratorBaseClass, List[IteratorBaseClass]]

    :raises TypeError: If ``iterators`` is neither an iterator nor a list of iterators.
    :raises ValueError: If no iterator is given or if the iterators are inconsistent
        regarding batch usage, number of channels, record length or time base.

    :return: Iterable object
    :rtype: IteratorCollection

    >>> it = H5Iterator(dh, "events", "event")
    >>> it_collection = IteratorCollection([it, it])
    >>> # Or simply (output of iterator addition is IteratorCollection)
    >>> it_collection = it + it
    """
    def __init__(self, iterators: Union[IteratorBaseClass, List[IteratorBaseClass]]):
        # We do not construct the superclass because batching is handled differently
        self.fncs = list()

        # Check if all elements are IteratorBaseClass instances
        if isinstance(iterators, list):
            for it in iterators:
                if not isinstance(it, IteratorBaseClass):
                    raise TypeError(f"All iterators must be child classes of 'IteratorBaseClass'. Not '{type(it)}'.")
        else:
            if isinstance(iterators, IteratorBaseClass):
                iterators = [iterators]
            else:
                raise TypeError(f"Unsupported type '{type(iterators)}' for input argument 'iterators'.")

        # An empty list would otherwise be reported as inconsistent batch usage
        if len(iterators) == 0:
            raise ValueError("At least one iterator is required to create an 'IteratorCollection'. Got an empty list.")

        # Check if batch usage, number of channels, record_length and dt_us are consistent
        batch_usage = [it.uses_batches for it in iterators]
        channel_usage = [it.n_channels for it in iterators]
        rec_usage = [it.record_length for it in iterators]
        dt_usage = [it.dt_us for it in iterators]

        if len(set(batch_usage)) != 1:
            raise ValueError(f"Either all iterators must use batches or none of them. Got {batch_usage}")
        if len(set(channel_usage)) != 1:
            raise ValueError(f"All iterators must contain the same number of channels. Got {channel_usage}")
        if len(set(rec_usage)) != 1:
            raise ValueError(f"All iterators must have the same record length. Got {rec_usage}")
        if len(set(dt_usage)) != 1:
            raise ValueError(f"All iterators must have the same time base. Got {dt_usage}")

        # Store a copy such that later changes to the caller's list cannot bypass
        # the consistency checks above
        self._iterators = list(iterators)
        self._uses_batches = batch_usage[0]  # made sure that batch usage is consistent above
        self._n_channels = channel_usage[0]  # made sure that number of channels is consistent above
        self._dt_us = dt_usage[0]            # made sure that time base is consistent above
        self._record_length = rec_usage[0]   # made sure that record length is consistent above

    # Overrides superclass
    def __len__(self):
        return sum(len(it) for it in self._iterators)

    def __repr__(self):
        out = f"{self.__class__.__name__}(n_events={len(self)})["
        for it in self._iterators:
            out += f"\n\t- {it}"
        out += "\n\t]"

        if len(self.fncs) > 0:
            out += f"(preprocessing: {self.fncs})"

        return out

    def __enter__(self):
        # Enter the context of every iterator in the collection
        for it in self._iterators:
            it.__enter__()
        return self

    def __exit__(self, typ, val, tb):
        # Exit the context of every iterator in the collection
        for it in self._iterators:
            it.__exit__(typ, val, tb)

    def __iter__(self):
        # Iterate over the iterators in the collection one after the other
        self._chain = itertools.chain.from_iterable(self._iterators)
        return self

    def _next_raw(self):
        return next(self._chain)

    def __getitem__(self, val):  # overriding IteratorBaseClass behavior
        if not isinstance(val, tuple):
            val = (val,)
        if len(val) > 2:
            raise IndexError(f"Too many indices for iterator: 2 values can be indexed (channel list and event list) but {len(val)} were indexed.")

        slice_channel = val[0]
        slice_inds = val[1] if len(val)>1 else slice(None)

        channels = np.arange(self._n_channels)                    # Make array of all channels
        channels_sliced = channels[slice_channel]                 # Channels that survive the slice
        channels_bool = np.zeros(self._n_channels, dtype=bool)    # Turn surviving channels into
        channels_bool[channels_sliced] = True                     # boolean array

        inds = np.arange(len(self))                               # Make array of all indices
        inds_sliced = inds[slice_inds]                            # Indices that survive the slice
        inds_bool = np.zeros(len(self), dtype=bool)               # Turn surviving indices into
        inds_bool[inds_sliced] = True                             # boolean array

        # Boolean array for channels is identical for all iterators in collection
        # But we have to split the boolean array of surviving events and forward the pieces
        # to iterators in collection
        lens = np.array([len(it) for it in self._iterators])
        sub_arrays = np.split(inds_bool, np.cumsum(lens)[:-1])

        # Slice iterators in collection
        new_iterators = [it[channels_bool, a] for it, a in zip(self._iterators, sub_arrays)]

        # Create new collection and add processing
        new_collection = IteratorCollection(new_iterators)
        new_collection.add_processing(self.fncs.copy())

        return new_collection

    @property
    def record_length(self):
        return self._record_length

    @property
    def dt_us(self):
        return self._dt_us

    @property
    def ds_start_us(self):
        # Earliest start among all iterators in the collection
        return np.min([it.ds_start_us for it in self._iterators])

    @property
    def timestamps(self):
        return np.concatenate([it.timestamps for it in self._iterators])

    # Overrides superclass
    @property
    def uses_batches(self):
        return self._uses_batches

    # Overrides superclass
    @property
    def n_batches(self):
        return sum(it.n_batches for it in self._iterators)

    @property
    def n_channels(self):
        return self._n_channels

    @property
    def _slice_info(self):
        ...  # Not needed here because __getitem__ was overridden

    @property
    def iterators(self):
        # The iterators which make up this collection
        return self._iterators