{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neural Networks for Event Classification\n",
"```{warning}\n",
"Note that this tutorial and the features it describes are currently **not maintained**. If you wish to help out and contribute to it, please reach out to the maintainers.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial we train a Long Short-Term Memory (LSTM) neural network to discriminate absorber and carrier recoil events.\n",
"\n",
"A neural network is a supervised, non-linear fit model capable of fitting high-dimensional and complex correlations. Here, both its input and its output are one-dimensional vectors. The LSTM network processes the input data in a time-distributed fashion by splitting the input vector into smaller vectors, called time steps. It reuses its fit parameters for every time step, while internally storing two gate-protected state vectors that represent long- and short-term correlations between time steps. This network type is therefore especially suitable for processing time-distributed data such as ours.\n",
"\n",
"For the neural networks we use the PyTorch framework, one of the most widely used frameworks in the research community. We also use the API of PyTorch Lightning, which significantly reduces the amount of boilerplate code in training scripts and increases flexibility, while still allowing us to add custom elements to the training process."
]
},
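{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea of time steps more concrete, here is a minimal NumPy sketch of a single-layer LSTM forward pass. It is only an illustration and not the Cait or PyTorch implementation: the helper `lstm_forward`, the random weights and the ordering of the gates are assumptions. A synthetic record of 256 values is split into 32 time steps of `input_size = 8` values. The same weights are reused at every step while the hidden (short-term) state `h` and the cell (long-term) state `c` are updated.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def lstm_forward(x, W, U, b, hidden_size):\n",
"    # hypothetical helper: x has shape (seq_steps, input_size), one row per time step\n",
"    h = np.zeros(hidden_size)  # short-term (hidden) state\n",
"    c = np.zeros(hidden_size)  # long-term (cell) state\n",
"    outputs = []\n",
"    for x_t in x:  # the same W, U and b are reused at every time step\n",
"        z = W @ x_t + U @ h + b\n",
"        i, f, g, o = np.split(z, 4)  # input, forget, candidate and output parts (ordering assumed)\n",
"        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # the gates control what is kept in the cell state\n",
"        h = sigmoid(o) * np.tanh(c)\n",
"        outputs.append(h)\n",
"    return np.stack(outputs)  # hidden states of all time steps, shape (seq_steps, hidden_size)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"input_size, hidden_size = 8, 80\n",
"event = rng.normal(size=256)            # synthetic stand-in for one downsampled record\n",
"steps = event.reshape(-1, input_size)   # 32 time steps of 8 values each\n",
"W = rng.normal(scale=0.1, size=(4 * hidden_size, input_size))\n",
"U = rng.normal(scale=0.1, size=(4 * hidden_size, hidden_size))\n",
"b = np.zeros(4 * hidden_size)\n",
"hs = lstm_forward(steps, W, U, b, hidden_size)\n",
"print(steps.shape, hs.shape)  # (32, 8) (32, 80)\n",
"```"
]
},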
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:06.173345Z",
"start_time": "2021-11-07T13:42:03.667021Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import cait as ai\n",
"from pytorch_lightning import Trainer\n",
"from torchvision import transforms\n",
"import h5py\n",
"from cait.datasets import RemoveOffset, Normalize, DownSample, ToTensor, CryoDataModule\n",
"from cait.models import LSTMModule, nn_predict\n",
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"import matplotlib.pyplot as plt\n",
"%config InlineBackend.figure_formats = ['svg'] # we need this for a suitable resolution of the plots"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameters below might seem a bit overwhelming at first. Each one is explained in the comment next to it."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:06.193615Z",
"start_time": "2021-11-07T13:42:06.188571Z"
}
},
"outputs": [],
"source": [
"# some parameters\n",
"# nmbr_gpus = ... uncomment and pass to the Trainer to use GPUs\n",
"path_h5 = 'test_data/labeled_001.h5'\n",
"type = 'events' # the group key for the data in the HDF5 set (note: this shadows Python's built-in type)\n",
"keys = ['event', 'labels'] # the datasets in the group from which we include data in the samples for the NN\n",
"channel_indices = [[0], [0]] # the first indices (channels) of the datasets\n",
"feature_indices = [None, None] # the third indices of the datasets\n",
"feature_keys = ['event_ch0'] # the keys in the samples of the NN dataset that are input to the NN\n",
" # in the NN dataset, the channel index is appended to each key\n",
"label_keys = ['labels_ch0'] # the keys in the samples of the NN dataset that are labels to the NN\n",
"norm_vals = {'event_ch0': [0, 3]} # we do a min-max normalization of all samples, so these are roughly the lowest and highest values of the events in the data set\n",
"down_keys = ['event_ch0'] # the keys of the raw time series that are downsampled before they are input to the NN\n",
"down = 64 # the samples with the keys in down_keys are downsampled by this factor\n",
"input_size = 8 # the input size of the LSTM cell, i.e. the number of values per time step\n",
"nmbr_classes = 10 # the number of classes (outputs of the NN) - attention: the class index of the carrier event is 8, therefore we need at least 9 classes, even though only two of them are present in our data set\n",
"max_epochs = 20 # the maximum number of training epochs of the neural network\n",
"save_naming = 'lstm-clf' # the name prefix used when saving the trained model"
]
},
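{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why `nmbr_classes` has to be larger than the largest class index becomes clear from a one-hot encoding of the labels, which we assume is what `keys_one_hot` requests further below. In this NumPy sketch the helper `one_hot` is hypothetical. It encodes the absorber (1) and carrier (8) labels as vectors of length `nmbr_classes`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def one_hot(labels, nmbr_classes):\n",
"    # hypothetical helper: row k is the unit vector of class labels[k]\n",
"    labels = np.asarray(labels, dtype=int)\n",
"    if labels.max() >= nmbr_classes:\n",
"        raise ValueError('nmbr_classes must be larger than the largest class index')\n",
"    return np.eye(nmbr_classes)[labels]\n",
"\n",
"labels = np.array([1, 8, 1, 1, 8])  # absorber = 1, carrier = 8\n",
"encoded = one_hot(labels, nmbr_classes=10)\n",
"print(encoded.shape)           # (5, 10)\n",
"print(encoded.argmax(axis=1))  # [1 8 1 1 8]\n",
"```\n",
"\n",
"With `nmbr_classes=8`, the carrier label 8 would not fit into the encoding, which is why at least 9 classes are needed."
]
},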
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset and Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define several transforms that are applied to all data input to the NN. First, the offset of the raw time series is removed, then the samples are min-max normalized. If downsampling is requested, it is applied next, and finally the NumPy arrays of the samples are converted to PyTorch tensors."
]
},
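{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the order of these transforms, the NumPy sketch below applies them to a synthetic record. The helper names are hypothetical. Estimating the offset from the first eighth of the record and downsampling by block averaging are assumptions, not Cait's documented behaviour. Only the normalization bounds `[0, 3]` and the factor `down = 64` are taken from the parameters above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def remove_offset(x, baseline_fraction=1/8):\n",
"    # assumption: the offset is the mean of the pre-trigger part of the record\n",
"    n = int(len(x) * baseline_fraction)\n",
"    return x - np.mean(x[:n])\n",
"\n",
"def normalize(x, lower, upper):\n",
"    # min-max normalization with fixed bounds, as in norm_vals = {'event_ch0': [0, 3]}\n",
"    return (x - lower) / (upper - lower)\n",
"\n",
"def downsample(x, down):\n",
"    # assumption: average over consecutive blocks of down samples\n",
"    return x.reshape(-1, down).mean(axis=1)\n",
"\n",
"rng = np.random.default_rng(1)\n",
"record = 0.5 + 0.01 * rng.normal(size=16384)  # synthetic baseline with an offset of 0.5\n",
"record[4096:] += 1.5                          # crude step as a stand-in for a pulse\n",
"x = downsample(normalize(remove_offset(record), 0, 3), 64)\n",
"print(x.shape)  # (256,)\n",
"```"
]
},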
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:17.138608Z",
"start_time": "2021-11-07T13:42:17.134361Z"
}
},
"outputs": [],
"source": [
"# create the transforms (named data_transforms so that the torchvision module 'transforms' is not overwritten)\n",
"data_transforms = transforms.Compose([RemoveOffset(keys=feature_keys),\n",
" Normalize(norm_vals=norm_vals),\n",
" DownSample(keys=down_keys, down=down),\n",
" ToTensor()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The CryoDataModule is a Cait-specific subclass of PyTorch Lightning's LightningDataModule class. It puts the data from the HDF5 file into a format that can be input to the NN. The data is still loaded lazily, to prevent memory issues."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:18.338118Z",
"start_time": "2021-11-07T13:42:18.334097Z"
}
},
"outputs": [],
"source": [
"# create data module and init the setup\n",
"dm = CryoDataModule(hdf5_path=path_h5,\n",
" type=type,\n",
" keys=keys,\n",
" channel_indices=channel_indices,\n",
" feature_indices=feature_indices,\n",
" transform=data_transforms)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the prepare_data routine we define several settings of the training process, in particular the sizes of the training, validation and test splits and the batch size. Here we can also set the number of worker processes used for loading the data during training. Try both 0 and the number of CPUs of your machine for the argument nmbr_workers!"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:20.004431Z",
"start_time": "2021-11-07T13:42:19.953483Z"
}
},
"outputs": [],
"source": [
"dm.prepare_data(val_size=0.2,\n",
" test_size=0.2,\n",
" batch_size=32,\n",
" dataset_size=None,\n",
" nmbr_workers=0, # e.g. the number of CPUs on the machine; surprisingly, 0 was faster than 8 in our tests, probably because our dataset implementation is inefficient for multiprocessing (an issue on GitLab is open)\n",
" only_idx=None,\n",
" shuffle_dataset=True,\n",
" random_seed=21,\n",
" feature_keys=feature_keys,\n",
" label_keys=label_keys,\n",
" keys_one_hot=label_keys)"
]
},
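{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of `val_size`, `test_size`, `shuffle_dataset` and `random_seed` can be sketched as a shuffled split of the event indices. This minimal NumPy illustration assumes that the split sizes are fractions of the full dataset. Cait's actual implementation (for example its rounding or its random number generator) may differ, `split_indices` is a hypothetical name, and the dataset size of 1000 is only illustrative.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def split_indices(n, val_size, test_size, shuffle=True, seed=None):\n",
"    # hypothetical helper: returns disjoint train, validation and test index arrays\n",
"    idx = np.arange(n)\n",
"    if shuffle:\n",
"        np.random.default_rng(seed).shuffle(idx)\n",
"    n_val = int(np.floor(val_size * n))\n",
"    n_test = int(np.floor(test_size * n))\n",
"    val = idx[:n_val]\n",
"    test = idx[n_val:n_val + n_test]\n",
"    train = idx[n_val + n_test:]\n",
"    return train, val, test\n",
"\n",
"train, val, test = split_indices(1000, val_size=0.2, test_size=0.2, seed=21)\n",
"print(len(train), len(val), len(test))  # 600 200 200\n",
"```"
]
},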
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The setup method performs internal definitions and preprocessing of the data. Unlike prepare_data, it is called separately on every process (e.g. on each GPU in distributed training)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:21.031766Z",
"start_time": "2021-11-07T13:42:21.024548Z"
}
},
"outputs": [],
"source": [
"dm.setup()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define the LSTM neural network model as a subclass of the LightningModule. We need to specify the architecture of the model (layers, etc.), as well as the learning rate, i.e. the size of the weight updates in the optimization process. If the training does not work, the cause is very often a badly chosen learning rate!"
]
},
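{
"cell_type": "markdown",
"metadata": {},
"source": [
"A toy example shows why the learning rate matters so much. Plain gradient descent on the one-dimensional loss f(w) = w^2 converges for a small step size and diverges for one that is too large. This is only an illustration with a hypothetical helper `gradient_descent`; the optimizer that `LSTMModule` uses internally is not shown in this notebook.\n",
"\n",
"```python\n",
"def gradient_descent(lr, w0=1.0, steps=50):\n",
"    # minimize f(w) = w**2, whose gradient is 2 * w\n",
"    w = w0\n",
"    for _ in range(steps):\n",
"        w = w - lr * 2 * w\n",
"    return w\n",
"\n",
"print(abs(gradient_descent(lr=0.1)))  # shrinks towards 0\n",
"print(abs(gradient_descent(lr=1.1)))  # grows, because every step overshoots the minimum\n",
"```\n",
"\n",
"Each step multiplies w by (1 - 2 lr), so this iteration only converges for 0 < lr < 1."
]
},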
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:22.531934Z",
"start_time": "2021-11-07T13:42:22.523755Z"
}
},
"outputs": [],
"source": [
"# create lstm clf\n",
"lstm = LSTMModule(input_size=input_size,\n",
" hidden_size=input_size * 10,\n",
" num_layers=2,\n",
" seq_steps=int(dm.dims[1] / input_size), # downsampling is already considered in dm\n",
" device_name='cpu',\n",
" nmbr_out=nmbr_classes, # this is the number of labels\n",
" lr=1e-3,\n",
" label_keys=label_keys,\n",
" feature_keys=feature_keys,\n",
" is_classifier=True,\n",
" down=down,\n",
" down_keys=feature_keys,\n",
" norm_vals=norm_vals,\n",
" offset_keys=feature_keys)"
]
},
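{
"cell_type": "markdown",
"metadata": {},
"source": [
"The chosen architecture can be checked against the model summary printed during training (80 K parameters in the LSTM, 25 K in `fc1`). The sketch below uses the standard parameter count of a PyTorch `nn.LSTM` layer (four gates, each with input weights, recurrent weights and two bias vectors) and of `nn.Linear`. It assumes, consistently with the printed model (`in_features=2560`), that `fc1` acts on the hidden states of all 32 time steps. That is, the downsampled record is assumed to have 32 x 8 = 256 samples. The helper names are hypothetical.\n",
"\n",
"```python\n",
"def lstm_param_count(input_size, hidden_size, num_layers):\n",
"    # per layer: 4 gates x (input weights + recurrent weights + bias_ih + bias_hh)\n",
"    total = 0\n",
"    for layer in range(num_layers):\n",
"        layer_input = input_size if layer == 0 else hidden_size\n",
"        total += 4 * hidden_size * (layer_input + hidden_size + 2)\n",
"    return total\n",
"\n",
"def linear_param_count(in_features, out_features):\n",
"    return in_features * out_features + out_features  # weights + bias\n",
"\n",
"input_size, hidden_size, num_layers, nmbr_classes = 8, 80, 2, 10\n",
"seq_steps = 256 // input_size  # assumed downsampled record length of 256 samples\n",
"print(seq_steps)                                                  # 32\n",
"print(lstm_param_count(input_size, hidden_size, num_layers))      # 80640\n",
"print(linear_param_count(seq_steps * hidden_size, nmbr_classes))  # 25610\n",
"```"
]
},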
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tensorboard"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tensorboard is a convenient and powerful tool for visualizing neural network training. It can be displayed inline in a Jupyter notebook, or in a web browser when started from a terminal. The folder lightning_logs is created automatically and stores the logs of the training process."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:40.151932Z",
"start_time": "2021-11-07T13:42:26.552658Z"
}
},
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=lightning_logs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you run scripts on a server without X forwarding, you can still run Tensorboard on the server and forward its port to your local machine. For this, include the -L flag when connecting to the server via SSH, as shown below. The command maps port 16006 on your machine to Tensorboard's default port 6006 on the server.\n",
"\n",
".. code:: console\n",
"\n",
" ssh -L 16006:127.0.0.1:6006 \n",
" \n",
"You can then open Tensorboard in your local web browser at http://127.0.0.1:16006."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"During training we save the model with the lowest loss on the validation set. Since the validation data are not used for the weight updates, the validation loss serves as our estimate of the performance on unseen data. Because the best epoch is selected with the validation set, the held-out test set is then used for a final, independent check."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:45.857310Z",
"start_time": "2021-11-07T13:42:45.848941Z"
}
},
"outputs": [],
"source": [
"# create callback to save the best model\n",
"checkpoint_callback = ModelCheckpoint(dirpath='callbacks',\n",
" monitor='val_loss',\n",
" filename=save_naming + '-{epoch:02d}-{val_loss:.2f}')"
]
},
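{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, the checkpoint callback keeps track of the epoch with the lowest validation loss. The plain-Python sketch below reproduces that bookkeeping for a made-up list of validation losses. The numbers are synthetic and the helper `best_epoch` is hypothetical. The file name mimics the pattern given to `ModelCheckpoint` above, which PyTorch Lightning expands to names like `epoch=17` (compare the checkpoint path printed in the evaluation).\n",
"\n",
"```python\n",
"def best_epoch(val_losses):\n",
"    # index and value of the epoch with the lowest validation loss\n",
"    best = min(range(len(val_losses)), key=lambda e: val_losses[e])\n",
"    return best, val_losses[best]\n",
"\n",
"val_losses = [0.41, 0.22, 0.15, 0.17, 0.12, 0.14]  # synthetic example values\n",
"epoch, loss = best_epoch(val_losses)\n",
"filename = 'lstm-clf' + '-epoch={:02d}-val_loss={:.2f}'.format(epoch, loss)\n",
"print(filename)  # lstm-clf-epoch=04-val_loss=0.12\n",
"```"
]
},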
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The PyTorch Lightning Trainer class takes care of the whole training process under the hood. We pass the maximum number of epochs and the checkpoint callback as arguments. Here we could also request GPUs, if any are available."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:42:47.226922Z",
"start_time": "2021-11-07T13:42:47.218622Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n"
]
}
],
"source": [
"# create instance of Trainer\n",
"trainer = Trainer(max_epochs=max_epochs,\n",
" callbacks=[checkpoint_callback])\n",
"# keyword gpus=nmbr_gpus for GPU usage\n",
"# keyword max_epochs sets the maximum number of epochs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we start the training of our model."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:08.958717Z",
"start_time": "2021-11-07T13:42:48.822104Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params\n",
"--------------------------------\n",
"0 | lstm | LSTM | 80 K \n",
"1 | fc1 | Linear | 25 K \n",
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n",
"\n",
"The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n",
"\n",
"The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:216: UserWarning:\n",
"\n",
"Please also save or load the state of the optimizer when saving or loading the scheduler.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# all training happens here\n",
"trainer.fit(model=lstm,\n",
" datamodule=dm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Having finished the training, we load the best model saved by our checkpoint callback for the evaluation."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:16.815124Z",
"start_time": "2021-11-07T13:43:16.800782Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"LSTMModule(\n",
" (lstm): LSTM(8, 80, num_layers=2, batch_first=True)\n",
" (fc1): Linear(in_features=2560, out_features=10, bias=True)\n",
")"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load best model\n",
"lstm.load_from_checkpoint(checkpoint_callback.best_model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The trainer keeps a reference to the data module and its test split, so we can evaluate our model on the test set directly."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:18.192435Z",
"start_time": "2021-11-07T13:43:18.047489Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n",
"\n",
"The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"DATALOADER:0 TEST RESULTS\n",
"{'test_loss': tensor(0.0525),\n",
" 'train_loss': tensor(0.0309),\n",
" 'val_loss': tensor(0.0304)}\n",
"--------------------------------------------------------------------------------\n",
"\n",
"[{'train_loss': 0.030931567773222923, 'val_loss': 0.03036707453429699, 'test_loss': 0.0525243803858757}]\n"
]
}
],
"source": [
"# run test set\n",
"result = trainer.test()\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss value alone says little about the classification accuracy of our model, so we evaluate the predictions manually below."
]
},
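{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small example shows why a low loss and a high accuracy are not the same thing. Assuming a cross-entropy loss (the loss function of `LSTMModule` is not shown in this notebook), two sets of predicted class probabilities with the same accuracy can have very different loss values. The probabilities and the helper functions below are made up for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def cross_entropy(probs, labels):\n",
"    # mean negative log-probability assigned to the true class\n",
"    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))\n",
"\n",
"def accuracy(probs, labels):\n",
"    # fraction of events whose most probable class is the true class\n",
"    return np.mean(probs.argmax(axis=1) == labels)\n",
"\n",
"labels = np.array([0, 1, 1, 0])  # toy labels of two classes\n",
"confident = np.array([[0.99, 0.01], [0.01, 0.99], [0.6, 0.4], [0.99, 0.01]])\n",
"hesitant = np.array([[0.6, 0.4], [0.4, 0.6], [0.6, 0.4], [0.6, 0.4]])\n",
"for name, probs in [('confident', confident), ('hesitant', hesitant)]:\n",
"    print(name, accuracy(probs, labels), round(cross_entropy(probs, labels), 3))\n",
"```\n",
"\n",
"Both toy models misclassify the same event and therefore reach the same accuracy of 0.75, but the confident one has a clearly lower loss."
]
},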
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:25.448628Z",
"start_time": "2021-11-07T13:43:25.359108Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PREDICTION: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1\n",
" 8 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 8 1 1 1 1 1 1 1 8 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 8 8 8 8 8 8 8 8 8\n",
" 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
" 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8\n",
" 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]\n",
"ACCURACY: 0.975\n",
"Best model: /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/lstm-clf-epoch=17-val_loss=0.02-v0.ckpt\n"
]
}
],
"source": [
"# predictions with the model are made this way\n",
"test_idx = sorted(dm.test_sampler.indices)  # h5py fancy indexing requires increasing indices\n",
"with h5py.File(dm.hdf5_path, 'r') as f:\n",
"    x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}  # array of shape: (nmbr_events, nmbr_features)\n",
"    y = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])\n",
"prediction = lstm.predict(x).numpy()\n",
"\n",
"# predictions can also be stored with an instance of EvaluationTools\n",
"print('PREDICTION: ', prediction)\n",
"print('ACCURACY: ', np.sum(prediction == y)/len(y))\n",
"print('Best model: ', checkpoint_callback.best_model_path)"
]
},
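{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond the overall accuracy, a confusion matrix shows into which class the misclassified events fall. Here is a minimal NumPy sketch with a hypothetical helper `confusion_matrix` and toy label arrays.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def confusion_matrix(y_true, y_pred, classes):\n",
"    # entry [i, j] counts events of true class classes[i] predicted as classes[j]\n",
"    matrix = np.zeros((len(classes), len(classes)), dtype=int)\n",
"    index = {c: k for k, c in enumerate(classes)}\n",
"    for t, p in zip(y_true, y_pred):\n",
"        matrix[index[t], index[p]] += 1\n",
"    return matrix\n",
"\n",
"y_true = np.array([1, 1, 1, 8, 8, 8, 8])  # toy labels: 1 = absorber, 8 = carrier\n",
"y_pred = np.array([1, 1, 8, 8, 8, 8, 1])\n",
"cm = confusion_matrix(y_true, y_pred, classes=[1, 8])\n",
"print(cm)\n",
"print('accuracy:', np.trace(cm) / cm.sum())\n",
"```\n",
"\n",
"With the variables of the cell above, a call like `confusion_matrix(y.astype(int), prediction, classes=[1, 8])` would tabulate the test-set results, provided both arrays only contain these two classes."
]
},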
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy on the test set is pretty good; it seems that our model learned something!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predictions on Real Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We ultimately want to include the predictions of our neural network model in the HDF5 file, in order to use them for quality cuts in the analysis. For this we can use the function nn_predict with the corresponding arguments."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:29.764223Z",
"start_time": "2021-11-07T13:43:29.631299Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"pred_lstm written to file test_data/test_001.h5.\n"
]
}
],
"source": [
"# include the predictions in another HDF5 file\n",
"nn_predict(h5_path='test_data/test_001.h5',\n",
" model=lstm,\n",
" feature_channel=0,\n",
" group_name='events',\n",
" prediction_name='pred_lstm',\n",
" keys=['event'],\n",
" no_channel_idx_in_pred=False)"
]
},
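{
"cell_type": "markdown",
"metadata": {},
"source": [
"A classifier with `nmbr_out=10` returns one score per class, and the predicted class is the index of the largest score. Judging from how the predictions are read back below (`dh.get('events', 'pred_lstm')[0]`), they are stored with a leading channel index. The NumPy sketch below shows this layout. The helper `store_predictions`, the synthetic scores and leaving the other channel at zero are assumptions, not a description of how `nn_predict` works internally.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def store_predictions(scores, nmbr_channels, feature_channel):\n",
"    # hypothetical helper: class index per event, placed in the row of the feature channel\n",
"    # (the other channels are left at zero, which is an assumption of this sketch)\n",
"    predictions = np.zeros((nmbr_channels, scores.shape[0]), dtype=int)\n",
"    predictions[feature_channel] = scores.argmax(axis=1)\n",
"    return predictions\n",
"\n",
"rng = np.random.default_rng(2)\n",
"scores = rng.normal(size=(160, 10))  # synthetic scores: 160 events, 10 classes\n",
"pred = store_predictions(scores, nmbr_channels=2, feature_channel=0)\n",
"print(pred.shape)  # (2, 160)\n",
"```"
]
},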
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We don't have labels for an exact evaluation of our raw data set, so we need another method of evaluation. For this we plot the decay times of the pulses against their pulse heights, since we saw in a previous notebook that the decay time is a good discriminating quantity."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:31.085856Z",
"start_time": "2021-11-07T13:43:31.071561Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataHandler Instance created.\n"
]
}
],
"source": [
"dh = ai.DataHandler(channels=[0, 1])\n",
"dh.set_filepath(path_h5='test_data/',\n",
"                fname='test_001',\n",
"                appendix=False)\n",
"\n",
"mainpar = dh.get('events', 'mainpar')[0]  # main parameters of channel 0, shape (nmbr_events, nmbr_mainpar)\n",
"ph = mainpar[:, 0]  # pulse heights\n",
"decay_time = (mainpar[:, 6] - mainpar[:, 4]) / dh.sample_frequency  # difference in samples, converted to seconds\n",
"pred_lstm = dh.get('events', 'pred_lstm')[0]\n",
"pred_absorber = ai.cuts.LogicalCut(pred_lstm == 1)\n",
"pred_carrier = ai.cuts.LogicalCut(pred_lstm == 8)"
]
},
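{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cut objects above wrap boolean conditions on the predictions, and in the plotting cell below their flags are used as masks. The same selection can be sketched with plain NumPy boolean masks, e.g. to compare the decay times of the two predicted classes. All arrays here are synthetic stand-ins for `pred_lstm` and `decay_time`, and the sampling frequency of 25 kHz is an assumption.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"sample_frequency = 25000  # Hz, an assumed value for this sketch\n",
"rng = np.random.default_rng(3)\n",
"pred_lstm = rng.choice([1, 8], size=100)  # synthetic predictions: 1 = absorber, 8 = carrier\n",
"# synthetic decay lengths in samples; the two classes differ only by construction\n",
"decay_samples = np.where(pred_lstm == 1, 100, 40) + rng.integers(0, 10, size=100)\n",
"decay_time = decay_samples / sample_frequency  # samples -> seconds\n",
"\n",
"absorber = pred_lstm == 1  # boolean mask, like the flag of LogicalCut(pred_lstm == 1)\n",
"carrier = pred_lstm == 8\n",
"print(absorber.sum(), carrier.sum())\n",
"print(np.median(decay_time[absorber]), np.median(decay_time[carrier]))\n",
"```"
]
},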
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plot below suggests that the neural network classifies almost all events in the raw data correctly! For full disclosure, we have to mention that the dataset used here is rather small for neural network training, where datasets usually contain at least several thousand samples, so this accuracy can likely be improved. In previous work, discrimination thresholds on the order of the baseline resolution were achieved."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-07T13:43:32.988119Z",
"start_time": "2021-11-07T13:43:32.573590Z"
}
},
"outputs": [],
"source": [
"plt.close()\n",
"ai.styles.use_cait_style(dpi=150)\n",
"plt.scatter(ph[pred_absorber.get_flag()], decay_time[pred_absorber.get_flag()], marker='.', alpha=0.9, zorder=10, label='Prediction Absorber')\n",
"plt.scatter(ph[pred_carrier.get_flag()], decay_time[pred_carrier.get_flag()], marker='.', alpha=0.9, zorder=10, label='Prediction Carrier')\n",
"ai.styles.make_grid()\n",
"plt.xlabel('Pulse Height (V)')\n",
"plt.ylabel('Decay Time (s)')\n",
"plt.xlim([0,5])\n",
"plt.ylim([0,0.01])\n",
"legend = plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), mode=\"expand\", borderaxespad=0., ncol=3)\n",
"for lh in legend.legendHandles:\n",
" lh.set_alpha(1.0)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please forward questions and correspondence about this notebook to felix.wagner(at)oeaw.ac.at."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}