mgplvm.crossval.crossval_bgpfa module
- mgplvm.crossval.crossval_bgpfa.train_cv_bgpfa(Y, device, train_ps, fit_ts, d_fit, ell, T1=None, N1=None, nt_train=None, nn_train=None, test=True, lat_scale=1, rel_scale=1, likelihood='Gaussian', model='bgpfa', ard=True, Bayesian=True)
- Parameters
- mod : mgplvm.models.svgplvm
instance of the svgplvm model on which to perform cross-validation.
- Y : array
data with dimensionality (n x m x n_samples)
- device : torch.device
GPU or CPU device on which to run the calculations
- train_ps : dict
dictionary of training parameters, constructed by crossval.training_params()
- T1 : Optional[List[int]]
indices of the conditions to use for training
- N1 : Optional[List[int]]
indices of the neurons to use for training
- nt_train : Optional[int]
number of randomly selected conditions to use for training
- nn_train : Optional[int]
number of randomly selected neurons to use for training
- likelihood : str
'Gaussian' (default) or 'NegativeBinomial'
- model : str
'bgpfa' (default) or 'vgpfa'
- ard : bool
whether to use automatic relevance determination (default True)
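The parameters above do not say how T1/N1 interact with nt_train/nn_train. The following is a minimal sketch of one plausible selection rule. It assumes that explicit indices take precedence over a random count, and that omitting both means all indices are used for training. `choose_train_indices` is a hypothetical helper for illustration only and is not part of mgplvm.

```python
import random
from typing import List, Optional, Tuple


def choose_train_indices(
    n_total: int,
    idx: Optional[List[int]] = None,
    n_train: Optional[int] = None,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return (train, held_out) indices along one axis.

    Assumed precedence (not confirmed by the mgplvm docs):
    explicit indices (like T1 / N1) are used as-is; otherwise n_train
    indices (like nt_train / nn_train) are drawn at random without
    replacement; if neither is given, all indices go to training.
    """
    if idx is not None:
        train = sorted(idx)
    elif n_train is not None:
        rng = random.Random(seed)  # seeded for reproducible splits
        train = sorted(rng.sample(range(n_total), n_train))
    else:
        train = list(range(n_total))
    train_set = set(train)
    # Everything not selected for training is held out for testing.
    held_out = [i for i in range(n_total) if i not in train_set]
    return train, held_out
```

For example, `choose_train_indices(10, n_train=6)` yields six training neurons and four disjoint held-out neurons, while `choose_train_indices(5, idx=[3, 1])` returns `([1, 3], [0, 2, 4])`.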
- Returns
- mod : mgplvm.svgplvm
model trained via cross-validation
- The function first constructs and trains one model, then saves its parameters and builds a new model into which the generative parameters are copied.
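The parameter-copying step can be sketched on plain dictionaries standing in for PyTorch state dicts. This is an illustration under stated assumptions, not mgplvm's implementation. The `"obs."` key prefix marking generative parameters and the helper name `copy_generative_params` are both hypothetical. With real modules, the filtered dictionary could be passed to `torch.nn.Module.load_state_dict(..., strict=False)`, which tolerates missing keys.

```python
from typing import Any, Dict


def copy_generative_params(
    trained: Dict[str, Any],
    fresh: Dict[str, Any],
    generative_prefix: str = "obs.",  # hypothetical prefix for generative params
) -> Dict[str, Any]:
    """Return a copy of `fresh` whose generative entries come from `trained`.

    Only keys that start with `generative_prefix` and also exist in `fresh`
    are overwritten. All other entries (e.g. latent-state parameters) keep
    their freshly initialised values. Neither input dict is modified.
    """
    merged = dict(fresh)
    for key, value in trained.items():
        if key.startswith(generative_prefix) and key in fresh:
            merged[key] = value
    return merged
```

Copying into a new dictionary leaves the trained model's saved parameters untouched, so they can be reused for further folds.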