Universal Training Data Augmentation


Correspondence to: felix.wagner@oeaw.ac.at

In this notebook we explore the data augmentation module of the Cait Python package.

Here we create a universal training set for distinguishing particle recoils and empty noise baselines from artifacts. In the second section of the notebook, we train a classifier model to perform this task.

from cait.augment import ParameterSampler, plot_events, unfold, L2
import numpy as np
import cait as ai  # install from the develop branch
import matplotlib.pyplot as plt
import matplotlib as mpl
import h5py
from tqdm.auto import tqdm
from scipy import signal
from cait.models import CNNModule
from torchvision import transforms
from cait.datasets import SingleMinMaxNorm, DownSample, ToTensor, CryoDataModule
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

%config InlineBackend.figure_formats = ['svg']

Some definitions for the plots.

mpl.rcParams['figure.figsize'] = (7.2, 4.45)
mpl.rcParams['savefig.dpi'] = 300

Data Augmentation

Define the global parameters. For details, see the notebook Cryogenic detector data augmentation.

For the whole notebook: Only capitalized parameters need to be set by the user.

RECORD_LENGTH = 16384
SAMPLE_FREQUENCY = 25000

POLYNOMIAL_DRIFTS = True
RASTERIZE = False
SQUARE_WAVES = False
SATURATION = True
CLASS_NAMES = ['Event Pulse',
               'Noise',
               'Decaying Baseline',
               'Temperature Rise',
               'Spike',
               'Squid Jump',
               'Reset',
               'Cosinus Tail',
               'Decaying Baseline with Event Pulse',
               'Pile Up',
               'Early or late Trigger',
               ]


# ----------------------------------------- 
# no need to change below parameters
# ----------------------------------------- 

label_names = {
    'unlabeled': 0,
    'Event Pulse': 1,
    'Test/Control Pulse': 2,
    'Noise': 3,
    'Squid Jump': 4,
    'Spike': 5,
    'Early or late Trigger': 6,
    'Pile Up': 7,
    'Carrier Event': 8,
    'Strongly Saturated Event Pulse': 9,
    'Strongly Saturated Test/Control Pulse': 10,
    'Decaying Baseline': 11,
    'Temperature Rise': 12,
    'Stick Event': 13,
    'Square Waves': 14,
    'Human Disturbance': 15,
    'Large Sawtooth': 16,
    'Cosinus Tail': 17,
    'Light only Event': 18,
    'Ring & Light Event': 19,
    'Sharp Light Event': 20,
    'Reset': 21,
    'Decaying Baseline with Event Pulse': 22,
    'Decaying Baseline with Tail Event': 23,
    'unknown/other': 99,
}

We create an instance of the sampler class.

parsam = ParameterSampler(record_length=RECORD_LENGTH,
                          sample_frequency=SAMPLE_FREQUENCY)

Plot some augmented events. We did not fix the baseline resolution, pulse shapes, saturation curve or noise power spectrum, so new ones are sampled for every batch, which makes the data more diverse overall.

NMBR_PLOTS = 1

# classes = np.random.choice(CLASS_NAMES, size=NMBR_PLOTS)
classes = [CLASS_NAMES[0] for i in range(NMBR_PLOTS)]

for i in range(NMBR_PLOTS):

    event, info = parsam.get_event(label=classes[i],
                                   size=9,
                                   rasterize=RASTERIZE,
                                   poly=POLYNOMIAL_DRIFTS,
                                   square=SQUARE_WAVES,
                                   saturation=SATURATION,
                                   verb=True,
                                   )

    print(classes[i])
    plot_events(event, t=parsam.t)
    if classes[i] == 'Event Pulse':
        print('Resolution (only meaningful without Saturation): ', np.std(np.abs(np.max(event, axis=1) - info['pulse_height'])))
Sample Noise...
Sample Polynomials...
Sample Pulse Nmbr  0
Sample Saturation ...
Event Pulse
../_images/13universaltrainingset_12_7.svg
Resolution (only meaningful without Saturation):  2.803244455274235

Now create the HDF5 data set.

FNAME = 'test_data/universal_v0_1.h5'
IDX = list(range(11))  # which to take from the class names
BATCHSIZE = 50
EVENTS_PER_CLASS = {  # always use a multiple of the batch size!
    'Event Pulse': 2000,
    'Noise': 1000,
    'Decaying Baseline': 150,
    'Temperature Rise': 150,
    'Spike': 150,
    'Squid Jump': 150,
    'Reset': 300,
    'Cosinus Tail': 500,
    'Decaying Baseline with Event Pulse': 500,
    'Pile Up': 1000,
    'Early or late Trigger': 500,
}
DIV = 1
TOTAL_EVENTS = int(np.sum(list(EVENTS_PER_CLASS.values()))/DIV)
BATCHSIZE = int(BATCHSIZE/DIV)

# -------------------------------------------
# no changes required below this line!
# -------------------------------------------

with h5py.File(FNAME, 'w') as f:
    f.require_group('events')
    f['events'].create_dataset('event',
                                shape=(1, TOTAL_EVENTS, RECORD_LENGTH),
                                dtype=np.float32)
    f['events'].create_dataset('labels',
                                shape=(1, TOTAL_EVENTS),
                                dtype=int)
    f['events'].create_dataset('true_ph',
                                shape=(1, TOTAL_EVENTS),
                                dtype=float)
    f['events'].create_dataset('true_onset',
                                shape=(TOTAL_EVENTS, ),
                                dtype=float)

    bar = tqdm(total=TOTAL_EVENTS)
    bcount = 0

    for i in IDX:

        bar.write('Simulating {} ...'.format(CLASS_NAMES[i]))

        nmbr_events = int(EVENTS_PER_CLASS[CLASS_NAMES[i]] / DIV)
        batches = int(nmbr_events / BATCHSIZE)

        for b in range(batches):
            event, info = parsam.get_event(label=CLASS_NAMES[i],
                                           size=BATCHSIZE,
                                           rasterize=RASTERIZE,
                                           poly=POLYNOMIAL_DRIFTS,
                                           square=SQUARE_WAVES,
                                           saturation=SATURATION,
                                           verb=False,
                                           )

            f['events']['event'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), :] = event
            f['events']['labels'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = label_names[CLASS_NAMES[i]]
            f['events']['true_ph'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = info['pulse_height']
            f['events']['true_onset'][int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = info['t0']*1000

            # attributes

            f['events']['labels'].attrs.create(name='unlabeled', data=0)
            f['events']['labels'].attrs.create(name='Event_Pulse', data=1)
            f['events']['labels'].attrs.create(name='Test/Control_Pulse', data=2)
            f['events']['labels'].attrs.create(name='Noise', data=3)
            f['events']['labels'].attrs.create(name='Squid_Jump', data=4)
            f['events']['labels'].attrs.create(name='Spike', data=5)
            f['events']['labels'].attrs.create(name='Early_or_late_Trigger', data=6)
            f['events']['labels'].attrs.create(name='Pile_Up', data=7)
            f['events']['labels'].attrs.create(name='Carrier_Event', data=8)
            f['events']['labels'].attrs.create(name='Strongly_Saturated_Event_Pulse', data=9)
            f['events']['labels'].attrs.create(name='Strongly_Saturated_Test/Control_Pulse', data=10)
            f['events']['labels'].attrs.create(name='Decaying_Baseline', data=11)
            f['events']['labels'].attrs.create(name='Temperature_Rise', data=12)
            f['events']['labels'].attrs.create(name='Stick_Event', data=13)
            f['events']['labels'].attrs.create(name='Square_Waves', data=14)
            f['events']['labels'].attrs.create(name='Human_Disturbance', data=15)
            f['events']['labels'].attrs.create(name='Large_Sawtooth', data=16)
            f['events']['labels'].attrs.create(name='Cosinus_Tail', data=17)
            f['events']['labels'].attrs.create(name='Light_only_Event', data=18)
            f['events']['labels'].attrs.create(name='Ring_Light_Event', data=19)
            f['events']['labels'].attrs.create(name='Sharp_Light_Event', data=20)
            f['events']['labels'].attrs.create(name='Reset', data=21)
            f['events']['labels'].attrs.create(name='Decaying_Baseline_with_Event_Pulse', data=22)
            f['events']['labels'].attrs.create(name='Decaying_Baseline_with_Tail_Pulse', data=23)
            f['events']['labels'].attrs.create(name='unknown/other', data=99)

            bar.update(BATCHSIZE)
            bcount += 1
Simulating Event Pulse ...
Simulating Noise ...
Simulating Decaying Baseline ...
Simulating Temperature Rise ...
Simulating Spike ...
Simulating Squid Jump ...
Simulating Reset ...
Simulating Cosinus Tail ...
Simulating Decaying Baseline with Event Pulse ...
Simulating Pile Up ...
Simulating Early or late Trigger ...
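
The class counts have to be multiples of the batch size because the number of batches per class is computed with an integer division: a remainder would be dropped, and the tail of the pre-allocated datasets would stay empty. A minimal sketch of a sanity check, reusing the values from the cell above; the helper name check_batching is made up for this illustration and is not part of cait:

```python
# Hypothetical helper (not part of cait): verify that every class count
# is a multiple of the batch size before starting the simulation.
EVENTS_PER_CLASS = {
    'Event Pulse': 2000,
    'Noise': 1000,
    'Decaying Baseline': 150,
    'Temperature Rise': 150,
    'Spike': 150,
    'Squid Jump': 150,
    'Reset': 300,
    'Cosinus Tail': 500,
    'Decaying Baseline with Event Pulse': 500,
    'Pile Up': 1000,
    'Early or late Trigger': 500,
}
BATCHSIZE = 50


def check_batching(events_per_class, batchsize):
    """Return the total number of events; raise if a class count would be truncated."""
    for name, n in events_per_class.items():
        if n % batchsize != 0:
            raise ValueError(f"{name}: {n} is not a multiple of {batchsize}")
    return sum(events_per_class.values())


total_events = check_batching(EVENTS_PER_CLASS, BATCHSIZE)
print(total_events, total_events // BATCHSIZE)
```

For the values used in this notebook, this gives 6400 events written in 128 batches.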

Let's look into the simulated data with the VizTool!

dh = ai.DataHandler(nmbr_channels=1,
              sample_frequency=SAMPLE_FREQUENCY,
              record_length=RECORD_LENGTH)
dh.set_filepath('test_data', 'universal_v0_1', appendix=False)
DataHandler Instance created.

First we calculate the main parameters and the additional main parameters.

dh.calc_mp()
dh.calc_additional_mp(no_of=True)
CALCULATE MAIN PARAMETERS.
CALCULATE ADDITIONAL MAIN PARAMETERS.

Looks like a total mess, right?

datasets = {
    'Pulse Height Phonon (V)': ['pulse_height', 0, None],
    'Rise Time Phonon (ms)': ['rise_time', 0, None],
    'Decay Time Phonon (ms)': ['decay_time', 0, None],
    'Onset Phonon (ms)': ['onset', 0, None],
    'Slope Phonon (V)': ['slope', 0, None],
    'Variance Phonon (V^2)': ['var', 0, None],
    'Mean Phonon (V)': ['mean', 0, None],
    'Skewness Phonon': ['skewness', 0, None],
}

viz = ai.VizTool(path_h5='test_data/', 
              fname='universal_v0_1',
              group='events', 
              datasets=datasets, 
              nmbr_channels=1, 
              batch_size=1000,
              sample_frequency=25000,
              record_length=16384)
viz.set_colors(dh.get('events', 'labels')[0])
viz.show()
DataHandler Instance created.

Universal Classification

Now we can train a universal classifier model, which distinguishes particle recoils from all kinds of artifacts, noise or pile-up! We mostly use the same functions as in the ‘Classification with Neural Network’ notebook, only this time we use a CNN model.

We want to do a binary classification between events and other artifacts (noise, etc.). Therefore we set all labels to 0 or 1, depending on whether they show an event or not.

label_cuts = ai.cuts.LogicalCut(np.logical_or(dh.get('events', 'labels')[0] == 1, dh.get('events', 'labels')[0] == 17))
print(label_cuts.counts(), label_cuts.counts()/label_cuts.total())
2500 0.390625
dh.include_values(values=label_cuts.get_flag(),
                 naming='labels_binary', channel=0, type='events', delete_old=True)
Included values.
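
The count and fraction printed above follow directly from the class sizes of the simulated set. The sketch below rebuilds the label array with plain NumPy (it does not use the cait cut classes) and marks the two label indices selected in the cut above, 1 (Event Pulse) and 17 (Cosinus Tail), as positive:

```python
import numpy as np

# Label indices taken from the label_names dictionary of this notebook.
EVENT_PULSE, COSINUS_TAIL = 1, 17

# (label index, number of simulated events) for each class in the data set.
counts = [(1, 2000), (3, 1000), (11, 150), (12, 150), (5, 150), (4, 150),
          (21, 300), (17, 500), (22, 500), (7, 1000), (6, 500)]
labels = np.concatenate([np.full(n, idx) for idx, n in counts])

# 1 where the record counts as an event, 0 for everything else.
labels_binary = np.isin(labels, [EVENT_PULSE, COSINUS_TAIL]).astype(int)

print(labels_binary.sum(), labels_binary.mean())
```

This reproduces the 2500 positive records and the fraction 0.390625 of the 6400 simulated events.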

The cells below work very much the same way as they do for the other NN classification notebook.

# some parameters
# nmbr_gpus = ... uncomment and put in trainer to use GPUs
path_h5 = 'test_data/universal_v0_1.h5'
path_h5_test = 'test_data/test_v0_1.h5'
type = 'events'  # the group key for the data in the HDF5 set
keys = ['event', 'labels_binary']  # the datasets in the group from which we include data in the samples for the NN
channel_indices = [[0], [0]]  # the first indices of the datasets
feature_indices = [None, None]  # the third indices of the datasets
feature_keys = ['event_ch0']    # the keys in the samples of the NN dataset that are input to the NN
                                # in the data set for the NN, the channel index is additionally appended to the keys
label_keys = ['labels_binary_ch0']  # the keys in the samples of the NN dataset that are labels to the NN
norm_vals = {'event_ch0': [0, 10]}  # we do a min - max normalization of all samples, so these are roughly the lowest and highest values of the events in the data set
down_keys = ['event_ch0']  # if we input the raw time series, we apply downsampling first
down = 32                  # all samples in the NN dataset with the keys specified above are downsampled by the factor down
record_length = 16384
nmbr_classes = 2  # the number of output classes: we do a binary classification (event or not), so two classes
max_epochs = 20 # the maximal number of training epochs of the neural network
save_naming = 'cnn-clf'  # the name addition if we want to save the trained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# create the transforms
trans = transforms.Compose([SingleMinMaxNorm(norm_vals.keys()),
                            DownSample(keys=down_keys, down=down),
                            ToTensor()])
# create data module and init the setup
dm = CryoDataModule(hdf5_path=path_h5,
                    type=type,
                    keys=keys,
                    channel_indices=channel_indices,
                    feature_indices=feature_indices,
                    transform=trans,
                   )
dm.prepare_data(val_size=0.2,
                test_size=0.1,
                batch_size=16,
                dataset_size=None,
                nmbr_workers=0,
                only_idx=None,
                shuffle_dataset=True,
                random_seed=21,
                feature_keys=feature_keys,
                label_keys=label_keys,
                keys_one_hot=label_keys,
               )
dm.setup()
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:175: LightningDeprecationWarning:

DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.
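
To make the preprocessing concrete, here is a minimal NumPy sketch of a per-sample min-max normalization followed by downsampling, in the same order as in the transform chain. It rests on assumptions: the internals of SingleMinMaxNorm and DownSample are not shown in this notebook, so scaling each record with its own minimum and maximum and averaging blocks of `down` samples are illustrative choices, and both function names are hypothetical.

```python
import numpy as np


def minmax_per_sample(x):
    """Scale each row to [0, 1] using that row's own minimum and maximum.

    Assumes no row is constant (otherwise the denominator would be zero).
    """
    lo = x.min(axis=1, keepdims=True)
    hi = x.max(axis=1, keepdims=True)
    return (x - lo) / (hi - lo)


def downsample_mean(x, down):
    """Average consecutive blocks of `down` samples along the time axis."""
    n, length = x.shape
    return x.reshape(n, length // down, down).mean(axis=2)


rng = np.random.default_rng(0)
events = rng.normal(size=(4, 16384))  # toy stand-in for four records
out = downsample_mean(minmax_per_sample(events), down=32)
print(out.shape)  # (4, 512)
```

With a record length of 16384 and down = 32, each record is reduced to 512 samples.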
# create cnn clf
cnn = CNNModule(
    kernelsize=8,
    input_size=int(record_length/down),
    device_name=device,
    nmbr_out=nmbr_classes,  # this is the number of labels
    lr=1e-3,
    label_keys=label_keys,
    feature_keys=feature_keys,
    down=down,
    down_keys=feature_keys,
    norm_vals=norm_vals,
    norm_type='indiv_minmax',
    offset_keys=feature_keys,
)
# create callback to save the best model
checkpoint_callback = ModelCheckpoint(dirpath='callbacks',
                                      monitor='val_loss',
                                      filename=save_naming + 'down{down}-{epoch:02d}-{val_loss:.2f}')
# create instance of Trainer
trainer = Trainer(max_epochs=max_epochs,
                  callbacks=[checkpoint_callback],
                  # gpus=1,
                 )
# keyword gpus=nmbr_gpus for GPU Usage
# keyword max_epochs for number of maximal epochs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
# all training happens here
trainer.fit(model=cnn,
            datamodule=dm)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:469: LightningDeprecationWarning:

DataModule.setup has already been called, so it will not be called again. In v1.6 this behavior will change to always call DataModule.setup.


  | Name  | Type   | Params
---------------------------------
0 | conv1 | Conv1d | 450   
1 | conv2 | Conv1d | 4.0 K 
2 | fc1   | Linear | 14.2 K
3 | fc2   | Linear | 402   
---------------------------------
19.1 K    Trainable params
0         Non-trainable params
19.1 K    Total params
0.076     Total estimated model params size (MB)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:623: UserWarning:

Checkpoint directory /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks exists and is not empty.
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning:

The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning:

The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
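# load best model; load_from_checkpoint is a classmethod that returns a new instance,
# so the result has to be assigned
cnn = CNNModule.load_from_checkpoint(checkpoint_callback.best_model_path)
# cnn = CNNModule.load_from_checkpoint('/Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/cnn-clfdowndown=0-epoch=18-val_loss=0.16.ckpt')
cnn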
CNNModule(
  (conv1): Conv1d(1, 50, kernel_size=(8,), stride=(4,))
  (conv2): Conv1d(50, 10, kernel_size=(8,), stride=(4,))
  (fc1): Linear(in_features=70, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=2, bias=True)
)
cnn.eval()  # switch to evaluation mode; returns the module, hence the printout
CNNModule(
  (conv1): Conv1d(1, 50, kernel_size=(8,), stride=(4,))
  (conv2): Conv1d(50, 10, kernel_size=(8,), stride=(4,))
  (fc1): Linear(in_features=70, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=2, bias=True)
)
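
The parameter counts in the training summary can be reproduced from the layer shapes printed above. The sketch uses only the standard formulas for a 1D convolution (input channels × output channels × kernel size, plus one bias per output channel) and a fully connected layer (inputs × outputs, plus one bias per output); the two helper functions are written for this illustration.

```python
def conv1d_params(c_in, c_out, kernel):
    """Weights plus biases of a 1D convolution layer."""
    return c_in * c_out * kernel + c_out


def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# Layer shapes as printed in the model representation.
layers = {
    'conv1': conv1d_params(1, 50, 8),
    'conv2': conv1d_params(50, 10, 8),
    'fc1': linear_params(70, 200),
    'fc2': linear_params(200, 2),
}
total = sum(layers.values())
print(layers, total)
```

The result is 450, 4010, 14200 and 402 parameters, 19062 in total, in agreement with the rounded values in the summary table.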
# run test set
result = trainer.test(model=cnn, datamodule=dm)
print(result)

Test

Now let's test the model with the data set we produced in the notebook Cryogenic Detector Data Augmentation!

dh_test = ai.DataHandler(nmbr_channels=1,
              sample_frequency=SAMPLE_FREQUENCY,
              record_length=RECORD_LENGTH)
dh_test.set_filepath('test_data', 'test_v0_1', appendix=False)
DataHandler Instance created.

We introduce again binary labels, the same way we did above.

test_label_cuts = ai.cuts.LogicalCut(np.logical_or(np.logical_or(dh_test.get('events', 'labels')[0] == 1, 
                                                   dh_test.get('events', 'labels')[0] == 17),
                                    dh_test.get('events', 'labels')[0] == 8))
print(test_label_cuts.counts(), test_label_cuts.counts()/test_label_cuts.total())
850 0.26153846153846155
dh_test.include_values(values=test_label_cuts.get_flag(),
                 naming='labels_binary', channel=0, type='events', delete_old=True)
Delete old labels_binary dataset
Included values.

We build samples from the test set and feed them to the model.

# predictions with the model are made that way
with h5py.File(path_h5_test, 'r') as f:
    start_from = 0
    nmbr_classes_test = 2
    test_size = f['events/event'].shape[1]
    test_idx = np.arange(start_from,start_from + test_size)
    x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}  # array of shape: (nmbr_events, nmbr_features)
    y = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])
    probs = cnn.get_prob(x).numpy()
    x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}  # this was changed in place
    prediction = cnn.predict(x).numpy()

    # predictions can be saved with instance of EvaluationTools
    # print('PREDICTION: ', prediction)
    print('ACCURACY: ', np.sum(prediction == y)/len(y))
    print('Best model: ', checkpoint_callback.best_model_path)
    print('Probabilities: ', np.exp(probs[:10,:]))  # probs holds log-softmax values
ACCURACY:  0.9436923076923077
Best model:  /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/cnn-clfdowndown=0-epoch=19-val_loss=0.16.ckpt
Probabilities:  [[0.06205076 0.93794924]
 [0.09400847 0.90599155]
 [0.07259091 0.92740905]
 [0.02571316 0.9742869 ]
 [0.05754973 0.94245034]
 [0.02000174 0.9799983 ]
 [0.08924916 0.9107508 ]
 [0.08573773 0.91426224]
 [0.07457279 0.9254272 ]
 [0.05423618 0.9457638 ]]
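
Since the model outputs are log-softmax values, exponentiating them yields class probabilities that sum to one per event. The following NumPy sketch illustrates this relation on toy numbers. That the predicted class is the one with the largest probability is an assumption about cnn.predict, not something shown in this notebook, and the log_softmax helper is written for this example.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax along the class axis."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


logits = np.array([[-1.0, 1.7], [2.0, -0.5], [0.1, 0.3]])  # toy network outputs
log_probs = log_softmax(logits)

probs = np.exp(log_probs)                  # probabilities, each row sums to 1
prediction = np.argmax(log_probs, axis=1)  # assumed decision rule: most probable class

print(probs.sum(axis=1), prediction)
```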

We store the predictions as cnn_cut and the probabilities as cnn_probs in the data set.

dh_test.include_values(values=prediction,
                 naming='cnn_cut', channel=0, type='events', delete_old=True)
Delete old cnn_cut dataset
Included values.
dh_test.include_values(values=np.exp(probs[:,0]),
                 naming='cnn_probs', channel=0, type='events', delete_old=True)
Delete old cnn_probs dataset
Included values.

Let's plot a confusion matrix to see which classes worked nicely and which did not.

confusion_matrix = np.zeros((24, 2), dtype=int)
for i,j in zip(dh_test.get('events', 'labels')[0], prediction):
    confusion_matrix[i,j] += 1
not_zero = np.sum(confusion_matrix, axis=1) > 0
nmbr_not_zero = np.sum(not_zero)
print('Nmbr used classes: ', nmbr_not_zero)
confusion_matrix_squeeze = confusion_matrix[not_zero]
Nmbr used classes:  13
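
The counting loop can also be written without an explicit Python loop. This is a sketch on toy arrays; np.add.at is used instead of plain fancy-index assignment because it accumulates correctly when the same (label, prediction) pair occurs more than once.

```python
import numpy as np

labels = np.array([1, 1, 3, 3, 3, 7, 17, 17])    # toy true class indices
prediction = np.array([1, 1, 0, 0, 1, 0, 1, 0])  # toy binary predictions

# Rows: class index (0..23), columns: predicted 0 (cut) or 1 (survived).
confusion = np.zeros((24, 2), dtype=int)
np.add.at(confusion, (labels, prediction), 1)

# Keep only the classes that actually occur in the data.
used = confusion.sum(axis=1) > 0
print(np.flatnonzero(used))
print(confusion[used])
```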

For the actual plot, we need the seaborn library. In case you have not installed it, just run

!pip install seaborn

in a new cell.

import seaborn as sns

labels = np.unique(dh_test.get('events', 'labels')[0])

ax = sns.heatmap(confusion_matrix_squeeze, annot=True, fmt='d', linewidths=.5, xticklabels=['Cut','Survived'], yticklabels=labels)
plt.ylabel("Label") 
plt.show()
../_images/13universaltrainingset_55_0.svg

Let's see what the standard event looks like when it is produced with the CNN cut.

sev_raw = np.mean(dh_test.get('events', 'event')[0], axis=0)
sev_raw /= np.max(sev_raw)
sev_pred = np.mean(dh_test.get('events', 'event')[0, prediction == 1], axis=0)
sev_pred /= np.max(sev_pred)
sev_label = np.mean(dh_test.get('events', 'event')[0, y == 1], axis=0)
sev_label /= np.max(sev_label)

plt.plot(sev_raw, label='SEV without cuts', color='blue')
plt.plot(sev_label, label='SEV with labels', color='green')
plt.plot(sev_pred, label='SEV with predictions', color='red', linewidth=2)
plt.xlabel('Sample Index')
plt.ylabel('Amplitude (normalized to maximum)')
plt.legend()
plt.show()
../_images/13universaltrainingset_57_0.svg
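
The standard event used here is simply the mean over the selected records, scaled to a maximum of one. A self-contained sketch on synthetic data; the pulse template, the noise level and the helper name standard_event are made up for this illustration:

```python
import numpy as np


def standard_event(events, mask=None):
    """Average the selected records and scale the result to a maximum of 1."""
    selected = events if mask is None else events[mask]
    sev = selected.mean(axis=0)
    return sev / sev.max()


rng = np.random.default_rng(1)

# Toy pulse template with the onset at sample 64 (illustrative shape only).
t = np.arange(256) - 64
pulse = np.where(t > 0, np.exp(-t / 60.0) - np.exp(-t / 10.0), 0.0)

# Six pulses of different height plus four pure-noise records.
heights = np.array([1.0, 2.0, 0.5, 3.0, 1.5, 2.5, 0.0, 0.0, 0.0, 0.0])
events = heights[:, None] * pulse + rng.normal(scale=0.01, size=(10, 256))

is_pulse = heights > 0  # plays the role of `prediction == 1` or `y == 1`
sev_all = standard_event(events)
sev_cut = standard_event(events, is_pulse)
print(sev_all.shape, sev_cut.max())
```

Averaging only the records that pass the cut keeps noise-only traces and artifacts out of the template, which is why the curves with and without cuts can differ.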

Apply to real data

Here we show how the model can easily be applied to another data set. During training we saved checkpoints from which the model can be reloaded at a later point. Since the model is already loaded in this notebook, we do not need to load it again and can directly apply the nn_predict function.

# include the predictions in another HDF5 file
ai.models.nn_predict(h5_path='test_data/test_001.h5',
           model=cnn,
           feature_channel=0,
           group_name='events',
           prediction_name='cnn_prob',
           keys=['event'],
           no_channel_idx_in_pred=False,
           use_prob=True)
cnn_prob written to file test_data/test_001.h5.

Because this is a multi-purpose model and probably useful in many situations, we stored it as a resource inside the library. So let's list all available resources:

ai.resources.list_resources()
Resources stored in /Users/felix/PycharmProjects/cait/cait/resources:
cnn-clf-binary-v1.ckpt
cnn-clf-binary-v0.ckpt

We have two versions of this model stored: v0 with 32-fold downsampling and v1 with 4-fold downsampling. Let's load v0 and apply it. The calls below use feature_channel=0 (the phonon channel); set feature_channel=1 to process the light channel.

ckp_path = ai.resources.get_resource_path('cnn-clf-binary-v0.ckpt')

ai.models.nn_predict(h5_path='test_data/test_001.h5',
           model=CNNModule.load_from_checkpoint(ckp_path),
           feature_channel=0,
           group_name='events',
           prediction_name='cnn_cut',
           keys=['event'],
           no_channel_idx_in_pred=False,
           use_prob=False)

ai.models.nn_predict(h5_path='test_data/test_001.h5',
           model=CNNModule.load_from_checkpoint(ckp_path),
           feature_channel=0,
           group_name='events',
           prediction_name='cnn_prob',
           keys=['event'],
           no_channel_idx_in_pred=False,
           use_prob=True)
cnn_cut written to file test_data/test_001.h5.
cnn_prob written to file test_data/test_001.h5.

Let's look at the predictions inside the HDF5 set with a DataHandler.

dh_real = ai.DataHandler(channels=[0,1])
dh_real.set_filepath(path_h5='test_data/',
                fname='test_001',
                appendix=False)
DataHandler Instance created.

For the phonon channel we expect only pulse events (1).

dh_real.get('events', 'cnn_cut')[0, :10]
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
np.exp(dh_real.get('events', 'cnn_prob')[0, :10])
array([[0.20349745, 0.79650256],
       [0.20579075, 0.79420924],
       [0.08865106, 0.91134895],
       [0.0886578 , 0.91134221],
       [0.16476031, 0.8352397 ],
       [0.08603831, 0.91396167],
       [0.18273026, 0.81726976],
       [0.16321516, 0.83678482],
       [0.16562542, 0.83437454],
       [0.10445209, 0.89554791]])
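
The stored probabilities allow a stricter or looser cut than the default prediction. The sketch below applies a threshold to the class-1 probability of the ten values printed above. Treating column 1 as the event class is an assumption, chosen because it is consistent with cnn_cut being 1 for these events, and the helper prob_cut is hypothetical.

```python
import numpy as np

# The ten probability pairs printed above (after np.exp of the stored values).
probs = np.array([[0.20349745, 0.79650256],
                  [0.20579075, 0.79420924],
                  [0.08865106, 0.91134895],
                  [0.0886578, 0.91134221],
                  [0.16476031, 0.8352397],
                  [0.08603831, 0.91396167],
                  [0.18273026, 0.81726976],
                  [0.16321516, 0.83678482],
                  [0.16562542, 0.83437454],
                  [0.10445209, 0.89554791]])


def prob_cut(probs, threshold=0.5):
    """Return 1.0 where the assumed event-class probability exceeds the threshold."""
    return (probs[:, 1] > threshold).astype(float)


cut_default = prob_cut(probs)        # for two classes, 0.5 is equivalent to argmax
cut_strict = prob_cut(probs, 0.9)    # keep only confidently classified events
print(cut_default, int(cut_strict.sum()))
```

With the threshold at 0.5 all ten events survive, as in the cnn_cut output above; raising it to 0.9 keeps three of them.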

For the light channel, we expect pulse and noise events (0,1).

dh_real.get('events', 'cnn_cut')[1]
array([0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.,
       0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1.,
       0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1.,
       0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0.,
       1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0.,
       1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
       1., 1., 0., 0., 1., 0., 1.])

Done.