Source code for mgplvm.manifolds.torus

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from .base import Manifold
from ..inducing_variables import InducingPoints
from typing import Optional
from sklearn import decomposition


[docs]class Torus(Manifold): def __init__(self, m: int, d: int): """ Parameters ---------- m : int number of conditions/timepoints d : int latent dimensionality """ super().__init__(d) self.m = m self.d2 = d # dimensionality of the group parameterization # per condition self.lprior_const = torch.tensor(-self.d * np.log(2 * np.pi))
[docs] @staticmethod def initialize(initialization, n_samples, m, d, Y): '''initializes latents - can add more exciting initializations as well''' if initialization in ['fa', 'FA']: #Y is N x m; reduce to d x m if Y is None: print('user must provide data for FA initialization') else: n = Y.shape[1] pca = decomposition.FactorAnalysis(n_components=d) Y = Y.transpose(0, 2, 1).reshape(n_samples * m, n) mudata = pca.fit_transform(Y) #m*n_samples x d mudata *= 2 * np.pi / (np.amax(mudata) - np.amin(mudata)) mudata = mudata.reshape(n_samples, m, d) return torch.tensor(mudata, dtype=torch.get_default_dtype()) elif initialization in ['random', 'Random']: mudata = torch.randn(n_samples, m, d) * 0.1 return mudata elif initialization in ['uniform_random']: mudata = torch.rand(n_samples, m, d) * 2 * np.pi return mudata else: print('initialization not recognized') return
[docs] def inducing_points(self, n, n_z, z=None): z = torch.rand(n, self.d, n_z) * 2 * np.pi if z is None else z return InducingPoints(n, self.d, n_z, z=z)
[docs] def lprior(self, g: Tensor) -> Tensor: return self.lprior_const * torch.ones(g.shape[:-1])
# log of the uniform prior (negative log volume) for T^d @property def log_uniform(self) -> Tensor: return -self.d * np.log(2 * np.pi) @property def name(self): return 'Torus(' + str(self.d) + ')'
[docs] @staticmethod def parameterise(x) -> Tensor: return x
[docs] @staticmethod def expmap(x: Tensor) -> Tensor: '''move to [-pi, pi]''' return (x + np.pi) % (2 * np.pi) - np.pi
[docs] @staticmethod def logmap(x: Tensor) -> Tensor: '''move to [-pi, pi]''' return (x + np.pi) % (2 * np.pi) - np.pi
[docs] @staticmethod def log_q(log_base_prob, x, d, kmax): """ log of the variational distribution (~-H(Q))""" ks = np.arange(-kmax, kmax + 1) zs = np.meshgrid(*(ks for _ in range(d))) zs = np.stack([z.flatten() for z in zs]).T * 2. * np.pi zs = torch.tensor(zs, dtype=torch.get_default_dtype()) zs = zs.to(x.device) # meshgrid shape (2kmax+1)^n y = x + zs[:, None, None, None, ...] # meshgrid x n_b x n_samples, m x d lp = torch.logsumexp(log_base_prob(y), dim=0) # n_b x n_samples, m return lp
[docs] @staticmethod def inverse(x: Tensor) -> Tensor: return -x
[docs] @staticmethod def gmul(x: Tensor, y: Tensor) -> Tensor: return x + y
[docs] @staticmethod def distance(x: Tensor, y: Tensor, ell: Optional[Tensor] = None) -> Tensor: # distance = 2 - 2 cos(x-y) # here we use the identity: cox(x-y) = cos(x)cos(y) + sin(x)sin(y) d = x.shape[-2] if ell is None: ell = torch.ones(1, 1, 1) cx = torch.cos(x) / ell #(... n x d x mx) cy = torch.cos(y) / ell sx = torch.sin(x) / ell sy = torch.sin(y) / ell z1_ = torch.cat([cx, sx], dim=-2) #(... n x 2d x mx) z2_ = torch.cat([cy, sy], dim=-2) #(... n x 2d x mx) const = d * ell.square().reciprocal().mean( -2) # (1/n x 1/d x 1) -> (1/n x 1) res = 2 * (const[..., None] - z1_.transpose(-1, -2).matmul(z2_)) res.clamp_min_(0) return res