Source code for mgplvm.crossval.train_model

import mgplvm
from mgplvm import optimisers
import numpy as np
import pickle
import torch
from torch import optim


def training_params(**kwargs):
    """Return a dict of default training settings, with any keyword overrides applied."""
    params = {
        'max_steps': 1001,
        'burnin': 150,
        'callback': None,
        'optimizer': optim.Adam,
        'batch_size': None,
        'ts': None,
        'print_every': 50,
        'lrate': 5E-2,
        'batch_pool': None,
        'neuron_idxs': None,
        'mask_Ts': None,
        'n_mc': 32,
        'prior_m': None,
        'analytic_kl': False,
        'accumulate_gradient': True,
        'batch_mc': None
    }
    for key, value in kwargs.items():
        if key in params.keys():
            params[key] = value
        else:
            # store unrecognised keys as well, so the 'adding' message is accurate
            print('adding', key)
            params[key] = value
    return params
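
A minimal usage sketch (not part of the original source): any keyword passed to training_params overrides the matching default, and unknown keys are added with a printed notice.

# hypothetical override of a few defaults
params = training_params(max_steps=500, lrate=1e-3, batch_size=100)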


def train_model(mod, data, params):
    """Fit `mod` to `data` with optimisers.svgp.fit, using the settings in `params`."""
    # wrap the data in a batched loader, optionally restricted to a pool of batch indices
    dataloader = optimisers.data.BatchDataLoader(
        data,
        batch_size=params['batch_size'],
        batch_pool=params['batch_pool'])
    # run the stochastic variational optimisation loop with the requested settings
    trained_mod = optimisers.svgp.fit(
        dataloader,
        mod,
        optimizer=params['optimizer'],
        max_steps=int(round(params['max_steps'])),
        burnin=params['burnin'],
        n_mc=params['n_mc'],
        lrate=params['lrate'],
        print_every=params['print_every'],
        stop=params['callback'],
        neuron_idxs=params['neuron_idxs'],
        mask_Ts=params['mask_Ts'],
        prior_m=params['prior_m'],
        analytic_kl=params['analytic_kl'],
        accumulate_gradient=params['accumulate_gradient'],
        batch_mc=params['batch_mc'])
    return trained_mod
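
A hedged end-to-end sketch, assuming a model `mod` and a data tensor `Y` already exist in whatever layout BatchDataLoader expects (both names are illustrative, not defined in this module):

# illustrative only: `mod` and `Y` are assumed, not defined here
params = training_params(max_steps=1001, n_mc=32, lrate=5E-2)
trained_mod = train_model(mod, Y, params)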