Source code for mgplvm.crossval.crossval

import numpy as np
import copy
import torch
from .train_model import train_model

torch.set_default_dtype(torch.float64)


def not_in(arr, inds):
    """return the elements of arr whose indices are not in inds"""
    mask = np.ones(arr.size, dtype=bool)
    mask[inds] = False
    return arr[mask]
def update_params(params, **kwargs):
    """return a copy of the params dict with the given keys overridden"""
    newps = copy.copy(params)
    for key, value in kwargs.items():
        newps[key] = value
    return newps
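# Illustrative examples of the helpers above (values are made up for the
# example; 'n_mc' is a key actually used by this module, 'lrate' is hypothetical):
#
#   >>> update_params({'n_mc': 32, 'lrate': 5e-2}, n_mc=64)
#   {'n_mc': 64, 'lrate': 5e-2}
#   >>> not_in(np.arange(5), [0, 3])
#   array([1, 2, 4])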
def train_cv(mod,
             Y,
             device,
             train_ps,
             T1=None,
             N1=None,
             nt_train=None,
             nn_train=None,
             test=True):
    """
    Parameters
    ----------
    mod : mgplvm.models.svgplvm
        instance of svgplvm model to perform cross-validation on.
    Y : array
        data with dimensionality (n_samples x n x m)
    device : torch.device
        GPU/CPU device on which to run the calculations
    train_ps : dict
        dictionary of training parameters. Constructed by crossval.training_params()
    T1 : Optional[int list]
        indices of the conditions to use for training
    N1 : Optional[int list]
        indices of the neurons to use for training
    nt_train : Optional[int]
        number of randomly selected conditions to use for training
    nn_train : Optional[int]
        number of randomly selected neurons to use for training

    Returns
    -------
    mod : mgplvm.svgplvm
        model trained via cross-validation
    split : dict
        the train/test split used (keys 'Y', 'T1', 'N1')
    """
    _, n, m = Y.shape
    data = torch.tensor(Y, device=device, dtype=torch.get_default_dtype())
    nt_train = int(round(m / 2)) if nt_train is None else nt_train
    nn_train = int(round(n / 2)) if nn_train is None else nn_train

    if T1 is None:  # random subset of conditions for training
        T1 = np.random.permutation(np.arange(m))[:nt_train]
    if N1 is None:  # random subset of neurons for training
        N1 = np.random.permutation(np.arange(n))[:nn_train]
    split = {'Y': Y, 'N1': N1, 'T1': T1}

    # step 1: train the full model on the training conditions only
    train_ps1 = update_params(train_ps, batch_pool=T1, prior_m=len(T1))
    train_model(mod, data, train_ps1)

    for p in mod.parameters():  # no gradients for the trained parameters
        p.requires_grad = False

    if 'GP' in mod.lat_dist.name:
        mod.lat_dist.nu.requires_grad = True
        mod.lat_dist._scale.requires_grad = True
        if 'circ' in mod.lat_dist.name:
            mod.lat_dist._c.requires_grad = True
        mask_Ts = None
    else:

        def mask_Ts(grad):
            """'mask' the gradients at the training conditions for cv"""
            grad[:, T1, ...] *= 0
            return grad

        for p in mod.lat_dist.parameters():  # only gradients for the latent distribution
            p.requires_grad = True

    # step 2: infer latents for the held-out conditions using the training neurons
    train_ps2 = update_params(train_ps,
                              neuron_idxs=N1,
                              mask_Ts=mask_Ts,
                              prior_m=None)
    train_model(mod, data, train_ps2)

    if test:
        test_cv(mod, split, device, n_mc=train_ps['n_mc'], Print=True)

    return mod, split
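# Usage sketch for train_cv (hedged): `training_params` is the constructor
# referenced in the docstring above, but the keyword arguments shown here are
# assumptions, and `mod` is any pre-constructed mgplvm svgplvm model.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   train_ps = training_params(max_steps=1000, n_mc=32)  # hypothetical kwargs
#   # trains on a random half of the neurons/conditions, then tests on the rest
#   mod, split = train_cv(mod, Y, device, train_ps)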
def test_cv(mod,
            split,
            device,
            n_mc=32,
            Print=False,
            sample_mean=False,
            sample_X=False):
    """assess the cross-validated fit on the held-out block of the
    (neurons x conditions) split constructed by train_cv"""
    Y, T1, N1 = split['Y'], split['T1'], split['N1']
    n_samples, n, m = Y.shape
    T2, N2 = not_in(np.arange(m), T1), not_in(np.arange(n), N1)

    # generate predictions for the held-out data
    Ytest = Y[:, N2, :][..., T2]  # (n_samples x N2 x T2)

    # latent means (n_samples, T2, d)
    if 'GP' in mod.lat_dist.name:
        latents = mod.lat_dist.lat_mu.detach()[:, T2, ...]
    else:
        latents = mod.lat_dist.prms[0].detach()[:, T2, ...]
    query = latents.transpose(-1, -2)  # (n_samples, d, T2)

    if sample_X:  # note: this only works when the data is structured as a single trial!
        n_mc = round(np.sqrt(n_mc))
        # g is shape (n_mc, n_samples, m, d)
        g, lq = mod.lat_dist.sample(torch.Size([n_mc]),
                                    torch.tensor(Y).to(device),
                                    batch_idxs=None,
                                    sample_idxs=None)
        assert g.shape[1] == 1  # assume there is only a single 'trial'
        query = g[:, 0, ...].transpose(-1, -2)  # now each sample is a 'trial'
        Ypred = mod.svgp.sample(query, n_mc=n_mc, noise=False)
        Ypred = Ypred.mean(0).mean(0)  # average over both sets of MC samples
        Ypred = Ypred.detach().cpu().numpy()[N2, :][:, T2][None, ...]  # (1 x N2 x T2)
    elif sample_mean:
        # no closed-form mean prediction, so sample from (mu|GP) and average instead
        Ypred = mod.svgp.sample(query, n_mc=n_mc, noise=False)  # (n_mc, n_samples, n, T2)
        Ypred = Ypred.mean(0).cpu().numpy()[:, N2, :]  # (n_samples x N2 x T2)
    else:
        Ypred, var = mod.svgp.predict(query[None, ...], False)
        Ypred = Ypred.detach().cpu().numpy()[0][:, N2, :]  # (n_samples, N2, T2)

    MSE_vals = np.mean((Ypred - Ytest)**2, axis=(0, -1))
    MSE = np.mean(MSE_vals)  # standard MSE
    norm_MSE = np.mean(MSE_vals / np.var(Ytest, axis=(0, -1)))  # normalized by neuron variance
    var_cap = 1 - np.var(Ytest - Ypred) / np.var(Ytest)  # fraction of variance captured

    ### compute cross-validated log likelihood ###
    data = torch.tensor(Y, device=device)
    # (n_mc, n_samples, n), (n_mc, n_samples)
    svgp_elbo, kl = mod.elbo(data[:, :, T2],
                             n_mc,
                             batch_idxs=T2,
                             neuron_idxs=N2,
                             m=len(T2))  # use the correct scaling factor for the test data
    svgp_elbo = svgp_elbo.sum(-1)  # sum over neurons
    LLs = svgp_elbo - kl  # LL for each MC sample
    LL = (torch.logsumexp(LLs, 0) - np.log(n_mc)).detach().cpu().numpy()
    LL = LL / (len(T2) * len(N2) * n_samples)  # per-datapoint log likelihood

    if Print:
        print('LL', LL)
        print('var_cap', var_cap)
        print('MSE', MSE, np.sqrt(np.mean(np.var(Ytest, axis=-1))))

    return MSE, LL, var_cap, norm_MSE
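# Usage sketch for test_cv: evaluates the model on the held-out block of
# (novel neurons x novel conditions). The returned metrics follow the code
# above: raw MSE, per-datapoint cross-validated log likelihood, fraction of
# variance captured, and variance-normalized MSE.
#
#   mod, split = train_cv(mod, Y, device, train_ps, test=False)
#   MSE, LL, var_cap, norm_MSE = test_cv(mod, split, device, n_mc=128)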