Source code for mgplvm.manifolds.s3

import numpy as np
from . import quaternion
from scipy import special
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from .base import Manifold
from typing import Tuple, Optional, List
from ..inducing_variables import InducingPoints


class S3(Manifold):

    # log of the uniform prior (negative log volume)
    log_uniform = (special.loggamma(2) - np.log(2) - 2 * np.log(np.pi))

    def __init__(self, m: int, d: Optional[int] = None):
        """
        Parameters
        ----------
        m : int
            number of conditions/timepoints
        d : int
            latent dimensionality
        """
        super().__init__(d=3)
        self.m = m
        self.d2 = 4  # dimensionality of the group parameterization per condition
        self.lprior_const = torch.tensor(special.loggamma(2) - np.log(2) -
                                         2 * np.log(np.pi))

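    # Both `log_uniform` and `lprior_const` equal -log(2 * pi**2): loggamma(2)
    # is zero, so the expression reduces to -(log(2) + 2*log(pi)), and
    # 2 * pi**2 is the surface area of the unit 3-sphere. A minimal sanity
    # check (illustrative sketch, not part of the library's tests):
    #
    #   >>> manif = S3(m=10)
    #   >>> bool(np.isclose(float(manif.lprior_const), -np.log(2 * np.pi**2)))
    #   True
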
    @staticmethod
    def initialize(initialization, n_samples, m, d, Y):
        '''initializes latents - more informative initializations could be added'''
        # initialize at the identity quaternion (1, 0, 0, 0)
        mudata = torch.zeros(n_samples, m, 4)
        mudata[:, :, 0] = 1
        return mudata

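    # Usage sketch (illustrative): every latent starts at the identity
    # quaternion (1, 0, 0, 0), i.e. zero rotation.
    #
    #   >>> mu = S3.initialize(None, n_samples=2, m=5, d=3, Y=None)
    #   >>> mu.shape
    #   torch.Size([2, 5, 4])
    #   >>> bool((mu[..., 0] == 1).all() and (mu[..., 1:] == 0).all())
    #   True
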
    def inducing_points(self, n, n_z, z=None):
        # initialize inducing points randomly on the sphere
        if z is None:
            z = torch.randn(n, self.d2, n_z)
            z = z / torch.norm(z, dim=1, keepdim=True)
        return InducingPoints(n,
                              self.d2,
                              n_z,
                              z=z,
                              parameterise=lambda x: self.expmap2(x, dim=-2))

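    # The default z is uniform on S^3, since normalizing i.i.d. Gaussian draws
    # samples the sphere uniformly; the `parameterise` closure keeps the
    # inducing points on the sphere during optimization by re-normalizing the
    # coordinate axis (dim=-2). Sketch of the initialization step alone:
    #
    #   >>> z = torch.randn(3, 4, 8)
    #   >>> z = z / torch.norm(z, dim=1, keepdim=True)
    #   >>> bool(torch.allclose(torch.norm(z, dim=1), torch.ones(3, 8)))
    #   True
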
    @property
    def name(self):
        return 'S(' + str(self.d) + ')'

    def lprior(self, g):
        # uniform log prior, broadcast to the batch shape of g
        return self.lprior_const * torch.ones(g.shape[:-1])

    @staticmethod
    def parameterise(x) -> Tensor:
        norms = torch.norm(x, dim=-1, keepdim=True)
        return x / norms

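    # `parameterise` maps unconstrained parameters in R^4 to unit quaternions
    # by projection, so optimization can run in Euclidean space. Sketch:
    #
    #   >>> g = S3.parameterise(torch.randn(5, 4))
    #   >>> bool(torch.allclose(torch.norm(g, dim=-1), torch.ones(5)))
    #   True
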
    @staticmethod
    def expmap(x: Tensor, dim: int = -1) -> Tensor:
        '''same as SO(3)'''
        theta = torch.norm(x, dim=dim, keepdim=True)
        v = x / theta
        y = torch.cat((torch.cos(theta), torch.sin(theta) * v), dim=dim)
        return y

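    # `expmap` wraps a tangent vector x in R^3 to the unit quaternion
    # (cos|x|, sin|x| * x/|x|); note the division by theta makes it undefined
    # (NaN) at x = 0. Sketch:
    #
    #   >>> x = torch.tensor([[np.pi / 2, 0.0, 0.0]])
    #   >>> q = S3.expmap(x)  # [[0., 1., 0., 0.]] up to float precision
    #   >>> bool(torch.allclose(q, torch.tensor([[0.0, 1.0, 0.0, 0.0]]), atol=1e-6))
    #   True
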
    @staticmethod
    def expmap2(x: Tensor, dim: int = -1) -> Tensor:
        return F.normalize(x, dim=dim)

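    # `expmap2` is plain projection onto the sphere; with dim=-2 it normalizes
    # the coordinate axis of an (n, 4, n_z) inducing-point array, which is how
    # `inducing_points` above uses it. Sketch:
    #
    #   >>> q = S3.expmap2(torch.randn(2, 4, 7), dim=-2)
    #   >>> bool(torch.allclose(torch.norm(q, dim=-2), torch.ones(2, 7)))
    #   True
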
    @staticmethod
    def logmap(q: Tensor, dim: int = -1) -> Tensor:
        '''same as SO(3)'''
        # keepdim=True so theta and y broadcast against the vector part of q
        # for batched inputs
        x = q[..., :1]
        y = torch.norm(q[..., 1:], dim=dim, keepdim=True)
        theta = 2 * torch.atan2(y, x)
        return theta * (q[..., 1:] / y)

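    # `logmap` returns the rotation vector of the SO(3) element represented by
    # q under the half-angle convention: (cos(phi/2), sin(phi/2) * v) maps to
    # phi * v. Sketch:
    #
    #   >>> q = torch.tensor([[0.0, 1.0, 0.0, 0.0]])  # rotation by pi about e_x
    #   >>> bool(torch.allclose(S3.logmap(q), torch.tensor([[np.pi, 0.0, 0.0]])))
    #   True
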
    @staticmethod
    def inverse(q: Tensor) -> Tensor:
        return quaternion.conj(q)

    @staticmethod
    def gmul(x: Tensor, y: Tensor) -> Tensor:
        '''same as SO(3)'''
        return quaternion.product(x, y)

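    # `gmul` composes group elements and `inverse` inverts them; for unit
    # quaternions the conjugate is the inverse, so composing the two gives the
    # identity. Sketch (assuming `quaternion.product` implements the Hamilton
    # product on quaternions stored along the last axis, as elsewhere in this
    # module):
    #
    #   >>> q = S3.parameterise(torch.randn(5, 4))
    #   >>> e = S3.gmul(q, S3.inverse(q))
    #   >>> eye = torch.tensor([1.0, 0.0, 0.0, 0.0]).expand(5, 4)
    #   >>> bool(torch.allclose(e, eye, atol=1e-6))
    #   True
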
    @staticmethod
    def log_q(log_base_prob, x, d, kmax, dim=-1):
        '''
        theta = |x|/2
        '''
        theta = torch.norm(x, dim=dim, keepdim=True)  # vector magnitudes
        v = x / theta  # unit vectors

        # construct equivalent tangent vectors by adding multiples of 2*pi
        # to the magnitudes
        ks = np.arange(-kmax, kmax + 1)
        zs = np.meshgrid(*(ks for _ in range(1)))
        zs = np.stack([z.flatten() for z in zs]).T * 2 * np.pi
        zs = torch.tensor(zs, dtype=torch.get_default_dtype()).to(theta.device)
        theta = theta + zs[:, None, None, None, ...]  # (nk, n_b, n_samples, m, 1)
        x = theta * v

        # |J| -> 1 as phi -> 0; the 1e-5 offset is for numerical stability
        phi = 2 * theta + 1e-5  # magnitude of rotation
        l0 = torch.square(phi)
        l1 = 2 - 2 * torch.cos(phi)
        # |J^(-1)| = phi^2 / (2 - 2*cos(phi)) = 2|x|^2 / (1 - cos(2|x|))
        ljac = torch.log(l0) - torch.log(l1)

        lp = torch.logsumexp(log_base_prob(x) + ljac[..., 0], dim=0)
        return lp

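    # `log_q` evaluates the log density of a distribution on S^3 obtained by
    # wrapping a base density on the tangent space through the exponential
    # map: log q(g) = logsumexp_k [log p(x_k) + log|J_k^{-1}|], where the x_k
    # are the preimages with magnitudes theta + 2*pi*k. Sketch with a
    # standard-normal base density (the input follows the
    # (n_b, n_samples, m, d) shape convention noted above):
    #
    #   >>> from torch.distributions import Normal
    #   >>> log_base_prob = lambda x: Normal(0.0, 1.0).log_prob(x).sum(-1)
    #   >>> x = 0.1 * torch.randn(1, 2, 5, 3)
    #   >>> S3.log_q(log_base_prob, x, d=3, kmax=5).shape
    #   torch.Size([1, 2, 5])
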
    @staticmethod
    def distance(x: Tensor,
                 y: Tensor,
                 ell: Optional[Tensor] = None) -> Tensor:
        # distance: 2 - 2 * (x . y)
        if ell is None:
            ell = torch.ones(1, 1, 1)
        z = x.transpose(-1, -2).matmul(y)
        res = 2 * (1 - z) / ell**2
        res.clamp_min_(0)
        return res

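    # `distance` expects coordinates along dim -2, i.e. x is (..., 4, m) and y
    # is (..., 4, n), and returns the pairwise squared chordal distances
    # 2 - 2 <x_i, y_j>, optionally scaled by a lengthscale ell. Sketch:
    #
    #   >>> x = F.normalize(torch.randn(1, 4, 6), dim=-2)
    #   >>> D = S3.distance(x, x)
    #   >>> D.shape
    #   torch.Size([1, 6, 6])
    #   >>> diag = torch.diagonal(D, dim1=-2, dim2=-1)
    #   >>> bool(torch.allclose(diag, torch.zeros(1, 6), atol=1e-5))
    #   True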