Source code for mgplvm.optimisers.data

import torch
import numpy as np
from torch.utils.data import Dataset


class DataLoader:
    """Simple data loader that returns the full dataset in a single pass."""

    def __init__(self, data):
        # data has shape (n_samples, n, m)
        n_samples, n, m = data.shape
        self.n = n
        self.n_samples = n_samples
        self.m = m
        self.batch_pool_size = m
        self.data = data

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        # yield the full dataset exactly once, with no index subsetting
        if self.i == 0:
            self.i += 1
            return (None, None, self.data)
        else:
            raise StopIteration
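A minimal usage sketch (not part of the module source above); the tensor shape is an illustrative assumption. DataLoader simply yields the full dataset once per pass, with None in place of sample and condition indices:

    import torch
    from mgplvm.optimisers.data import DataLoader

    data = torch.randn(2, 10, 100)  # synthetic tensor: (n_samples=2, n=10, m=100 conditions)
    loader = DataLoader(data)
    for sample_idxs, batch_idxs, batch in loader:
        print(sample_idxs, batch_idxs, batch.shape)  # None None torch.Size([2, 10, 100])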
class BatchDataLoader(DataLoader):
    """Data loader that iterates over mini-batches of samples and conditions."""

    def __init__(self,
                 data,
                 batch_size=None,
                 sample_size=None,
                 batch_pool=None,
                 sample_pool=None,
                 shuffle_batch=False,
                 shuffle_sample=False,
                 overlap=0):
        super(BatchDataLoader, self).__init__(data)
        m = self.m
        n_samples = self.n_samples
        self.overlap = overlap
        self.shuffle_batch = shuffle_batch
        self.shuffle_sample = shuffle_sample
        # pools of condition (batch) and sample indices to draw from
        self.batch_pool = list(range(m)) if batch_pool is None else batch_pool
        self.sample_pool = list(
            range(n_samples)) if sample_pool is None else sample_pool
        self.batch_pool_size = len(self.batch_pool)
        self.sample_pool_size = len(self.sample_pool)
        # default to using the full pools if no batch/sample size is given
        self.batch_size = self.batch_pool_size if batch_size is None else batch_size
        self.sample_size = self.sample_pool_size if sample_size is None else sample_size
        # restrict the stored data to the requested pools
        if sample_pool is not None:
            self.data = self.data[sample_pool]
        if batch_pool is not None:
            self.data = self.data[:, :, batch_pool]
        if self.batch_size > self.batch_pool_size:
            raise Exception(
                "batch size greater than number of conditions in pool")
        if self.sample_size > self.sample_pool_size:
            raise Exception(
                "sample size greater than number of samples in pool")

    def __iter__(self):
        # reset iteration state and (optionally) shuffle samples and conditions
        self.i = 0  # position in the sample pool
        self.k = 0  # position in the condition (batch) pool
        if self.shuffle_sample:
            sample_shuffle_idxs = list(range(self.sample_pool_size))
            np.random.shuffle(sample_shuffle_idxs)
            self.sample_pool = [
                self.sample_pool[i] for i in sample_shuffle_idxs
            ]
            self.data = self.data[sample_shuffle_idxs]
        if self.shuffle_batch:
            batch_shuffle_idxs = list(range(self.batch_pool_size))
            np.random.shuffle(batch_shuffle_idxs)
            self.batch_pool = [self.batch_pool[i] for i in batch_shuffle_idxs]
            self.data = self.data[:, :, batch_shuffle_idxs]
        return self

    def __next__(self):
        if self.i >= self.sample_pool_size:
            raise StopIteration
        else:
            if self.k >= self.batch_pool_size:
                # exhausted the condition pool; move on to the next block of samples
                self.k = 0
                self.i += self.sample_size
                if self.i >= self.sample_pool_size:
                    raise StopIteration
                else:
                    return self.get_next()
            else:
                return self.get_next()

    def get_next(self):
        # slice the next block of samples [i0, i1) and conditions [k0, k1),
        # letting consecutive condition blocks overlap by `overlap` indices
        i0 = self.i
        i1 = i0 + self.sample_size
        k0 = self.k - self.overlap * (self.k > 0)
        k1 = k0 + self.batch_size
        if i1 > self.sample_pool_size:
            i1 = self.sample_pool_size
        if k1 > self.batch_pool_size:
            k1 = self.batch_pool_size
        batch = self.data[i0:i1][:, :, k0:k1]
        self.k = k1
        # map positions in the (possibly shuffled) pools back to original indices
        batch_idxs = list(range(k0, k1))
        batch_idxs = [self.batch_pool[i] for i in batch_idxs]
        sample_idxs = list(range(i0, i1))
        sample_idxs = [self.sample_pool[i] for i in sample_idxs]
        return sample_idxs, batch_idxs, batch
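A minimal usage sketch for BatchDataLoader (not part of the module source above); the tensor shape and the batch_size/sample_size values are illustrative assumptions. Each iteration returns the original sample and condition indices of the current slice together with the corresponding data block:

    import torch
    from mgplvm.optimisers.data import BatchDataLoader

    data = torch.randn(2, 10, 100)  # synthetic tensor: (n_samples=2, n=10, m=100 conditions)
    loader = BatchDataLoader(data, batch_size=25, sample_size=1, shuffle_batch=True)
    for sample_idxs, batch_idxs, batch in loader:
        # one block of sample_size samples and up to batch_size conditions per step
        print(sample_idxs, len(batch_idxs), batch.shape)  # e.g. [0] 25 torch.Size([1, 10, 25])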