Neural Networks for Regression
In this tutorial, we train a neural network for pulse height regression. In previous work, this approach achieved pulse height discrimination as precise as that of the optimum filter, with improved robustness to noise fluctuations.
As in the previous notebook, we use an LSTM model built with PyTorch and PyTorch Lightning. We do not explain every utility step again; for details, see the notebook on classification with neural networks.
import numpy as np
import cait as ai
from pytorch_lightning import Trainer
from torchvision import transforms
import h5py
from cait.datasets import RemoveOffset, Normalize, DownSample, ToTensor, CryoDataModule
from cait.models import LSTMModule, nn_predict
from pytorch_lightning.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg'] # we need this for a suitable resolution of the plots
As we now work on a regression problem, we use the mean squared error between the predicted and the true pulse heights as the optimization objective.
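Written out for N events with predicted pulse heights \hat{h}_i and true pulse heights h_i (notation introduced here for illustration), the objective reads as follows; the RMS deviation computed in the evaluation at the end of this notebook is its square root.

```latex
\mathcal{L}_\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( \hat{h}_i - h_i \right)^2 ,
\qquad
\mathrm{RMS} = \sqrt{\mathcal{L}_\mathrm{MSE}}
```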
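# some parameters
# nmbr_gpus = ...  # uncomment and pass gpus=nmbr_gpus to the Trainer to use GPUs
path_h5 = 'test_data/efficiency_001.h5'  # HDF5 file with the events and their true pulse heights
type = 'events'  # HDF5 group to read from (note: shadows the built-in type)
keys = ['event', 'true_ph']  # data sets in the group: event traces and true pulse heights
channel_indices = [[0], [0]]  # use channel 0 for both keys
feature_indices = [None, None]
feature_keys = ['event_ch0']  # input of the network
label_keys = ['true_ph_ch0']  # regression target
norm_vals = {'event_ch0': [0, 1]}  # values passed to the Normalize transform
down_keys = ['event_ch0']  # keys to which DownSample is applied
down = 8  # downsampling factor
input_size = 8  # number of samples fed to the LSTM per time step
nmbr_out = 1  # a single output: the pulse height
device_name = 'cpu'
max_epochs = 10  # maximum number of training epochs
save_naming = 'lstm-reg'  # prefix of the checkpoint file names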
Dataset and Model
As in the previous notebook, we define the data transformations, the DataModule and the LightningModule.
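To illustrate what the offset removal and downsampling do to a single record, here is a NumPy sketch with the hypothetical helpers `remove_offset` and `down_sample`. They are not cait's implementations: the baseline window (first 1/8 of the record) and the averaging over groups of `down` samples are assumptions, and the Normalize step is left out.

```python
import numpy as np


def remove_offset(event, baseline_fraction=0.125):
    """Hypothetical: subtract the mean of the baseline, assumed to be the first 1/8 of the record."""
    n = max(1, int(len(event) * baseline_fraction))
    return event - np.mean(event[:n])


def down_sample(event, down):
    """Hypothetical: average non-overlapping groups of `down` samples.

    Assumes the record length is divisible by `down`.
    """
    return event.reshape(-1, down).mean(axis=1)


# toy record: constant baseline of 1.0, then a step of height 2.0
event = np.ones(16)
event[8:] += 2.0
print(down_sample(remove_offset(event), down=8))
```

After these two steps, the record is 8 times shorter and its baseline sits at zero, which is the form in which the LSTM sees the events.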
# create the transforms; the name `transform` avoids overwriting the imported torchvision module `transforms`
transform = transforms.Compose([RemoveOffset(keys=feature_keys),
Normalize(norm_vals=norm_vals),
DownSample(keys=down_keys, down=down),
ToTensor()])
# create data module and init the setup
dm = CryoDataModule(hdf5_path=path_h5,
type=type,
keys=keys,
channel_indices=channel_indices,
feature_indices=feature_indices,
transform=transform)
dm.prepare_data(val_size=0.2,
test_size=0.2,
batch_size=8,
dataset_size=None,
nmbr_workers=0, # set to the number of CPUs on the machine
only_idx=None,
shuffle_dataset=True,
random_seed=42,
feature_keys=feature_keys,
label_keys=label_keys,
keys_one_hot=[])
dm.setup()
# create the LSTM regression model
lstm = LSTMModule(input_size=input_size,
hidden_size=input_size * 10,
num_layers=2,
seq_steps=int(dm.dims[1] / input_size), # downsampling is already considered in dm
device_name=device_name,
nmbr_out=nmbr_out, # this is the number of labels
lr=1e-4,
label_keys=label_keys,
feature_keys=feature_keys,
is_classifier=False,
down=down,
down_keys=feature_keys,
norm_vals=norm_vals,
offset_keys=feature_keys)
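The layer sizes in the model summary printed during training (about 80 K parameters in the LSTM, 20 K in fc1) and in the module printed after loading (fc1 with in_features=20480) can be reproduced by hand. The sketch below assumes a downsampled record length of 2048 samples. This value is inferred from 20480 = 80 * 256, not read from the data file. It also assumes that fc1 receives the flattened hidden states of all time steps.

```python
# Assumption: dm.dims[1] == 2048 (inferred from the printed fc1 in_features, not read from the file)
record_length = 2048
input_size = 8
hidden_size = input_size * 10
num_layers = 2

seq_steps = record_length // input_size  # time steps seen by the LSTM
fc_in = hidden_size * seq_steps          # hidden states of all time steps, flattened


def lstm_params(input_size, hidden_size, num_layers):
    """Parameter count of a torch.nn.LSTM with biases (4 gates per layer)."""
    total = 0
    for layer in range(num_layers):
        layer_in = input_size if layer == 0 else hidden_size
        # weight_ih (4H x in) + weight_hh (4H x H) + bias_ih (4H) + bias_hh (4H)
        total += 4 * hidden_size * (layer_in + hidden_size) + 2 * 4 * hidden_size
    return total


print(seq_steps, fc_in, lstm_params(input_size, hidden_size, num_layers), fc_in + 1)
# → 256 20480 80640 20481
```

The last number is the size of fc1 with a single output (20480 weights plus one bias), matching the 20 K in the summary.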
TensorBoard
As in the previous notebook, we can start an instance of TensorBoard to monitor the training.
.. code:: python
%load_ext tensorboard
%tensorboard --logdir=lightning_logs
.. note::
**TensorBoard on a Server without X-Forwarding**
If you work on a remote server with X-forwarding deactivated, i.e., without the option to display graphical elements, you can start the SSH connection with the additional -L flag:
ssh -L 16006:127.0.0.1:6006 <SERVER_SSH_ADDRESS>
This forwards port 16006 on your local machine to port 6006, TensorBoard's default port, on the remote server. You can then open the TensorBoard interface in a browser on your local machine by entering http://127.0.0.1:16006/ in the address bar.
Training
Thanks to the unified PyTorch Lightning framework, the training works the same way as for a classifier.
# create callback to save the best model
checkpoint_callback = ModelCheckpoint(dirpath='callbacks',
monitor='val_loss',
filename=save_naming + '-{epoch:02d}-{val_loss:.2f}')
# create instance of Trainer
trainer = Trainer(max_epochs=max_epochs,
callbacks=[checkpoint_callback])
# pass gpus=nmbr_gpus to use GPUs
# max_epochs sets the maximum number of training epochs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
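As the checkpoint path printed in the evaluation shows (lstm-reg-epoch=08-val_loss=0.00.ckpt), Lightning fills the filename template with the metric values and prefixes each value with its name. The hypothetical helper `checkpoint_name` below mimics that substitution with plain Python string formatting. It is not Lightning's implementation. It shows that the `.2f` format rounds a validation loss of about 0.002 to `0.00`, so a template with more digits, e.g. `{val_loss:.4f}`, would keep the file names distinguishable.

```python
def checkpoint_name(template: str, **metrics) -> str:
    """Hypothetical helper: fill a checkpoint filename template the way the
    printed checkpoint path suggests (each placeholder prefixed with its name)."""
    for name in metrics:
        # '{val_loss:.2f}' -> 'val_loss={val_loss:.2f}'
        template = template.replace('{' + name, name + '={' + name)
    return template.format(**metrics)


# illustrative validation loss of the order seen in this notebook
print(checkpoint_name('lstm-reg-{epoch:02d}-{val_loss:.2f}', epoch=8, val_loss=0.00197))
# → lstm-reg-epoch=08-val_loss=0.00
print(checkpoint_name('lstm-reg-{epoch:02d}-{val_loss:.4f}', epoch=8, val_loss=0.00197))
# → lstm-reg-epoch=08-val_loss=0.0020
```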
# all training happens here
trainer.fit(model=lstm,
datamodule=dm)
| Name | Type | Params
--------------------------------
0 | lstm | LSTM | 80 K
1 | fc1 | Linear | 20 K
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:216: UserWarning: Please also save or load the state of the optimizer when saving or loading the scheduler.
warnings.warn(SAVE_STATE_WARNING, UserWarning)
1
Evaluation
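# load the best model; load_from_checkpoint is a class method that returns a new model, so assign its result
lstm = LSTMModule.load_from_checkpoint(checkpoint_callback.best_model_path)
lstm  # display the model architecture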
LSTMModule(
(lstm): LSTM(8, 80, num_layers=2, batch_first=True)
(fc1): Linear(in_features=20480, out_features=1, bias=True)
)
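In PyTorch Lightning, `load_from_checkpoint` is a class method: it builds and returns a new model and does not modify the object it is called on. The pure-Python sketch below uses a hypothetical `Model` class (not Lightning's code, and with a simplified argument instead of a checkpoint path) to show why the returned instance has to be assigned.

```python
class Model:
    """Hypothetical stand-in for a LightningModule; not Lightning's implementation."""

    def __init__(self, weights=0.0):
        self.weights = weights

    @classmethod
    def load_from_checkpoint(cls, stored_weights):
        # a class method constructs and returns a *new* instance
        return cls(weights=stored_weights)


model = Model(weights=0.0)
model.load_from_checkpoint(1.5)          # result discarded: model is unchanged
print(model.weights)                     # 0.0
model = Model.load_from_checkpoint(1.5)  # assign the returned instance
print(model.weights)                     # 1.5
```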
# run test set
result = trainer.test()
print(result)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
warnings.warn(*args, **kwargs)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0017),
'train_loss': tensor(0.0030),
'val_loss': tensor(0.0020)}
--------------------------------------------------------------------------------
[{'train_loss': 0.003030629362910986, 'val_loss': 0.00197098427452147, 'test_loss': 0.0017100092954933643}]
For evaluation, we compute the root mean square (RMS) deviation between the predicted and the true pulse heights of the test set.
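The model returns one value per event as a column of shape (nmbr_events, 1), as the printed predictions show, whereas the true pulse heights read from the HDF5 file form a flat array of shape (nmbr_events,). Subtracting the two directly makes NumPy broadcast them to an (nmbr_events, nmbr_events) matrix. The sketch below uses made-up numbers to show the effect and why the predictions should be flattened first.

```python
import numpy as np

# made-up numbers: predictions as a column (like the model output), labels as a flat array
prediction = np.array([[0.5], [0.2], [0.8]])  # shape (3, 1)
y = np.array([0.4, 0.2, 0.9])                 # shape (3,)

diff_broadcast = prediction - y               # shape (3, 3): every prediction minus every label
diff = prediction.flatten() - y               # shape (3,): event by event, as intended

rms_wrong = np.sqrt(np.mean(diff_broadcast ** 2))
rms = np.sqrt(np.mean(diff ** 2))
print(diff_broadcast.shape, rms, rms_wrong)
```

The broadcast version compares every prediction with every true pulse height, so it measures the spread of the pulse height distribution rather than the accuracy of the model.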
# make predictions on the test set
f = h5py.File(dm.hdf5_path, 'r')
test_idx = dm.test_sampler.indices
test_idx.sort()  # h5py fancy indexing requires increasing indices
x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]} # array of shape: (nmbr_events, nmbr_features)
y = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])  # true pulse heights, shape: (nmbr_events,)
f.close()
prediction = lstm.predict(x).numpy()  # shape: (nmbr_events, 1)
# predictions can be saved with an instance of EvaluationTools
# flatten the predictions, otherwise (nmbr_events, 1) - (nmbr_events,) broadcasts to (nmbr_events, nmbr_events)
print('RMS OF PREDICTION: ', np.sqrt(np.mean((prediction.flatten() - y)**2)))
print('Best model: ', checkpoint_callback.best_model_path)
print('Predictions: ', prediction)
Best model: /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/lstm-reg-epoch=08-val_loss=0.00.ckpt
Predictions: [[0.5190951 ]
[0.21318421]
[0.5722188 ]
[0.868355 ]
[0.2756531 ]
[0.48385212]
[0.63625795]
[0.7250182 ]
[0.23767623]
[0.1014615 ]
[0.27463716]
[0.30882367]
[0.79756474]
[0.49813315]
[0.29637933]
[0.3356294 ]
[1.0668943 ]
[0.88048095]
[0.9578627 ]
[0.8911479 ]
[0.34058082]
[0.20126037]
[0.33892164]
[0.13703091]
[0.28569117]
[0.35804522]
[0.8819514 ]
[0.11296244]
[0.6042915 ]
[0.6077907 ]
[0.7591764 ]
[0.8522276 ]
[0.5375982 ]
[0.12401645]
[0.27267727]
[0.80078995]
[0.23864874]
[0.17496601]
[0.2327037 ]
[0.17392637]
[0.2112688 ]
[0.27889025]
[0.7392883 ]
[0.15140365]
[0.42190912]
[0.44661522]
[0.1106218 ]
[0.17759526]
[0.4016158 ]
[0.9523 ]
[0.5110052 ]
[0.31491077]
[0.59735143]
[0.22782765]
[0.37556702]
[0.31949657]
[0.1590495 ]
[0.58973217]
[0.26327324]
[0.26482552]
[0.17542289]
[0.21271276]
[0.4088708 ]
[0.2576398 ]
[0.35049498]
[0.13741563]
[0.34548762]
[0.14705749]
[0.39027983]
[0.6021813 ]
[0.61985874]
[0.5223465 ]
[0.9593096 ]
[0.17346363]
[0.21017638]
[0.8206003 ]
[0.7811061 ]
[0.5741181 ]
[0.7600825 ]
[0.8185749 ]
[0.457417 ]
[0.1574406 ]
[0.15766893]
[0.12770347]
[0.10976886]
[1.072875 ]
[0.39373276]
[0.6390164 ]
[0.6039722 ]
[0.9202833 ]
[1.0242869 ]
[0.14536014]
[0.35428348]
[0.1805871 ]
[0.8126399 ]
[0.17252713]
[0.71183455]
[0.16563846]
[0.33092207]
[0.6755552 ]]
At this point, we leave the thorough evaluation of the pulse height regression method as an exercise to the reader ;-)
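One possible starting point is to check how the deviation depends on the true pulse height. The sketch below uses a hypothetical helper `binned_rms` with made-up numbers; for the data of this notebook, one would pass `y` and the flattened `prediction`. The choice of bins is left open.

```python
import numpy as np


def binned_rms(true_ph, pred_ph, bin_edges):
    """Hypothetical helper: RMS deviation of predictions per bin of true pulse height.

    Bins are half-open [lo, hi); events outside the edges are ignored.
    Returns (bin_centers, rms), with NaN for empty bins.
    """
    true_ph = np.asarray(true_ph, dtype=float).ravel()
    pred_ph = np.asarray(pred_ph, dtype=float).ravel()
    bin_edges = np.asarray(bin_edges, dtype=float)
    idx = np.digitize(true_ph, bin_edges)  # bin number per event, 0 = below the first edge
    centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    rms = np.full(len(centers), np.nan)
    for b in range(1, len(bin_edges)):
        mask = idx == b
        if mask.any():
            rms[b - 1] = np.sqrt(np.mean((pred_ph[mask] - true_ph[mask]) ** 2))
    return centers, rms


# made-up example with two bins
centers, rms_per_bin = binned_rms([0.1, 0.15, 0.6], [0.12, 0.13, 0.5], [0.0, 0.5, 1.0])
print(centers, rms_per_bin)
```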
Please forward questions and correspondence about this notebook to felix.wagner(at)oeaw.ac.at.