Source code for mgplvm.lpriors.euclidean

import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dists
from torch.distributions import transform_to, constraints
from ..kernels import Kernel
from ..manifolds import Euclid
from ..manifolds.base import Manifold
from ..models import Svgp
from ..inducing_variables import InducingPoints
from ..likelihoods import Gaussian
from .common import Lprior
from ..utils import softplus, inv_softplus
from typing import Optional


class LpriorEuclid(Lprior):

    def __init__(self, manif):
        if not isinstance(manif, Euclid):
            raise Exception("GP prior only works with Euclidean manifolds")
        super().__init__(manif)


class GP(LpriorEuclid):
    name = "GP"

    def __init__(self,
                 n,
                 m,
                 n_samples,
                 manif: Manifold,
                 kernel: Kernel,
                 ts: torch.Tensor,
                 n_z: int = 20,
                 d=1,
                 learn_sigma=False):
        """
        __init__ method for the GP prior class (only works for Euclidean manifolds).

        Parameters
        ----------
        n : int
            number of output dimensions (i.e. dimensionality of the latent space)
        m : int
            number of time points
        n_samples : int
            number of samples (each with a separate GP posterior)
        manif : mgplvm.manifolds.Manifold
            latent manifold
        kernel : mgplvm.kernels.Kernel
            kernel used in the prior (does not have to match the p(Y|G) kernel)
        ts : torch.Tensor
            input time points for each sample (n_samples x d x m)
        n_z : Optional[int]
            number of inducing points used in the GP prior
        d : Optional[int]
            number of input dimensions -- defaults to 1 since the input is
            assumed to be time, but could also be other higher-dimensional
            observed variables.
        learn_sigma : Optional[bool]
            whether to learn the observation noise of the prior GP

        A shape-bookkeeping sketch of the expected input shapes follows this
        class definition.
        """
        super().__init__(manif)
        self.n = n
        self.m = m
        self.n_samples = n_samples
        self.d = d

        # 1d latent and n_z inducing points
        zinit = torch.linspace(0., torch.max(ts).item(), n_z).reshape(1, 1, n_z)
        # separate inducing points for each latent dimension
        z = InducingPoints(n, d, n_z, z=zinit.repeat(n, d, 1))
        self.ts = ts
        # consider fixing this to a small value as in GPFA
        self.lik = Gaussian(n, sigma=torch.ones(n) * 0.2, learn_sigma=learn_sigma)
        # construct the sparse variational GP that defines the prior
        self.svgp = Svgp(kernel,
                         n,
                         m,
                         n_samples,
                         z,
                         self.lik,
                         whiten=True,
                         tied_samples=False)

    @property
    def prms(self):
        q_mu, q_sqrt, z = self.svgp.prms
        sigma_n = self.svgp.likelihood.prms
        return q_mu, q_sqrt, z, sigma_n

    def forward(self, x, batch_idxs=None):
        '''
        x is a latent of shape (n_mc x n_samples x m x n)
        self.ts are the corresponding time points of shape (n_samples x d x m)
        '''
        n_mc, n_samples, m, n = x.shape
        assert (m == self.m)
        batch_size = m
        ts = self.ts.to(x.device)
        assert (n == self.n)
        # transpose x to shape (n_mc, n_samples, n, m)
        x = x.transpose(-1, -2)
        ts = ts.reshape(1, n_samples, self.d, -1).repeat(n_mc, 1, 1, 1)
        svgp_lik, svgp_kl = self.svgp.elbo(x, ts)
        # rescale the KL term so that it is over the batch rather than the full
        # dataset, as expected by the SVGPLVM objective (the inducing points are
        # shared across the full dataset)
        elbo = (batch_size / m) * (svgp_lik - svgp_kl)
        return elbo.sum(-1)  # sum over dimensions

    @property
    def msg(self):
        ell = self.svgp.kernel.prms[1].mean()
        noise = self.lik.sigma.mean()
        return (' prior ell {:.3f} | prior noise {:.3f} |').format(
            ell.item(), noise.item())
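

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the mgplvm API: the shape bookkeeping that
# GP.forward performs on its inputs.  Latents x arrive as
# (n_mc x n_samples x m x n) and are transposed to (n_mc x n_samples x n x m);
# time points ts of shape (n_samples x d x m) are broadcast across the n_mc
# Monte Carlo samples before being passed to the underlying Svgp.  The sizes
# below are arbitrary example values; only `torch` is required.
def _example_gp_prior_shapes():
    n_mc, n_samples, m, n, d = 8, 2, 100, 3, 1
    x = torch.randn(n_mc, n_samples, m, n)  # latent samples
    ts = torch.arange(m, dtype=torch.float32).reshape(1, 1, m).repeat(
        n_samples, d, 1)  # (n_samples x d x m), the format passed to GP(..., ts=ts)
    x = x.transpose(-1, -2)  # -> (n_mc, n_samples, n, m)
    ts = ts.reshape(1, n_samples, d, -1).repeat(n_mc, 1, 1, 1)  # -> (n_mc, n_samples, d, m)
    return x.shape, ts.shape  # torch.Size([8, 2, 3, 100]), torch.Size([8, 2, 1, 100])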


def fio_id(x):
    return x


def fio_ReLU(x):
    return torch.nn.functional.relu(x)


def fio_tanh(x):
    return torch.tanh(x)


class DS(LpriorEuclid):
    name = "DS"

    def __init__(self, manif: Manifold, fio=fio_id):
        """
        Dynamical systems prior
            x_t = f(A @ x_{t-1}) + N(0, Q)
        where A is constrained to be stable (eigenvalues inside the unit circle)
        and Q is diagonal.
        f can be the identity (default; LDS prior) or a non-linear function
        such as fio_ReLU or fio_tanh.

        A standalone sketch of the stable parameterization of A follows this
        class definition.
        """
        super().__init__(manif)
        d = self.d
        self.fio = fio  # transition non-linearity applied in forward()
        Q = torch.diag_embed(torch.ones(d) * 0.5)
        # fixing Q sets the scale and orientation of the latents
        self.Q = nn.Parameter(data=Q, requires_grad=False)
        A = torch.diag_embed(torch.ones(d))
        self.A = nn.Parameter(data=A, requires_grad=True)
        print('initialized DS')

    @property
    def prms(self):
        # construct a transition matrix with eigenvalues inside the unit circle
        # from the unconstrained parameter self.A via its QR decomposition
        O, R = torch.qr(self.A)
        signs = torch.diag_embed(torch.sign(torch.diag(R)))
        O = O @ signs
        R = signs @ R
        Lsqrt = torch.diag(R)
        L_I = torch.sqrt(torch.square(Lsqrt) + 1)**(-1)
        A = torch.diag_embed(Lsqrt) @ O @ torch.diag_embed(L_I)
        Q = self.Q
        return A, Q

    def forward(self, x, batch_idxs=None):
        """
        x: (n_mc, n_samples, m, d)
        """
        A, Q = self.prms
        # apply the transition function to the previous state
        xA = self.fio(torch.matmul(x, A))  # (n_mc, n_samples, m, d)
        # innovations x_t - f(A x_{t-1})
        dx = x[..., 1:, :] - xA[..., :-1, :]
        mu = torch.zeros(self.d).to(x.device)
        normal = dists.MultivariateNormal(mu, scale_tril=Q)
        lq = normal.log_prob(dx)  # (n_mc, n_samples, m-1)
        lq = lq.sum(-1).sum(-1)  # (n_mc)
        # in the future, we may want an explicit prior over the initial point
        return lq

    @property
    def msg(self):
        A, Q = self.prms
        lp_msg = (' A {:.3f} |').format(torch.diag(A).mean().item())
        return lp_msg
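

# ---------------------------------------------------------------------------
# Standalone sketch, not part of the mgplvm API: the stable-transition
# parameterization used in DS.prms above.  Given an unconstrained matrix,
# A = diag(l) @ O @ diag(1 / sqrt(l**2 + 1)) (with O, l taken from a
# sign-corrected QR decomposition) has all eigenvalues strictly inside the
# unit circle, so x_t = A x_{t-1} defines stable linear dynamics.
# torch.linalg.qr is used here in place of the deprecated torch.qr; only
# `torch` is required.
def _example_stable_transition(d: int = 3) -> torch.Tensor:
    raw = torch.randn(d, d)  # unconstrained parameter (the analogue of self.A)
    O, R = torch.linalg.qr(raw)
    signs = torch.diag_embed(torch.sign(torch.diag(R)))
    O, R = O @ signs, signs @ R  # fix signs so that diag(R) >= 0
    l = torch.diag(R)
    A = torch.diag_embed(l) @ O @ torch.diag_embed(1.0 / torch.sqrt(l**2 + 1))
    # the largest eigenvalue magnitude is < 1 for any raw matrix
    return torch.linalg.eigvals(A).abs().max()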