Source code for mgplvm.models.bfa

# Bayesian Factor Analysis
import abc
import torch
import torch.nn as nn
from torch import Tensor
import numpy as np
from mgplvm.utils import softplus, inv_softplus
from ..base import Module
from ..kernels import Kernel
from ..inducing_variables import InducingPoints
from typing import Tuple, List, Optional, Union
from torch.distributions import (MultivariateNormal, LowRankMultivariateNormal,
                                 kl_divergence, transform_to, constraints,
                                 Normal)
from ..likelihoods import Likelihood
from sklearn import decomposition
from .gp_base import GpBase
import itertools

jitter: float = 1E-8
log2pi: float = np.log(2 * np.pi)


def batch_capacitance_tril(W, D):
    r"""
    Copied from the PyTorch source code.
    Computes the Cholesky factor of :math:`I + W^T D^{-1} W`
    for a batch of matrices :math:`W` and a batch of vectors :math:`D`.
    """
    m = W.size(-1)
    Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, ::m + 1] += 1  # add the identity along the diagonal
    return torch.cholesky(K)
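
# Illustrative sketch (not part of the original module): a quick numerical check
# of what batch_capacitance_tril computes, comparing it against a direct Cholesky
# of I + W^T D^{-1} W built with torch.linalg.cholesky. Shapes are arbitrary.
def _example_batch_capacitance_tril():
    W = torch.randn(4, 10, 3, dtype=torch.double)   # batch of (m x d) factor matrices
    D = torch.rand(4, 10, dtype=torch.double) + 0.1  # batch of diagonal noise variances
    L = batch_capacitance_tril(W, D)                  # (4, 3, 3) lower-triangular factors
    K = torch.eye(3, dtype=torch.double) + W.transpose(-1, -2) @ (W / D.unsqueeze(-1))
    assert torch.allclose(L, torch.linalg.cholesky(K))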
class Bfa(GpBase):
    """
    Bayesian Factor Analysis.
    Assumes Gaussian observation noise.
    Computes log_prob and posterior predictions exactly.
    """

    name = "Bfa"

    def __init__(self,
                 n: int,
                 d: int,
                 sigma: Optional[Tensor] = None,
                 learn_sigma=True,
                 Y=None,
                 learn_neuron_scale=False,
                 ard=False,
                 learn_scale=None):
        super().__init__()

        if Y is not None:
            n_samples_fa, n_fa, m_fa = Y.shape
            mod = decomposition.FactorAnalysis(n_components=d)
            Y_fa = Y.transpose(0, 2, 1).reshape(n_samples_fa * m_fa, n_fa)
            mudata = mod.fit_transform(Y_fa)  # (m * n_samples) x d
            C = torch.tensor(mod.components_.T)  # (n x d)

        #### initialize noise parameters ####
        if sigma is None:
            if Y is None:
                sigma = torch.ones(n,)  # TODO: FA init
            else:
                sigma = torch.tensor(np.sqrt(mod.noise_variance_))
        self._sigma = nn.Parameter(data=sigma, requires_grad=learn_sigma)
        self.n = n

        #### initialize prior parameters ####
        _scale = torch.ones(1)
        _dim_scale = torch.ones(d)
        _neuron_scale = torch.ones(n)
        if learn_scale is None:
            learn_scale = not (ard or learn_neuron_scale)

        if Y is not None:  # initialize from FA
            if learn_scale:
                _scale = torch.square(C).mean().sqrt()  # global scale
            if learn_neuron_scale:
                _neuron_scale = torch.square(C).mean(1).sqrt()  # per neuron
            if ard:
                _dim_scale = torch.square(C).mean(0).sqrt()  # per latent dimension

        self._scale = nn.Parameter(inv_softplus(_scale),
                                   requires_grad=learn_scale)
        self._neuron_scale = nn.Parameter(inv_softplus(_neuron_scale),
                                          requires_grad=learn_neuron_scale)
        self._dim_scale = nn.Parameter(inv_softplus(_dim_scale),
                                       requires_grad=ard)

    @property
    def prms(self) -> Tensor:
        """p(y_i | f_i) = N(f_i, sigma_i^2); returns the noise variances"""
        variance = torch.square(self._sigma)
        return variance

    @property
    def sigma(self) -> Tensor:
        return (1e-20 + self.prms).sqrt()

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def neuron_scale(self):
        return softplus(self._neuron_scale)[:, None]

    @property
    def dim_scale(self):
        return softplus(self._dim_scale)[:, None]

    def _dist(self, x):
        """construct the low-rank prior MVN N(0, X^T X + sigma^2 I)"""
        m = x.shape[-1]
        x = self.scale * self.dim_scale * x
        cov_factor = x[..., None, :, :].transpose(-1, -2)  # (n_samples x 1 x m x d)
        cov_factor = self.neuron_scale[..., None] * cov_factor  # (n_samples x n x m x d)
        cov_diag = self.prms[:, None] * torch.ones(m).to(x.device)  # (n x m)
        dist = LowRankMultivariateNormal(loc=torch.zeros(self.n, m).to(x.device),
                                         cov_factor=cov_factor,
                                         cov_diag=cov_diag)
        return dist
    def log_prob(self, y, x):
        """compute the marginal likelihood p(y) = N(y | 0, X^T X + sigma^2 I)"""
        lp = self._dist(x).log_prob(y)  # ((n_mc) x n_samples x n)
        return lp
    def elbo(self,
             y: Tensor,
             x: Tensor,
             sample_idxs: Optional[List[int]] = None,
             m: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        y : Tensor
            data tensor with dimensions (n_samples x n x m)
        x : Tensor (single kernel) or Tensor list (product kernels)
            input tensor(s) with dimensions (n_mc x n_samples x d x m)

        Returns
        -------
        lik, prior_kl : Tuple[torch.Tensor, torch.Tensor]
            lik has dimensions (n_mc x n)
            prior_kl has dimensions (n) and is zero
        """
        lik = self.log_prob(y, x)  # ((n_mc) x n_samples x n)
        lik = lik.sum(-2)
        prior_kl = torch.zeros(self.n).to(x.device)
        return lik, prior_kl
    def predict(self, xstar, y, x, full_cov=False):
        """compute the posterior p(f* | x, y)"""
        prec = self._dist(x.squeeze()).precision_matrix  # (K + sigma^2 I)^-1
        m = x.shape[-1]
        d = x.shape[-2]
        x = self.scale * self.dim_scale * x
        xstar = self.scale * self.dim_scale * xstar
        x = x[..., None, :, :]  # (..., d, m)
        xstar = xstar[..., None, :, :]  # (..., d, m)
        xt = x.transpose(-1, -2)  # (..., m, d)
        variance = self.prms  # (n) (p(Y|F) variance)
        cov_diag = variance[..., None] * torch.ones(m)  # (n x m)
        capacitance_tril = batch_capacitance_tril(xt, cov_diag)
        xdinv = (x / variance[..., None, None])
        A = torch.triangular_solve(xdinv, capacitance_tril,
                                   upper=False)[0]  # (..., d, m)
        y = y[..., None]

        _mu1 = xdinv.matmul(y)
        _mu2 = x.matmul(A.transpose(-1, -2).matmul(A.matmul(y)))
        mu = xstar.transpose(-1, -2).matmul(_mu1 - _mu2).squeeze(-1)

        if not full_cov:
            v1 = torch.square(xstar).sum(-2)
            v2 = xstar.transpose(-1, -2).matmul(
                x / variance[:, None, None].sqrt()).square().sum(-1)
            v3 = A.matmul(x.transpose(-1, -2)).matmul(xstar).square().sum(-2)
            return mu, v1 - v2 + v3
        else:
            xxT = xdinv.matmul(x.transpose(-1, -2))
            xAT = x.matmul(A.transpose(-1, -2))
            xATAxT = xAT.matmul(xAT.transpose(-1, -2))
            z = torch.eye(d) - xxT + xATAxT
            c = xstar.transpose(-1, -2).matmul(z).matmul(xstar)
            return mu, c
    def g0_parameters(self):
        return []

    def g1_parameters(self):
        return [self._sigma, self._neuron_scale, self._dim_scale, self._scale]

    @property
    def msg(self):
        return ('scale {:.3f} |').format(
            (self.scale.mean() * self.neuron_scale.mean() *
             self.dim_scale.mean()).item())
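
# Illustrative usage sketch (not part of the original module): build a Bfa model
# from data and evaluate its exact ELBO terms and posterior predictions. The
# shapes and hyperparameters are arbitrary choices that follow the docstrings
# above (Y: n_samples x n x m, x: n_mc x n_samples x d x m).
def _example_bfa_usage():
    n_samples, n, m, d, n_mc = 2, 20, 50, 3, 8
    Y = np.random.randn(n_samples, n, m)
    model = Bfa(n, d, Y=Y, learn_neuron_scale=True)  # sigma and scales initialized from sklearn FA
    y = torch.tensor(Y)
    x = torch.randn(n_mc, n_samples, d, m).to(y.dtype)  # Monte Carlo latent samples
    lik, prior_kl = model.elbo(y, x)       # lik: (n_mc x n), prior_kl: zeros(n)
    mu, v = model.predict(x[0], y, x[0])   # posterior mean/variance at the training inputs
    return lik, prior_kl, mu, v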
class Bvfa(GpBase):

    name = "Bvfa"

    def __init__(self,
                 n: int,
                 d: int,
                 m: int,
                 n_samples: int,
                 likelihood: Likelihood,
                 q_mu: Optional[Tensor] = None,
                 q_sqrt: Optional[Tensor] = None,
                 tied_samples=True,
                 Y=None,
                 learn_neuron_scale=False,
                 ard=False,
                 learn_scale=None,
                 rel_scale=1,
                 scale=None,
                 dim_scale=None,
                 neuron_scale=None):
        """
        __init__ method for Base Variational Factor Analysis

        Parameters
        ----------
        n : int
            number of neurons
        d : int
            latent dimensionality
        m : int
            number of conditions
        n_samples : int
            number of samples
        likelihood : Likelihood
            likelihood module used for computing the variational expectation
        q_mu : Optional Tensor
            optional Tensor for initialization
        q_sqrt : Optional Tensor
            optional Tensor for initialization
        tied_samples : Optional bool
            if True, the variational distribution is shared across samples
        """
        super().__init__()
        self.n = n
        self.d = d
        self.m = m
        self.tied_samples = tied_samples
        self.n_samples = n_samples
        #self.z, self.kernel = [NoneClass() for i in range(2)]

        #### initialize prior parameters ####
        _scale = torch.ones(1)
        _dim_scale = torch.ones(d)
        _neuron_scale = torch.ones(n)
        if learn_scale is None:
            learn_scale = not (ard or learn_neuron_scale)

        if Y is not None:  # initialize from FA
            n_samples_fa, n_fa, m_fa = Y.shape
            mod = decomposition.FactorAnalysis(n_components=d)
            Y_fa = Y.transpose(0, 2, 1).reshape(n_samples_fa * m_fa, n_fa)
            mudata = mod.fit_transform(Y_fa)  # (m * n_samples) x d
            C = torch.tensor(mod.components_.T)  # (n x d)
            #print(C.shape)
            if learn_scale:
                _scale = rel_scale * torch.square(C).mean().sqrt()  # global scale
            if learn_neuron_scale:
                _neuron_scale = rel_scale * torch.square(C).mean(1).sqrt()  # per neuron
            if ard:
                _dim_scale = rel_scale * torch.square(C).mean(0).sqrt()  # per latent dimension

        ## optionally provide these as parameters ##
        scale = _scale if scale is None else scale
        dim_scale = _dim_scale if dim_scale is None else dim_scale
        neuron_scale = _neuron_scale if neuron_scale is None else neuron_scale

        self._scale = nn.Parameter(inv_softplus(scale),
                                   requires_grad=learn_scale)
        self._neuron_scale = nn.Parameter(inv_softplus(neuron_scale),
                                          requires_grad=learn_neuron_scale)
        self._dim_scale = nn.Parameter(inv_softplus(dim_scale),
                                       requires_grad=ard)

        #### initialize variational distribution (should we initialize this to the Gaussian ground truth?) ####
        if q_mu is None:
            if tied_samples:
                q_mu = torch.zeros(1, n, d)
            else:
                q_mu = torch.zeros(n_samples, n, d)

        if q_sqrt is None:
            if tied_samples:
                q_sqrt = torch.diag_embed(torch.ones(1, n, d))
            else:
                q_sqrt = torch.diag_embed(torch.ones(n_samples, n, d))
        else:
            q_sqrt = transform_to(constraints.lower_cholesky).inv(q_sqrt)

        assert (q_mu is not None)
        assert (q_sqrt is not None)
        if self.tied_samples:
            assert (q_mu.shape[0] == 1)
            assert (q_sqrt.shape[0] == 1)
        else:
            assert (q_mu.shape[0] == n_samples)
            assert (q_sqrt.shape[0] == n_samples)

        self._q_mu = nn.Parameter(q_mu, requires_grad=True)
        self._q_sqrt = nn.Parameter(q_sqrt, requires_grad=True)

        self.likelihood = likelihood

    @property
    def scale(self):
        return softplus(self._scale)

    @property
    def neuron_scale(self):
        return softplus(self._neuron_scale)[:, None]

    @property
    def dim_scale(self):
        return softplus(self._dim_scale)[:, None]

    @property
    def q_mu(self):
        return self._q_mu

    @property
    def q_sqrt(self):
        return transform_to(constraints.lower_cholesky)(self._q_sqrt)
    def prior_kl(self, sample_idxs=None):
        """compute KL(q(f) || p(f))"""
        q_mu, q_sqrt = self.prms
        assert (q_mu.shape[0] == q_sqrt.shape[0])
        if not self.tied_samples and sample_idxs is not None:
            q_mu = q_mu[sample_idxs]
            q_sqrt = q_sqrt[sample_idxs]
        q = MultivariateNormal(q_mu, scale_tril=q_sqrt)
        e = torch.eye(self.d).to(q_mu.device)
        p_mu = torch.zeros(self.n, self.d).to(q_mu.device)
        prior = MultivariateNormal(p_mu, scale_tril=e)
        return kl_divergence(q, prior)  # consider implementing this directly
    def elbo(self,
             y: Tensor,
             x: Tensor,
             sample_idxs: Optional[List[int]] = None,
             m: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        y : Tensor
            data tensor with dimensions (n_samples x n x m)
        x : Tensor (single kernel) or Tensor list (product kernels)
            input tensor(s) with dimensions (n_mc x n_samples x d x m)
        m : Optional int
            used to scale the svgp likelihood.
            If not provided, self.m is used which is provided at initialization.
            This parameter is useful if we subsample data but want to weight the
            prior as if it was the full dataset.
            We use this e.g. in crossvalidation.

        Returns
        -------
        lik, prior_kl : Tuple[torch.Tensor, torch.Tensor]
            lik has dimensions (n_mc x n)
            prior_kl has dimensions (n)
        """
        assert (x.shape[-3] == y.shape[-3])
        assert (x.shape[-1] == y.shape[-1])
        batch_size = x.shape[-1]
        sample_size = x.shape[-3]

        # prior KL(q(u) || p(u)): (1 x n) if tied_samples, otherwise (n_samples x n)
        prior_kl = self.prior_kl(sample_idxs)
        # predictive mean and var at x
        f_mean, f_var = self.predict(x, full_cov=False, sample_idxs=sample_idxs)

        prior_kl = prior_kl.sum(-2)
        if not self.tied_samples:
            prior_kl = prior_kl * (self.n_samples / sample_size)

        # (n_mc, n_samples, n)
        lik = self.likelihood.variational_expectation(y, f_mean, f_var)
        # scale is (m / batch_size) * (self.n_samples / sample_size)
        # to compute an unbiased estimate of the likelihood of the full dataset
        m = (self.m if m is None else m)
        scale = (m / batch_size) * (self.n_samples / sample_size)
        lik = lik.sum(-2)
        lik = lik * scale
        return lik, prior_kl
    def sample(self,
               query: Tensor,
               n_mc: int = 1000,
               square: bool = False,
               noise: bool = True):
        """
        Parameters
        ----------
        query : Tensor (single kernel)
            test input tensor with dimensions (n_samples x d x m)
        n_mc : int
            number of samples to return
        square : bool
            determines whether to square the output
        noise : bool
            determines whether we also sample explicitly from the noise model
            or simply return samples of the mean

        Returns
        -------
        y_samps : Tensor
            samples from the model (n_mc x n_samples x n x m)
        """
        query = query[None, ...]  # add batch dimension (1 x n_samples x d x m)
        mu, v = self.predict(query, False)  # (1 x n_samples x n x m), (1 x n_samples x n x m)
        # remove batch dimension
        mu = mu[0]  # (n_samples x n x m)
        v = v[0]  # (n_samples x n x m)
        # sample from p(f|u)
        dist = Normal(mu, torch.sqrt(v))
        f_samps = dist.sample((n_mc,))  # (n_mc x n_samples x n x m)

        if noise:
            # sample from the observation model p(y|f)
            y_samps = self.likelihood.sample(f_samps)  # (n_mc x n_samples x n x m)
        else:
            # compute mean observations mu(f) for each f
            y_samps = self.likelihood.dist_mean(f_samps)  # (n_mc x n_samples x n x m)

        if square:
            y_samps = y_samps**2
        return y_samps
    def predict(self, x: Tensor, full_cov: bool,
                sample_idxs=None) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        x : Tensor (single kernel) or Tensor list (product kernels)
            test input tensor(s) with dimensions (n_b x n_samples x d x m)
        full_cov : bool
            returns the full covariance if true, otherwise returns the diagonal

        Returns
        -------
        mu : Tensor
            mean of the predictive density at the test inputs (n_b x n_samples x n x m)
        v : Tensor
            variance/covariance of the predictive density at the test inputs;
            if full_cov is true returns the full covariance, otherwise the diagonal variance
        """
        q_mu, q_sqrt = self.prms
        assert (q_mu.shape[0] == q_sqrt.shape[0])
        if (not self.tied_samples) and sample_idxs is not None:
            q_mu = q_mu[sample_idxs]
            q_sqrt = q_sqrt[sample_idxs]

        x = self.scale * self.dim_scale * x  # multiply each dimension by the prior scale
        mu = q_mu.matmul(x)  # n_b x n_samples x n x m
        l = x[..., None, :, :].transpose(-1, -2).matmul(q_sqrt)  # n_b x n_samples x n x m x d

        if not full_cov:
            return mu, torch.square(l).sum(-1)
        else:
            return mu, l.matmul(l.transpose(-1, -2))
    @property
    def prms(self) -> Tuple[Tensor, Tensor]:
        q_mu = self.q_mu
        q_sqrt = self.q_sqrt
        # multiply the posterior by a scale factor for each neuron
        q_mu, q_sqrt = self.neuron_scale * q_mu, self.neuron_scale[..., None] * q_sqrt
        return q_mu, q_sqrt

    def g0_parameters(self):
        return [self._q_mu, self._q_sqrt]

    def g1_parameters(self):
        return list(
            itertools.chain.from_iterable([
                self.likelihood.parameters(),
                [self._scale, self._neuron_scale, self._dim_scale]
            ]))

    @property
    def msg(self):
        newmsg = ('scale {:.3f} |').format(
            (self.scale.mean() * self.neuron_scale.mean() *
             self.dim_scale.mean()).item())
        return newmsg + self.likelihood.msg
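
# Illustrative usage sketch (not part of the original module): construct a Bvfa
# model and combine its ELBO terms, given a Likelihood module (e.g. one of the
# likelihood classes defined in mgplvm.likelihoods). Shapes follow the elbo
# docstring: y is (n_samples x n x m) and x is (n_mc x n_samples x d x m).
def _example_bvfa_usage(likelihood: Likelihood):
    n, d, m, n_samples, n_mc = 20, 3, 50, 2, 8
    model = Bvfa(n, d, m, n_samples, likelihood, ard=True)
    y = torch.randn(n_samples, n, m)            # use data appropriate for the chosen likelihood
    x = torch.randn(n_mc, n_samples, d, m)      # Monte Carlo latent samples
    lik, prior_kl = model.elbo(y, x)            # lik: (n_mc x n), prior_kl: (n)
    return (lik - prior_kl).sum(-1).mean()      # Monte Carlo estimate of the ELBO terms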
class Fa(GpBase):
    """
    Standard non-Bayesian Factor Analysis.
    Assumes Gaussian observation noise.
    Computes log_prob and posterior predictions exactly.
    """

    name = "Fa"

    def __init__(self,
                 n: int,
                 d: int,
                 sigma: Optional[Tensor] = None,
                 learn_sigma=True,
                 Y=None,
                 C=None):
        """
        n: number of neurons
        d: number of latents
        """
        super().__init__()
        self.n = n

        if Y is None:
            _C = torch.randn(n, d) * d**(-0.5)  # TODO: FA init
            _sigma = torch.ones(n,) * 0.5  # TODO: FA init
        else:
            n_samples, n, m = Y.shape
            mod = decomposition.FactorAnalysis(n_components=d)
            Y = Y.transpose(0, 2, 1).reshape(n_samples * m, n)
            mudata = mod.fit_transform(Y)  # (m * n_samples) x d
            _sigma = torch.tensor(np.sqrt(mod.noise_variance_))
            _C = torch.tensor(mod.components_.T)

        sigma = _sigma if sigma is None else sigma
        C = _C if C is None else C
        self._sigma = nn.Parameter(data=sigma, requires_grad=learn_sigma)
        self.C = nn.Parameter(data=C, requires_grad=True)

    @property
    def prms(self) -> Tensor:
        """p(y_i | f_i) = N(f_i, sigma_i^2); returns the noise variances"""
        variance = torch.square(self._sigma)
        return variance

    @property
    def sigma(self) -> Tensor:
        return (1e-20 + self.prms).sqrt()
    def log_prob(self, y, x):
        """
        compute p(y|x) = N(y | Cx, sigma^2)
        x is (n_mc x n_samples x d x m)
        y is (n_samples x n x m)
        """
        mean = self.C @ x  # (... x n x m)
        mean = mean.transpose(-1, -2)  # (... x m x n)
        dist = Normal(loc=mean, scale=self.sigma)
        lp = dist.log_prob(y.transpose(-1, -2))  # (... x m x n)
        #print('lp:', lp.shape)
        return lp
    def elbo(self,
             y: Tensor,
             x: Tensor,
             sample_idxs: Optional[List[int]] = None,
             m: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        y : Tensor
            data tensor with dimensions (n_samples x n x m)
        x : Tensor (single kernel) or Tensor list (product kernels)
            input tensor(s) with dimensions (n_mc x n_samples x d x m)

        Returns
        -------
        lik, prior_kl : Tuple[torch.Tensor, torch.Tensor]
            lik has dimensions (n_mc x n)
            prior_kl has dimensions (n) and is zero
        """
        lik = self.log_prob(y, x)  # (n_mc x n_samples x m x n)
        lik = lik.sum(-2).sum(-2)  # (n_mc x n)
        prior_kl = torch.zeros(self.n).to(x.device)
        return lik, prior_kl
    def predict(self, xstar, full_cov=False):
        """compute the posterior p(f* | x, y, C) = N(C @ x*, Sig)"""
        mu = self.C @ xstar  # (n_samples x n x m)
        cov = torch.zeros(mu.shape)  # p(f | C, x) is a delta function
        if not full_cov:
            return mu, cov
        else:
            return mu, torch.diag_embed(cov)
    def sample(self,
               query: Tensor,
               n_mc: int = 1000,
               square: bool = False,
               noise: bool = True):
        """
        Parameters
        ----------
        query : Tensor (single kernel)
            test input tensor with dimensions (n_samples x d x m)
        n_mc : int
            number of samples to return
        square : bool
            determines whether to square the output
        noise : bool
            determines whether we also sample explicitly from the noise model
            or simply return samples of the mean

        Returns
        -------
        y_samps : Tensor
            samples from the model (n_mc x n_samples x n x m)
        """
        query = query[None, ...]  # add batch dimension (1 x n_samples x d x m)
        mu, _ = self.predict(query, False)  # (1 x n_samples x n x m)
        # remove batch dimension
        mu = mu[0]  # (n_samples x n x m)
        # p(f|x) is a delta function for FA, so the "samples" of f are just the mean
        f_samps = mu

        if noise:
            # sample from the observation model p(y|f)
            dist = Normal(loc=f_samps, scale=self.sigma[..., None])
            y_samps = dist.sample((n_mc,))  # (n_mc x n_samples x n x m)
        else:
            # compute mean observations mu(f) for each f
            y_samps = torch.ones(n_mc, mu.shape[0], mu.shape[1],
                                 mu.shape[2]).to(query.device) * f_samps  # (n_mc x n_samples x n x m)

        if square:
            y_samps = y_samps**2
        return y_samps
    def g0_parameters(self):
        return []

    def g1_parameters(self):
        return [self._sigma, self.C]

    @property
    def msg(self):
        return ''
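
# Illustrative usage sketch (not part of the original module): initialize an Fa
# model from data via sklearn FactorAnalysis and evaluate its exact likelihood
# term and posterior mean. Shapes are arbitrary and follow the elbo docstring.
def _example_fa_usage():
    n_samples, n, m, d, n_mc = 2, 20, 50, 3, 8
    Y = np.random.randn(n_samples, n, m)
    model = Fa(n, d, Y=Y)                    # C and sigma initialized from sklearn FA
    y = torch.tensor(Y)
    x = torch.randn(n_mc, n_samples, d, m).to(y.dtype)
    lik, prior_kl = model.elbo(y, x)         # lik: (n_mc x n), prior_kl: zeros(n)
    mu, _ = model.predict(x[0])              # posterior mean C @ x: (n_samples x n x m)
    return lik, mu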
class vFa(GpBase):
    """
    Variational non-Bayesian Factor Analysis.
    Allows for non-Gaussian noise.
    """

    name = "vFa"

    def __init__(self,
                 n: int,
                 d: int,
                 m: int,
                 n_samples: int,
                 likelihood: Likelihood,
                 Y=None,
                 rel_scale=1,
                 C=None):
        """
        n: number of neurons
        d: number of latents
        """
        super().__init__()
        self.n = n
        self.m = m
        self.n_samples = n_samples

        if Y is None:
            _C = torch.randn(n, d) * d**(-0.5) * rel_scale
        else:
            n_samples, n, m = Y.shape
            mod = decomposition.FactorAnalysis(n_components=d)
            Y = Y.transpose(0, 2, 1).reshape(n_samples * m, n)
            mudata = mod.fit_transform(Y)  # (m * n_samples) x d
            _C = torch.tensor(mod.components_.T) * rel_scale

        C = _C if C is None else C
        self.C = nn.Parameter(data=C, requires_grad=True)

        self.likelihood = likelihood
    def elbo(self,
             y: Tensor,
             x: Tensor,
             sample_idxs: Optional[List[int]] = None,
             m: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        y : Tensor
            data tensor with dimensions (n_samples x n x m)
        x : Tensor (single kernel) or Tensor list (product kernels)
            input tensor(s) with dimensions (n_mc x n_samples x d x m)

        Returns
        -------
        lik, prior_kl : Tuple[torch.Tensor, torch.Tensor]
            lik has dimensions (n_mc x n)
            prior_kl has dimensions (1 x n) and is zero
        """
        assert (x.shape[-3] == y.shape[-3])
        assert (x.shape[-1] == y.shape[-1])
        batch_size = x.shape[-1]
        sample_size = x.shape[-3]

        # predictive mean and var at x
        f_mean = self.C @ x  # (... x n x m)
        if sample_idxs is not None:
            f_mean = f_mean[:, sample_idxs, ...]
        f_var = torch.zeros(f_mean.shape).to(f_mean.device) + 1e-12

        # (n_mc, n_samples, n)
        lik = self.likelihood.variational_expectation(y, f_mean, f_var)
        # scale is (m / batch_size) * (self.n_samples / sample_size)
        # to compute an unbiased estimate of the likelihood of the full dataset
        m = (self.m if m is None else m)
        scale = (m / batch_size) * (self.n_samples / sample_size)
        lik = lik.sum(-2)
        lik = lik * scale

        prior_kl = torch.zeros(1, lik.shape[-1]).to(lik.device)  # not Bayesian; no prior term (1 x n)
        return lik, prior_kl
    def predict(self, xstar, full_cov=False):
        """compute the posterior p(f* | x, y, C) = N(C @ x*, Sig)"""
        mu = self.C @ xstar  # (n_samples x n x m)
        cov = torch.zeros(mu.shape)  # p(f | C, x) is a delta function
        if not full_cov:
            return mu, cov
        else:
            return mu, torch.diag_embed(cov)
    def sample(self,
               query: Tensor,
               n_mc: int = 1000,
               square: bool = False,
               noise: bool = True):
        """
        Parameters
        ----------
        query : Tensor (single kernel)
            test input tensor with dimensions (n_samples x d x m)
        n_mc : int
            number of samples to return
        square : bool
            determines whether to square the output
        noise : bool
            determines whether we also sample explicitly from the noise model
            or simply return samples of the mean

        Returns
        -------
        y_samps : Tensor
            samples from the model (n_mc x n_samples x n x m)
        """
        query = query[None, ...]  # add batch dimension (1 x n_samples x d x m)
        mu, _ = self.predict(query, False)  # (1 x n_samples x n x m)
        # remove batch dimension
        mu = mu[0]  # (n_samples x n x m)
        # p(f|x) is a delta function for FA, so the "samples" of f are just the mean
        f_samps = torch.ones(n_mc, mu.shape[0], mu.shape[1],
                             mu.shape[2]).to(query.device) * mu  # (n_mc x n_samples x n x m)

        if noise:
            # sample from the observation model p(y|f); vFa has no Gaussian sigma,
            # so the noise model lives in the likelihood module
            y_samps = self.likelihood.sample(f_samps)  # (n_mc x n_samples x n x m)
        else:
            # return the (deterministic) latent mean for each f
            y_samps = f_samps  # (n_mc x n_samples x n x m)

        if square:
            y_samps = y_samps**2
        return y_samps
    @property
    def prms(self) -> Tensor:
        return self.C

    def g0_parameters(self):
        return []

    def g1_parameters(self):
        return list(
            itertools.chain.from_iterable(
                [self.likelihood.parameters(), [self.C]]))

    @property
    def msg(self):
        newmsg = ('C norm {:.3f} |').format((self.C**2).mean().item())
        return newmsg + self.likelihood.msg
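
# Illustrative usage sketch (not part of the original module): a vFa model with
# a user-supplied Likelihood module (e.g. one of the likelihood classes in
# mgplvm.likelihoods). Shapes follow the elbo docstring; y should be data
# appropriate for the chosen likelihood.
def _example_vfa_usage(likelihood: Likelihood):
    n, d, m, n_samples, n_mc = 20, 3, 50, 2, 8
    model = vFa(n, d, m, n_samples, likelihood)
    y = torch.randn(n_samples, n, m)
    x = torch.randn(n_mc, n_samples, d, m)   # Monte Carlo latent samples
    lik, prior_kl = model.elbo(y, x)         # lik: (n_mc x n), prior_kl: zeros(1 x n)
    return lik.sum(-1).mean()                # Monte Carlo average of the data term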