Cryogenic Detector Data Augmentation

Correspondence to: felix.wagner@oeaw.ac.at

In this notebook we explore the data augmentation module of the Cait Python package. We are going to simulate a dataset of realistic events, including two different pulse shapes (absorber and carrier recoils), and a set of artifacts.

from cait.augment import ParameterSampler, plot_events, unfold, L2
import numpy as np
import cait as ai  # install from the develop branch
import matplotlib.pyplot as plt
import matplotlib as mpl
import h5py
from tqdm.auto import tqdm
from scipy import signal
from sklearn.ensemble import RandomForestClassifier
%config InlineBackend.figure_formats = ['svg']

Some default settings for the plots.

mpl.rcParams['figure.figsize'] = (6, 4)
mpl.rcParams['savefig.dpi'] = 300

Define the global parameters. These set the most important properties of the augmented data: record length, sample frequency, and resolution (standard deviation) of the noise.

Optionally, polynomial drifts can be added to the baselines, as we often observe in measured baselines. The rasterization feature introduces the discretization effect caused by the 16-bit precision of the digitizer. The saturation option applies a saturation curve to all events, as happens in any real TES measurement.
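To illustrate what rasterization does, here is a minimal sketch that snaps a trace onto the voltage grid of a 16-bit digitizer. The helper name `rasterize_trace` and the ±10 V input range are hypothetical assumptions for illustration; cait's own rasterization may be implemented differently.

```python
import numpy as np


def rasterize_trace(trace, v_min=-10.0, v_max=10.0, n_bits=16):
    """Quantize a voltage trace onto the grid of an n_bits digitizer.

    Hypothetical helper: the +-10 V input range is an assumption.
    """
    step = (v_max - v_min) / (2 ** n_bits - 1)   # voltage of one ADC count
    clipped = np.clip(trace, v_min, v_max)        # out-of-range values cannot be digitized
    counts = np.round((clipped - v_min) / step)   # nearest integer ADC count
    return v_min + counts * step


demo_t = np.linspace(0, 1, 1000)
demo_trace = 0.001 * np.sin(2 * np.pi * 5 * demo_t)  # a 1 mV signal spans only a few ADC counts
demo_raster = rasterize_trace(demo_trace)
print('distinct levels:', len(np.unique(demo_raster)))
```

Small signals are the ones most affected: with the assumed range, one ADC count is about 0.3 mV, so a 1 mV sine collapses onto a handful of discrete levels.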

The CLASS_NAMES list contains the names of all classes we want to simulate in the dataset; here, these are all classes currently available in cait.

Throughout the notebook, only the capitalized parameters need to be set by the user.

RECORD_LENGTH = 16384
SAMPLE_FREQUENCY = 25000
RESOLUTION = 0.01  # 0.003

POLYNOMIAL_DRIFTS = False
RASTERIZE = True
SQUARE_WAVES = False
SATURATION = True
CLASS_NAMES = ['Event Pulse',
               'Noise',
               'Decaying Baseline',
               'Temperature Rise',
               'Spike',
               'Squid Jump',
               'Reset',
               'Cosinus Tail',
               'Decaying Baseline with Event Pulse',
               'Pile Up',
               'Early or late Trigger',
               'Carrier Event',
               'Decaying Baseline with Tail Event',
               ]


# ----------------------------------------- 
# no need to change below parameters
# ----------------------------------------- 

label_names = {
    'unlabeled': 0,
    'Event Pulse': 1,
    'Test/Control Pulse': 2,
    'Noise': 3,
    'Squid Jump': 4,
    'Spike': 5,
    'Early or late Trigger': 6,
    'Pile Up': 7,
    'Carrier Event': 8,
    'Strongly Saturated Event Pulse': 9,
    'Strongly Saturated Test/Control Pulse': 10,
    'Decaying Baseline': 11,
    'Temperature Rise': 12,
    'Stick Event': 13,
    'Square Waves': 14,
    'Human Disturbance': 15,
    'Large Sawtooth': 16,
    'Cosinus Tail': 17,
    'Light only Event': 18,
    'Ring & Light Event': 19,
    'Sharp Light Event': 20,
    'Reset': 21,
    'Decaying Baseline with Event Pulse': 22,
    'Decaying Baseline with Tail Event': 23,
    'unknown/other': 99,
}

Initialize an instance of the sampler class. This class is responsible for simulating all events. We set the relevant arguments explicitly below; any argument that is not set will be sampled randomly (see the notebook Universal Training Set).

parsam = ParameterSampler(record_length=RECORD_LENGTH,
                          sample_frequency=SAMPLE_FREQUENCY)

We define the detector resolution. The measured resolution will deviate from the set value, depending on the noise power spectrum and on the method used to reconstruct the resolution.

parsam.set_args(resolution=np.array([RESOLUTION]))

We define the pulse shapes. You can either let the ParameterSampler instance generate shapes randomly (you can always resample if you don’t like them) or set your own parameters. The pulses follow the pulse shape model for cryogenic TES detectors introduced by Franz Pröbst. In the cell below, the option to sample pulse shapes randomly is commented out; you can uncomment it (and comment out the “define your own” section instead).
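As a sketch of the pulse shape model, the function below evaluates the two-component form commonly written for the Pröbst model: a non-thermal component (amplitude An, time constants tau_n and tau_in) plus a thermal component (amplitude At, time constant tau_t), both starting at the onset t0. It uses the same parameter names as `ps1` above but is our own NumPy version, not cait's `ai.fit.pulse_template`; we assume all times are given in seconds.

```python
import numpy as np


def pulse_model(t, t0, An, At, tau_n, tau_in, tau_t):
    """Two-component pulse model: zero before the onset t0,
    non-thermal (An) plus thermal (At) component afterwards."""
    dt = t - t0
    pulse = np.zeros_like(t, dtype=float)
    after = dt > 0  # the pulse only exists after the onset
    pulse[after] = (An * (np.exp(-dt[after] / tau_n) - np.exp(-dt[after] / tau_in))
                    + At * (np.exp(-dt[after] / tau_t) - np.exp(-dt[after] / tau_n)))
    return pulse


# parameters of ps1 from the cell below (times assumed to be in seconds)
norm = 0.0715207609981429
t_demo = np.arange(-0.1, 0.5, 1e-4)
p1 = pulse_model(t_demo, t0=0., An=-0.02508469 / norm, At=0.14102789 / norm,
                 tau_n=0.01395427, tau_in=0.02059209, tau_t=0.05815608)
print('peak position:', t_demo[np.argmax(p1)], 'peak height:', p1.max())
```

The rise is governed mainly by the faster time constants and the long tail by tau_t, so a "long" and a "short" pulse shape (as for absorber and carrier recoils) differ mostly in these time constants.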

# sample pulse shapes

# ps1, _ = parsam.sample_pulse_par(size=1, t0=np.array([0.]))
# ps2, _ = parsam.sample_pulse_par(size=1, t0=np.array([0.]))
# print(ps1)
# print(ps2)

# or define your own pulse shapes ...
ps1 = {'t0': np.array([0.]), 
       'tau_t': np.array([0.05815608]), 
       'tau_in': np.array([0.02059209]), 
       'tau_n': np.array([0.01395427]), 
       'An': np.array([-0.02508469])/0.0715207609981429, 
       'At': np.array([0.14102789])/0.0715207609981429}
ps2 = {'t0': np.array([0.]), 
       'tau_t': np.array([0.01730027]), 
       'tau_in': np.array([0.00128385]), 
       'tau_n': np.array([0.00815371]), 
       'An': np.array([2.59789859])/1.553262177826752,
       'At': np.array([0.02694577])/1.553262177826752}

for i,p in enumerate([ps1, ps2]):

    event = ai.fit.pulse_template(parsam.t, **unfold(p, 1))
    plot_events(event.reshape(1,-1), t=parsam.t, show=True, text=['Standard Event'])
    
parsam.set_args(pulse_shapes=[[ps1['t0'], ps1['An'], ps1['At'], ps1['tau_n'], ps1['tau_in'], ps1['tau_t']],
                              [ps2['t0'], ps2['An'], ps2['At'], ps2['tau_n'], ps2['tau_in'], ps2['tau_t']],
                              ])
[Figure: Standard Event, pulse shape ps1]
[Figure: Standard Event, pulse shape ps2]

We define the noise power spectrum (NPS) with randomly sampled parameters. You can always resample if you don’t like it.

The noise power spectrum follows a parametric shape: two \(1/f^a\) functions are superposed, with \(a\) sampled between 1 and 2. The constant (zero-frequency) component is forced to zero, and a low-pass filter sets in above \(10^4\) Hz. The characteristic 50 Hz peak and its first and second harmonics are added with heights sampled relative to the overall NPS height.
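The recipe above can be sketched in a few lines of NumPy. This is a toy construction, not cait's sampler: the exponents, weights, the second-order (Butterworth-like) roll-off, the choice of 100 Hz and 150 Hz as the harmonics, and measuring the line heights relative to the NPS maximum are all assumptions; `parametric_nps` is a hypothetical name.

```python
import numpy as np


def parametric_nps(record_length, sample_frequency, a=(1.2, 1.8), weights=(1.0, 0.1),
                   cutoff=1e4, peak_rel=(0.5, 0.2, 0.1), line_freq=50.0):
    """Toy NPS built from the ingredients listed above; all defaults are hypothetical."""
    freq = np.fft.rfftfreq(n=record_length, d=1 / sample_frequency)
    nps = np.zeros_like(freq)
    f = freq[1:]                                   # leave nps[0] = 0: no constant component
    for a_i, w_i in zip(a, weights):               # two superposed 1/f^a components
        nps[1:] += w_i / f ** a_i
    nps[1:] /= 1 + (f / cutoff) ** 4               # low-pass roll-off above the cutoff
    ref = nps.max()                                # reference height for the line peaks
    for k, rel in enumerate(peak_rel, start=1):    # 50 Hz line and two multiples of it
        idx = int(round(k * line_freq / freq[1]))  # nearest frequency bin
        if idx < len(nps):
            nps[idx] += rel * ref
    return freq, nps


freq, nps = parametric_nps(record_length=16384, sample_frequency=25000)
```

With 16384 samples at 25 kHz, the frequency resolution is about 1.5 Hz, so the 50 Hz line falls into a single bin close to 50 Hz.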

# sample nps

parsam.set_args(nps=None)

noise_par, info = parsam.sample_noise(size=1)

# or define your own nps ...
# noise_par['nps'] = nps[:, 1].reshape(1,-1)
# info['fq'] = np.fft.rfftfreq(n=RECORD_LENGTH, d=1 / SAMPLE_FREQUENCY)

# plot begin -------

plt.close()

fig, ax = plt.subplots(constrained_layout=True, figsize=(6,2.5))

ax.loglog(info['fq'], noise_par['nps'][0], color='black', linewidth=2.5, zorder=50)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('Amplitude (V^2 / Hz)')
# ax.text(x=.6,y=.5,s='Noise Power Spectrum', 
#         transform=ax.transAxes,
#         bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))
plt.savefig('test_data/nps.pdf')
plt.show()

parsam.set_args(nps=noise_par['nps'])
[Figure: noise power spectrum, amplitude (V^2 / Hz) vs. frequency (Hz)]

We define the saturation curve. It is modeled with a generalized logistic function. We set the parameters explicitly below, but you can also sample new ones randomly (see the commented-out lines).
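For reference, the generalized logistic (Richards) function in its textbook form is sketched below, with parameter names matching the `sat` dictionary. Whether `ai.fit.scaled_logistic_curve` applies additional scaling on top of this form is not shown in this notebook, so treat this as an illustration; the dictionary is renamed `sat_example` so that it does not overwrite `sat`.

```python
import numpy as np


def generalized_logistic(x, A, K, C, Q, B, nu):
    """Generalized logistic (Richards) function."""
    return A + (K - A) / (C + Q * np.exp(-B * x)) ** (1 / nu)


# same numbers as the `sat` dictionary below, as plain floats
sat_example = dict(K=8.8786987, C=1.23606841, Q=2.12402435,
                   B=0.57876924, nu=1.61744008, A=-7.95906284)
x = np.arange(0, 10, 0.1)
y = generalized_logistic(x, **sat_example)
# limit for x -> infinity: the largest pulse height the saturated detector can show
plateau = sat_example['A'] + (sat_example['K'] - sat_example['A']) / sat_example['C'] ** (1 / sat_example['nu'])
print(y[0], plateau)
```

With these parameters the textbook curve passes (numerically) through the origin and rises monotonically towards a finite plateau, so large true pulse heights are compressed.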

# sat = parsam.sample_saturation(size=1)
# print(sat)

# or define your own saturation ...
sat = {'K': np.array([8.8786987]),
       'C': np.array([1.23606841]),
       'Q': np.array([2.12402435]),
       'B': np.array([0.57876924]),
       'nu': np.array([1.61744008]),
       'A': np.array([-7.95906284])}

plt.close()

fig, ax = plt.subplots(constrained_layout=True, figsize=(6,2.5))

ax.plot(np.arange(0, 10, 0.1), ai.fit.scaled_logistic_curve(np.arange(0, 10, 0.1), **sat), color='black', linewidth=2.5)
ax.set_xlabel('True pulse height (V)')
ax.set_ylabel('Saturated pulse height (V)')

# ax.text(x=.6,y=.5,s='Saturation Curve', transform=ax.transAxes,
#         bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))

plt.savefig('test_data/sat.pdf')
plt.show()

parsam.set_args(saturation_pars=sat)
[Figure: saturation curve, saturated pulse height (V) vs. true pulse height (V)]

Now we are ready to plot some augmented events!

NMBR_PLOTS = 6

classes = np.random.choice(CLASS_NAMES, size=NMBR_PLOTS)
classes[0] = CLASS_NAMES[0]
classes[1] = CLASS_NAMES[1]
classes[2] = CLASS_NAMES[11]
# classes = [CLASS_NAMES[0] for i in range(NMBR_PLOTS)]

for i in range(NMBR_PLOTS):

    event, info = parsam.get_event(label=classes[i],
                                   size=4,
                                   rasterize=RASTERIZE,
                                   poly=POLYNOMIAL_DRIFTS,
                                   square=SQUARE_WAVES,
                                   saturation=SATURATION,
                                   verb=False,
                                   )

    print(classes[i])
    plot_events(event, t=parsam.t)
    if classes[i] == 'Event Pulse':
        print('Resolution (only meaningful without Saturation): ', np.std(np.abs(np.max(event, axis=1) - info['pulse_height'])))
Event Pulse
[Figure: four simulated Event Pulse records]
Resolution (only meaningful without Saturation):  0.09678070047987046
Noise
[Figure: four simulated Noise records]
Carrier Event
[Figure: four simulated Carrier Event records]
Decaying Baseline
[Figure: four simulated Decaying Baseline records]
Carrier Event
[Figure: four simulated Carrier Event records]
Cosinus Tail
[Figure: four simulated Cosinus Tail records]
# plot for the paper
classes = ['Event Pulse',
 'Noise',
 'Spike',
 'Reset',
 'Decaying Baseline with Event Pulse',
 'Pile Up',
 #'Early or late Trigger',
          ]
text = ['(i) Target recoil', '(ii) Noise trigger', '(iii) Spike from electronics', 
        '(iv) Flux quantum loss', '(v) Decaying baseline', '(vi) Pile-up', 
        #'Early trigger',
       ]
nmbr_x = 3
nmbr_y = 2

fig, axs = plt.subplots(nmbr_y, nmbr_x, constrained_layout=True, figsize=(12,4))

clwheel = ['black', 'black']

for i,c in enumerate(classes):

    event, info = parsam.get_event(label=c,
                                   size=1,
                                   rasterize=RASTERIZE,
                                   poly=POLYNOMIAL_DRIFTS,
                                   square=SQUARE_WAVES,
                                   saturation=SATURATION,
                                   verb=False,
                                   )
    
    ax = axs[i // nmbr_x, i % nmbr_x]  # row-major position in the 2 x 3 grid
    ax.plot(parsam.t, event[0], color=clwheel[i % 2], linewidth=2)

    ax.set_title(text[i])

#     ax.text(x=.6, y=.5, s=text[i], transform=ax.transAxes,
#             bbox=dict(boxstyle='square', fc='white', alpha=0.8, ec='k'))
    
fig.supxlabel('Time (ms)')
fig.supylabel('Amplitude (V)')
plt.savefig('test_data/events_plot.pdf')
plt.show()
[Figure: example events (i) Target recoil to (vi) Pile-up, amplitude (V) vs. time (ms)]

Now we are ready to create an HDF5 data set with these simulated events!

To change the number of simulated events, edit the EVENTS_PER_CLASS dictionary. Please note that the number of events per class must always be a multiple of BATCHSIZE: events are simulated in whole batches, so any remainder is silently dropped and the end of the pre-allocated datasets stays empty (label 0, 'unlabeled').
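A quick check before starting the (long) simulation catches this mistake early. The helper `check_batch_multiples` is hypothetical and not part of cait; in the notebook you would call it as `check_batch_multiples(EVENTS_PER_CLASS, BATCHSIZE)`.

```python
def check_batch_multiples(events_per_class, batchsize):
    """Return {class name: remainder} for every class whose event count is not a
    multiple of batchsize; these remainders would simply not be simulated."""
    return {name: n % batchsize for name, n in events_per_class.items() if n % batchsize}


# example with one deliberately wrong entry
problems = check_batch_multiples({'Event Pulse': 8000, 'Noise': 3000, 'Spike': 520},
                                 batchsize=50)
if problems:
    print('Not a multiple of the batch size (remainders):', problems)
```

If DIV is changed, apply the check to the divided numbers, since the cell below divides both the event counts and the batch size by DIV.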

FNAME = 'test_data/test_v0_1.h5'  # you can change this to your desired name
IDX = list(range(13))  # indices into CLASS_NAMES of the classes to simulate
BATCHSIZE = 50
EVENTS_PER_CLASS = {  # use always multiple of batch size!
    'Event Pulse': 8000,
    'Noise': 3000,
    'Decaying Baseline': 500,
    'Temperature Rise': 500,
    'Spike': 500,
    'Squid Jump': 500,
    'Reset': 500,
    'Cosinus Tail': 500,
    'Decaying Baseline with Event Pulse': 500,
    'Decaying Baseline with Tail Event': 500,
    'Pile Up': 9000,
    'Early or late Trigger': 500,
    'Carrier Event': 8000,
}
DIV = 1  # divides all event numbers and the batch size, e.g. for a quick test run
TOTAL_EVENTS = int(np.sum(list(EVENTS_PER_CLASS.values()))/DIV)
BATCHSIZE = int(BATCHSIZE/DIV)

# -------------------------------------------
# no changes required below this line!
# -------------------------------------------

with h5py.File(FNAME, 'w') as f:
    f.require_group('events')
    f['events'].create_dataset('event',
                               shape=(1, TOTAL_EVENTS, RECORD_LENGTH),
                               dtype=np.float32)
    f['events'].create_dataset('labels',
                               shape=(1, TOTAL_EVENTS),
                               dtype=int)
    f['events'].create_dataset('true_ph',
                               shape=(1, TOTAL_EVENTS),
                               dtype=float)
    f['events'].create_dataset('true_onset',
                               shape=(TOTAL_EVENTS, ),
                               dtype=float)

    f.require_group('saturation')
    f['saturation'].create_dataset('fitpar',
                                   data=np.array([sat['A'], sat['K'], sat['C'], sat['Q'], sat['B'], sat['nu'], ]))

    # attributes: written once, instead of once per batch

    f['events']['labels'].attrs.create(name='unlabeled', data=0)
    f['events']['labels'].attrs.create(name='Event_Pulse', data=1)
    f['events']['labels'].attrs.create(name='Test/Control_Pulse', data=2)
    f['events']['labels'].attrs.create(name='Noise', data=3)
    f['events']['labels'].attrs.create(name='Squid_Jump', data=4)
    f['events']['labels'].attrs.create(name='Spike', data=5)
    f['events']['labels'].attrs.create(name='Early_or_late_Trigger', data=6)
    f['events']['labels'].attrs.create(name='Pile_Up', data=7)
    f['events']['labels'].attrs.create(name='Carrier_Event', data=8)
    f['events']['labels'].attrs.create(name='Strongly_Saturated_Event_Pulse', data=9)
    f['events']['labels'].attrs.create(name='Strongly_Saturated_Test/Control_Pulse', data=10)
    f['events']['labels'].attrs.create(name='Decaying_Baseline', data=11)
    f['events']['labels'].attrs.create(name='Temperature_Rise', data=12)
    f['events']['labels'].attrs.create(name='Stick_Event', data=13)
    f['events']['labels'].attrs.create(name='Square_Waves', data=14)
    f['events']['labels'].attrs.create(name='Human_Disturbance', data=15)
    f['events']['labels'].attrs.create(name='Large_Sawtooth', data=16)
    f['events']['labels'].attrs.create(name='Cosinus_Tail', data=17)
    f['events']['labels'].attrs.create(name='Light_only_Event', data=18)
    f['events']['labels'].attrs.create(name='Ring_Light_Event', data=19)
    f['events']['labels'].attrs.create(name='Sharp_Light_Event', data=20)
    f['events']['labels'].attrs.create(name='Reset', data=21)
    f['events']['labels'].attrs.create(name='Decaying_Baseline_with_Event_Pulse', data=22)
    f['events']['labels'].attrs.create(name='Decaying_Baseline_with_Tail_Pulse', data=23)
    f['events']['labels'].attrs.create(name='unknown/other', data=99)

    f['saturation']['fitpar'].attrs.create(name='A', data=0)
    f['saturation']['fitpar'].attrs.create(name='K', data=1)
    f['saturation']['fitpar'].attrs.create(name='C', data=2)
    f['saturation']['fitpar'].attrs.create(name='Q', data=3)
    f['saturation']['fitpar'].attrs.create(name='B', data=4)
    f['saturation']['fitpar'].attrs.create(name='nu', data=5)

    bar = tqdm(total=TOTAL_EVENTS)
    bcount = 0

    for i in IDX:

        bar.write('Simulating {} ...'.format(CLASS_NAMES[i]))

        nmbr_events = int(EVENTS_PER_CLASS[CLASS_NAMES[i]] / DIV)
        batches = int(nmbr_events / BATCHSIZE)

        for b in range(batches):
            event, info = parsam.get_event(label=CLASS_NAMES[i],
                                           size=BATCHSIZE,
                                           rasterize=RASTERIZE,
                                           poly=POLYNOMIAL_DRIFTS,
                                           square=SQUARE_WAVES,
                                           saturation=SATURATION,
                                           verb=False,
                                           )

            # position of this batch in the pre-allocated datasets
            batch = slice(bcount * BATCHSIZE, (bcount + 1) * BATCHSIZE)
            f['events']['event'][0, batch, :] = event
            f['events']['labels'][0, batch] = label_names[CLASS_NAMES[i]]
            f['events']['true_ph'][0, batch] = info['pulse_height']
            f['events']['true_onset'][batch] = info['t0']*1000

            bar.update(BATCHSIZE)
            bcount += 1

    bar.close()
Simulating Event Pulse ...
Simulating Noise ...
Simulating Decaying Baseline ...
Simulating Temperature Rise ...
Simulating Spike ...
Simulating Squid Jump ...
Simulating Reset ...
Simulating Cosinus Tail ...
Simulating Decaying Baseline with Event Pulse ...
Simulating Pile Up ...
Simulating Early or late Trigger ...
Simulating Carrier Event ...
Simulating Decaying Baseline with Tail Event ...

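Now use the VizTool to look at these events. Try to calculate a nice SEV (standard event)! Not so easy, eh? Note that the cells below work with test_data/test_v0_2.h5 rather than the file FNAME written above; to analyze your own simulation, use the commented-out test_v0_1 line instead and adapt the file names in the later cells.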
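In its simplest form, a standard event is the average of many clean pulses, each baseline-subtracted and normalized to unit height. The sketch below does exactly that on toy data; the helper `simple_sev`, the baseline window (first eighth of the record) and the toy pulse are our assumptions, not cait's SEV routine. It only gives a good template if the averaged records are really clean and unsaturated, which is exactly what the artifacts in this dataset make difficult.

```python
import numpy as np


def simple_sev(events, baseline_fraction=1 / 8):
    """Average of baseline-subtracted, max-normalized events (hypothetical helper)."""
    n_base = int(events.shape[1] * baseline_fraction)
    baselines = events[:, :n_base].mean(axis=1, keepdims=True)
    corrected = events - baselines                 # remove the baseline offset
    heights = corrected.max(axis=1, keepdims=True)
    return np.mean(corrected / heights, axis=0)    # unit-height average pulse


# toy check: noisy copies of one pulse shape with different heights and offsets
rng = np.random.default_rng(0)
t_toy = np.linspace(-0.1, 0.5, 2000)
template = np.where(t_toy > 0, np.exp(-t_toy / 0.05) - np.exp(-t_toy / 0.005), 0.0)
template /= template.max()
toy_events = (rng.uniform(0.5, 2.0, (200, 1)) * template
              + rng.normal(0, 0.01, (200, len(t_toy)))
              + rng.uniform(-1, 1, (200, 1)))
sev = simple_sev(toy_events)
```

Taking the maximum of a noisy record slightly overestimates its height, so this simple normalization biases the SEV low by an amount that grows with the noise level.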

dh = ai.DataHandler(nmbr_channels=1,
                    sample_frequency=SAMPLE_FREQUENCY,
                    record_length=RECORD_LENGTH)
# dh.set_filepath('test_data', 'test_v0_1', appendix=False)
dh.set_filepath('test_data', 'test_v0_2', appendix=False)
DataHandler Instance created.
dh.content()
The following properties are in the HDF5 sets can be accessed through the get(group, dataset) methode.
The following data sets are contained in the the group events:
dataset: add_mainpar, shape: (1, 32500, 16)
dataset: cnn_cut, shape: (1, 32500)
dataset: cnn_prob, shape: (1, 32500, 2)
dataset: event, shape: (1, 32500, 16384)
dataset: labels, shape: (1, 32500)
dataset: mainpar, shape: (1, 32500, 10)
dataset: true_onset, shape: (32500,)
dataset: true_ph, shape: (1, 32500)
dataset: pulse_height, shape: (1, 32500)
dataset: onset, shape: (1, 32500)
dataset: rise_time, shape: (1, 32500)
dataset: decay_time, shape: (1, 32500)
dataset: slope, shape: (1, 32500)
dataset: array_max, shape: (1, 32500)
dataset: array_min, shape: (1, 32500)
dataset: var_first_eight, shape: (1, 32500)
dataset: mean_first_eight, shape: (1, 32500)
dataset: var_last_eight, shape: (1, 32500)
dataset: mean_last_eight, shape: (1, 32500)
dataset: var, shape: (1, 32500)
dataset: mean, shape: (1, 32500)
dataset: skewness, shape: (1, 32500)
dataset: max_derivative, shape: (1, 32500)
dataset: ind_max_derivative, shape: (1, 32500)
dataset: min_derivative, shape: (1, 32500)
dataset: ind_min_derivative, shape: (1, 32500)
dataset: max_filtered, shape: (1, 32500)
dataset: ind_max_filtered, shape: (1, 32500)
dataset: skewness_filtered_peak, shape: (1, 32500)
The following data sets are contained in the the group saturation:
dataset: fitpar, shape: (6, 1)

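First, we calculate the main parameters and the additional main parameters. Afterwards, we apply the pretrained binary CNN classifier `cnn-clf-binary-v1.ckpt` from the cait resources and store its decision as `cnn_cut`.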

dh.calc_mp()
dh.calc_additional_mp(no_of=True)
CALCULATE MAIN PARAMETERS.
CALCULATE ADDITIONAL MAIN PARAMETERS.
ckp_path = ai.resources.get_resource_path('cnn-clf-binary-v1.ckpt')

ai.models.nn_predict(h5_path='test_data/test_v0_2.h5',
                     model=ai.models.CNNModule.load_from_checkpoint(ckp_path),
                     feature_channel=0,
                     group_name='events',
                     prediction_name='cnn_cut',
                     keys=['event'],
                     no_channel_idx_in_pred=False,
                     use_prob=False)
cnn_prob written to file test_data/test_v0_2.h5.
# calc metrics

fair_labels = dh.get('events', 'labels')[0]
fair_labels[fair_labels == 8] = 1  # carrier events count as events, too

event_preds = dh.get('events', 'cnn_cut')[0, fair_labels == 1]
artifact_preds = dh.get('events', 'cnn_cut')[0, fair_labels != 1]

accuracy = (np.sum(event_preds) + np.sum(np.logical_not(artifact_preds)))/(len(event_preds) + len(artifact_preds))
precision = np.sum(event_preds)/(np.sum(event_preds) + np.sum(artifact_preds))
recall = np.sum(event_preds)/len(event_preds)
efficiency = recall  # the efficiency is the same quantity as the recall
f_score = 2*(precision * recall)/(precision + recall)

metrics = {'accuracy': accuracy, 'precision': precision, 'recall': recall,
           'efficiency': efficiency, 'f_score': f_score}
for name, value in metrics.items():
    print(name, value)
accuracy 0.9045230769230769
precision 0.8534392984379282
recall 0.9731875
efficiency 0.9731875
f_score 0.9093882318586654
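The numbers above follow the usual definitions for a binary classifier in which "event" (labels 1 and 8) is the positive class and surviving the CNN cut is a positive prediction. A compact, self-contained version with made-up toy arrays (the helper name `binary_metrics` is ours) makes the four counts explicit:

```python
import numpy as np


def binary_metrics(is_event, survived):
    """Accuracy, precision, recall and F1 score for two boolean arrays of equal length."""
    is_event = np.asarray(is_event, dtype=bool)
    survived = np.asarray(survived, dtype=bool)
    tp = np.sum(is_event & survived)      # events kept by the cut
    fp = np.sum(~is_event & survived)     # artifacts that leak through
    tn = np.sum(~is_event & ~survived)    # artifacts removed
    fn = np.sum(is_event & ~survived)     # events lost
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)               # identical to the efficiency above
    return {'accuracy': (tp + tn) / is_event.size,
            'precision': precision,
            'recall': recall,
            'f_score': 2 * precision * recall / (precision + recall)}


toy = binary_metrics(is_event=[1, 1, 1, 1, 0, 0, 0, 0],
                     survived=[1, 1, 1, 0, 1, 0, 0, 0])
print(toy)
```

In the results above, the recall is high while the precision is lower, i.e. the cut loses few events but lets a noticeable fraction of artifacts through.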
# calc resolution
clean_unsat = np.logical_and(dh.get('events', 'labels')[0] == 1, dh.get('events', 'pulse_height')[0] < 1)
clean_unsat = np.logical_and(clean_unsat, dh.get('events', 'pulse_height')[0] > 0.006)
resolution = np.sqrt(np.mean((dh.get('events', 'pulse_height')[0, clean_unsat] - dh.get('events', 'true_ph')[0, clean_unsat])**2))
print('resolution', resolution)
resolution 0.002472199846284186
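The resolution here is the root-mean-square deviation between reconstructed and true pulse heights of clean, unsaturated events. As a plausibility check with synthetic numbers (not data from this notebook; the value of `sigma` is an assumption): if the reconstruction error is unbiased Gaussian noise of width sigma, this RMS converges to sigma.

```python
import numpy as np

rng = np.random.default_rng(42)
sigma = 0.0025                                  # assumed true resolution in V
true_ph = rng.uniform(0.006, 1.0, 100_000)      # same pulse-height window as above
reco_ph = true_ph + rng.normal(0.0, sigma, true_ph.size)

rms = np.sqrt(np.mean((reco_ph - true_ph) ** 2))
print(rms)
```

Because the squared RMS equals the variance plus the squared mean offset, any systematic bias of the pulse-height estimator also inflates this number.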
# clean, unsaturated events that also survived the CNN cut
clean_unsat_surv = np.logical_and(clean_unsat, dh.get('events', 'cnn_cut')[0] == 1)

plt.scatter(dh.get('events', 'true_ph')[0, clean_unsat], 
            dh.get('events', 'pulse_height')[0, clean_unsat], marker='o', s=10, rasterized=True,
            label='clean events')
plt.scatter(dh.get('events', 'true_ph')[0, clean_unsat_surv], 
            dh.get('events', 'pulse_height')[0, clean_unsat_surv], marker='o', s=10, rasterized=True,
            label='survived CNN cut')
plt.xlabel('True PH')
plt.ylabel('PH')
plt.legend()
plt.show()
[Figure: reconstructed vs. true pulse height of clean, unsaturated events]
confusion_matrix = np.zeros((24, 2), dtype=int)
for i,j in zip(dh.get('events', 'labels')[0], dh.get('events', 'cnn_cut')[0]):
    confusion_matrix[i,int(j)] += 1
not_zero = np.sum(confusion_matrix, axis=1) > 0
nmbr_not_zero = np.sum(not_zero)
print('Nmbr used classes: ', nmbr_not_zero)
confusion_matrix_squeeze = confusion_matrix[not_zero]
Nmbr used classes:  13
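The loop above fills the label-versus-decision table one event at a time. The same table can be built in a single vectorized call with `np.add.at`, which accumulates correctly even when the same (label, decision) pair occurs many times. The arrays below are toy values, not results from this notebook:

```python
import numpy as np

toy_labels = np.array([1, 1, 3, 7, 7, 8, 23])           # toy label codes
toy_cut = np.array([1., 1., 0., 0., 1., 1., 0.])        # toy cut decisions (1 = survived)

toy_confusion = np.zeros((24, 2), dtype=int)
np.add.at(toy_confusion, (toy_labels, toy_cut.astype(int)), 1)  # unbuffered += 1 per event

toy_used = toy_confusion.sum(axis=1) > 0                # keep only classes that occur
print(np.flatnonzero(toy_used))
print(toy_confusion[toy_used])
```

As in the original loop, a table with 24 rows cannot hold the code 99 ('unknown/other'); such labels would need a larger table or a mapping to consecutive indices.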
import seaborn as sns

labels = np.unique(dh.get('events', 'labels')[0])

ax = sns.heatmap(confusion_matrix_squeeze, # /np.sum(confusion_matrix_squeeze, axis=1)
                 annot=True, fmt='d', linewidths=.5, xticklabels=['Cut','Survived'], yticklabels=labels)
plt.ylabel("Label") 
plt.show()
[Figure: confusion matrix, label vs. CNN decision (Cut / Survived)]
# label codes used below (see label_names): 7 = 'Pile Up', 23 = 'Decaying Baseline with Tail Event'
ev_labels = dh.get('events', 'labels')[0]
ev_cut = dh.get('events', 'cnn_cut')[0]
ev_logprob = dh.get('events', 'cnn_prob')[0]

# one column per (label, CNN decision), three example records per column
# entries: label, cnn_cut value, first example index, x-position of the text, text colour
columns = [(7, 0, 0, -0.1, 'red'),    # pile-ups rejected by the CNN
           (7, 1, 3, 0.2, 'green'),   # pile-ups accepted by the CNN
           (23, 0, 0, 0.2, 'red'),    # decaying baselines with tail event, rejected
           (23, 1, 3, 0.2, 'green'),  # decaying baselines with tail event, accepted
           ]

with h5py.File('test_data/test_v0_2.h5', 'r') as f:

    fig, axs = plt.subplots(3, 4, constrained_layout=True, figsize=(12,5))

    for col, (label, cut, first, x_text, color) in enumerate(columns):
        idx = np.flatnonzero(np.logical_and(ev_labels == label, ev_cut == cut))
        for row in range(3):
            ev = f['events']['event'][0, idx[first + row]]
            axs[row, col].plot(parsam.t, ev, color='black', linewidth=2.5)
            # exp() of the stored cnn_prob value of class 1 gives the displayed CNN output
            p_event = np.exp(ev_logprob[idx[first + row], 1])
            axs[row, col].text(x_text, 0.85*np.max(ev), f"CNN out: {p_event:.2f}", color=color)

    fig.supxlabel('Time (ms)')
    fig.supylabel('Amplitude (V)')
    plt.savefig('test_data/false_positive.pdf')
    plt.show()
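[Figure: pile-ups (left two columns) and decaying baselines with tail event (right two columns), rejected (red) and accepted (green) by the CNN]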
# calc energy dependent efficiency

clean_unsat = dh.get('events', 'labels')[0] == 1  # np.logical_or(dh.get('events', 'labels')[0] == 1, dh.get('events', 'labels')[0] == 8)
clean_unsat = np.logical_and(clean_unsat, dh.get('events', 'pulse_height')[0] < 1)
correct = np.logical_and(clean_unsat, dh.get('events', 'cnn_cut')[0])

clean_unsat_2 = dh.get('events', 'labels')[0] == 8  # np.logical_or(dh.get('events', 'labels')[0] == 1, dh.get('events', 'labels')[0] == 8)
clean_unsat_2 = np.logical_and(clean_unsat_2, dh.get('events', 'pulse_height')[0] < 1)
correct_2 = np.logical_and(clean_unsat_2, dh.get('events', 'cnn_cut')[0])

bins = np.linspace(0.001, 0.05, 50)
all_ , _ = np.histogram(dh.get('events', 'pulse_height')[0, clean_unsat], bins=bins)
surv_ , _ = np.histogram(dh.get('events', 'pulse_height')[0, correct], bins=bins)

all_2 , _ = np.histogram(dh.get('events', 'pulse_height')[0, clean_unsat_2], bins=bins)
surv_2 , _ = np.histogram(dh.get('events', 'pulse_height')[0, correct_2], bins=bins)

all_noise, _ = np.histogram(dh.get('events', 'pulse_height')[0, dh.get('events', 'labels')[0] == 3], bins=bins)
surv_noise, _ = np.histogram(dh.get('events', 'pulse_height')[0, np.logical_and(dh.get('events', 'labels')[0] == 3, dh.get('events', 'cnn_cut')[0] == 1)], bins=bins)

# bins without any events give 0/0 = NaN; silence the corresponding NumPy warning
with np.errstate(divide='ignore', invalid='ignore'):
    rate_short = surv_2/all_2
    rate_long = surv_/all_
    rate_noise = surv_noise/all_noise

fig, ax = plt.subplots(1,1,figsize=(7.2, 4.45))
ax.hist(bins[:-1], bins, weights=rate_short, histtype='step', color='#deebf7', linewidth=2.5, label='Survival rate short pulse shape')
ax.hist(bins[:-1], bins, weights=rate_long, histtype='step', color='#9ecae1', linewidth=2.5, label='Survival rate long pulse shape')
ax.hist(bins[:-1], bins, weights=rate_noise, histtype='step', color='#3182bd', linewidth=2.5, label='Survival rate noise')
ax.axvline(x=5*resolution, color='black', linestyle='dashed', linewidth=2, label='5 sigma threshold')
ax.legend(loc='lower right')
ax.set_xlim(0,0.05)
ax.set_ylim(-0.05,1.05)
fig.supylabel('Survival rate')
fig.supxlabel('Pulse height (V)')
plt.tight_layout()
plt.savefig('test_data/efficiency_hist.pdf')
plt.show()
[Figure: survival rate vs. pulse height (V) for the short and long pulse shapes and for noise, with the 5 sigma threshold]
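The survival rate per pulse-height bin is the ratio of two histograms. A small helper that marks empty bins explicitly as NaN, instead of relying on a 0/0 division, could look like the following; `survival_rate` is a hypothetical name and the data are synthetic, with a made-up trigger-like turn-on.

```python
import numpy as np


def survival_rate(values, survived, bins):
    """Fraction of entries per bin that survived; NaN where a bin is empty."""
    values = np.asarray(values)
    survived = np.asarray(survived, dtype=bool)
    n_all, _ = np.histogram(values, bins=bins)
    n_surv, _ = np.histogram(values[survived], bins=bins)
    rate = np.full(n_all.shape, np.nan)
    np.divide(n_surv, n_all, out=rate, where=n_all > 0)  # skip empty bins
    return rate


demo_bins = np.linspace(0.001, 0.05, 50)
rng = np.random.default_rng(1)
demo_ph = rng.uniform(0.001, 0.03, 5000)                           # nothing above 0.03 V
demo_survived = rng.random(demo_ph.size) < np.clip(demo_ph / 0.01, 0, 1)
demo_rate = survival_rate(demo_ph, demo_survived, demo_bins)
```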
event_labels = dh.get('events', 'labels')[0]
ph = dh.get('events', 'pulse_height')[0]
decay = dh.get('events', 'decay_time')[0]
survived = dh.get('events', 'cnn_cut')[0] == 1
is_event = np.logical_or(event_labels == 1, event_labels == 8)  # target and carrier events

fig, axes = plt.subplots(1,2,figsize=(8,4), sharex=True, sharey=True)

# left: true classes
axes[0].scatter(ph[~is_event], decay[~is_event],
                rasterized=True, marker='o', s=3, color='grey', label='artifacts')
axes[0].scatter(ph[is_event], decay[is_event],
                rasterized=True, marker='o', s=3, color='green', label='events')
axes[0].legend()

# right: decision of the CNN cut
axes[1].scatter(ph[~survived], decay[~survived],
                rasterized=True, marker='o', s=3, color='grey', label='rejected')
axes[1].scatter(ph[survived], decay[survived],
                rasterized=True, marker='o', s=3, color='red', label='accepted')
axes[1].legend()

fig.supxlabel('Pulse height (V)')
fig.supylabel('Decay time (ms)')
plt.tight_layout()
plt.savefig('test_data/efficiency_scatter.pdf')
plt.show()
[Figure: decay time (ms) vs. pulse height (V), colored by true class (left) and by CNN decision (right)]

Now let's look at the events in the VizTool!

datasets = {
    'Pulse Height Phonon (V)': ['pulse_height', 0, None],
    'Rise Time Phonon (ms)': ['rise_time', 0, None],
    'Decay Time Phonon (ms)': ['decay_time', 0, None],
    'Onset Phonon (ms)': ['onset', 0, None],
    'Slope Phonon (V)': ['slope', 0, None],
    'Variance Phonon (V^2)': ['var', 0, None],
    'Mean Phonon (V)': ['mean', 0, None],
    'Skewness Phonon': ['skewness', 0, None],
}

viz = ai.VizTool(path_h5='test_data/',
                 fname='test_v0_2',
                 group='events',
                 datasets=datasets,
                 nmbr_channels=1,
                 batch_size=1000,
                 sample_frequency=SAMPLE_FREQUENCY,
                 record_length=RECORD_LENGTH)
viz.set_colors(dh.get('events', 'cnn_cut')[0])
viz.show()
DataHandler Instance created.

Finally, let's look at a t-SNE plot of the events and try to classify them with a random forest classifier trained on the additional main parameters.

et = ai.EvaluationTools()

et.add_events_from_file(file='test_data/test_v0_2.h5',
                        channel=0,
                        which_data='add_mainpar',
                        )

et.split_test_train(test_size=0.60)

_, _, X_train, _, y_train = et.get_train()

clf_rf = RandomForestClassifier(criterion='entropy', max_depth=7)

clf_rf.fit(X_train, 
           y_train)

et.add_prediction(pred_method='RFC', 
                  pred=clf_rf.predict(et.features), 
                  true_labels=True)

et.plt_pred_with_tsne_plotly(pred_methods=['RFC'], what='all', verb=True, inline=True)
/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/sklearn/manifold/_t_sne.py:780: FutureWarning:

The default initialization in TSNE will change from 'random' to 'pca' in 1.2.

/Users/felix/.pyenv/versions/3.8.6/lib/python3.8/site-packages/sklearn/manifold/_t_sne.py:790: FutureWarning:

The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.

Done.