Universal Training Data Augmentation
Warning
Note that this tutorial and the features it describes are currently not maintained. If you wish to help out and contribute to it, please reach out to the maintainers.
Correspondence to: felix.wagner@oeaw.ac.at
In this notebook we explore the data augmentation module of the Cait Python package.
We create here a universal training set, to identify particle recoils and empty noise baselines against artifacts. In the second section of the notebook, we train a classifier model to perform this task.
from cait.augment import ParameterSampler, plot_events, unfold, L2
import numpy as np
import cait as ai # install from the develop branch
import matplotlib.pyplot as plt
import matplotlib as mpl
import h5py
from tqdm.auto import tqdm
from scipy import signal
from cait.models import CNNModule
from torchvision import transforms
from cait.datasets import SingleMinMaxNorm, DownSample, ToTensor, CryoDataModule
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
%config InlineBackend.figure_formats = ['svg']
Some definitions for the plots.
# Default figure size and the resolution used when figures are saved.
mpl.rcParams.update({
    'figure.figsize': (7.2, 4.45),
    'savefig.dpi': 300,
})
Data Augmentation
Define the global parameters. For details, see the notebook Cryogenic detector data augmentation.
For the whole notebook: Only capitalized parameters need to be set by the user.
# -----------------------------------------
# user settings
# -----------------------------------------
RECORD_LENGTH = 16384  # passed to the sampler and data handlers as record_length
SAMPLE_FREQUENCY = 25000  # passed on as sample_frequency (presumably in Hz - confirm)

# Switches forwarded to ParameterSampler.get_event as
# poly / rasterize / square / saturation.
POLYNOMIAL_DRIFTS = True
RASTERIZE = False
SQUARE_WAVES = False
SATURATION = True

# Event classes that can be simulated; every entry is a key of label_names.
CLASS_NAMES = [
    'Event Pulse',
    'Noise',
    'Decaying Baseline',
    'Temperature Rise',
    'Spike',
    'Squid Jump',
    'Reset',
    'Cosinus Tail',
    'Decaying Baseline with Event Pulse',
    'Pile Up',
    'Early or late Trigger',
]

# -----------------------------------------
# no need to change below parameters
# -----------------------------------------

# Mapping from the human readable class name to the integer label that is
# written to the HDF5 file.
label_names = {
    'unlabeled': 0,
    'Event Pulse': 1,
    'Test/Control Pulse': 2,
    'Noise': 3,
    'Squid Jump': 4,
    'Spike': 5,
    'Early or late Trigger': 6,
    'Pile Up': 7,
    'Carrier Event': 8,
    'Strongly Saturated Event Pulse': 9,
    'Strongly Saturated Test/Control Pulse': 10,
    'Decaying Baseline': 11,
    'Temperature Rise': 12,
    'Stick Event': 13,
    'Square Waves': 14,
    'Human Disturbance': 15,
    'Large Sawtooth': 16,
    'Cosinus Tail': 17,
    'Light only Event': 18,
    'Ring & Light Event': 19,
    'Sharp Light Event': 20,
    'Reset': 21,
    'Decaying Baseline with Event Pulse': 22,
    'Decaying Baseline with Tail Event': 23,
    'unknown/other': 99,
}
We define the sampler class.
# One sampler instance is shared by all simulation cells of the notebook.
parsam = ParameterSampler(
    record_length=RECORD_LENGTH,
    sample_frequency=SAMPLE_FREQUENCY,
)
Plot some augmented events. We did not define the baseline resolution, pulse shapes, saturation curve or noise power spectrum. Therefore, a new one is sampled for every batch, making the data more diverse overall.
NMBR_PLOTS = 1
# To preview random classes instead, use:
# classes = np.random.choice(CLASS_NAMES, size=NMBR_PLOTS)
classes = [CLASS_NAMES[0]] * NMBR_PLOTS

for label in classes:
    # Draw one batch of nine events of the requested class and show it.
    event, info = parsam.get_event(
        label=label,
        size=9,
        rasterize=RASTERIZE,
        poly=POLYNOMIAL_DRIFTS,
        square=SQUARE_WAVES,
        saturation=SATURATION,
        verb=True,
    )
    print(label)
    plot_events(event, t=parsam.t)

    if label != 'Event Pulse':
        continue

    # Spread of the absolute difference between the maximum of each record
    # and the pulse height reported by the sampler.
    deviation = np.abs(np.max(event, axis=1) - info['pulse_height'])
    print('Resolution (only meaningful without Saturation): ', np.std(deviation))
Sample Noise...
Sample Polynomials...
Sample Pulse Nmbr 0
Sample Saturation ...
Event Pulse
Resolution (only meaningful without Saturation): 2.803244455274235
Now create the HDF5 data set.
FNAME = 'test_data/universal_v0_1.h5'
IDX = list(range(11))  # which to take from the class names
BATCHSIZE = 50
EVENTS_PER_CLASS = {  # use always multiple of batch size!
    'Event Pulse': 2000,
    'Noise': 1000,
    'Decaying Baseline': 150,
    'Temperature Rise': 150,
    'Spike': 150,
    'Squid Jump': 150,
    'Reset': 300,
    'Cosinus Tail': 500,
    'Decaying Baseline with Event Pulse': 500,
    'Pile Up': 1000,
    'Early or late Trigger': 500,
}
DIV = 1  # common divisor applied to the event counts and to the batch size

BATCHSIZE = int(BATCHSIZE / DIV)
if BATCHSIZE < 1:
    raise ValueError('BATCHSIZE / DIV must be at least 1, got {}'.format(BATCHSIZE))

# Number of events that is actually simulated for every selected class.
# Only the classes selected through IDX contribute, so that the datasets
# contain no trailing rows that are never written.
_events_to_simulate = [
    (CLASS_NAMES[i], int(EVENTS_PER_CLASS[CLASS_NAMES[i]] / DIV)) for i in IDX
]
for _class_name, _nmbr in _events_to_simulate:
    # A remainder would silently be dropped by the batch loop below and leave
    # zero-filled records with label 0 ("unlabeled") in the file.
    if _nmbr % BATCHSIZE != 0:
        raise ValueError(
            'The number of events for "{}" ({}) is not a multiple of the '
            'batch size ({}).'.format(_class_name, _nmbr, BATCHSIZE)
        )
TOTAL_EVENTS = sum(_nmbr for _, _nmbr in _events_to_simulate)

# -------------------------------------------
# no changes required below this line!
# -------------------------------------------

# Attributes stored on the labels dataset (attribute name, integer label).
# The names are written verbatim as in the original notebook.
# NOTE(review): 'Decaying_Baseline_with_Tail_Pulse' differs from the
# label_names key 'Decaying Baseline with Tail Event' - confirm which one the
# library expects before renaming.
_LABEL_ATTRIBUTES = (
    ('unlabeled', 0),
    ('Event_Pulse', 1),
    ('Test/Control_Pulse', 2),
    ('Noise', 3),
    ('Squid_Jump', 4),
    ('Spike', 5),
    ('Early_or_late_Trigger', 6),
    ('Pile_Up', 7),
    ('Carrier_Event', 8),
    ('Strongly_Saturated_Event_Pulse', 9),
    ('Strongly_Saturated_Test/Control_Pulse', 10),
    ('Decaying_Baseline', 11),
    ('Temperature_Rise', 12),
    ('Stick_Event', 13),
    ('Square_Waves', 14),
    ('Human_Disturbance', 15),
    ('Large_Sawtooth', 16),
    ('Cosinus_Tail', 17),
    ('Light_only_Event', 18),
    ('Ring_Light_Event', 19),
    ('Sharp_Light_Event', 20),
    ('Reset', 21),
    ('Decaying_Baseline_with_Event_Pulse', 22),
    ('Decaying_Baseline_with_Tail_Pulse', 23),
    ('unknown/other', 99),
)

with h5py.File(FNAME, 'w') as f:
    f.require_group('events')
    f['events'].create_dataset('event',
                               shape=(1, TOTAL_EVENTS, RECORD_LENGTH),
                               dtype=np.float32)
    f['events'].create_dataset('labels',
                               shape=(1, TOTAL_EVENTS),
                               dtype=int)
    f['events'].create_dataset('true_ph',
                               shape=(1, TOTAL_EVENTS),
                               dtype=float)
    # Unlike the other datasets, true_onset has no leading channel dimension.
    f['events'].create_dataset('true_onset',
                               shape=(TOTAL_EVENTS, ),
                               dtype=float)

    # The label attributes do not depend on the simulated data, so they are
    # written once instead of once per batch.
    for _attr_name, _attr_value in _LABEL_ATTRIBUTES:
        f['events']['labels'].attrs.create(name=_attr_name, data=_attr_value)

    bcount = 0  # running batch counter over all classes
    with tqdm(total=TOTAL_EVENTS) as bar:  # closes the progress bar on exit
        for i in IDX:
            bar.write('Simulating {} ...'.format(CLASS_NAMES[i]))
            nmbr_events = int(EVENTS_PER_CLASS[CLASS_NAMES[i]] / DIV)
            batches = nmbr_events // BATCHSIZE
            for b in range(batches):
                event, info = parsam.get_event(label=CLASS_NAMES[i],
                                               size=BATCHSIZE,
                                               rasterize=RASTERIZE,
                                               poly=POLYNOMIAL_DRIFTS,
                                               square=SQUARE_WAVES,
                                               saturation=SATURATION,
                                               verb=False,
                                               )
                # Rows of the datasets that belong to the current batch.
                rows = slice(bcount * BATCHSIZE, (bcount + 1) * BATCHSIZE)
                f['events']['event'][0, rows, :] = event
                f['events']['labels'][0, rows] = label_names[CLASS_NAMES[i]]
                f['events']['true_ph'][0, rows] = info['pulse_height']
                # t0 is scaled by 1000 (presumably seconds to milliseconds - confirm).
                f['events']['true_onset'][rows] = info['t0'] * 1000
                bar.update(BATCHSIZE)
                bcount += 1
Simulating Event Pulse ...
Simulating Noise ...
Simulating Decaying Baseline ...
Simulating Temperature Rise ...
Simulating Spike ...
Simulating Squid Jump ...
Simulating Reset ...
Simulating Cosinus Tail ...
Simulating Decaying Baseline with Event Pulse ...
Simulating Pile Up ...
Simulating Early or late Trigger ...
Let's look at the simulated data with the VizTool!
# Data handler for the single-channel file written above.
dh = ai.DataHandler(
    nmbr_channels=1,
    sample_frequency=SAMPLE_FREQUENCY,
    record_length=RECORD_LENGTH,
)
dh.set_filepath('test_data', 'universal_v0_1', appendix=False)
DataHandler Instance created.
First we calculate the main parameters and the additional main parameters.
# Calculate the main parameters of all events in the file.
dh.calc_mp()
# Calculate the additional main parameters as well.
# NOTE(review): no_of=True presumably skips the optimum-filter based values - confirm.
dh.calc_additional_mp(no_of=True)
CALCULATE MAIN PARAMETERS.
CALCULATE ADDITIONAL MAIN PARAMETERS.
Looks like a total mess, right?
# Quantities offered in the VizTool: display name -> [dataset name, 0, None].
# NOTE(review): the second and third list entries are presumably the channel
# index and the feature index of the dataset - confirm against VizTool.
datasets = {
    'Pulse Height Phonon (V)': ['pulse_height', 0, None],
    'Rise Time Phonon (ms)': ['rise_time', 0, None],
    'Decay Time Phonon (ms)': ['decay_time', 0, None],
    'Onset Phonon (ms)': ['onset', 0, None],
    'Slope Phonon (V)': ['slope', 0, None],
    'Variance Phonon (V^2)': ['var', 0, None],
    'Mean Phonon (V)': ['mean', 0, None],
    'Skewness Phonon': ['skewness', 0, None],
}

# Use the global settings instead of repeating their values, so the tool stays
# consistent with the simulated file when the user changes them.
viz = ai.VizTool(path_h5='test_data/',
                 fname='universal_v0_1',
                 group='events',
                 datasets=datasets,
                 nmbr_channels=1,
                 batch_size=1000,
                 sample_frequency=SAMPLE_FREQUENCY,
                 record_length=RECORD_LENGTH)
# Color the points by the true class label of channel 0.
viz.set_colors(dh.get('events', 'labels')[0])
viz.show()
DataHandler Instance created.
Universal Classification
Now we can train a universal classifier model, which identifies particle recoils against all kinds of artifacts, noise or pile-up! We mostly use the same functions as in the ‘Classification with Neural Network’ notebook, just this time we will use a CNN model.
We want to do a binary classification between events and other artifacts (noise, etc). Therefore we set all labels to 0 or 1, depending on if they show an event or not.
# Events with label 1 (Event Pulse) or 17 (Cosinus Tail) form the positive class.
event_labels = dh.get('events', 'labels')[0]
label_cuts = ai.cuts.LogicalCut(np.isin(event_labels, (1, 17)))
# Number and fraction of events in the positive class.
nmbr_positive = label_cuts.counts()
print(nmbr_positive, nmbr_positive / label_cuts.total())
2500 0.390625
# Store the binary flag as dataset 'labels_binary' of channel 0 in the
# events group, replacing an older dataset of the same name.
dh.include_values(
    values=label_cuts.get_flag(),
    naming='labels_binary',
    channel=0,
    type='events',
    delete_old=True,
)
Included values.
The cells below work very much the same way as they do for the other NN classification notebook.
# some parameters
# nmbr_gpus = ...  uncomment and pass to the Trainer to use GPUs
path_h5 = 'test_data/universal_v0_1.h5'
path_h5_test = 'test_data/test_v0_1.h5'
# the group key for the data in the HDF5 set
# (the name shadows the builtin `type`; it is kept because later cells read it)
type = 'events'
keys = ['event', 'labels_binary']  # the datasets in the group from which we include data in the samples for the NN
channel_indices = [[0], [0]]  # the first indices of the datasets
feature_indices = [None, None]  # the third indices of the datasets
# the keys in the samples of the NN dataset that are input to the NN;
# in the data set for the NN, the keys have additionally appended the channel index
feature_keys = ['event_ch0']
label_keys = ['labels_binary_ch0']  # the keys in the samples of the NN dataset that are labels to the NN
# we do a min - max normalization of all samples, so these are roughly the
# lowest and highest values of the events in the data set
norm_vals = {'event_ch0': [0, 10]}
down_keys = ['event_ch0']  # if we input the raw time series, we apply downsampling first
down = 32  # all samples listed in down_keys are downsampled by this factor
record_length = RECORD_LENGTH  # reuse the global setting instead of repeating its value
nmbr_classes = 2  # binary classification: labels_binary only takes the values 0 and 1
max_epochs = 20  # the maximal number of training epochs of the neural network
save_naming = 'cnn-clf'  # the name addition used when saving the trained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# create the transforms
trans = transforms.Compose([SingleMinMaxNorm(norm_vals.keys()),
                            DownSample(keys=down_keys, down=down),
                            ToTensor()])

# create data module and init the setup
dm = CryoDataModule(hdf5_path=path_h5,
                    type=type,
                    keys=keys,
                    channel_indices=channel_indices,
                    feature_indices=feature_indices,
                    transform=trans,
                    )
dm.prepare_data(val_size=0.2,
                test_size=0.1,
                batch_size=16,
                dataset_size=None,
                nmbr_workers=0,
                only_idx=None,
                shuffle_dataset=True,
                random_seed=21,
                feature_keys=feature_keys,
                label_keys=label_keys,
                keys_one_hot=label_keys,
                )
dm.setup()
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:175: LightningDeprecationWarning:
DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.
# create cnn clf
cnn = CNNModule(
    kernelsize=8,
    input_size=int(record_length / down),  # length of a record after downsampling
    device_name=device,
    nmbr_out=nmbr_classes,  # this is the number of labels
    lr=1e-3,
    label_keys=label_keys,
    feature_keys=feature_keys,
    down=down,
    down_keys=feature_keys,
    norm_vals=norm_vals,
    norm_type='indiv_minmax',
    offset_keys=feature_keys,
)

# create callback to save the best model
# The downsampling factor is inserted here, because ModelCheckpoint only
# fills in logged metrics: a literal '{down}' placeholder ended up as
# 'down=0' in the file name. The doubled braces keep the epoch and val_loss
# placeholders intact for ModelCheckpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath='callbacks',
    monitor='val_loss',
    filename=f'{save_naming}-down{down}-{{epoch:02d}}-{{val_loss:.2f}}',
)

# create instance of Trainer
trainer = Trainer(max_epochs=max_epochs,
                  callbacks=[checkpoint_callback],
                  # gpus=1,
                  )
# keyword gpus=nmbr_gpus for GPU Usage
# keyword max_epochs for number of maximal epochs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
# The complete training loop runs inside this call; the data module
# provides the training and validation data.
trainer.fit(
    model=cnn,
    datamodule=dm,
)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/core/datamodule.py:469: LightningDeprecationWarning:
DataModule.setup has already been called, so it will not be called again. In v1.6 this behavior will change to always call DataModule.setup.
| Name | Type | Params
---------------------------------
0 | conv1 | Conv1d | 450
1 | conv2 | Conv1d | 4.0 K
2 | fc1 | Linear | 14.2 K
3 | fc2 | Linear | 402
---------------------------------
19.1 K Trainable params
0 Non-trainable params
19.1 K Total params
0.076 Total estimated model params size (MB)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:623: UserWarning:
Checkpoint directory /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks exists and is not empty.
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning:
The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning:
The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
# load best model
# load_from_checkpoint is a classmethod that returns a NEW model instance, so
# its result has to be assigned - calling it on `cnn` and discarding the
# return value leaves `cnn` with the weights of the last training epoch.
cnn = CNNModule.load_from_checkpoint(checkpoint_callback.best_model_path)
# A specific checkpoint file can be loaded the same way, e.g.
# cnn = CNNModule.load_from_checkpoint('callbacks/<checkpoint name>.ckpt')
cnn
CNNModule(
(conv1): Conv1d(1, 50, kernel_size=(8,), stride=(4,))
(conv2): Conv1d(50, 10, kernel_size=(8,), stride=(4,))
(fc1): Linear(in_features=70, out_features=200, bias=True)
(fc2): Linear(in_features=200, out_features=2, bias=True)
)
# Switch the model to evaluation mode before it is used for inference.
# eval() returns the module itself, which is what the notebook displays;
# calling the module without an input batch would not produce that output.
cnn.eval()
CNNModule(
(conv1): Conv1d(1, 50, kernel_size=(8,), stride=(4,))
(conv2): Conv1d(50, 10, kernel_size=(8,), stride=(4,))
(fc1): Linear(in_features=70, out_features=200, bias=True)
(fc2): Linear(in_features=200, out_features=2, bias=True)
)
# run test set
# The data module is handed over as a whole.
# The recorded traceback below comes from an older version of this cell,
# which passed the bound method dm.test_dataloader instead.
result = trainer.test(model=cnn, datamodule=dm)
print(result)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-35-0076aad303dc> in <module>
1 # run test set
----> 2 result = trainer.test(model=cnn, dataloaders=[dm.test_dataloader])
3 print(result)
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in test(self, model, dataloaders, ckpt_path, verbose, datamodule, test_dataloaders)
904 )
905 dataloaders = test_dataloaders
--> 906 return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
907
908 def _test_impl(
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
680 """
681 try:
--> 682 return trainer_fn(*args, **kwargs)
683 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
684 except KeyboardInterrupt as exception:
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _test_impl(self, model, dataloaders, ckpt_path, verbose, datamodule)
947
948 # run test
--> 949 results = self._run(model, ckpt_path=self.tested_ckpt_path)
950
951 assert self.state.stopped
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
1193
1194 # dispatch `start_training` or `start_evaluating` or `start_predicting`
-> 1195 self._dispatch()
1196
1197 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
1269 def _dispatch(self):
1270 if self.evaluating:
-> 1271 self.training_type_plugin.start_evaluating(self)
1272 elif self.predicting:
1273 self.training_type_plugin.start_predicting(self)
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_evaluating(self, trainer)
204 def start_evaluating(self, trainer: "pl.Trainer") -> None:
205 # double dispatch to initiate the test loop
--> 206 self._results = trainer.run_stage()
207
208 def start_predicting(self, trainer: "pl.Trainer") -> None:
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
1280
1281 if self.evaluating:
-> 1282 return self._run_evaluate()
1283 if self.predicting:
1284 return self._run_predict()
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_evaluate(self)
1328
1329 with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
-> 1330 eval_loop_results = self._evaluation_loop.run()
1331
1332 # remove the tensors from the eval results
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
143 try:
144 self.on_advance_start(*args, **kwargs)
--> 145 self.advance(*args, **kwargs)
146 self.on_advance_end()
147 self.restarting = False
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in advance(self, *args, **kwargs)
108 dl_max_batches = self._max_batches[dataloader_idx]
109
--> 110 dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
111
112 # store batch level output per dataloader
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
138 self.reset()
139
--> 140 self.on_run_start(*args, **kwargs)
141
142 while not self.done:
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py in on_run_start(self, data_fetcher, dataloader_idx, dl_max_batches, num_dataloaders)
84
85 self._reload_dataloader_state_dict(data_fetcher)
---> 86 self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
87
88 def advance(
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py in _update_dataloader_iter(data_fetcher, batch_idx)
119 if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
120 # restore iteration
--> 121 dataloader_iter = enumerate(data_fetcher, batch_idx)
122 else:
123 dataloader_iter = iter(data_fetcher)
~/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/fetching.py in __iter__(self)
195 raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
196 self.reset()
--> 197 self.dataloader_iter = iter(self.dataloader)
198 self._apply_patch()
199 self.prefetching(self.prefetch_batches)
TypeError: 'method' object is not iterable
Test
Now lets test the model with the data set we produced in the notebook Cryogenic Detector Data Augmentation!
# Data handler for the independent test file test_data/test_v0_1.
dh_test = ai.DataHandler(
    nmbr_channels=1,
    sample_frequency=SAMPLE_FREQUENCY,
    record_length=RECORD_LENGTH,
)
dh_test.set_filepath('test_data', 'test_v0_1', appendix=False)
DataHandler Instance created.
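We again introduce binary labels, in the same way as above. In addition to event pulses (1) and cosinus tails (17), carrier events (8) are counted as events here.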
# Labels 1 (Event Pulse), 17 (Cosinus Tail) and 8 (Carrier Event) form the
# positive class of the test set.
test_event_labels = dh_test.get('events', 'labels')[0]
test_label_cuts = ai.cuts.LogicalCut(np.isin(test_event_labels, (1, 17, 8)))
# Number and fraction of test events in the positive class.
nmbr_test_positive = test_label_cuts.counts()
print(nmbr_test_positive, nmbr_test_positive / test_label_cuts.total())
850 0.26153846153846155
# Store the binary flag as 'labels_binary' of channel 0 in the test file,
# replacing an older dataset of the same name.
dh_test.include_values(
    values=test_label_cuts.get_flag(),
    naming='labels_binary',
    channel=0,
    type='events',
    delete_old=True,
)
Delete old labels_binary dataset
Included values.
We build samples from the test set and insert it to the model.
# predictions with the model are made that way
with h5py.File(path_h5_test, 'r') as f:
    start_from = 0
    nmbr_classes_test = 2
    test_size = f['events/event'].shape[1]
    test_idx = np.arange(start_from, start_from + test_size)

    # Model input: one entry per feature key,
    # array of shape (nmbr_events, nmbr_features).
    x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}
    # True binary labels of the same events.
    y = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])

    probs = cnn.get_prob(x).numpy()

    # The input was changed in place by the call above, so it is read again.
    x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}
    prediction = cnn.predict(x).numpy()

# predictions can be saved with instance of EvaluationTools
# print('PREDICTION: ', prediction)
nmbr_correct = np.count_nonzero(prediction == y)
print('ACCURACY: ', nmbr_correct / len(y))
print('Best model: ', checkpoint_callback.best_model_path)
# The model returns log softmax values; exp turns them into probabilities.
print('Probabilites: ', np.exp(probs[:10, :]))
ACCURACY: 0.9436923076923077
Best model: /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/cnn-clfdowndown=0-epoch=19-val_loss=0.16.ckpt
Probabilites: [[0.06205076 0.93794924]
[0.09400847 0.90599155]
[0.07259091 0.92740905]
[0.02571316 0.9742869 ]
[0.05754973 0.94245034]
[0.02000174 0.9799983 ]
[0.08924916 0.9107508 ]
[0.08573773 0.91426224]
[0.07457279 0.9254272 ]
[0.05423618 0.9457638 ]]
We store the predictions as cnn_cut in the data set and the probabilities.
# Write the predicted class of every test event as dataset 'cnn_cut'.
dh_test.include_values(
    values=prediction,
    naming='cnn_cut',
    channel=0,
    type='events',
    delete_old=True,
)
Delete old cnn_cut dataset
Included values.
# Write the exponentiated first output column (class index 0) as 'cnn_probs'.
class_zero_probs = np.exp(probs[:, 0])
dh_test.include_values(
    values=class_zero_probs,
    naming='cnn_probs',
    channel=0,
    type='events',
    delete_old=True,
)
Delete old cnn_probs dataset
Included values.
Lets plot a confusion matrix to see which classes worked nicely, and which did not.
# Count for every true label (row) how often each binary prediction
# (column 0 or 1) was made.
_true_labels = np.asarray(dh_test.get('events', 'labels')[0]).astype(int)
_predicted = np.asarray(prediction).astype(int)  # array indices must be integers

# At least 24 rows as before, but enough rows for every label that occurs
# (e.g. 99 for 'unknown/other'), which would otherwise raise an IndexError.
_nmbr_rows = max(24, int(_true_labels.max()) + 1) if _true_labels.size else 24
confusion_matrix = np.zeros((_nmbr_rows, 2), dtype=int)

# np.add.at accumulates repeated (label, prediction) pairs correctly.
np.add.at(confusion_matrix, (_true_labels, _predicted), 1)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-40-f98c1290f435> in <module>
1 confusion_matrix = np.zeros((24, 2), dtype=int)
----> 2 for i,j in zip(dh_test.get('events', 'labels')[0], prediction):
3 confusion_matrix[i,j] += 1
NameError: name 'prediction' is not defined
# Keep only the rows of labels that occur in the test set at least once.
not_zero = confusion_matrix.sum(axis=1) > 0
nmbr_not_zero = np.count_nonzero(not_zero)
print('Nmbr used classes: ', nmbr_not_zero)
confusion_matrix_squeeze = confusion_matrix[not_zero]
Nmbr used classes: 13
For the actual plot, we need the seaborn library. In case you have not installed it, just run
!pip install seaborn
in a new cell.
import seaborn as sns

# The sorted label values that occur in the test set, used as row names.
labels = np.unique(dh_test.get('events', 'labels')[0])

# One row per true label, one column per binary prediction.
ax = sns.heatmap(
    confusion_matrix_squeeze,
    annot=True,
    fmt='d',
    linewidths=.5,
    xticklabels=['Cut', 'Survived'],
    yticklabels=labels,
)
plt.ylabel("Label")
plt.show()
Lets see how the standard event looks like, produced with the cnn cut.
# All records of channel 0 of the test set.
test_events = dh_test.get('events', 'event')[0]

# Mean over all events, over the events the model predicts as class 1, and
# over the events whose binary label is 1 - each scaled to a maximum of one.
sev_raw = np.mean(test_events, axis=0)
sev_raw /= np.max(sev_raw)

sev_pred = np.mean(test_events[prediction == 1], axis=0)
sev_pred /= np.max(sev_pred)

sev_label = np.mean(test_events[y == 1], axis=0)
sev_label /= np.max(sev_label)

# Same drawing order as before: raw, labels, predictions.
for curve, legend, color, extra in (
    (sev_raw, 'SEV without cuts', 'blue', {}),
    (sev_label, 'SEV with labels', 'green', {}),
    (sev_pred, 'SEV with predictions', 'red', {'linewidth': 2}),
):
    plt.plot(curve, label=legend, color=color, **extra)

plt.xlabel('Sample Index')
plt.ylabel('Amplitude (V)')
plt.legend()
plt.show()
Apply to real data
Here we show how the model can easily be applied to another data set. During the training, we saved checkpoints with which the model can be reloaded at a later point. As the model is already loaded in this notebook, we do not need to load it again. We can just apply the nn_predict function.
# Apply the model trained above to another HDF5 file and store its output
# there under the name 'cnn_prob'.
ai.models.nn_predict(
    h5_path='test_data/test_001.h5',
    model=cnn,
    feature_channel=0,
    group_name='events',
    prediction_name='cnn_prob',
    keys=['event'],
    no_channel_idx_in_pred=False,
    use_prob=True,
)
cnn_prob written to file test_data/test_001.h5.
Because this is a multi-purpose model and probably useful in many situations, we stored it as a resource inside the library. So lets list all available resources:
# Print the resource files that are stored inside the library.
ai.resources.list_resources()
Resources stored in /Users/felix/PycharmProjects/cait/cait/resources:
cnn-clf-binary-v1.ckpt
cnn-clf-binary-v0.ckpt
We have two different versions of this model stored: the v0 with 32 times downsampling, and the v1 with 4 times downsampling. Let's load and apply the v0 to the light channel (feature_channel 1).
ckp_path = ai.resources.get_resource_path('cnn-clf-binary-v0.ckpt')

# Load the stored model once and reuse it for both calls, instead of reading
# the checkpoint from disk for each of them.
model_v0 = CNNModule.load_from_checkpoint(ckp_path)

# First the class predictions ('cnn_cut'), then the probabilities
# ('cnn_prob'); apart from these two arguments the calls are identical.
# NOTE(review): the accompanying text mentions the light channel
# (feature_channel 1), while feature_channel=0 is used here - confirm.
for prediction_name, use_prob in (('cnn_cut', False), ('cnn_prob', True)):
    ai.models.nn_predict(h5_path='test_data/test_001.h5',
                         model=model_v0,
                         feature_channel=0,
                         group_name='events',
                         prediction_name=prediction_name,
                         keys=['event'],
                         no_channel_idx_in_pred=False,
                         use_prob=use_prob)
cnn_cut written to file test_data/test_001.h5.
cnn_prob written to file test_data/test_001.h5.
Let's look at the predictions inside the HDF5 set with a DataHandler.
# Data handler for the two-channel file test_data/test_001.
dh_real = ai.DataHandler(channels=[0, 1])
dh_real.set_filepath(
    path_h5='test_data/',
    fname='test_001',
    appendix=False,
)
DataHandler Instance created.
For the phonon channel we expect only pulse events (1).
# Predicted classes of the first ten events of channel 0.
dh_real.get('events', 'cnn_cut')[0, :10]
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
# Exponentiated 'cnn_prob' values of the first ten events of channel 0
# (presumably log softmax outputs, as for the test set - confirm).
np.exp(dh_real.get('events', 'cnn_prob')[0, :10])
array([[0.20349745, 0.79650256],
[0.20579075, 0.79420924],
[0.08865106, 0.91134895],
[0.0886578 , 0.91134221],
[0.16476031, 0.8352397 ],
[0.08603831, 0.91396167],
[0.18273026, 0.81726976],
[0.16321516, 0.83678482],
[0.16562542, 0.83437454],
[0.10445209, 0.89554791]])
For the light channel, we expect pulse and noise events (0,1).
# Predicted classes of all events of channel 1.
dh_real.get('events', 'cnn_cut')[1]
array([0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1.,
0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1.,
0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0.,
1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0.,
1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
1., 1., 0., 0., 1., 0., 1.])
Done.