Source code for cait.datasets._pt_sampler

import numpy as np

try: from torch.utils.data.sampler import SubsetRandomSampler
except ImportError: SubsetRandomSampler = None
    
[docs]def get_random_samplers(test_size, val_size, dataset_size=None, only_idx=None, shuffle_dataset=True, random_seed=None): """ Chooses the indices for the Split datasets. :param test_size: float between 0 and 1, the size of the testset :param val_size: float between 0 and 1, the size of the validation set :param dataset_size: Size of the whole dataset, is a number :param only_idx: list of ints or None, if set only these indices from the dataset are included :param shuffle_dataset: When true, the indices are dataset is shuffled befor the indices are assigned :param random_seed: set of some value to get the same datasets always for comparability :return: indices for training, validation and test set """ # CHECK IF TORCH IS INSTALLED if SubsetRandomSampler is None: raise RuntimeError("Install 'torch>=1.8' to use this feature.") if dataset_size is None and only_idx is None: raise KeyError('At least one of dataset_size and only_idx must be set!') if only_idx is None: only_idx = list(range(dataset_size)) val_split = int((1 - (val_size + test_size)) * len(only_idx)) # floor rounds down test_split = int((1 - (test_size)) * len(only_idx)) if shuffle_dataset: if random_seed is not None: np.random.seed(random_seed) np.random.shuffle(only_idx) train_indices, val_indices, test_indices = only_idx[:val_split], only_idx[val_split:test_split], only_idx[test_split:] return SubsetRandomSampler(train_indices), \ SubsetRandomSampler(val_indices), \ SubsetRandomSampler(test_indices)