Source code for mgplvm.rdist.relie

import abc
import torch
from torch import nn, Tensor
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
from torch.distributions import transform_to, constraints
from ..utils import softplus, inv_softplus
from ..manifolds.base import Manifold
from .common import Rdist
from typing import Optional
from ..base import Module


class ReLieBase(Rdist):
    name = "ReLieBase"

    def __init__(self,
                 manif: Manifold,
                 f,
                 kmax: int = 5,
                 diagonal: bool = False):
        super(ReLieBase, self).__init__(manif, kmax)
        self.f = f
        self.diagonal = diagonal

    def lat_prms(self, Y=None, batch_idxs=None, sample_idxs=None):
        gmu, gamma = self.f(Y, batch_idxs, sample_idxs)
        return gmu, gamma

    def lat_gmu(self, Y=None, batch_idxs=None, sample_idxs=None):
        return self.lat_prms(Y, batch_idxs, sample_idxs)[0]

    def lat_gamma(self, Y=None, batch_idxs=None, sample_idxs=None):
        return self.lat_prms(Y, batch_idxs, sample_idxs)[1]

    @property
    def prms(self):
        return self.f.prms

    def mvn(self, gamma=None):
        gamma = self.lat_gamma() if gamma is None else gamma
        n_samples, m, _, _ = gamma.shape
        mu = torch.zeros(n_samples, m, self.d).to(gamma.device)
        return MultivariateNormal(mu, scale_tril=gamma)

    def sample(self,
               size,
               Y=None,
               batch_idxs=None,
               sample_idxs=None,
               kmax=5,
               analytic_kl=False,
               prior=None):
        """
        Generate samples and compute their log densities
        (whose negative mean estimates the entropy).
        """
        gmu, gamma = self.lat_prms(Y, batch_idxs, sample_idxs)
        q = self.mvn(gamma)
        # sample a batch with dims: (n_mc x n_samples x batch_size x d)
        x = q.rsample(size)

        m = x.shape[-2]
        mu = torch.zeros(m).to(gamma.device)[..., None]
        if self.diagonal:
            # compute the ReLie log density per dimension (diagonal covariance)
            lq = torch.stack([
                self.manif.log_q(
                    Normal(mu, gamma[..., j, j][..., None]).log_prob,
                    x[..., j, None], 1, self.kmax).sum(dim=-1)
                for j in range(self.d)
            ]).sum(dim=0)
        else:
            # compute the ReLie log density with the full covariance matrix
            lq = self.manif.log_q(q.log_prob, x, self.manif.d, self.kmax)

        gtilde = self.manif.expmap(x)
        # apply g_mu with dims: (n_mc x n_samples x m x d)
        g = self.manif.gmul(gmu, gtilde)
        return g, lq

    def gmu_parameters(self):
        return self.f.gmu_parameters()

    def concentration_parameters(self):
        return self.f.concentration_parameters()

    def msg(self, Y=None, batch_idxs=None, sample_idxs=None):
        mu, gamma = self.lat_prms(Y=Y,
                                  batch_idxs=batch_idxs,
                                  sample_idxs=sample_idxs)
        gamma = gamma.diagonal(dim1=-1, dim2=-2)

        mu_mag = torch.sqrt(torch.mean(mu**2)).item()
        sig = torch.median(gamma).item()
        string = (' |mu| {:.3f} | sig {:.3f} |').format(mu_mag, sig)
        return string
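
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): what `manif.log_q`
# computes in `ReLieBase.sample` above, specialised to the circle S^1. The
# density of g = expmap(x) under a Gaussian q(x) is approximated by summing
# q over the (2 * kmax + 1) preimages x + 2*pi*k of g, i.e. a truncated
# wrapped Gaussian. `circle_log_q` is a hypothetical helper for illustration,
# not part of the mgplvm API.
def circle_log_q(x: Tensor, sigma: float, kmax: int = 5) -> Tensor:
    base = Normal(0.0, sigma)
    shifts = 2 * torch.pi * torch.arange(-kmax, kmax + 1)
    # log-sum-exp over the 2*kmax+1 shifted copies of the Gaussian log density
    return torch.logsumexp(base.log_prob(x[..., None] + shifts), dim=-1)

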
class _F(Module):
    """Variational parameters for ReLie: group means gmu and
    (unconstrained) covariance factors gamma."""

    def __init__(self,
                 manif: Manifold,
                 m: int,
                 n_samples: int,
                 kmax: int = 5,
                 sigma: float = 1.5,
                 gamma: Optional[Tensor] = None,
                 fixed_gamma=False,
                 diagonal=False,
                 mu=None,
                 initialization: Optional[str] = 'random',
                 Y=None):
        super(_F, self).__init__()
        self.manif = manif
        self.diagonal = diagonal

        # initialize the variational means on the manifold
        if mu is None:
            gmu = self.manif.initialize(initialization, n_samples, m, manif.d,
                                        Y)
        else:
            assert mu.shape == (n_samples, m, manif.d2)
            gmu = torch.tensor(mu)
        self.gmu = nn.Parameter(data=gmu, requires_grad=True)

        # initialize the covariance factors and store them unconstrained
        if gamma is None:
            gamma = torch.ones(n_samples, m, manif.d) * sigma
        assert gamma.shape == (n_samples, m, manif.d)
        if diagonal:
            gamma = inv_softplus(gamma)
        else:
            gamma = torch.diag_embed(gamma)
            gamma = transform_to(constraints.lower_cholesky).inv(gamma)
        assert gamma is not None
        self.gamma = nn.Parameter(data=gamma,
                                  requires_grad=(not fixed_gamma))

    def forward(self, Y=None, batch_idxs=None, sample_idxs=None):
        gmu, gamma = self.prms
        if sample_idxs is not None:
            gmu = gmu[sample_idxs]
            gamma = gamma[sample_idxs]
        if batch_idxs is None:
            return gmu, gamma
        else:
            return gmu[:, batch_idxs, :], gamma[:, batch_idxs, :]

    @property
    def prms(self):
        gmu = self.manif.parameterise(self.gmu)
        if self.diagonal:
            gamma = torch.diag_embed(softplus(self.gamma))
        else:
            gamma = torch.distributions.transform_to(
                MultivariateNormal.arg_constraints['scale_tril'])(self.gamma)
        return gmu, gamma

    def gmu_parameters(self):
        return [self.gmu]

    def concentration_parameters(self):
        return [self.gamma]
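
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the unconstrained
# <-> Cholesky parameterisation used by _F. transform_to(lower_cholesky)
# maps an arbitrary square matrix to a lower-triangular matrix with a
# positive diagonal, so _F stores gamma in inverse-transformed form (as in
# __init__) and recovers a valid scale_tril in `prms`. The demo helper and
# the 3x3 value below are hypothetical, for illustration only.
def _cholesky_roundtrip_demo() -> Tensor:
    tfm = transform_to(constraints.lower_cholesky)
    raw = tfm.inv(torch.eye(3) * 1.5)  # unconstrained storage (as in _F.__init__)
    scale_tril = tfm(raw)  # valid Cholesky factor (as in _F.prms)
    assert torch.allclose(scale_tril, torch.eye(3) * 1.5)
    return scale_tril

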
class ReLie(ReLieBase):
    name = "ReLie"

    def __init__(self,
                 manif: Manifold,
                 m: int,
                 n_samples: int,
                 kmax: int = 5,
                 sigma: float = 1.5,
                 gamma: Optional[Tensor] = None,
                 fixed_gamma=False,
                 diagonal=False,
                 mu=None,
                 initialization: Optional[str] = 'random',
                 Y=None):
        """
        Parameters
        ----------
        manif : Manifold
            manifold of ReLie
        m : int
            number of conditions/timepoints
        n_samples : int
            number of samples
        kmax : Optional[int]
            the ReLie approximation uses (2 * kmax + 1) terms per dimension
        sigma : Optional[float]
            initial diagonal standard deviation of the variational distribution
        gamma : Optional[Tensor]
            optional initialization of the variational covariance factors
            (n_samples x m x d)
        fixed_gamma : Optional[bool]
            if True, gamma is not optimized
        diagonal : Optional[bool]
            if True, constrain the variational covariance to be diagonal
        mu : Optional[np.ndarray]
            initialization of the variational means (n_samples x m x d2)
        initialization : Optional[str]
            string specifying the type of initialization
            ('random'/'PCA'/'identity' depending on the manifold)
        Y : Optional[np.ndarray]
            data used to initialize the latents (n x m)

        Notes
        -----
        gamma parameterises the reference distribution; it is
        inverse-transformed before storing, since it is passed through
        constraints.lower_cholesky when used.
        If gamma is None, it is initialized as a diagonal matrix with
        value sigma.
        If diagonal, the covariance is constrained to be diagonal. This
        approximation is useful for T^n since it avoids the exponentially
        growing number of terms in the full ReLie computation.
        The diagonal approximation only works for T^n and R^n.
        """
        f = _F(manif, m, n_samples, kmax, sigma, gamma, fixed_gamma, diagonal,
               mu, initialization, Y)
        super(ReLie, self).__init__(manif, f, kmax, diagonal)

    @property
    def prms(self):
        return self.f.prms
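
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module). The manifold
# constructor and its signature below are assumptions, not verified against
# a particular mgplvm version; see mgplvm.manifolds for the actual API.
#
#   from mgplvm.manifolds import Torus            # assumed import path
#   manif = Torus(m=50, d=2)                      # assumed constructor
#   rdist = ReLie(manif, m=50, n_samples=1, sigma=0.5, diagonal=True)
#   g, lq = rdist.sample(torch.Size([32]))        # 32 Monte Carlo samples
#   # g : group elements, shape (n_mc x n_samples x m x d2)
#   # lq: ReLie log densities of the samples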