mgplvm.crossval.crossval module
- mgplvm.crossval.crossval.test_cv(mod, split, device, n_mc=32, Print=False, sample_mean=False, sample_X=False)
- mgplvm.crossval.crossval.train_cv(mod, Y, device, train_ps, T1=None, N1=None, nt_train=None, nn_train=None, test=True)
- Parameters
- mod : mgplvm.models.svgplvm
instance of svgplvm model to perform crossvalidation on.
- Y : array
data with dimensionality (n x m x n_samples)
- device : torch.device
GPU/CPU device on which to run the calculations
- train_ps : dict
dictionary of training parameters. Constructed by crossval.training_params()
- T1 : Optional[int list]
indices of the conditions to use for training
- N1 : Optional[int list]
indices of the neurons to use for training
- nt_train : Optional[int]
number of randomly selected conditions to use for training
- nn_train : Optional[int]
number of randomly selected neurons to use for training
- Returns
- mod : mgplvm.svgplvm
model trained via crossvalidation
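
The parameters above describe two ways to pick the training subset: explicit index lists (T1, N1) or a count of randomly selected items (nt_train, nn_train). The sketch below illustrates that general idea in plain Python. It is not the library's implementation: the helper choose_train_indices, its seed argument, and the fallback of using half the items when neither option is given are all assumptions made for illustration, since the documentation above does not say how train_cv resolves these arguments internally.

```python
import random


def choose_train_indices(n_total, n_train=None, explicit=None, seed=0):
    """Hypothetical helper: return (train, held_out) index lists.

    explicit -- indices to train on (plays the role of T1 or N1)
    n_train  -- number of random indices to draw (plays the role of
                nt_train or nn_train); ignored when explicit is given
    """
    if explicit is not None:
        train = sorted(set(explicit))
    else:
        if n_train is None:
            # Assumption for this sketch only: default to half the items.
            n_train = n_total // 2
        rng = random.Random(seed)
        train = sorted(rng.sample(range(n_total), n_train))
    train_set = set(train)
    # Everything not selected for training is held out for testing.
    held_out = [i for i in range(n_total) if i not in train_set]
    return train, held_out


# 10 conditions, 7 drawn at random for training.
T1, T2 = choose_train_indices(10, n_train=7, seed=1)
# 6 neurons, with the training neurons given explicitly.
N1, N2 = choose_train_indices(6, explicit=[0, 2, 4])
```

In this sketch the explicit list takes precedence over the count; whether train_cv applies the same precedence is not stated in the source and should be checked against the library code.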