# Source code for cait.versatile.iterators.impl_stream

from typing import List, Union

import numpy as np

from .iteratorbase import IteratorBaseClass


def _validate_new_window(stream, name, new_alignment, new_record_length):
    """
    Check that all record windows stay inside the stream for a new window definition.

    For every timestamp of ``stream``, the first and last sample timestamps of the
    record window implied by ``new_alignment`` and ``new_record_length`` are computed
    and compared against the first (``stream.ds_start_us``) and last
    (``stream._stream.time[-1]``) timestamps of the underlying stream.

    :param stream: Object exposing ``timestamps``, ``dt_us``, ``ds_start_us`` and ``_stream.time``.
    :param name: Name of the changed quantity; only used in the error message.
    :type name: str
    :param new_alignment: Alignment of the new record window relative to the timestamps.
    :type new_alignment: float
    :param new_record_length: Number of samples of the new record window.
    :type new_record_length: int

    :raises IndexError: If at least one record window would extend outside the stream.
    """
    # Coerce to an array so that scalar or list-like timestamps behave like
    # arrays. On plain Python bools, ``~`` is an integer bitwise negation and
    # out-of-range windows would silently go undetected.
    ts = np.asarray(stream.timestamps)
    dt = stream.dt_us

    # Timestamps of the first and last samples of the records for the new
    # window. The pre-trigger sample count is truncated with int().
    record_start_ts = ts - dt * int(new_alignment * new_record_length)
    record_time = dt * new_record_length
    record_end_ts = record_start_ts + (record_time - dt)

    # Check if any of the windows would extend outside the stream samples.
    flag_inside = np.logical_and(
        record_start_ts >= stream.ds_start_us,
        record_end_ts <= stream._stream.time[-1],
    )
    n_outside = int(np.count_nonzero(np.logical_not(flag_inside)))

    if n_outside > 0:
        raise IndexError(f"If {name} was changed, {n_outside} of the events would have record windows extending outside of the valid sample range of the stream.")

class StreamIterator(IteratorBaseClass):
    """
    Iterator object that returns voltage traces for given trigger indices of a stream file.

    :param stream: The stream object to read the voltage traces from.
    :type stream: StreamBaseClass
    :param keys: The keys (channel names) of the stream object to be iterated over.
    :type keys: Union[str, List[str]]
    :param inds: The stream indices for which we want to read the voltage traces. This index is aligned according to 'alignment' (default: at 1/4th of the record window).
    :type inds: Union[int, List[int]]
    :param record_length: The number of samples to be returned for each index. Usually, those are powers of 2, e.g. 16384
    :type record_length: int
    :param alignment: A number in the interval [0,1] which determines the alignment of the record window (of length `record_length`) relative to the specified index. E.g. if `alignment=1/2`, the record window is centered around the index. Defaults to 1/4.
    :type alignment: float
    :param batch_size: The number of events to be returned at once (these are all read together). There will be a trade-off: large batch_sizes cause faster read speed but increase the memory usage.
    :type batch_size: int

    :return: Iterable object
    :rtype: StreamIterator

    :raises ValueError: If `alignment` is outside the interval [0,1].
    """

    def __init__(self,
                 stream,
                 keys: Union[str, List[str]],
                 inds: Union[int, List[int]],
                 record_length: int,
                 alignment: float = 1/4,
                 batch_size: Union[int, None] = None):
        if not 0 <= alignment <= 1:
            raise ValueError("'alignment' has to be in the interval [0,1]")

        self._keys = [keys] if isinstance(keys, str) else keys

        # Normalize to a list of Python ints. NumPy integer scalars are
        # treated as a single index as well (they are not iterable).
        if isinstance(inds, (int, np.integer)):
            inds = [int(inds)]
        else:
            inds = [int(i) for i in inds]

        # Does batch handling and creates properties self._inds, self.uses_batches, and self.n_batches.
        # Also sets up serializing.
        super().__init__(inds=inds,
                         batch_size=batch_size,
                         stream=stream,
                         keys=keys,
                         record_length=record_length,
                         alignment=alignment)

        self._stream = stream
        self._record_length = record_length

        # Save values to reconstruct iterator:
        self._params = {'stream': stream,
                        'keys': self._keys,
                        'inds': inds,
                        'record_length': record_length,
                        'alignment': alignment,
                        'batch_size': batch_size}

        self._alignment = alignment

        # Number of samples before the index (truncated) and from the index
        # onward; together they add up to exactly record_length.
        n_before = int(alignment*record_length)
        self._interval = (n_before, record_length - n_before)

    def __iter__(self):
        """Reset the batch counter and return the iterator itself."""
        self._current_batch_ind = 0
        return self

    def __enter__(self):
        """Enter the context of the underlying stream and return the iterator."""
        self._stream.__enter__()
        return self

    def __exit__(self, typ, val, tb):
        """Leave the context of the underlying stream (its return value is not propagated)."""
        self._stream.__exit__(typ, val, tb)

    def _next_raw(self):
        """
        Read the voltage traces of the next batch (or single event) from the stream.

        :return: The traces of the next event/batch as a numpy array.
        :rtype: np.ndarray

        :raises StopIteration: If all batches have been returned.
        """
        if self._current_batch_ind >= self.n_batches:
            raise StopIteration

        event_inds_in_batch = self._inds[self._current_batch_ind]
        self._current_batch_ind += 1

        n_before, n_after = self._interval

        if isinstance(event_inds_in_batch, (int, np.integer)):
            # Single event: one contiguous slice around the index.
            s = slice(event_inds_in_batch - n_before, event_inds_in_batch + n_after)
            if len(self._keys) == 1:
                out = self._stream[self._keys[0], s, 'as_voltage']
            else:
                out = [self._stream[k, s, 'as_voltage'] for k in self._keys]

            return np.array(out)

        # Batch of events: build an index array with one row of sample
        # indices (of length record_length) per event.
        all_slices = np.r_.__getitem__(
            tuple(slice(i - n_before, i + n_after) for i in event_inds_in_batch)
        ).reshape(len(event_inds_in_batch), self._record_length)

        if len(self._keys) == 1:
            out = self._stream[self._keys[0], all_slices, 'as_voltage']
        else:
            out = [self._stream[k, all_slices, 'as_voltage'] for k in self._keys]
            # The list is ordered by key; swap the first two axes so that the
            # event axis comes first (assumes the stream returns one 2d array
            # per key -- TODO confirm).
            out = np.transpose(np.array(out), axes=[1,0,2])

        return np.array(out)

    def with_alignment(self, alignment: float):
        """
        Return an iterator for identical timestamps but with different alignment.

        :param alignment: A number in the interval [0,1] which determines the alignment of the record window (of length `record_length`) relative to the specified index. E.g. if `alignment=1/2`, the record window is centered around the index.
        :type alignment: float

        :return: A new iterator of the same class, built from this iterator's parameters with `alignment` replaced.

        :raises ValueError: If `alignment` is outside the interval [0,1].
        :raises NotImplementedError: If the iterator has processing attached.
        :raises IndexError: If any record window would extend outside the stream.

        .. warning:: Requires that all event traces are still within the stream boundaries after changing alignment.
        """
        if not 0 <= alignment <= 1:
            raise ValueError("'alignment' has to be in the interval [0,1]")

        if self.has_processing:
            raise NotImplementedError(
                "Cannot change alignment for iterators with processing because processing might depend on alignment and cause obscure issues. "
                "Manually remove processing first using 'old_processing = it.pop_processing()', call 'with_alignment' on the iterator without processing, "
                "then add the 'old_processing' again if it does not depend on alignment, or add it again after adjusting its parameters to work with the new alignment."
            )

        _validate_new_window(self, "alignment", new_alignment=alignment, new_record_length=self.record_length)

        params, _ = self._slice_info
        new_params = params.copy()
        new_params["alignment"] = alignment

        return self.__class__(**new_params)

    def with_record_length(self, record_length: int):
        """
        Return an iterator for identical timestamps but with different record length.

        :param record_length: The number of samples to be returned for each event. Usually, those are powers of 2, e.g. 16384
        :type record_length: int

        :return: A new iterator of the same class, built from this iterator's parameters with `record_length` replaced.

        :raises NotImplementedError: If the iterator has processing attached.
        :raises IndexError: If any record window would extend outside the stream.

        .. warning:: Requires that all event traces are still within the stream boundaries after changing record length.
        """
        if self.has_processing:
            raise NotImplementedError(
                "Cannot change record_length for iterators with processing because processing might depend on window size and cause obscure issues. "
                "Manually remove processing first using 'old_processing = it.pop_processing()', call 'with_record_length' on the iterator without processing, "
                "then add the 'old_processing' again if it does not depend on window size, or add it again after adjusting its parameters to work with the new window size."
            )

        _validate_new_window(self, "record_length", new_alignment=self.alignment, new_record_length=record_length)

        params, _ = self._slice_info
        new_params = params.copy()
        new_params["record_length"] = record_length

        return self.__class__(**new_params)

    def with_extended_window(self):
        """Return an iterator for identical timestamps but with the window size increased to include one additional record length before and after the previous window."""
        # With N = record_length and a = alignment, the index should sit at
        # N + a*N samples into a window of 3*N samples, i.e. at the relative
        # position (1 + a)/3 = 1/3 + a/3.
        return self.with_alignment(1 / 3 + self.alignment / 3).with_record_length(3 * self.record_length)

    @property
    def alignment(self):
        """
        The time axis alignment of the iterator. For most event iterators, this is 1/4, i.e. the timestamp of an event corresponds to the sample at 1/4th of the record window. However, when constructing a StreamIterator, you may choose the alignment. Therefore, for StreamIterators, this value may be anything in the interval [0, 1].
        """
        return self._alignment

    # Overridden here because StreamIterator may define different alignment
    @property
    def t(self):
        """
        Return the time axis (record window) of the events in the iterator. It is a millisecond array with 0 aligned according to the 'alignment' argument used when constructing the StreamIterator.
        """
        # NOTE(review): the offset here is alignment*record_length without
        # truncation, whereas the sample window uses int(alignment*record_length);
        # the two differ by less than one sample if the product is not integer.
        return (np.arange(self.record_length) - self.alignment*self.record_length)*self.dt_us/1000

    @property
    def record_length(self):
        """The number of samples per record window."""
        return self._record_length

    @property
    def dt_us(self):
        """The `dt_us` value of the underlying stream (presumably the sample spacing in microseconds)."""
        return self._stream.dt_us

    @property
    def ds_start_us(self):
        """The `start_us` value of the underlying stream (presumably the timestamp of its first sample in microseconds)."""
        return self._stream.start_us

    @property
    def timestamps(self):
        """The entries of the stream's `time` object at the indices of this iterator."""
        return self._stream.time[self._params["inds"]]

    @property
    def n_channels(self):
        """The number of keys (channels) the iterator reads."""
        return len(self._keys)

    @property
    def _slice_info(self):
        """The constructor parameters together with the names ('keys', 'inds') of those that can be sliced."""
        return (self._params, ('keys', 'inds'))