{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural Networks for Regression\n", "\n", "```{warning}\n", "Note that this tutorial and the features it describes are currently **not maintained**. If you wish to help out and contribute to it, please reach out to the maintainers.\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial we train a neural network for the task of pulse height regression. In previous work, a pulse height discrimination as precise as with the optimum filter could be achieved, with improved robustness to noise fluctuations.\n", "\n", "As in the previous notebook, we use an LSTM model, PyTorch and PyTorch Lightning. We will not explain all utility steps again but refer to the notebook on classification with neural networks." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:45:03.830608Z", "start_time": "2021-11-07T13:45:01.622740Z" } }, "outputs": [], "source": [ "import numpy as np\n", "import cait as ai\n", "from pytorch_lightning import Trainer\n", "from torchvision import transforms\n", "import h5py\n", "from cait.datasets import RemoveOffset, Normalize, DownSample, ToTensor, CryoDataModule\n", "from cait.models import LSTMModule, nn_predict\n", "from pytorch_lightning.callbacks import ModelCheckpoint\n", "import matplotlib.pyplot as plt\n", "%config InlineBackend.figure_formats = ['svg'] # we need this for a suitable resolution of the plots" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we now work on a regression problem, we use the mean squared error with respect to the true pulse height as the optimization objective." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:45:04.134126Z", "start_time": "2021-11-07T13:45:02.778Z" } }, "outputs": [], "source": [ "# some parameters\n", "# nmbr_gpus = ... 
uncomment and put in trainer to use GPUs\n", "path_h5 = 'test_data/efficiency_001.h5'\n", "type = 'events'\n", "keys = ['event', 'true_ph']\n", "channel_indices = [[0], [0]]\n", "feature_indices = [None, None]\n", "feature_keys = ['event_ch0']\n", "label_keys = ['true_ph_ch0']\n", "norm_vals = {'event_ch0': [0, 1]}\n", "down_keys = ['event_ch0']\n", "down = 8\n", "input_size = 8\n", "nmbr_out = 1\n", "device_name = 'cpu'\n", "max_epochs = 10\n", "save_naming = 'lstm-reg'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset and Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As in the previous notebook, we define the data transformations, the DataModule and the LightningModule." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:45:04.273668Z", "start_time": "2021-11-07T13:45:04.240995Z" } }, "outputs": [], "source": [ "# create the transforms (named `transform` so that the torchvision module `transforms` is not overwritten)\n", "transform = transforms.Compose([RemoveOffset(keys=feature_keys),\n", " Normalize(norm_vals=norm_vals),\n", " DownSample(keys=down_keys, down=down),\n", " ToTensor()])\n", "\n", "# create data module and init the setup\n", "dm = CryoDataModule(hdf5_path=path_h5,\n", " type=type,\n", " keys=keys,\n", " channel_indices=channel_indices,\n", " feature_indices=feature_indices,\n", " transform=transform)\n", "\n", "dm.prepare_data(val_size=0.2,\n", " test_size=0.2,\n", " batch_size=8,\n", " dataset_size=None,\n", " nmbr_workers=8, # set to number of CPUs on the machine\n", " only_idx=None,\n", " shuffle_dataset=True,\n", " random_seed=42,\n", " feature_keys=feature_keys,\n", " label_keys=label_keys,\n", " keys_one_hot=[])\n", "\n", "dm.setup()\n", "\n", "# create the LSTM regression model\n", "lstm = LSTMModule(input_size=input_size,\n", " hidden_size=input_size * 10,\n", " num_layers=2,\n", " seq_steps=int(dm.dims[1] / input_size), # downsampling is already considered in dm\n", " device_name=device_name,\n", " nmbr_out=nmbr_out, # this is the number 
of labels\n", " lr=1e-4,\n", " label_keys=label_keys,\n", " feature_keys=feature_keys,\n", " is_classifier=False,\n", " down=down,\n", " down_keys=feature_keys,\n", " norm_vals=norm_vals,\n", " offset_keys=feature_keys)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensorboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can again start an instance of Tensorboard." ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2021-02-26T01:39:27.966390Z", "start_time": "2021-02-26T01:39:27.960239Z" } }, "source": [ ".. code:: python\n", "\n", " %load_ext tensorboard\n", " %tensorboard --logdir=lightning_logs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ".. note:: \n", "\n", " **Tensorboard on Server without X-Forwarding**\n", " If you work on a remote server that has X-forwarding deactivated, i.e. you don't have the option to show graphical elements, you can start the ssh connection with the additional -L flag:\n", "\n", " ssh -L 16006:127.0.0.1:6006 \n", "\n", " This forwards port 16006 on your local machine to the standard Tensorboard port 6006 on the remote server, so you can open the Tensorboard interface in a browser on your local machine by entering http://127.0.0.1:16006/ in the address bar." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thanks to the unified PyTorch Lightning framework, the training works the same way as for a classifier." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:45:07.540863Z", "start_time": "2021-11-07T13:45:07.529480Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n" ] } ], "source": [ "# create callback to save the best model\n", "checkpoint_callback = ModelCheckpoint(dirpath='callbacks',\n", " monitor='val_loss',\n", " filename=save_naming + '-{epoch:02d}-{val_loss:.2f}')\n", "\n", "# create instance of Trainer\n", "trainer = Trainer(max_epochs=max_epochs,\n", " callbacks=[checkpoint_callback])\n", "# keyword gpus=nmbr_gpus for GPU Usage\n", "# keyword max_epochs for number of maximal epochs" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:46:13.357692Z", "start_time": "2021-11-07T13:45:08.774562Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | lstm | LSTM | 80 K \n", "1 | fc1 | Linear | 20 K \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a0a162fcd5974fa485d4ce2a34fcc3ba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', 
layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:216: UserWarning:\n", "\n", "Please also save or load the state of the optimizer when saving or loading the scheduler.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# all training happens here\n", "trainer.fit(model=lstm,\n", " datamodule=dm)" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2021-02-25T21:56:32.294873Z", "start_time": "2021-02-25T21:56:32.285320Z" } }, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:47:01.179223Z", "start_time": "2021-11-07T13:47:01.160958Z" } }, "outputs": [ { "data": { "text/plain": [ "LSTMModule(\n", " (lstm): LSTM(8, 80, num_layers=2, batch_first=True)\n", " (fc1): Linear(in_features=20480, out_features=1, bias=True)\n", ")" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# load best model\n", 
"lstm.load_from_checkpoint(checkpoint_callback.best_model_path)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:47:09.489074Z", "start_time": "2021-11-07T13:47:02.230069Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "15de0bfc659043c9932a480465cb8afc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------------\n", "DATALOADER:0 TEST RESULTS\n", "{'test_loss': tensor(0.0023),\n", " 'train_loss': tensor(0.0003),\n", " 'val_loss': tensor(0.0018)}\n", "--------------------------------------------------------------------------------\n", "\n", "[{'train_loss': 0.0002781856164801866, 'val_loss': 0.0018266444094479084, 'test_loss': 0.0023284959606826305}]\n" ] } ], "source": [ "# run test set\n", "result = trainer.test()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For evaluation, we calculate the RMS between our predictions and the true pulse heights on the test set." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2021-11-07T13:48:12.168632Z", "start_time": "2021-11-07T13:48:11.866536Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMS OF PREDICTION: 0.39585079037963333\n", "Best model: /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/lstm-reg-epoch=08-val_loss=0.00.ckpt\n", "Predictions: [[0.2937095 ]\n", " [0.31969512]\n", " [0.46127528]\n", " [0.34636137]\n", " [0.22584124]\n", " [0.2579584 ]\n", " [0.8273574 ]\n", " [0.3698974 ]\n", " [0.2475244 ]\n", " [0.9473152 ]\n", " [0.86203074]\n", " [0.23335579]\n", " [0.7685009 ]\n", " [0.5914988 ]\n", " [0.49016467]\n", " [0.43513584]\n", " [0.60164446]\n", " [0.2559728 ]\n", " [0.4249911 ]\n", " [0.17477603]\n", " [0.13642003]\n", " [0.36697793]\n", " [0.21585909]\n", " [0.8377478 ]\n", " [0.28353265]\n", " [0.5347649 ]\n", " [0.21387848]\n", " [0.47983745]\n", " [0.73118913]\n", " [0.11893168]\n", " [0.8334492 ]\n", " [0.1365829 ]\n", " [0.82246524]\n", " [0.48721352]\n", " [0.2677841 ]\n", " [0.23085369]\n", " [0.43038726]\n", " [0.14360833]\n", " [0.7804697 ]\n", " [0.12496844]\n", " [0.42676687]\n", " [0.23684482]\n", " [0.47283232]\n", " [0.1116754 ]\n", " [0.7609559 ]\n", " [0.11592295]\n", " [0.13516772]\n", " [0.9346943 ]\n", " [0.9774377 ]\n", " [0.41864622]\n", " [0.1303454 ]\n", " [0.6284125 ]\n", " [0.56606853]\n", " [0.9467117 ]\n", " [0.36443248]\n", " [0.9275517 ]\n", " [0.40045786]\n", " [0.6748831 ]\n", " [0.32109845]\n", " [0.6209897 ]\n", " [0.13024886]\n", " [0.7977837 ]\n", " [0.72050947]\n", " [0.45500043]\n", " [0.36789766]\n", " [1.019515 ]\n", " [1.0309461 ]\n", " [0.9192072 ]\n", " [0.32847393]\n", " [0.7248096 ]\n", " [0.6283419 ]\n", " [0.5696617 ]\n", " [0.61126596]\n", " [0.6211911 ]\n", " [0.86789256]\n", " [0.2289334 ]\n", " [0.19279546]\n", " [0.4699389 ]\n", " [0.421057 ]\n", " [1.0869708 ]\n", " [0.66301364]\n", " [0.35148814]\n", " [0.11834735]\n", " 
[0.39033556]\n", " [0.52159286]\n", " [0.11959913]\n", " [0.14312187]\n", " [0.88150346]\n", " [0.48689187]\n", " [0.45247683]\n", " [0.10437884]\n", " [0.47180602]\n", " [0.45709407]\n", " [0.11745495]\n", " [0.18325084]\n", " [0.53727484]\n", " [0.12649499]\n", " [0.46097454]\n", " [0.16666122]\n", " [0.89489394]]\n" ] } ], "source": [ "# this is how predictions are made with the model\n", "f = h5py.File(dm.hdf5_path, 'r')\n", "test_idx = dm.test_sampler.indices\n", "test_idx.sort()\n", "x = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]} # array of shape: (nmbr_events, nmbr_features)\n", "y = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])\n", "prediction = lstm.predict(x).numpy()\n", "\n", "# predictions can be saved with an instance of EvaluationTools\n", "# flatten both arrays first: the predictions have shape (nmbr_events, 1), and subtracting\n", "# a 1D array from them would broadcast to an (nmbr_events, nmbr_events) matrix\n", "print('RMS OF PREDICTION: ', np.sqrt(np.mean((prediction.flatten() - y.flatten())**2)))\n", "print('Best model: ', checkpoint_callback.best_model_path)\n", "print('Predictions: ', prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point, we leave the thorough evaluation of the pulse height regression method as an exercise to the reader ;-)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Please forward questions and correspondence about this notebook to felix.wagner(at)oeaw.ac.at." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }