{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pile Up Augmentation\n",
"\n",
"```{warning}\n",
"Note that this tutorial and the features it describes are currently **not maintained**. If you wish to help out and contribute to it, please reach out to the maintainers.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2021-07-29T09:47:46.276007Z",
"start_time": "2021-07-29T09:47:46.268922Z"
}
},
"source": [
"*Correspondence to: felix.wagner@oeaw.ac.at*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "In this notebook, we use the data augmentation module of the Cait Python package once more, this time to create pile-up events of a detector. The notebook is based on the workshop contribution \"Nonlinear pile-up separation with LSTM neural networks for cryogenic particle detectors\", presented at the Machine Learning and the Physical Sciences workshop at NeurIPS 2021 (https://ml4physicalsciences.github.io/2021/)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:52:13.633572Z",
"start_time": "2021-11-30T08:52:09.221139Z"
}
},
"outputs": [],
"source": [
"from cait.augment import ParameterSampler, plot_events, unfold, L2, EventDefinition, Distribution\n",
"import numpy as np\n",
"import cait as ai \n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import h5py\n",
"from tqdm.auto import tqdm\n",
"from scipy import signal\n",
"from PIL import Image, ImageDraw\n",
"import os\n",
"%config InlineBackend.figure_formats = ['svg']"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:52:14.879087Z",
"start_time": "2021-11-30T08:52:14.876035Z"
}
},
"outputs": [],
"source": [
    "os.makedirs('test_data/', exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "Set some default parameters for the plots."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:52:16.369741Z",
"start_time": "2021-11-30T08:52:16.366455Z"
}
},
"outputs": [],
"source": [
"mpl.rcParams['figure.figsize'] = (7.2, 4.45)\n",
"mpl.rcParams['savefig.dpi'] = 300"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the global parameters.\n",
"\n",
    "Throughout the notebook, only the capitalized parameters need to be set by the user."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:54:52.010665Z",
"start_time": "2021-11-30T08:54:52.003725Z"
}
},
"outputs": [],
"source": [
"RECORD_LENGTH = 16384\n",
"SAMPLE_FREQUENCY = 25000\n",
"RESOLUTION = 0.01\n",
"\n",
"POLYNOMIAL_DRIFTS = True\n",
"RASTERIZE = False\n",
"SQUARE_WAVES = False\n",
"SATURATION = True\n",
"CLASS_NAMES = [#'Event Pulse',\n",
" #'Noise',\n",
" #'Decaying Baseline',\n",
" #'Temperature Rise',\n",
" #'Spike',\n",
" #'Squid Jump',\n",
" #'Reset',\n",
" #'Cosinus Tail',\n",
" #'Decaying Baseline with Event Pulse',\n",
" 'Pile Up',\n",
" #'Early or late Trigger',\n",
" #'Carrier Event',\n",
" #'Decaying Baseline with Tail Event',\n",
" ]\n",
"\n",
"\n",
"# ----------------------------------------- \n",
"# no need to change below parameters\n",
"# ----------------------------------------- \n",
"\n",
"label_names = {\n",
" 'unlabeled': 0,\n",
" 'Event Pulse': 1,\n",
" 'Test/Control Pulse': 2,\n",
" 'Noise': 3,\n",
" 'Squid Jump': 4,\n",
" 'Spike': 5,\n",
" 'Early or late Trigger': 6,\n",
" 'Pile Up': 7,\n",
" 'Carrier Event': 8,\n",
" 'Strongly Saturated Event Pulse': 9,\n",
" 'Strongly Saturated Test/Control Pulse': 10,\n",
" 'Decaying Baseline': 11,\n",
" 'Temperature Rise': 12,\n",
" 'Stick Event': 13,\n",
" 'Square Waves': 14,\n",
" 'Human Disturbance': 15,\n",
" 'Large Sawtooth': 16,\n",
" 'Cosinus Tail': 17,\n",
" 'Light only Event': 18,\n",
" 'Ring & Light Event': 19,\n",
" 'Sharp Light Event': 20,\n",
" 'Reset': 21,\n",
" 'Decaying Baseline with Event Pulse': 22,\n",
" 'Decaying Baseline with Tail Event': 23,\n",
" 'unknown/other': 99,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We create an instance of the `ParameterSampler` class, which is used to sample the parameters of the simulated events."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:54:54.539982Z",
"start_time": "2021-11-30T08:54:54.536121Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"parsam = ParameterSampler(record_length=RECORD_LENGTH,\n",
" sample_frequency=SAMPLE_FREQUENCY)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We define the detector resolution. The measured resolution will deviate from the set value, depending on the noise power spectrum and on the method used to reconstruct it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:54:55.556287Z",
"start_time": "2021-11-30T08:54:55.552840Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"parsam.set_args(resolution=np.array([RESOLUTION]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
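    "For the pile-up data set, we define two custom components. The first is an event definition, which specifies for each class the number of pulses per event (`pileups`) and the onset interval (`onset_iv`). The second is a pulse height distribution, which draws the pulse heights uniformly between 0 and 0.5."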
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:01.943535Z",
"start_time": "2021-11-30T08:55:01.936267Z"
}
},
"outputs": [],
"source": [
    "class PileUpEventDefinition(EventDefinition):\n",
    "\n",
    "    def get_class_pars(self, label, size, **kwargs):\n",
    "\n",
    "        if label == 'Event Pulse':\n",
    "            pileups = np.ones(size)  # one pulse per event\n",
    "            ps_nmbr = 0\n",
    "            decay = False\n",
    "            rise = False\n",
    "            spike = False\n",
    "            jump = False\n",
    "            pulse_reset = False\n",
    "            tail = False\n",
    "            onset_iv = np.array([-160, 400]) * 0.001  # onset interval, given in ms and converted to s\n",
    "\n",
    "        elif label == 'Pile Up':\n",
    "            pileups = np.full(size, 2, dtype=int)  # two pulses per event; alternatively np.random.choice(a=[2, 3], size=size, p=[0.5, 0.5])\n",
    "            ps_nmbr = 0\n",
    "            decay = False\n",
    "            rise = False\n",
    "            spike = False\n",
    "            jump = False\n",
    "            pulse_reset = False\n",
    "            tail = False\n",
    "            onset_iv = np.array([-20, 20]) * 0.001  # onset interval, given in ms and converted to s\n",
    "\n",
    "        else:\n",
    "            raise KeyError('Class {} not available.'.format(label))\n",
    "\n",
    "        return pileups, ps_nmbr, decay, rise, spike, jump, pulse_reset, tail, onset_iv\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:04.715465Z",
"start_time": "2021-11-30T08:55:04.710586Z"
}
},
"outputs": [],
"source": [
    "class PileUpPulseHeights(Distribution):\n",
    "\n",
    "    def __init__(self, mini=0, maxi=0.5):\n",
    "        self.mini = mini\n",
    "        self.maxi = maxi\n",
    "\n",
    "    def sample(self, size, **kwargs):\n",
    "        # pulse heights drawn uniformly from [mini, maxi)\n",
    "        pulse_height = np.random.uniform(low=self.mini, high=self.maxi, size=size)\n",
    "\n",
    "        return pulse_height"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:06.502503Z",
"start_time": "2021-11-30T08:55:06.498649Z"
}
},
"outputs": [],
"source": [
"parsam.set_args(event_definition=PileUpEventDefinition())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:07.042579Z",
"start_time": "2021-11-30T08:55:07.039318Z"
}
},
"outputs": [],
"source": [
"parsam.set_args(ph_dist=PileUpPulseHeights())"
]
},
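  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell sketches by hand what a pile-up event is: two pulses with independent heights and onsets are superposed in the same record window. This is a simplified, self-contained illustration, not the Cait implementation. The two-component pulse formula, the pulse parameters and the time axis are our assumptions, and noise, drifts and saturation are omitted. The height range [0, 0.5] is taken from `PileUpPulseHeights` and the onset window of ±20 ms from the `'Pile Up'` class defined above. The separation model trained later in this notebook receives such a sum as input and has to recover the individual pulses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# assumed two-component pulse model: zero before the onset t0, afterwards a\n",
    "# non-thermal (tau_n) and a thermal (tau_t) component sharing the rise time tau_in\n",
    "def pulse_model(t, t0, An, At, tau_n, tau_in, tau_t):\n",
    "    dt = t - t0\n",
    "    out = np.zeros_like(t)\n",
    "    m = dt > 0\n",
    "    out[m] = (An * (np.exp(-dt[m] / tau_n) - np.exp(-dt[m] / tau_in))\n",
    "              + At * (np.exp(-dt[m] / tau_t) - np.exp(-dt[m] / tau_in)))\n",
    "    return out\n",
    "\n",
    "rng = np.random.default_rng(seed=1)\n",
    "record_length, sample_frequency = 16384, 25000\n",
    "# hypothetical time axis in s, with the trigger at one quarter of the record window\n",
    "t = (np.arange(record_length) - record_length // 4) / sample_frequency\n",
    "\n",
    "shape = dict(An=1.0, At=0.3, tau_n=0.01, tau_in=0.001, tau_t=0.05)  # illustrative values only\n",
    "norm = pulse_model(t, 0.0, **shape).max()  # scale factor so that a single pulse peaks at 1\n",
    "\n",
    "heights = rng.uniform(0, 0.5, size=2)        # as in PileUpPulseHeights()\n",
    "onsets = rng.uniform(-0.020, 0.020, size=2)  # as in onset_iv for 'Pile Up'\n",
    "\n",
    "# the individual pulses (the targets of the separation model) and their sum (its input)\n",
    "components = np.stack([h * pulse_model(t, t0, **shape) / norm for h, t0 in zip(heights, onsets)])\n",
    "pile_up = components.sum(axis=0)\n",
    "\n",
    "print('pulse heights:', heights)\n",
    "print('onsets (ms):', onsets * 1000)\n",
    "print('maximum of the pile-up trace:', pile_up.max())"
   ]
  },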
{
"cell_type": "markdown",
"metadata": {},
"source": [
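    "We define the pulse shapes. They can be sampled with the parameter sampler (commented out below) or defined by hand, as done here. You can always resample them if you don't like them."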
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:23.995035Z",
"start_time": "2021-11-30T08:55:23.564271Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# sample pulse shapes\n",
"\n",
"# ps1, _ = parsam.sample_pulse_par(size=1, t0=np.array([0.]))\n",
"# ps2, _ = parsam.sample_pulse_par(size=1, t0=np.array([0.]))\n",
"# print(ps1)\n",
"# print(ps2)\n",
"\n",
"mpl.rcParams['figure.figsize'] = (5, 4)\n",
"# mpl.rcParams['savefig.dpi'] = 300\n",
"fontsize = 14\n",
"\n",
"mpl.rcParams['xtick.labelsize'] = fontsize\n",
"mpl.rcParams['ytick.labelsize'] = fontsize\n",
"mpl.rcParams['font.size'] = fontsize\n",
"mpl.rcParams['axes.titlesize'] = fontsize\n",
"mpl.rcParams['axes.labelsize'] = fontsize\n",
"mpl.rcParams['legend.fontsize'] = fontsize\n",
"mpl.rcParams['mathtext.fontset'] = 'stix'\n",
"mpl.rcParams['font.family'] = 'STIXGeneral'\n",
"\n",
"# or define your own pulse shapes ...\n",
"ps1 = {'t0': np.array([0.]), \n",
" 'tau_t': np.array([0.05815608]), \n",
" 'tau_in': np.array([0.02059209]), \n",
" 'tau_n': np.array([0.01395427]), \n",
" 'An': np.array([-0.02508469])/0.0715207609981429, \n",
" 'At': np.array([0.14102789])/0.0715207609981429}\n",
"ps2 = {'t0': np.array([0.]), \n",
" 'tau_t': np.array([0.01730027]), \n",
" 'tau_in': np.array([0.00128385]), \n",
" 'tau_n': np.array([0.00815371]), \n",
" 'An': np.array([2.59789859])/1.553262177826752,\n",
" 'At': np.array([0.02694577])/1.553262177826752}\n",
"\n",
"for i,p in enumerate([ps1, ps2]):\n",
"\n",
" event = ai.fit.pulse_template(parsam.t, **unfold(p, 1))\n",
" plot_events(event.reshape(1,-1), t=parsam.t, show=True, \n",
" savepath='test_data/pulse_shape_{}.pdf'.format(i), text=['']) # 'Standard Event'\n",
" \n",
"parsam.set_args(pulse_shapes=[[ps1['t0'], ps1['An'], ps1['At'], ps1['tau_n'], ps1['tau_in'], ps1['tau_t']],\n",
" [ps2['t0'], ps2['An'], ps2['At'], ps2['tau_n'], ps2['tau_in'], ps2['tau_t']],\n",
" ])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We define the noise power spectrum (NPS). You can always resample it if you don't like it."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:37.641720Z",
"start_time": "2021-11-30T08:55:36.420376Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# sample nps\n",
"\n",
"parsam.set_args(nps=None)\n",
"\n",
"noise_par, info = parsam.sample_noise(size=1)\n",
"\n",
"# or define your own nps ...\n",
"# noise_par['nps'] = nps[:, 1].reshape(1,-1)\n",
"# info['fq'] = np.fft.rfftfreq(n=RECORD_LENGTH, d=1 / SAMPLE_FREQUENCY)\n",
"\n",
"# plot begin -------\n",
"\n",
"plt.close()\n",
"\n",
"fig, ax = plt.subplots(constrained_layout=True)\n",
"\n",
"ax.loglog(info['fq'], noise_par['nps'][0], color='black', linewidth=1.5, zorder=50)\n",
"\n",
"# eye guidelines\n",
"\n",
"fig.supxlabel('Frequency (Hz)')\n",
"fig.supylabel('Amplitude (a.u.)')\n",
"ax.text(x=.6,y=.5,s='', # 'Noise Power Spectrum', \n",
" transform=ax.transAxes,\n",
" bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))\n",
"plt.savefig('test_data/noise_power_spectrum.pdf')\n",
"plt.show()\n",
"\n",
"parsam.set_args(nps=noise_par['nps'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We define the saturation curve. It can be sampled with the parameter sampler (commented out below) or defined by hand, as done here. You can always resample it if you don't like it."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:47.389525Z",
"start_time": "2021-11-30T08:55:46.506059Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# sat = parsam.sample_saturation(size=1)\n",
"# print(sat)\n",
"\n",
"# or define your own saturation ...\n",
"#sat = {'K': np.array([8.8786987]),\n",
"# 'C': np.array([1.23606841]),\n",
"# 'Q': np.array([2.12402435]),\n",
"# 'B': np.array([0.57876924]),\n",
"# 'nu': np.array([1.61744008]),\n",
"# 'A': np.array([-7.95906284])}\n",
"\n",
"sat = {'K': np.array([10]),\n",
" 'C': np.array([3]),\n",
" 'Q': np.array([0.5]),\n",
" 'B': np.array([5]),\n",
" 'nu': np.array([0.5])}\n",
"\n",
"sat['A'] = sat['K'] / (1 - (sat['C'] + sat['Q'])**(1/sat['nu']))\n",
"\n",
"plt.close()\n",
"\n",
"fig, ax = plt.subplots(constrained_layout=True)\n",
"\n",
"ax.plot(np.arange(0, 0.5, 0.01), ai.fit.scaled_logistic_curve(np.arange(0, 0.5, 0.01), **sat), color='black', linewidth=2.5)\n",
"ax.plot(np.arange(0, 0.5, 0.01), np.arange(0, 0.5, 0.01), linestyle='dotted', color='black')\n",
"\n",
"fig.supxlabel('True Pulse Height (V)')\n",
"fig.supylabel('Saturated Pulse Height (V)')\n",
"\n",
"ax.set_xlim(0,0.5)\n",
"ax.set_ylim(0,0.3)\n",
"\n",
"ax.text(x=.6,y=.3,s='', transform=ax.transAxes, # Saturation Curve\n",
" bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))\n",
"\n",
"plt.savefig('test_data/saturation_curve.pdf')\n",
"plt.show()\n",
"\n",
"parsam.set_args(saturation_pars=sat)"
]
},
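  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter `A` of the saturation curve above is not set by hand but computed from the other parameters. Assume that `ai.fit.scaled_logistic_curve` is based on the generalized logistic function; this is an assumption about its internal form. Under that assumption, the expression for `A` is the condition that the curve passes through the origin, so that a true pulse height of zero maps to a saturated pulse height of zero:\n",
    "\n",
    "$$\n",
    "f(x) = A + \\frac{K - A}{\\left(C + Q\\, e^{-B x}\\right)^{1/\\nu}}, \\qquad f(0) = 0 \\;\\Rightarrow\\; A\\left[(C + Q)^{1/\\nu} - 1\\right] = -K \\;\\Rightarrow\\; A = \\frac{K}{1 - (C + Q)^{1/\\nu}}\n",
    "$$\n",
    "\n",
    "For the values chosen above ($K = 10$, $C = 3$, $Q = 0.5$, $\\nu = 0.5$), this gives $A = 10 / (1 - 3.5^2) = -8/9 \\approx -0.889$."
   ]
  },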
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plot some augmented events."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:55:57.812629Z",
"start_time": "2021-11-30T08:55:55.822928Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample Noise...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c72235cdeb284ce286f32b6f07254b4b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Sample Polynomials...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf946921ba6c4477aafcf867d73b2916",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Sample Pulse Nmbr 0\n",
"Sample Pulse Nmbr 1\n",
"Sample Saturation ...\n",
"Pile Up\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"NMBR_PLOTS = 1\n",
"\n",
"# classes = np.random.choice(CLASS_NAMES, size=NMBR_PLOTS)\n",
"classes = [CLASS_NAMES[0] for i in range(NMBR_PLOTS)]\n",
"\n",
    "for i in range(NMBR_PLOTS):\n",
    "\n",
    "    event, info = parsam.get_event(label=classes[i],\n",
    "                                   size=4,\n",
    "                                   rasterize=RASTERIZE,\n",
    "                                   poly=POLYNOMIAL_DRIFTS,\n",
    "                                   square=SQUARE_WAVES,\n",
    "                                   saturation=SATURATION,\n",
    "                                   verb=True,\n",
    "                                   )\n",
    "\n",
    "    print(classes[i])\n",
    "    plot_events(event, t=parsam.t, show=False)\n",
    "    if classes[i] == 'Event Pulse':\n",
    "        # spread of the (signed) differences between reconstructed and true pulse heights\n",
    "        print('Resolution (only meaningful without Saturation): ', np.std(np.max(event, axis=1) - info['pulse_height']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create an HDF5 data set."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:59:03.540810Z",
"start_time": "2021-11-30T08:58:08.553318Z"
},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "50aef184759642488d3096606c8a2204",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3000.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r\n",
"Simulating Pile Up ...\n"
]
}
],
"source": [
    "FNAME = 'test_data/pileup_v0_6.h5'\n",
    "IDX = list(range(1))  # indices of the classes in CLASS_NAMES to simulate\n",
    "BATCHSIZE = 500\n",
    "EVENTS_PER_CLASS = {\n",
    "    #'Event Pulse': 30000,  # 9000,\n",
    "    #'Noise': 9000,\n",
    "    #'Decaying Baseline': 1500,\n",
    "    #'Temperature Rise': 1500,\n",
    "    #'Spike': 1500,\n",
    "    #'Squid Jump': 1500,\n",
    "    #'Reset': 6000,\n",
    "    #'Cosinus Tail': 3000,\n",
    "    #'Decaying Baseline with Event Pulse': 3000,\n",
    "    #'Decaying Baseline with Tail Event': 3000,\n",
    "    'Pile Up': 3000,  # 9000,\n",
    "    #'Early or late Trigger': 3000,\n",
    "    #'Carrier Event': 9000,\n",
    "}\n",
    "DIV = 1  # all event numbers and the batch size are divided by this factor, e.g. for quick tests\n",
    "TOTAL_EVENTS = int(np.sum(list(EVENTS_PER_CLASS.values())) / DIV)\n",
    "BATCHSIZE = int(BATCHSIZE / DIV)\n",
    "\n",
    "NMBR_PILEUP = 2\n",
    "\n",
    "# -------------------------------------------\n",
    "# no changes required below this line!\n",
    "# -------------------------------------------\n",
    "\n",
    "with h5py.File(FNAME, 'w') as f:\n",
    "\n",
    "    ev = f.require_group('events')\n",
    "    ev.create_dataset('event', shape=(1, TOTAL_EVENTS, RECORD_LENGTH), dtype=np.float32)\n",
    "    ev.create_dataset('labels', shape=(1, TOTAL_EVENTS), dtype=int)\n",
    "    ev.create_dataset('true_ph', shape=(1, TOTAL_EVENTS), dtype=float)\n",
    "    ev.create_dataset('true_onset', shape=(TOTAL_EVENTS, ), dtype=float)\n",
    "\n",
    "    # the individual pulses of each event and their parameters\n",
    "    ev.create_dataset('pulse_traces', shape=(1, TOTAL_EVENTS, NMBR_PILEUP, RECORD_LENGTH), dtype=np.float32)\n",
    "    ev.create_dataset('pulse_height_pileup', shape=(1, TOTAL_EVENTS, NMBR_PILEUP), dtype=np.float32)\n",
    "    ev.create_dataset('t0_pileup', shape=(1, TOTAL_EVENTS, NMBR_PILEUP), dtype=np.float32)\n",
    "\n",
    "    f.require_group('saturation')\n",
    "    f['saturation'].create_dataset('fitpar',\n",
    "                                   data=np.array([sat['A'], sat['K'], sat['C'], sat['Q'], sat['B'], sat['nu']]))\n",
    "\n",
    "    bar = tqdm(total=TOTAL_EVENTS)\n",
    "    bcount = 0\n",
    "\n",
    "    for i in IDX:\n",
    "\n",
    "        bar.write('Simulating {} ...'.format(CLASS_NAMES[i]))\n",
    "\n",
    "        nmbr_events = int(EVENTS_PER_CLASS[CLASS_NAMES[i]] / DIV)\n",
    "        batches = int(nmbr_events / BATCHSIZE)\n",
    "\n",
    "        for b in range(batches):\n",
    "            event, info = parsam.get_event(label=CLASS_NAMES[i],\n",
    "                                           size=BATCHSIZE,\n",
    "                                           rasterize=RASTERIZE,\n",
    "                                           poly=POLYNOMIAL_DRIFTS,\n",
    "                                           square=SQUARE_WAVES,\n",
    "                                           saturation=SATURATION,\n",
    "                                           verb=False,\n",
    "                                           )\n",
    "\n",
    "            start, stop = bcount * BATCHSIZE, (bcount + 1) * BATCHSIZE  # rows of this batch\n",
    "\n",
    "            ev['event'][0, start:stop, :] = event\n",
    "            ev['labels'][0, start:stop] = label_names[CLASS_NAMES[i]]\n",
    "            ev['true_ph'][0, start:stop] = info['pulse_height']\n",
    "            ev['true_onset'][start:stop] = info['t0'] * 1000\n",
    "            if 'pulse_traces' in info:\n",
    "                # pile-up events: store the individual pulses once, then their heights and onsets\n",
    "                ev['pulse_traces'][0, start:stop] = info['pulse_traces']\n",
    "                for j in range(NMBR_PILEUP):\n",
    "                    ev['pulse_height_pileup'][0, start:stop, j] = info['pulse_height_pileup_{}'.format(j)]\n",
    "                    ev['t0_pileup'][0, start:stop, j] = info['t0_pileup_{}'.format(j)]\n",
    "            else:\n",
    "                # single pulses: the event itself is the first and only pulse trace\n",
    "                ev['pulse_traces'][0, start:stop] = 0\n",
    "                ev['pulse_height_pileup'][0, start:stop] = 0\n",
    "                ev['t0_pileup'][0, start:stop] = 0\n",
    "                ev['pulse_traces'][0, start:stop, 0] = event\n",
    "                ev['pulse_height_pileup'][0, start:stop, 0] = info['pulse_height']\n",
    "                ev['t0_pileup'][0, start:stop, 0] = info['t0'] * 1000\n",
    "\n",
    "            bar.update(BATCHSIZE)\n",
    "            bcount += 1\n",
    "\n",
    "    bar.close()\n",
    "\n",
    "    # attributes: the integer value of each label name (written once, after the simulation)\n",
    "    ev['labels'].attrs.create(name='unlabeled', data=0)\n",
    "    ev['labels'].attrs.create(name='Event_Pulse', data=1)\n",
    "    ev['labels'].attrs.create(name='Test/Control_Pulse', data=2)\n",
    "    ev['labels'].attrs.create(name='Noise', data=3)\n",
    "    ev['labels'].attrs.create(name='Squid_Jump', data=4)\n",
    "    ev['labels'].attrs.create(name='Spike', data=5)\n",
    "    ev['labels'].attrs.create(name='Early_or_late_Trigger', data=6)\n",
    "    ev['labels'].attrs.create(name='Pile_Up', data=7)\n",
    "    ev['labels'].attrs.create(name='Carrier_Event', data=8)\n",
    "    ev['labels'].attrs.create(name='Strongly_Saturated_Event_Pulse', data=9)\n",
    "    ev['labels'].attrs.create(name='Strongly_Saturated_Test/Control_Pulse', data=10)\n",
    "    ev['labels'].attrs.create(name='Decaying_Baseline', data=11)\n",
    "    ev['labels'].attrs.create(name='Temperature_Rise', data=12)\n",
    "    ev['labels'].attrs.create(name='Stick_Event', data=13)\n",
    "    ev['labels'].attrs.create(name='Square_Waves', data=14)\n",
    "    ev['labels'].attrs.create(name='Human_Disturbance', data=15)\n",
    "    ev['labels'].attrs.create(name='Large_Sawtooth', data=16)\n",
    "    ev['labels'].attrs.create(name='Cosinus_Tail', data=17)\n",
    "    ev['labels'].attrs.create(name='Light_only_Event', data=18)\n",
    "    ev['labels'].attrs.create(name='Ring_Light_Event', data=19)\n",
    "    ev['labels'].attrs.create(name='Sharp_Light_Event', data=20)\n",
    "    ev['labels'].attrs.create(name='Reset', data=21)\n",
    "    ev['labels'].attrs.create(name='Decaying_Baseline_with_Event_Pulse', data=22)\n",
    "    ev['labels'].attrs.create(name='Decaying_Baseline_with_Tail_Pulse', data=23)\n",
    "    ev['labels'].attrs.create(name='unknown/other', data=99)\n",
    "\n",
    "    # f['saturation']['fitpar'].attrs.create(name='A', data=0)\n",
    "    # f['saturation']['fitpar'].attrs.create(name='K', data=1)\n",
    "    # f['saturation']['fitpar'].attrs.create(name='C', data=2)\n",
    "    # f['saturation']['fitpar'].attrs.create(name='Q', data=3)\n",
    "    # f['saturation']['fitpar'].attrs.create(name='B', data=4)\n",
    "    # f['saturation']['fitpar'].attrs.create(name='nu', data=5)\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:59:40.711123Z",
"start_time": "2021-11-30T08:59:40.595233Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "# plot one pile-up event together with its individual pulses\n",
    "with h5py.File(FNAME, 'r') as f:\n",
    "    idx = 5\n",
    "    plt.plot(f['events']['event'][0, idx])\n",
    "    for i in range(NMBR_PILEUP):\n",
    "        plt.plot(f['events']['pulse_traces'][0, idx, i])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T08:59:54.415313Z",
"start_time": "2021-11-30T08:59:54.330667Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "with h5py.File(FNAME, 'r') as f:\n",
" idx = 8\n",
" plt.plot(f['events']['event'][0, idx], color='black', alpha=0.05, linewidth=0.5)\n",
" plt.axis('off')\n",
" plt.savefig('post_background.pdf')\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Done."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the separation model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "We now train a model to separate the simulated pile-up events: from each pile-up event, the model tries to reconstruct the individual pulses it consists of. We use an LSTM model."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:05:40.535703Z",
"start_time": "2021-11-30T09:05:40.521160Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"from cait.cuts import LogicalCut\n",
"from cait.fit import scaled_logistic_curve, pulse_template\n",
"from cait.models import SeparationLSTM\n",
"import os\n",
"from pytorch_lightning import Trainer\n",
"from torchvision import transforms\n",
"import h5py\n",
"from cait.datasets import RemoveOffset, Normalize, ToTensor, CryoDataModule, PileUpDownSample\n",
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import torch\n",
"from tqdm.auto import trange, tqdm\n",
"%config InlineBackend.figure_formats = ['svg'] # we need this for a suitable resolution of the plots"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:05:41.132283Z",
"start_time": "2021-11-30T09:05:41.125632Z"
}
},
"outputs": [],
"source": [
"fontsize = 14\n",
"\n",
"mpl.rcParams['xtick.labelsize'] = fontsize\n",
"mpl.rcParams['ytick.labelsize'] = fontsize\n",
"mpl.rcParams['font.size'] = fontsize\n",
"mpl.rcParams['axes.titlesize'] = fontsize\n",
"mpl.rcParams['axes.labelsize'] = fontsize\n",
"mpl.rcParams['legend.fontsize'] = fontsize\n",
"mpl.rcParams['mathtext.fontset'] = 'stix'\n",
"mpl.rcParams['font.family'] = 'STIXGeneral'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:05:43.038688Z",
"start_time": "2021-11-30T09:05:43.034881Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"False\n"
]
}
],
"source": [
"print(torch.cuda.is_available())\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:05:52.308847Z",
"start_time": "2021-11-30T09:05:52.303190Z"
}
},
"outputs": [],
"source": [
    "# some parameters\n",
    "# nmbr_gpus = ...  # uncomment and pass to the Trainer to use GPUs\n",
    "path_h5 = 'test_data/pileup_v0_6.h5'\n",
    "type = 'events'  # the group key for the data in the HDF5 file\n",
    "keys = ['event', 'pulse_traces']  # the datasets in the group from which we include data in the samples for the NN\n",
    "channel_indices = [[0], [0]]  # the first indices (channels) of the datasets\n",
    "feature_indices = [None, None]  # the third indices of the datasets\n",
    "feature_keys = ['event_ch0']  # the keys in the samples of the NN dataset that are inputs to the NN;\n",
    "                              # in the NN dataset, the channel index is appended to the keys\n",
    "label_keys = ['pulse_traces_ch0']  # the keys in the samples of the NN dataset that are labels for the NN\n",
    "norm_vals = {'event_ch0': [0, 0.5], 'pulse_traces_ch0': [0, 0.5]}  # min-max normalization of all samples; these are roughly the lowest and highest values in the data set\n",
    "down_keys = ['event_ch0', 'pulse_traces_ch0']  # the raw time series are downsampled first\n",
    "down = 64  # all samples in the NN dataset with the keys above are downsampled by this factor\n",
    "record_length = 16384\n",
    "max_epochs = 5  # the maximum number of training epochs of the neural network\n",
    "\n",
    "save_naming = 'lstm'  # 'lstm' or 'unet'\n",
    "nmbr_pileup = 2\n",
    "plots = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset and Model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:05:54.224298Z",
"start_time": "2021-11-30T09:05:54.220334Z"
}
},
"outputs": [],
"source": [
"# create the transforms\n",
"trans = transforms.Compose([# RemoveOffset(keys=feature_keys),\n",
" Normalize(norm_vals=norm_vals),\n",
" PileUpDownSample(keys=down_keys, down=down),\n",
" ToTensor()])"
]
},
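  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The transforms above rescale each trace with the fixed bounds in `norm_vals` and downsample it by the factor `down`. The next cell sketches what this amounts to for a single trace of length 16384. It assumes that `Normalize` maps the interval given in `norm_vals` linearly onto [0, 1] and that `PileUpDownSample` averages non-overlapping blocks of `down` samples. Both are assumptions about the Cait transforms, not their actual code. Because the same fixed bounds are used for every trace, such a normalization preserves the relative pulse heights between events."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "record_length, down = 16384, 64\n",
    "norm_min, norm_max = 0.0, 0.5  # as in norm_vals['event_ch0']\n",
    "\n",
    "rng = np.random.default_rng(seed=0)\n",
    "trace = rng.normal(0.0, 0.01, size=record_length)  # random stand-in for one event_ch0 trace\n",
    "\n",
    "# assumed min-max normalization with fixed bounds\n",
    "normed = (trace - norm_min) / (norm_max - norm_min)\n",
    "\n",
    "# assumed downsampling: mean over non-overlapping blocks of down samples\n",
    "downsampled = normed.reshape(-1, down).mean(axis=1)\n",
    "\n",
    "print(downsampled.shape)  # (256,)"
   ]
  },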
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:05:57.470949Z",
"start_time": "2021-11-30T09:05:57.466483Z"
}
},
"outputs": [],
"source": [
"# create data module and init the setup\n",
"dm = CryoDataModule(hdf5_path=path_h5,\n",
" type=type,\n",
" keys=keys,\n",
" channel_indices=channel_indices,\n",
" feature_indices=feature_indices,\n",
" transform=trans)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:06:11.691502Z",
"start_time": "2021-11-30T09:06:11.529252Z"
}
},
"outputs": [],
"source": [
"dm.prepare_data(val_size=0.1,\n",
" test_size=0.2,\n",
" batch_size=64,\n",
" dataset_size=None,\n",
    "                nmbr_workers=0,  # number of data loader workers, e.g. the number of CPUs; in our tests 0 was faster than 8, probably because our dataset is implemented inefficiently for multiprocessing (an issue is open on GitLab)\n",
" only_idx=None,\n",
" shuffle_dataset=True,\n",
" random_seed=21, # 21, 1, 2, 3, 4\n",
" feature_keys=feature_keys,\n",
" label_keys=label_keys,\n",
" keys_one_hot=label_keys)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:06:12.975444Z",
"start_time": "2021-11-30T09:06:12.962512Z"
}
},
"outputs": [],
"source": [
"dm.setup()"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:12:59.183608Z",
"start_time": "2021-11-30T09:12:59.171340Z"
}
},
"outputs": [],
"source": [
"if save_naming == 'lstm':\n",
" \n",
" ips = 16\n",
" \n",
" model = SeparationLSTM(\n",
" input_size=ips,\n",
" hidden_size=ips * 10,\n",
" num_layers=3,\n",
" seq_steps=int(dm.dims[1] / ips), # downsampling is already considered in dm\n",
" nmbr_pileup=nmbr_pileup, \n",
" device_name=device,\n",
" lr=1e-3,\n",
" feature_keys=feature_keys,\n",
" label_keys=label_keys,\n",
" down=down,\n",
" down_keys=feature_keys,\n",
" norm_vals=norm_vals,\n",
" offset_keys=feature_keys,\n",
" )\n",
"\n",
"else:\n",
" raise NotImplementedError"
]
},
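  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `down = 64`, each trace has 16384 / 64 = 256 samples after downsampling. With `input_size=ips=16`, the LSTM therefore processes `seq_steps = 256 / 16 = 16` time steps of 16 values each. The next cell illustrates this bookkeeping with plain PyTorch modules of the same sizes. We assume that `SeparationLSTM` splits the trace into consecutive chunks and maps each LSTM output to `nmbr_pileup * ips` values with a linear layer; this is not its actual code. The assumption is at least consistent with the parameter counts in the training log below (LSTM 526 K, Linear 5 K). An LSTM with these sizes has 526,080 parameters, and a linear layer from 160 to 32 outputs has 5,152."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "record_length, down, ips, nmbr_pileup = 16384, 64, 16, 2\n",
    "length = record_length // down  # 256 samples after downsampling\n",
    "seq_steps = length // ips        # 16 LSTM time steps\n",
    "\n",
    "lstm = torch.nn.LSTM(input_size=ips, hidden_size=ips * 10, num_layers=3, batch_first=True)\n",
    "fc = torch.nn.Linear(ips * 10, nmbr_pileup * ips)  # hypothetical output head\n",
    "\n",
    "x = torch.randn(4, length)  # a batch of 4 downsampled pile-up traces\n",
    "out, _ = lstm(x.reshape(4, seq_steps, ips))  # shape (4, 16, 160)\n",
    "y = fc(out).reshape(4, seq_steps, nmbr_pileup, ips)\n",
    "y = y.permute(0, 2, 1, 3).reshape(4, nmbr_pileup, length)  # one trace per pile-up pulse\n",
    "\n",
    "print(sum(p.numel() for p in lstm.parameters()))  # 526080\n",
    "print(sum(p.numel() for p in fc.parameters()))    # 5152\n",
    "print(y.shape)  # torch.Size([4, 2, 256])"
   ]
  },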
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## View Events"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:01.020268Z",
"start_time": "2021-11-30T09:13:00.565617Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
    "if plots:\n",
    "\n",
    "    mpl.rcParams['figure.figsize'] = (6, 6)\n",
    "\n",
    "    width = 4\n",
    "    fig, axes = plt.subplots(nrows=width, ncols=width)\n",
    "    for i in range(width**2):\n",
    "        # show the first 8 events, then the events starting at index 108\n",
    "        if i < 8:\n",
    "            idx = i\n",
    "        else:\n",
    "            idx = i + 100\n",
    "        example = dm.dataset_full[idx]['event_ch0']\n",
    "        axes[i // width, i % width].plot(example, color='black')\n",
    "        axes[i // width, i % width].axis('off')\n",
    "        # axes[i // width, i % width].set_ylim((-0.1, 1.1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:02.825213Z",
"start_time": "2021-11-30T09:13:02.821258Z"
}
},
"outputs": [],
"source": [
"checkpoint_callback = ModelCheckpoint(dirpath='callbacks',\n",
" monitor='val_loss',\n",
" filename=save_naming + '-{epoch:02d}-{val_loss:.2f}')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:03.793326Z",
"start_time": "2021-11-30T09:13:03.783201Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n"
]
}
],
"source": [
"trainer = Trainer(max_epochs=max_epochs,\n",
" callbacks=[checkpoint_callback],\n",
" gpus=0)\n",
    "# use the keyword gpus=nmbr_gpus for GPU usage\n",
    "# the keyword max_epochs sets the maximum number of training epochs"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:24.219762Z",
"start_time": "2021-11-30T09:13:04.711395Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params\n",
"--------------------------------\n",
"0 | lstm | LSTM | 526 K \n",
"1 | fc | Linear | 5 K \n",
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n",
"\n",
"The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n",
"\n",
"The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7801e3010919423a90dd8d0bea3173ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:216: UserWarning:\n",
"\n",
"Please also save or load the state of the optimizer when saving or loading the scheduler.\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.fit(model=model,\n",
" datamodule=dm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:26.479144Z",
"start_time": "2021-11-30T09:13:26.475711Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best model: /Users/felix/PycharmProjects/cait/docs/source/tutorials/callbacks/lstm-epoch=03-val_loss=0.00.ckpt\n"
]
}
],
"source": [
"print('Best model:', checkpoint_callback.best_model_path)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:27.714413Z",
"start_time": "2021-11-30T09:13:27.697118Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"SeparationLSTM(\n",
" (lstm): LSTM(16, 160, num_layers=3, batch_first=True)\n",
" (fc): Linear(in_features=160, out_features=32, bias=True)\n",
")"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
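"# load_from_checkpoint is a classmethod that returns a new model instance,\n",
"# so the result must be assigned; calling it alone leaves `model` unchanged\n",
"model = model.load_from_checkpoint(checkpoint_callback.best_model_path)\n",
"model"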
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:28.182342Z",
"start_time": "2021-11-30T09:13:28.179401Z"
}
},
"outputs": [],
"source": [
"# uncomment to move the model to the GPU (requires a CUDA device)\n",
"# model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:29.778323Z",
"start_time": "2021-11-30T09:13:29.272276Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning:\n",
"\n",
"The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "41b85361d1b74a41b59dfb9dbac2e2f8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"DATALOADER:0 TEST RESULTS\n",
"{'test_loss': tensor(0.0028),\n",
" 'train_loss': tensor(0.0022),\n",
" 'val_loss': tensor(0.0029)}\n",
"--------------------------------------------------------------------------------\n",
"\n",
"[{'train_loss': 0.0021677217446267605, 'val_loss': 0.002877358580008149, 'test_loss': 0.002773477230221033}]\n"
]
}
],
"source": [
"result = trainer.test(model=model,\n",
" datamodule=dm)\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"ExecuteTime": {
"end_time": "2021-11-30T09:13:32.746169Z",
"start_time": "2021-11-30T09:13:32.467439Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"