{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Pile Up Augmentation\n", "\n", "```{warning}\n", "Note that this tutorial and the features it describes are currently **not maintained**. If you wish to help out and contribute to it, please reach out to the maintainers.\n", "```" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2021-07-29T09:47:46.276007Z", "start_time": "2021-07-29T09:47:46.268922Z" } }, "source": [ "*Correspondence to: felix.wagner@oeaw.ac.at*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook we explore the data augmentation module of the Cait Python package again and create pile up events of a detector. This notebook is based on the workshop presentation \"Nonlinear pile-up separation with LSTM neural networks for cryogenic particle detectors\", at the Machine Learning for the Physical Sciences workshop at Neurips21 (https://ml4physicalsciences.github.io/2021/)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:52:13.633572Z", "start_time": "2021-11-30T08:52:09.221139Z" } }, "outputs": [], "source": [ "from cait.augment import ParameterSampler, plot_events, unfold, L2, EventDefinition, Distribution\n", "import numpy as np\n", "import cait as ai \n", "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import h5py\n", "from tqdm.auto import tqdm\n", "from scipy import signal\n", "from PIL import Image, ImageDraw\n", "import os\n", "%config InlineBackend.figure_formats = ['svg']" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:52:14.879087Z", "start_time": "2021-11-30T08:52:14.876035Z" } }, "outputs": [], "source": [ "if not os.path.exists('test_data/'):\n", " os.makedirs('test_data/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some definitions for the plots." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:52:16.369741Z", "start_time": "2021-11-30T08:52:16.366455Z" } }, "outputs": [], "source": [ "mpl.rcParams['figure.figsize'] = (7.2, 4.45)\n", "mpl.rcParams['savefig.dpi'] = 300" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the global parameters.\n", "\n", "For the whole notebook: Only capitalized parameters need to be set by the user." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:54:52.010665Z", "start_time": "2021-11-30T08:54:52.003725Z" } }, "outputs": [], "source": [ "RECORD_LENGTH = 16384\n", "SAMPLE_FREQUENCY = 25000\n", "RESOLUTION = 0.01\n", "\n", "POLYNOMIAL_DRIFTS = True\n", "RASTERIZE = False\n", "SQUARE_WAVES = False\n", "SATURATION = True\n", "CLASS_NAMES = [#'Event Pulse',\n", " #'Noise',\n", " #'Decaying Baseline',\n", " #'Temperature Rise',\n", " #'Spike',\n", " #'Squid Jump',\n", " #'Reset',\n", " #'Cosinus Tail',\n", " #'Decaying Baseline with Event Pulse',\n", " 'Pile Up',\n", " #'Early or late Trigger',\n", " #'Carrier Event',\n", " #'Decaying Baseline with Tail Event',\n", " ]\n", "\n", "\n", "# ----------------------------------------- \n", "# no need to change below parameters\n", "# ----------------------------------------- \n", "\n", "label_names = {\n", " 'unlabeled': 0,\n", " 'Event Pulse': 1,\n", " 'Test/Control Pulse': 2,\n", " 'Noise': 3,\n", " 'Squid Jump': 4,\n", " 'Spike': 5,\n", " 'Early or late Trigger': 6,\n", " 'Pile Up': 7,\n", " 'Carrier Event': 8,\n", " 'Strongly Saturated Event Pulse': 9,\n", " 'Strongly Saturated Test/Control Pulse': 10,\n", " 'Decaying Baseline': 11,\n", " 'Temperature Rise': 12,\n", " 'Stick Event': 13,\n", " 'Square Waves': 14,\n", " 'Human Disturbance': 15,\n", " 'Large Sawtooth': 16,\n", " 'Cosinus Tail': 17,\n", " 'Light only Event': 18,\n", " 'Ring & Light Event': 19,\n", " 'Sharp Light Event': 20,\n", " 'Reset': 21,\n", " 'Decaying Baseline with Event Pulse': 22,\n", " 'Decaying Baseline with Tail Event': 23,\n", " 'unknown/other': 99,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the sampler class." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:54:54.539982Z", "start_time": "2021-11-30T08:54:54.536121Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "parsam = ParameterSampler(record_length=RECORD_LENGTH,\n", " sample_frequency=SAMPLE_FREQUENCY)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the detector resolution. The measured resolution will deviate from the set resolution, depending on the noise power spectrum and the method for reconstructing the resolution." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:54:55.556287Z", "start_time": "2021-11-30T08:54:55.552840Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "parsam.set_args(resolution=np.array([RESOLUTION]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the pile up data set we define a custom event definition." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:01.943535Z", "start_time": "2021-11-30T08:55:01.936267Z" } }, "outputs": [], "source": [ "class PileUpEventDefinition(EventDefinition):\n", "\n", " def get_class_pars(self, label, size, **kwargs):\n", "\n", " if label == 'Event Pulse':\n", " pileups = np.ones(size)\n", " ps_nmbr = 0\n", " decay = False\n", " rise = False\n", " spike = False\n", " jump = False\n", " pulse_reset = False\n", " tail = False\n", " onset_iv = np.array([-160, 400]) * 0.001\n", "\n", " elif label == 'Pile Up':\n", " pileups = int(2)*np.ones(size, dtype=int) # np.random.choice(a=[2, 3], size=size, p=[0.5, 0.5])\n", " ps_nmbr = 0\n", " decay = False\n", " rise = False\n", " spike = False\n", " jump = False\n", " pulse_reset = False\n", " tail = False\n", " onset_iv = np.array([-20, 20]) * 0.001\n", "\n", " else:\n", " raise KeyError('Class {} not available.'.format(label))\n", "\n", " return pileups, ps_nmbr, decay, rise, spike, jump, pulse_reset, tail, onset_iv\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:04.715465Z", "start_time": "2021-11-30T08:55:04.710586Z" } }, "outputs": [], "source": [ "class PileUpPulseHeights(Distribution):\n", "\n", " def __init__(self, mini=0, maxi=0.5):\n", " self.mini = mini\n", " self.maxi = maxi\n", "\n", " def sample(self, size, **kwargs):\n", "\n", " pulse_height = np.random.uniform(low=self.mini, high=self.maxi, size=size)\n", "\n", " return pulse_height" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:06.502503Z", "start_time": "2021-11-30T08:55:06.498649Z" } }, "outputs": [], "source": [ "parsam.set_args(event_definition=PileUpEventDefinition())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:07.042579Z", "start_time": "2021-11-30T08:55:07.039318Z" } }, "outputs": [], "source": [ "parsam.set_args(ph_dist=PileUpPulseHeights())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the pulse shapes. You can always resample, in case you don't like them." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:23.995035Z", "start_time": "2021-11-30T08:55:23.564271Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:55:23.780528\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:55:23.972761\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# sample pulse shapes\n", "\n", "# ps1, _ = parsam.sample_pulse_par(size=1, t0=np.array([0.]))\n", "# ps2, _ = parsam.sample_pulse_par(size=1, t0=np.array([0.]))\n", "# print(ps1)\n", "# print(ps2)\n", "\n", "mpl.rcParams['figure.figsize'] = (5, 4)\n", "# mpl.rcParams['savefig.dpi'] = 300\n", "fontsize = 14\n", "\n", "mpl.rcParams['xtick.labelsize'] = fontsize\n", "mpl.rcParams['ytick.labelsize'] = fontsize\n", "mpl.rcParams['font.size'] = fontsize\n", "mpl.rcParams['axes.titlesize'] = fontsize\n", "mpl.rcParams['axes.labelsize'] = fontsize\n", "mpl.rcParams['legend.fontsize'] = fontsize\n", "mpl.rcParams['mathtext.fontset'] = 'stix'\n", "mpl.rcParams['font.family'] = 'STIXGeneral'\n", "\n", "# or define your own pulse shapes ...\n", "ps1 = {'t0': np.array([0.]), \n", " 'tau_t': np.array([0.05815608]), \n", " 'tau_in': np.array([0.02059209]), \n", " 'tau_n': np.array([0.01395427]), \n", " 'An': np.array([-0.02508469])/0.0715207609981429, \n", " 'At': np.array([0.14102789])/0.0715207609981429}\n", "ps2 = {'t0': np.array([0.]), \n", " 'tau_t': np.array([0.01730027]), \n", " 'tau_in': np.array([0.00128385]), \n", " 'tau_n': np.array([0.00815371]), \n", " 'An': np.array([2.59789859])/1.553262177826752,\n", " 'At': np.array([0.02694577])/1.553262177826752}\n", "\n", "for i,p in enumerate([ps1, ps2]):\n", "\n", " event = ai.fit.pulse_template(parsam.t, **unfold(p, 1))\n", " plot_events(event.reshape(1,-1), t=parsam.t, show=True, \n", " savepath='test_data/pulse_shape_{}.pdf'.format(i), text=['']) # 'Standard Event'\n", " \n", "parsam.set_args(pulse_shapes=[[ps1['t0'], ps1['An'], ps1['At'], ps1['tau_n'], ps1['tau_in'], ps1['tau_t']],\n", " [ps2['t0'], ps2['An'], ps2['At'], ps2['tau_n'], ps2['tau_in'], ps2['tau_t']],\n", " ])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the noise power spectrum. You can always resample, in case you don't like it." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:37.641720Z", "start_time": "2021-11-30T08:55:36.420376Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:55:37.514986\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# sample nps\n", "\n", "parsam.set_args(nps=None)\n", "\n", "noise_par, info = parsam.sample_noise(size=1)\n", "\n", "# or define your own nps ...\n", "# noise_par['nps'] = nps[:, 1].reshape(1,-1)\n", "# info['fq'] = np.fft.rfftfreq(n=RECORD_LENGTH, d=1 / SAMPLE_FREQUENCY)\n", "\n", "# plot begin -------\n", "\n", "plt.close()\n", "\n", "fig, ax = plt.subplots(constrained_layout=True)\n", "\n", "ax.loglog(info['fq'], noise_par['nps'][0], color='black', linewidth=1.5, zorder=50)\n", "\n", "# eye guidelines\n", "\n", "fig.supxlabel('Frequency (Hz)')\n", "fig.supylabel('Amplitude (a.u.)')\n", "ax.text(x=.6,y=.5,s='', # 'Noise Power Spectrum', \n", " transform=ax.transAxes,\n", " bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))\n", "plt.savefig('test_data/noise_power_spectrum.pdf')\n", "plt.show()\n", "\n", "parsam.set_args(nps=noise_par['nps'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the saturation curve. You can always resample, in case you don't like it." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:47.389525Z", "start_time": "2021-11-30T08:55:46.506059Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:55:47.356653\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# sat = parsam.sample_saturation(size=1)\n", "# print(sat)\n", "\n", "# or define your own saturation ...\n", "#sat = {'K': np.array([8.8786987]),\n", "# 'C': np.array([1.23606841]),\n", "# 'Q': np.array([2.12402435]),\n", "# 'B': np.array([0.57876924]),\n", "# 'nu': np.array([1.61744008]),\n", "# 'A': np.array([-7.95906284])}\n", "\n", "sat = {'K': np.array([10]),\n", " 'C': np.array([3]),\n", " 'Q': np.array([0.5]),\n", " 'B': np.array([5]),\n", " 'nu': np.array([0.5])}\n", "\n", "sat['A'] = sat['K'] / (1 - (sat['C'] + sat['Q'])**(1/sat['nu']))\n", "\n", "plt.close()\n", "\n", "fig, ax = plt.subplots(constrained_layout=True)\n", "\n", "ax.plot(np.arange(0, 0.5, 0.01), ai.fit.scaled_logistic_curve(np.arange(0, 0.5, 0.01), **sat), color='black', linewidth=2.5)\n", "ax.plot(np.arange(0, 0.5, 0.01), np.arange(0, 0.5, 0.01), linestyle='dotted', color='black')\n", "\n", "fig.supxlabel('True Pulse Height (V)')\n", "fig.supylabel('Saturated Pulse Height (V)')\n", "\n", "ax.set_xlim(0,0.5)\n", "ax.set_ylim(0,0.3)\n", "\n", "ax.text(x=.6,y=.3,s='', transform=ax.transAxes, # Saturation Curve\n", " bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))\n", "\n", "plt.savefig('test_data/saturation_curve.pdf')\n", "plt.show()\n", "\n", "parsam.set_args(saturation_pars=sat)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot some augmented events." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:55:57.812629Z", "start_time": "2021-11-30T08:55:55.822928Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sample Noise...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c72235cdeb284ce286f32b6f07254b4b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Sample Polynomials...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cf946921ba6c4477aafcf867d73b2916", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Sample Pulse Nmbr 0\n", "Sample Pulse Nmbr 1\n", "Sample Saturation ...\n", "Pile Up\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:55:57.735564\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "NMBR_PLOTS = 1\n", "\n", "# classes = np.random.choice(CLASS_NAMES, size=NMBR_PLOTS)\n", "classes = [CLASS_NAMES[0] for i in range(NMBR_PLOTS)]\n", "\n", "for i in range(NMBR_PLOTS):\n", "\n", " event, info = parsam.get_event(label=classes[i],\n", " size=4,\n", " rasterize=RASTERIZE,\n", " poly=POLYNOMIAL_DRIFTS,\n", " square=SQUARE_WAVES,\n", " saturation=SATURATION,\n", " verb=True,\n", " )\n", "\n", " print(classes[i])\n", " plot_events(event, t=parsam.t, show=False)\n", " if classes[i] == 'Event Pulse':\n", " print('Resolution (only meaningful without Saturation): ', np.std(np.abs(np.max(event, axis=1) - info['pulse_height'])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create an HDF5 data set." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:59:03.540810Z", "start_time": "2021-11-30T08:58:08.553318Z" }, "pycharm": { "name": "#%%\n" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "50aef184759642488d3096606c8a2204", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3000.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "Simulating Pile Up ...\n" ] } ], "source": [ "FNAME = 'test_data/pileup_v0_6.h5'\n", "IDX = list(range(1)) # which to take from the class names\n", "BATCHSIZE = 500\n", "EVENTS_PER_CLASS = {\n", " #'Event Pulse': 30000, # 9000,\n", " #'Noise': 9000,\n", " #'Decaying Baseline': 1500,\n", " #'Temperature Rise': 1500,\n", " #'Spike': 1500,\n", " #'Squid Jump': 1500,\n", " #'Reset': 6000,\n", " #'Cosinus Tail': 3000,\n", " #'Decaying Baseline with Event Pulse': 3000,\n", " #'Decaying Baseline with Tail Event': 3000,\n", " 'Pile Up': 3000, # 9000,\n", " #'Early or late Trigger': 3000,\n", " #'Carrier Event': 9000,\n", "}\n", "DIV = 1\n", "TOTAL_EVENTS = int(np.sum(list(EVENTS_PER_CLASS.values()))/DIV)\n", "BATCHSIZE = int(BATCHSIZE/DIV)\n", "\n", "NMBR_PILEUP = 2\n", "\n", "# -------------------------------------------\n", "# no changes required below this line!\n", "# -------------------------------------------\n", "\n", "with h5py.File(FNAME, 'w') as f:\n", " \n", " f.require_group('events')\n", " f['events'].create_dataset('event',\n", " shape=(1, TOTAL_EVENTS, RECORD_LENGTH),\n", " dtype=np.float32)\n", " f['events'].create_dataset('labels',\n", " shape=(1, TOTAL_EVENTS),\n", " dtype=int)\n", " f['events'].create_dataset('true_ph',\n", " shape=(1, TOTAL_EVENTS),\n", " dtype=float)\n", " f['events'].create_dataset('true_onset',\n", " shape=(TOTAL_EVENTS, ),\n", " dtype=float)\n", " \n", " f['events'].create_dataset('pulse_traces',\n", " shape=(1, TOTAL_EVENTS, NMBR_PILEUP, RECORD_LENGTH),\n", " dtype=np.float32)\n", " f['events'].create_dataset('pulse_height_pileup',\n", " shape=(1, TOTAL_EVENTS, NMBR_PILEUP),\n", " dtype=np.float32)\n", " f['events'].create_dataset('t0_pileup',\n", " shape=(1, TOTAL_EVENTS, NMBR_PILEUP),\n", " dtype=np.float32)\n", "\n", " f.require_group('saturation')\n", " f['saturation'].create_dataset('fitpar',\n", " data=np.array([sat['A'], sat['K'], sat['C'], sat['Q'], sat['B'], sat['nu'], ]))\n", "\n", " bar = tqdm(total=TOTAL_EVENTS)\n", " bcount = 0\n", "\n", " for i in IDX:\n", "\n", " bar.write('Simulating {} ...'.format(CLASS_NAMES[i]))\n", "\n", " nmbr_events = int(EVENTS_PER_CLASS[CLASS_NAMES[i]] / DIV)\n", " batches = int(nmbr_events / BATCHSIZE)\n", "\n", " for b in range(batches):\n", " event, info = parsam.get_event(label=CLASS_NAMES[i],\n", " size=BATCHSIZE,\n", " rasterize=RASTERIZE,\n", " poly=POLYNOMIAL_DRIFTS,\n", " square=SQUARE_WAVES,\n", " saturation=SATURATION,\n", " verb=False,\n", " )\n", "\n", " f['events']['event'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), :] = event\n", " f['events']['labels'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = label_names[CLASS_NAMES[i]]\n", " f['events']['true_ph'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = info['pulse_height']\n", " f['events']['true_onset'][int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = info['t0']*1000\n", " if 'pulse_traces' in info:\n", " for j in range(NMBR_PILEUP):\n", " f['events']['pulse_traces'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = info['pulse_traces']\n", " f['events']['pulse_height_pileup'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), j] = info['pulse_height_pileup_{}'.format(j)]\n", " f['events']['t0_pileup'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), j] = info['t0_pileup_{}'.format(j)]\n", " else:\n", " f['events']['pulse_traces'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = 0\n", " f['events']['pulse_height_pileup'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = 0\n", " f['events']['t0_pileup'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE)] = 0\n", " f['events']['pulse_traces'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), 0] = event\n", " f['events']['pulse_height_pileup'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), 0] = info['pulse_height']\n", " f['events']['t0_pileup'][0, int(bcount * BATCHSIZE):int((bcount + 1) * BATCHSIZE), 0] = info['t0']*1000\n", " \n", " # attributes\n", "\n", " f['events']['labels'].attrs.create(name='unlabeled', data=0)\n", " f['events']['labels'].attrs.create(name='Event_Pulse', data=1)\n", " f['events']['labels'].attrs.create(name='Test/Control_Pulse', data=2)\n", " f['events']['labels'].attrs.create(name='Noise', data=3)\n", " f['events']['labels'].attrs.create(name='Squid_Jump', data=4)\n", " f['events']['labels'].attrs.create(name='Spike', data=5)\n", " f['events']['labels'].attrs.create(name='Early_or_late_Trigger', data=6)\n", " f['events']['labels'].attrs.create(name='Pile_Up', data=7)\n", " f['events']['labels'].attrs.create(name='Carrier_Event', data=8)\n", " f['events']['labels'].attrs.create(name='Strongly_Saturated_Event_Pulse', data=9)\n", " f['events']['labels'].attrs.create(name='Strongly_Saturated_Test/Control_Pulse', data=10)\n", " f['events']['labels'].attrs.create(name='Decaying_Baseline', data=11)\n", " f['events']['labels'].attrs.create(name='Temperature_Rise', data=12)\n", " f['events']['labels'].attrs.create(name='Stick_Event', data=13)\n", " f['events']['labels'].attrs.create(name='Square_Waves', data=14)\n", " f['events']['labels'].attrs.create(name='Human_Disturbance', data=15)\n", " f['events']['labels'].attrs.create(name='Large_Sawtooth', data=16)\n", " f['events']['labels'].attrs.create(name='Cosinus_Tail', data=17)\n", " f['events']['labels'].attrs.create(name='Light_only_Event', data=18)\n", " f['events']['labels'].attrs.create(name='Ring_Light_Event', data=19)\n", " f['events']['labels'].attrs.create(name='Sharp_Light_Event', data=20)\n", " f['events']['labels'].attrs.create(name='Reset', data=21)\n", " f['events']['labels'].attrs.create(name='Decaying_Baseline_with_Event_Pulse', data=22)\n", " f['events']['labels'].attrs.create(name='Decaying_Baseline_with_Tail_Pulse', data=23)\n", " f['events']['labels'].attrs.create(name='unknown/other', data=99)\n", "\n", " # f['saturation']['fitpar'].attrs.create(name='A', data=0)\n", " # f['saturation']['fitpar'].attrs.create(name='K', data=1)\n", " # f['saturation']['fitpar'].attrs.create(name='C', data=2)\n", " # f['saturation']['fitpar'].attrs.create(name='Q', data=3)\n", " # f['saturation']['fitpar'].attrs.create(name='B', data=4)\n", " # f['saturation']['fitpar'].attrs.create(name='nu', data=5)\n", "\n", " bar.update(BATCHSIZE)\n", " bcount += 1\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:59:40.711123Z", "start_time": "2021-11-30T08:59:40.595233Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:59:40.683848\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "with h5py.File(FNAME, 'r+') as f:\n", " idx = 5\n", " plt.plot(f['events']['event'][0, idx])\n", " for i in range(NMBR_PILEUP):\n", " plt.plot(f['events']['pulse_traces'][0, idx, i])" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T08:59:54.415313Z", "start_time": "2021-11-30T08:59:54.330667Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T09:59:54.405965\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "with h5py.File('test_data/pileup_v0_6.h5', 'r+') as f:\n", " idx = 8\n", " plt.plot(f['events']['event'][0, idx], color='black', alpha=0.05, linewidth=0.5)\n", " plt.axis('off')\n", " plt.savefig('post_background.pdf')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Done." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the separation model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We simulate the non-pileuped pulses and try to reconstruct the individual ones. We use an LSTM model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:05:40.535703Z", "start_time": "2021-11-30T09:05:40.521160Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from cait.cuts import LogicalCut\n", "from cait.fit import scaled_logistic_curve, pulse_template\n", "from cait.models import SeparationLSTM\n", "import os\n", "from pytorch_lightning import Trainer\n", "from torchvision import transforms\n", "import h5py\n", "from cait.datasets import RemoveOffset, Normalize, ToTensor, CryoDataModule, PileUpDownSample\n", "from pytorch_lightning.callbacks import ModelCheckpoint\n", "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import torch\n", "from tqdm.auto import trange, tqdm\n", "%config InlineBackend.figure_formats = ['svg'] # we need this for a suitable resolution of the plots" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:05:41.132283Z", "start_time": "2021-11-30T09:05:41.125632Z" } }, "outputs": [], "source": [ "fontsize = 14\n", "\n", "mpl.rcParams['xtick.labelsize'] = fontsize\n", "mpl.rcParams['ytick.labelsize'] = fontsize\n", "mpl.rcParams['font.size'] = fontsize\n", "mpl.rcParams['axes.titlesize'] = fontsize\n", "mpl.rcParams['axes.labelsize'] = fontsize\n", "mpl.rcParams['legend.fontsize'] = fontsize\n", "mpl.rcParams['mathtext.fontset'] = 'stix'\n", "mpl.rcParams['font.family'] = 'STIXGeneral'" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:05:43.038688Z", "start_time": "2021-11-30T09:05:43.034881Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n" ] } ], "source": [ "print(torch.cuda.is_available())\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:05:52.308847Z", "start_time": "2021-11-30T09:05:52.303190Z" } }, "outputs": [], "source": [ "# some parameters\n", "# nmbr_gpus = ... uncommment and put in trainer to use GPUs\n", "path_h5 = 'test_data/pileup_v0_6.h5'\n", "type = 'events' # the group key for the data in the HDF5 set\n", "keys = ['event', 'pulse_traces'] # the datasets in the group from which we include data in the samples for the NN\n", "channel_indices = [[0], [0]] # the first indices of the datasets\n", "feature_indices = [None, None] # the third indices of the datasets\n", "feature_keys = ['event_ch0'] # the keys in the samples of the NN dataset that are input to the NN\n", " # in the data set for the NN, the keys have additionally appended the channel index \n", "label_keys = ['pulse_traces_ch0'] # the keys in the samples of the NN dataset that are labels to the NN\n", "norm_vals = {'event_ch0': [0, 0.5], 'pulse_traces_ch0': [0, 0.5]} # we do a min - max normalization of all samples, so these are roughly the lowest and highest values of the events in the data set\n", "down_keys = ['event_ch0', 'pulse_traces_ch0'] # if we input the raw time series, we apply downsampling first\n", "down = 64 # all samples in the NN dataset with the indices specified above are by the factor down downsampled\n", "record_length = 16384\n", "max_epochs = 5 # the maximal trianing epochs of the neural network\n", "\n", "save_naming = 'lstm' # 'lstm' or 'unet'\n", "nmbr_pileup = 2\n", "plots = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset and Model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:05:54.224298Z", "start_time": "2021-11-30T09:05:54.220334Z" } }, "outputs": [], "source": [ "# create the transforms\n", "trans = transforms.Compose([# RemoveOffset(keys=feature_keys),\n", " Normalize(norm_vals=norm_vals),\n", " PileUpDownSample(keys=down_keys, down=down),\n", " ToTensor()])" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:05:57.470949Z", "start_time": "2021-11-30T09:05:57.466483Z" } }, "outputs": [], "source": [ "# create data module and init the setup\n", "dm = CryoDataModule(hdf5_path=path_h5,\n", " type=type,\n", " keys=keys,\n", " channel_indices=channel_indices,\n", " feature_indices=feature_indices,\n", " transform=trans)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:06:11.691502Z", "start_time": "2021-11-30T09:06:11.529252Z" } }, "outputs": [], "source": [ "dm.prepare_data(val_size=0.1,\n", " test_size=0.2,\n", " batch_size=64,\n", " dataset_size=None,\n", " nmbr_workers=0, # set to number of CPUS on the machine (strange fact: for me 0 is faster than 8?) probabily inefficient implementation of our dataset for multithreading (issue on GitLab is open)\n", " only_idx=None,\n", " shuffle_dataset=True,\n", " random_seed=21, # 21, 1, 2, 3, 4\n", " feature_keys=feature_keys,\n", " label_keys=label_keys,\n", " keys_one_hot=label_keys)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:06:12.975444Z", "start_time": "2021-11-30T09:06:12.962512Z" } }, "outputs": [], "source": [ "dm.setup()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:12:59.183608Z", "start_time": "2021-11-30T09:12:59.171340Z" } }, "outputs": [], "source": [ "if save_naming == 'lstm':\n", " \n", " ips = 16\n", " \n", " model = SeparationLSTM(\n", " input_size=ips,\n", " hidden_size=ips * 10,\n", " num_layers=3,\n", " seq_steps=int(dm.dims[1] / ips), # downsampling is already considered in dm\n", " nmbr_pileup=nmbr_pileup, \n", " device_name=device,\n", " lr=1e-3,\n", " feature_keys=feature_keys,\n", " label_keys=label_keys,\n", " down=down,\n", " down_keys=feature_keys,\n", " norm_vals=norm_vals,\n", " offset_keys=feature_keys,\n", " )\n", "\n", "else:\n", " raise NotImplementedError" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## View Events" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:01.020268Z", "start_time": "2021-11-30T09:13:00.565617Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T10:13:00.999948\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "if plots:\n", "\n", " mpl.rcParams['figure.figsize'] = (6, 6)\n", "\n", " width=4\n", " fig, axes = plt.subplots(nrows=width, ncols=width)\n", " for i in range(width**2):\n", " if i < 8:\n", " idx = i\n", " else:\n", " idx = i + 100\n", " example = dm.dataset_full[idx]['event_ch0']\n", " axes[int(i/width),int(i%width)].plot(example, color='black')\n", " axes[int(i/width),int(i%width)].axis('off')\n", " # axes[int(i/width),int(i%width)].set_ylim((-0.1,1.1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:02.825213Z", "start_time": "2021-11-30T09:13:02.821258Z" } }, "outputs": [], "source": [ "checkpoint_callback = ModelCheckpoint(dirpath='callbacks',\n", " monitor='val_loss',\n", " filename=save_naming + '-{epoch:02d}-{val_loss:.2f}')" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:03.793326Z", "start_time": "2021-11-30T09:13:03.783201Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n" ] } ], "source": [ "trainer = Trainer(max_epochs=max_epochs,\n", " callbacks=[checkpoint_callback],\n", " gpus=0)\n", "# keyword gpus=nmbr_gpus for GPU Usage\n", "# keyword max_epochs for number of maximal epochs" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:24.219762Z", "start_time": "2021-11-30T09:13:04.711395Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | lstm | LSTM | 526 K \n", "1 | fc | Linear | 5 K \n", "/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n", "\n", "The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n", "\n", "The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7801e3010919423a90dd8d0bea3173ae", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:216: UserWarning:\n", "\n", "Please also save or load the state of the optimizer when saving or loading the scheduler.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.fit(model=model,\n", " datamodule=dm)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 42, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:26.479144Z", "start_time": "2021-11-30T09:13:26.475711Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best model: /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/lstm-epoch=03-val_loss=0.00.ckpt\n" ] } ], "source": [ "print('Best model: ', checkpoint_callback.best_model_path)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:27.714413Z", "start_time": "2021-11-30T09:13:27.697118Z" } }, "outputs": [ { "data": { "text/plain": [ "SeparationLSTM(\n", " (lstm): LSTM(16, 160, num_layers=3, batch_first=True)\n", " (fc): Linear(in_features=160, out_features=32, bias=True)\n", ")" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_from_checkpoint(checkpoint_callback.best_model_path)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:28.182342Z", "start_time": "2021-11-30T09:13:28.179401Z" } }, "outputs": [], "source": [ "# model.cuda()" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:29.778323Z", "start_time": "2021-11-30T09:13:29.272276Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n", "\n", "The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "41b85361d1b74a41b59dfb9dbac2e2f8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------------\n", "DATALOADER:0 TEST RESULTS\n", "{'test_loss': tensor(0.0028),\n", " 'train_loss': tensor(0.0022),\n", " 'val_loss': tensor(0.0029)}\n", "--------------------------------------------------------------------------------\n", "\n", "[{'train_loss': 0.0021677217446267605, 'val_loss': 0.002877358580008149, 'test_loss': 0.002773477230221033}]\n" ] } ], "source": [ "result = trainer.test(model=model,\n", " datamodule=dm)\n", "print(result)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:32.746169Z", "start_time": "2021-11-30T09:13:32.467439Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T10:13:32.692972\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot saturation curve\n", "\n", "mpl.rcParams['figure.figsize'] = (5, 4)\n", "\n", "x_temp = np.arange(17,20,0.005)\n", "y_temp = np.zeros(len(x_temp))\n", "y_temp[x_temp > 17.5] = 1.6*(x_temp[x_temp > 17.5] - 17.5)\n", "y_temp[x_temp > 17.7] = - 1/(x_temp[x_temp > 17.7] - 19) - 0.47\n", "y_temp[x_temp > 18.7] = 0.3*np.log(10*x_temp[x_temp > 18.7] - 10*18.65) + 3\n", "y_temp[x_temp > 19] = 3.38\n", "y_temp[x_temp < 17.5] += np.random.normal(scale=0.01, size=len(x_temp[x_temp < 17.5]))\n", "y_temp[np.logical_and(x_temp < 19, x_temp > 17.5)] += np.random.normal(scale=0.04, size=len(x_temp[np.logical_and(x_temp < 19, x_temp > 17.5)]))\n", "y_temp[x_temp > 19] += np.random.normal(scale=0.0015, size=len(x_temp[x_temp > 19]))\n", "x_temp += np.random.normal(scale=0.01, size=len(x_temp))\n", "y_temp /= y_temp[-1]\n", "y_temp[y_temp > 1.005] = 1.005\n", "y_temp[y_temp < -0.01] = -0.01\n", "\n", "plt.axhline(y=0.5, xmax=0.51, color='red')\n", "plt.axhline(y=0.82, xmax=0.56, color='red')\n", "plt.axvline(x=18.54, ymax=0.497, color='red')\n", "plt.axvline(x=18.7, ymax=0.787, color='red')\n", "plt.scatter(x_temp, y_temp, marker='.', s=4, color='black')\n", "\n", "plt.text(x=17.1, y=0.64, s='$\\Delta R$', color='red')\n", "plt.text(x=18.8, y=0.05, s='$\\Delta T$', color='red')\n", "\n", "plt.xlabel('Temperature (mK)')\n", "plt.ylabel('Thermometer Resistance (a.u.)')\n", "plt.tight_layout()\n", "\n", "plt.savefig('test_data/curve.pdf')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:34.587668Z", "start_time": "2021-11-30T09:13:34.567352Z" } }, "outputs": [], "source": [ "# plot some outputs\n", "f = h5py.File(dm.hdf5_path, 'r')\n", "test_idx = dm.test_sampler.indices[5:10]\n", "test_idx.sort()\n", "x_ = {feature_keys[0]: f[type][keys[0]][channel_indices[0][0], test_idx]}\n", "y_ = np.array(f[type][keys[1]][channel_indices[1][0], test_idx])\n", "seperated = model.predict(x_.copy()).cpu().numpy()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:35.839098Z", "start_time": "2021-11-30T09:13:35.073509Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T10:13:35.779454\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "if plots:\n", " \n", " sat = {'K': np.array([10]),\n", " 'C': np.array([3]),\n", " 'Q': np.array([0.5]),\n", " 'B': np.array([5]),\n", " 'nu': np.array([0.5])}\n", "\n", " sat['A'] = sat['K'] / (1 - (sat['C'] + sat['Q'])**(1/sat['nu']))\n", "\n", " mpl.rcParams['figure.figsize'] = (10, 4)\n", "\n", " fig, axes = plt.subplots(nrows=nmbr_pileup + 1, ncols=5, sharey=False)\n", "\n", " axes[0, 0].set_title(' ')\n", "\n", " for i in range(5): # rows\n", " example = x_['event_ch0'][i]\n", " lab = 'measured' if i == 0 else ''\n", " axes[0,i].plot(example, color='black', label=lab)\n", " lab = 'separated (ours)' if i == 0 else ''\n", " axes[0,i].plot(scaled_logistic_curve(np.sum(seperated[i], axis=0), **sat), color='red', label=lab)\n", " axes[0,i].axis('off')\n", " for j in range(nmbr_pileup): # columns\n", " lab = 'ground truth' if j == 0 and i == 0 else ''\n", " axes[j+1, i].plot(y_[i, j], color='olive', label=lab)\n", " sep_ = seperated[i, j]\n", " axes[j+1, i].plot(sep_, color='red')\n", " axes[j+1, i].axis('off')\n", " \n", " lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]\n", " lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]\n", "\n", " fig.legend(lines, labels, ncol=3, frameon=False, loc='upper right', fontsize=14)\n", " \n", " axes[0, 0].text(0.06,0.8,'Input',rotation=90, transform=plt.gcf().transFigure)\n", " axes[1, 0].text(0.06,0.45,'Output 1',rotation=90, transform=plt.gcf().transFigure)\n", " axes[2, 0].text(0.06,0.15,'Output 2',rotation=90, transform=plt.gcf().transFigure)\n", "\n", " # fig.supylabel('Amplitude (a.u.)')\n", " # fig.supxlabel('Time (a.u.)')\n", " fig.tight_layout()\n", " \n", " plt.savefig('test_data/events.pdf', dpi=150)\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:13:44.632866Z", "start_time": "2021-11-30T09:13:44.628443Z" } }, "outputs": [], "source": [ "def get_ph(x):\n", " return np.max(x - np.mean(x[..., :500], keepdims=True), axis=-1)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "ExecuteTime": { "end_time": "2021-11-30T09:14:00.413234Z", "start_time": "2021-11-30T09:14:00.135640Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-11-30T10:14:00.365308\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "if plots:\n", "\n", " mpl.rcParams['figure.figsize'] = (5, 4)\n", "\n", " bins = np.linspace(0,0.6,300)\n", " wei = 1/(bins[1] - bins[0])/100000\n", "\n", " plt.hist(ph_true, weights=wei*np.ones(len(ph_true)), bins=bins, \n", " label='ground truth', histtype='step', linewidth=1.5, color='olive')\n", " plt.hist(ph_recon, weights=wei*np.ones(len(ph_recon)), bins=bins, \n", " label='measured', histtype='step', linewidth=1.5, color='black')\n", " plt.hist(ph_sep, weights=wei*np.ones(len(ph_sep)), bins=bins, \n", " label='separated (ours)', histtype='step', linewidth=2.5, color='red')\n", "\n", " plt.xlim(0,0.6)\n", " # plt.ylim(0,2)\n", "\n", " plt.xlabel('Recoil Energy (a.u.)')\n", " plt.ylabel(r'$10^6$ $\\cdot$ Counts / Energy Unit')\n", " plt.legend(loc='upper right')\n", " plt.tight_layout()\n", " plt.savefig('test_data/hist.pdf')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Done." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }