Source code for cait.versatile.functions.apply

from typing import Callable, List
from inspect import signature, _empty
# Note that multiprocessing.Pool cannot handle lambdas
# which is why we use multiprocess here. The Pool interface
# is otherwise identical
from multiprocess import Pool
import itertools

import numpy as np
from tqdm.auto import tqdm

import cait as ai
from ..iterators.iteratorbase import IteratorBaseClass
from ..iterators.batchresolver import BatchResolver

class Compose:
    """
    Chain several single-argument callables into one callable.

    Calling the instance passes ``data`` to the first function, feeds each
    result into the next function and returns the output of the last one.
    An empty list of functions behaves like the identity (``data`` is
    returned unchanged).

    :param fncs: Functions to be applied consecutively, in list order.
    :type fncs: List[Callable]
    """
    def __init__(self, fncs: List[Callable]):
        self._fncs = fncs

    def __call__(self, data):
        # Plain left-to-right fold. Starting from 'data' itself avoids a
        # special case for the first function and therefore also works for
        # an empty list (which previously raised an IndexError).
        out = data
        for f in self._fncs:
            out = f(out)
        return out
    
def apply(f: Callable, ev_iter: IteratorBaseClass, n_processes: int = None, pb_prefix: str = ""):
    """
    Apply a function to events provided by an EventIterator. Multiprocessing and resolving batches as returned by the iterator is done automatically.

    The function returns a numpy array where the first dimension corresponds to the events returned by the iterator. Higher dimensions are as returned by the function that is applied. Batches are resolved, i.e. calls with an ``EventIterator(..., batch_size=1)`` and ``EventIterator(..., batch_size=100)`` yield identical results.

    :param f: Function to be applied to events. It must be callable with a single positional argument (the event); additional parameters are fine as long as they are optional.
    :type f: Callable
    :param ev_iter: Events for which the function should be applied.
    :type ev_iter: :class:`~cait.versatile.iterators.iteratorbase.IteratorBaseClass`
    :param n_processes: Number of processes to use for multiprocessing. If None, ``cait._available_workers`` is used. Defaults to None.
    :type n_processes: int, optional
    :param pb_prefix: An optional prefix for the progress bar.
    :type pb_prefix: str

    :return: Results of ``f`` for all events in ``ev_iter``. Has same structure as output of ``f`` (just with an additional event dimension).
    :rtype: Any

    :raises TypeError: If ``ev_iter`` is not an ``IteratorBaseClass`` instance, if ``f`` is not callable, or if ``f`` cannot be called with a single positional argument.
    :raises IndexError: If ``ev_iter`` is empty.

    **Example:**

    .. code-block:: python

        import cait.versatile as vai
        import numpy as np

        def func1(event): return np.max(event)
        def func2(event): return np.min(event), np.max(event)

        # Example when func has one output
        it = vai.MockData().get_event_iterator(batch_size=42)
        out = vai.apply(func1, it)

        # Example when func has two outputs
        it = vai.MockData().get_event_iterator(batch_size=42)
        out1, out2 = vai.apply(func2, it)

        # Example using a function defined inline
        it = vai.MockData().get_event_iterator()[0]
        out = vai.apply(lambda x: np.max(x), it)
    """
    # Check if 'ev_iter' is a cait.versatile iterator object
    if not isinstance(ev_iter, IteratorBaseClass):
        raise TypeError(f"Input argument 'ev_iter' must be an instance of {IteratorBaseClass} not '{type(ev_iter)}'.")

    # Check if 'ev_iter' is not empty
    if len(ev_iter)==0:
        raise IndexError(f"Input argument 'ev_iter' must contain at least 1 event (iterator is empty).")

    # Check if 'f' is indeed a function
    if not callable(f):
        raise TypeError(f"Input argument 'f' must be callable.")

    # Use all available workers if no number of processes is provided
    if n_processes is None:
        n_processes = ai._available_workers

    # Check if 'f' can be called with exactly one positional argument (the event).
    # Counting parameters without a default is not sufficient for this: '*args'
    # and '**kwargs' never have a default, and keyword-only parameters cannot
    # receive the event. Binding a dummy argument to the signature applies the
    # interpreter's own calling rules instead.
    try:
        sig = signature(f)
    except ValueError:
        # No introspectable signature (e.g. some builtins or C-implemented
        # callables). We cannot validate up front, so the call itself will
        # raise if 'f' is unsuitable.
        sig = None

    if sig is not None:
        try:
            sig.bind(None)
        except TypeError:
            raise TypeError(f"Input function {f} cannot be called with a single positional argument (the event). Only functions which take one (non-default) argument (the event) are supported.") from None

    # pop processing from iterator (if exists) and construct a list of
    # functions to be applied to each event
    if ev_iter.has_processing:
        # make copy of event iterator before removing its processing
        ev_iter = ev_iter[:, :]
        processing = ev_iter.pop_processing()
        fncs = processing + [f]
    else:
        fncs = [f]

    # If iterator returns batches, we dress all functions with a BatchResolver
    if ev_iter.uses_batches:
        fncs = [BatchResolver(fnc, ev_iter.n_channels) for fnc in fncs]

    # Finally, we compose the list of functions (they will be applied consecutively
    # by the workers in the process pool). If we only applied the function (but not
    # the iterator processing) in the process pool, the parent process would have
    # to calculate all processing outputs (defeating the purpose of using multiple
    # workers)
    F = Compose(fncs)

    tqdm_config = dict(total=ev_iter.n_batches,
                       unit="batches" if ev_iter.uses_batches else "events",
                       delay=2,
                       desc=pb_prefix)

    with ev_iter as ev_it:
        if n_processes > 1:
            with Pool(n_processes) as pool:
                out = list(tqdm(pool.imap(F, ev_it), **tqdm_config))
        else:
            out = [F(ev) for ev in tqdm(ev_it, **tqdm_config)]

    # Chain batches such that the list is indistinguishable from a list using no batches
    # (If uses_batches, 'out' is a list of lists)
    if ev_iter.uses_batches:
        out = list(itertools.chain.from_iterable(out))

    # If elements in 'out' are tuples, this means that the function had multiple outputs.
    # In this case, we transpose the list so that we have a tuple of outputs where each
    # element in the tuple is also converted to a numpy.array of len(ev_iter)
    if isinstance(out[0], tuple):
        out = tuple(np.array(x) for x in zip(*out))
    else:
        out = np.array(out)

    return out