mgplvm.crossval.crossval module

mgplvm.crossval.crossval.not_in(arr, inds)
mgplvm.crossval.crossval.test_cv(mod, split, device, n_mc=32, Print=False, sample_mean=False, sample_X=False)
mgplvm.crossval.crossval.train_cv(mod, Y, device, train_ps, T1=None, N1=None, nt_train=None, nn_train=None, test=True)
Parameters
mod : mgplvm.models.svgplvm
    instance of svgplvm model to perform crossvalidation on.

Y : array
    data with dimensionality (n x m x n_samples)

device : torch.device
    GPU/CPU device on which to run the calculations

train_ps : dict
    dictionary of training parameters. Constructed by crossval.training_params()

T1 : Optional[int list]
    indices of the conditions to use for training

N1 : Optional[int list]
    indices of the neurons to use for training

nt_train : Optional[int]
    number of randomly selected conditions to use for training

nn_train : Optional[int]
    number of randomly selected neurons to use for training

Returns
mod : mgplvm.svgplvm
    model trained via crossvalidation
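
The parameters above describe two ways of choosing the training subset along each axis: explicit indices (T1, N1) or a number of randomly selected entries (nt_train, nn_train). The following is a minimal NumPy sketch of that idea, not the library's implementation. The helpers choose_train_indices and held_out are hypothetical names introduced for illustration, and the rule that explicit indices take precedence over a random draw is an assumption.

```python
import numpy as np


def choose_train_indices(n_total, explicit=None, n_random=None, rng=None):
    """Hypothetical helper: pick training indices along one axis.

    `explicit` plays the role of T1 or N1; `n_random` plays the role of
    nt_train or nn_train. Precedence of `explicit` is an assumption.
    """
    if explicit is not None:
        # Explicit indices are used exactly as given.
        return np.asarray(explicit, dtype=int)
    if n_random is None:
        raise ValueError("provide either explicit indices or n_random")
    rng = np.random.default_rng() if rng is None else rng
    # Draw n_random distinct indices from 0..n_total-1, then sort them.
    return np.sort(rng.choice(n_total, size=n_random, replace=False))


def held_out(n_total, train_inds):
    """Indices along the axis that were not selected for training."""
    return np.setdiff1d(np.arange(n_total), train_inds)


# Illustrative sizes only (not taken from the documentation).
n_neurons, n_conditions = 10, 20
rng = np.random.default_rng(0)

T1 = choose_train_indices(n_conditions, n_random=12, rng=rng)  # random conditions
N1 = choose_train_indices(n_neurons, explicit=[0, 2, 4, 6])    # fixed neurons
T2 = held_out(n_conditions, T1)  # conditions left out of training
N2 = held_out(n_neurons, N1)     # neurons left out of training
```

Passing a seeded generator makes the random split reproducible. The held-out index sets are what a test step would typically evaluate on.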

mgplvm.crossval.crossval.update_params(params, **kwargs)