mgplvm.crossval.crossval_bgpfa module

mgplvm.crossval.crossval_bgpfa.train_cv_bgpfa(Y, device, train_ps, fit_ts, d_fit, ell, T1=None, N1=None, nt_train=None, nn_train=None, test=True, lat_scale=1, rel_scale=1, likelihood='Gaussian', model='bgpfa', ard=True, Bayesian=True)
Parameters
mod : mgplvm.models.svgplvm

instance of an svgplvm model on which to perform crossvalidation.

Y : array

data of shape (n x m x n_samples)

device : torch.device

GPU/CPU device on which to run the calculations

train_ps : dict

dictionary of training parameters, constructed by crossval.training_params()

T1 : Optional[int list]

indices of the conditions to use for training

N1 : Optional[int list]

indices of the neurons to use for training

nt_train : Optional[int]

number of randomly selected conditions to use for training

nn_train : Optional[int]

number of randomly selected neurons to use for training

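likelihood : str

either 'Gaussian' (default) or 'NegativeBinomial'

model : str

either 'bgpfa' (default) or 'vgpfa'

ard : bool

whether to use automatic relevance determination (ARD); True (default) or False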
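The T1/N1 and nt_train/nn_train options give two ways to choose a training subset: an explicit list of indices, or a number of randomly drawn indices. The sketch below shows one way such options could be reconciled. The helper `select_train_indices` is hypothetical, and the precedence rule (an explicit list wins over a random count; with neither, every index is used) is an assumption, not taken from the mgplvm source.

```python
import random
from typing import List, Optional


def select_train_indices(total: int,
                         indices: Optional[List[int]] = None,
                         n_train: Optional[int] = None,
                         seed: Optional[int] = None) -> List[int]:
    """Hypothetical helper: choose training indices from range(total).

    Assumed precedence: an explicit list (like T1 or N1) wins over a
    random count (like nt_train or nn_train); with neither, use all indices.
    """
    if indices is not None:
        return sorted(indices)
    if n_train is not None:
        rng = random.Random(seed)
        # draw n_train distinct indices without replacement
        return sorted(rng.sample(range(total), n_train))
    return list(range(total))


# e.g. 100 conditions, 80 drawn at random for training
T1 = select_train_indices(100, n_train=80, seed=0)
train_set = set(T1)
# the held-out conditions are the complement of the training set
T2 = [t for t in range(100) if t not in train_set]
```

The same helper would apply unchanged to neurons (N1 / nn_train) by passing the neuron count as `total`.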
Returns
mod : mgplvm.svgplvm

model trained via crossvalidation

The function first constructs one model, then saves its parameters and builds a new model that copies over the generative parameters.
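A minimal sketch of the two-model step described above, with parameters held in plain dicts. The helper `copy_generative_params`, the `obs.` prefix used to mark generative parameters, and all parameter names and values are hypothetical; in a PyTorch model the same filtering would typically be applied to `state_dict()` before calling `load_state_dict(..., strict=False)` on the new model.

```python
from typing import Dict


def copy_generative_params(trained: Dict[str, float],
                           fresh: Dict[str, float],
                           prefix: str = "obs.") -> Dict[str, float]:
    """Return a copy of `fresh` whose generative parameters come from `trained`.

    Assumption (hypothetical naming): generative parameters are those whose
    names start with `prefix`; all others (e.g. latent variational
    parameters) keep the fresh model's values.
    """
    new = dict(fresh)  # leave the fresh model's dict untouched
    for name, value in trained.items():
        if name.startswith(prefix) and name in new:
            new[name] = value
    return new


# first model: parameters after fitting (hypothetical values)
trained = {"obs.C": 0.7, "obs.noise": 0.1, "lat.mu": 2.0}
# second model: freshly initialised
fresh = {"obs.C": 0.0, "obs.noise": 1.0, "lat.mu": 0.0}
merged = copy_generative_params(trained, fresh)
```

Under this assumption the second model inherits the learned generative parameters while its latent parameters start from their fresh initial values.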