import torch
import numpy as np
from torch import nn, Tensor
from torch.distributions.multivariate_normal import MultivariateNormal
from ..utils import softplus, inv_softplus
from ..manifolds.base import Manifold
from .common import Rdist
from typing import Optional
from ..fast_utils.toeplitz import sym_toeplitz_matmul
class GPbase(Rdist):
    name = "GPbase"  # important: child classes must include "GP" in their name, as this is used in control flow
def __init__(self,
manif: Manifold,
m: int,
n_samples: int,
ts: torch.Tensor,
_scale=0.9,
ell=None):
"""
Parameters
----------
manif: Manifold
manifold of ReLie
m : int
number of conditions/timepoints
n_samples: int
number of samples
ts: Tensor
input timepoints for each sample (n_samples x 1 x m)
        _scale : float
            initial value of the per-timepoint variational scale parameters
        ell : Optional[float or Tensor]
            initial GP length scale; defaults to (max(ts) - min(ts)) / 20

        Notes
        -----
        Our GP has prior N(0, K).
        We parameterize our posterior as N(K^(1/2) nu, K^(1/2) I^2 K^(1/2)),
        where K^(1/2) K^(1/2) = K and I = I(s) is an inner matrix which can take
        different forms in different child classes.
        s is a vector of scale parameters, one per time point.
"""
super(GPbase, self).__init__(manif, 1) #kmax = 1
self.manif = manif
self.d = manif.d
self.m = m
#initialize GP mean parameters
nu = torch.randn((n_samples, self.d, m)) * 0.01
        self._nu = nn.Parameter(data=nu, requires_grad=True)  # variational mean parameter (denoted "m" in the notes)
#initialize covariance parameters
        _scale = torch.ones(n_samples, self.d, m) * _scale  # (n_samples x d x m)
self._scale = nn.Parameter(data=inv_softplus(_scale),
requires_grad=True)
#initialize length scale
if ell is None:
_ell = torch.ones(1, self.d,
1) * (torch.max(ts) - torch.min(ts)) / 20
else:
            if isinstance(ell, (float, int)):
_ell = torch.ones(1, self.d, 1) * ell
else:
_ell = ell
self._ell = nn.Parameter(data=inv_softplus(_ell), requires_grad=True)
        # pre-compute squared time differences (only one row is needed for the symmetric Toeplitz operations)
self.ts = ts
dts_sq = torch.square(ts - ts[..., :1]) #(n_samples x 1 x m)
#sum over _input_ dimension, add an axis for _output_ dimension
dts_sq = dts_sq.sum(-2)[:, None, ...] #(n_samples x 1 x m)
        self.dts_sq = nn.Parameter(data=dts_sq, requires_grad=False)
        self.dt = (ts[0, 0, 1] - ts[0, 0, 0]).item()  # timestep, used to scale the kernel
@property
def scale(self) -> torch.Tensor:
return softplus(self._scale)
@property
def nu(self) -> torch.Tensor:
return self._nu
@property
def ell(self) -> torch.Tensor:
return softplus(self._ell)
@property
def prms(self):
return self.nu, self.scale, self.ell
@property
def lat_mu(self):
"""return variational mean mu = K_half @ nu"""
nu = self.nu
K_half = self.K_half() #(n_samples x d x m)
mu = sym_toeplitz_matmul(K_half, nu[..., None])[..., 0]
return mu.transpose(-1, -2) #(n_samples x m x d)
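    # For intuition, the Toeplitz product above is equivalent to the dense computation
    # (hypothetical sketch for a single sample and latent dimension):
    #   K_half_mat = scipy.linalg.toeplitz(K_half_row)  # (m x m) symmetric Toeplitz matrix
    #   mu = K_half_mat @ nu
    # sym_toeplitz_matmul computes the same product from the first column alone.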
    def K_half(self, sample_idxs=None):
"""compute one column of the square root of the prior matrix"""
nu = self.nu #mean parameters
#K^(1/2) has length scale ell/sqrt(2) if K has ell
ell_half = self.ell / np.sqrt(2)
#K^(1/2) has sig var sig*2^1/4*pi^(-1/4)*ell^(-1/2) if K has sig^2 (1 x d x 1)
sig_sqr_half = 1 * (2**(1 / 4)) * np.pi**(-1 / 4) * self.ell**(
-1 / 2) * self.dt**(1 / 2)
if (sample_idxs is None) or (self.dts_sq.shape[0] == 1):
dts = self.dts_sq[:, ...]
else:
dts = self.dts_sq[sample_idxs, ...]
# (n_samples x d x m)
K_half = sig_sqr_half * torch.exp(-dts / (2 * torch.square(ell_half)))
return K_half
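    # A worked check of the constants in K_half (a sketch, not used by the class):
    # each entry of K_half is c * exp(-(t_i - t_j)^2 / ell^2) with
    # c = 2^(1/4) * pi^(-1/4) * ell^(-1/2) * dt^(1/2), so
    #   (K_half @ K_half)_{ij} = sum_s c^2 * exp(-(t_i - s)^2 / ell^2) * exp(-(s - t_j)^2 / ell^2)
    #                          ~ (c^2 / dt) * ell * sqrt(pi/2) * exp(-(t_i - t_j)^2 / (2 ell^2))
    #                          = exp(-(t_i - t_j)^2 / (2 ell^2)),
    # i.e. an RBF prior K with unit signal variance, as intended.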
    def I_v(self, v, sample_idxs=None):
"""
        Compute I @ v for some vector v.
        This should be implemented by each child class.
        v is (n_samples x d x m x n_mc), where n_samples is the number of sample_idxs.
        """
        raise NotImplementedError
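    # A minimal sketch of one possible child implementation (not the only choice):
    # a diagonal inner matrix I = diag(scale), i.e. independent scales per timepoint:
    #   def I_v(self, v, sample_idxs=None):
    #       scale = self.scale if sample_idxs is None else self.scale[sample_idxs]
    #       return scale[..., None] * v  # (n_samples x d x m x n_mc)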
    def kl(self, batch_idxs=None, sample_idxs=None):
"""
        Compute the KL divergence between the prior and the posterior.
        This should be implemented by each child class.
        """
        raise NotImplementedError
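    # For reference, with q = N(K^(1/2) nu, K^(1/2) I^2 K^(1/2)) and prior p = N(0, K),
    # the Gaussian KL reduces (per latent dimension, with m timepoints) to
    #   KL(q || p) = 0.5 * (nu^T nu + tr(I^2) - m - log det(I^2)),
    # a closed form that child classes can evaluate for their particular choice of I.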
    def full_cov(self):
"""Compute the full covariance Khalf @ I @ I @ Khalf"""
v = torch.diag_embed(torch.ones(
self._scale.shape)) #(n_samples x d x m x m)
I = self.I_v(v) #(n_samples x d x m x m)
K_half = self.K_half() #(n_samples x d x m)
Khalf_I = sym_toeplitz_matmul(K_half, I) #(n_samples x d x m x m)
K_post = Khalf_I @ Khalf_I.transpose(-1, -2) #Kpost = Khalf@I@I@Khalf
return K_post.detach()
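    # Usage sketch (hypothetical): the posterior marginal variances are the diagonal
    # of the returned matrix, e.g.
    #   var = torch.diagonal(self.full_cov(), dim1=-2, dim2=-1)  # (n_samples x d x m)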
    def sample(self,
size,
Y=None,
batch_idxs=None,
sample_idxs=None,
kmax=5,
analytic_kl=False,
prior=None):
"""
        generate samples and compute the analytic KL divergence to the prior
        (returned in place of the sample log densities)
"""
#compute KL analytically
lq = self.kl(batch_idxs=batch_idxs,
sample_idxs=sample_idxs) #(n_samples x d)
K_half = self.K_half(sample_idxs=sample_idxs) #(n_samples x d x m)
n_samples, d, m = K_half.shape
# sample a batch with dims: (n_samples x d x m x n_mc)
v = torch.randn(n_samples, d, m, size[0]) # v ~ N(0, 1)
#compute I @ v (n_samples x d x m x n_mc)
I_v = self.I_v(v, sample_idxs=sample_idxs)
nu = self.nu #mean parameter (n_samples, d, m)
if sample_idxs is not None:
nu = nu[sample_idxs, ...]
samp = nu[..., None] + I_v #add mean parameter to each sample
#compute K@(I@v+nu)
x = sym_toeplitz_matmul(K_half, samp) #(n_samples x d x m x n_mc)
x = x.permute(-1, 0, 2, 1) #(n_mc x n_samples x m x d)
if batch_idxs is not None: #only select some time points
x = x[..., batch_idxs, :]
#(n_mc x n_samples x m x d), (n_samples x d)
return x, lq
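    # In other words, samples are drawn by reparameterization as
    #   x = K^(1/2) (nu + I v),  v ~ N(0, 1),
    # so x has mean K^(1/2) nu and covariance K^(1/2) I^2 K^(1/2),
    # matching the parameterization in the class docstring.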
    def gmu_parameters(self):
return [self._nu]
    def concentration_parameters(self):
return [self._scale, self._ell]
    def msg(self, Y=None, batch_idxs=None, sample_idxs=None):
mu_mag = torch.sqrt(torch.mean(self.nu**2)).item()
sig = torch.median(self.scale).item()
ell = self.ell.mean().item()
string = (' |mu| {:.3f} | sig {:.3f} | prior_ell {:.3f} |').format(
mu_mag, sig, ell)
return string
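# Usage sketch (hypothetical): GPbase leaves I_v and kl unimplemented, so in practice
# a child class providing them would be used, e.g.
#   ts = torch.arange(m).float()[None, None, :].repeat(n_samples, 1, 1)
#   dist = SomeGPChild(manif, m, n_samples, ts)   # SomeGPChild is a placeholder name
#   x, lq = dist.sample((n_mc,))
#   # x: (n_mc x n_samples x m x d), lq: (n_samples x d)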