Source code for mgplvm.lpriors.torus

import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dists
from ..manifolds import Torus
from .common import Lprior


class LpriorTorus(Lprior):
    """Base class for latent priors that are only defined on tori.

    Parameters
    ----------
    manif : Torus
        Manifold the prior is attached to; forwarded unchanged to
        ``Lprior.__init__``.

    Raises
    ------
    TypeError
        If ``manif`` is not a ``Torus`` instance.
    """

    def __init__(self, manif):
        if not isinstance(manif, Torus):
            # This check is inherited by every torus prior, so name the
            # concrete class instead of hard-coding one of them.  TypeError
            # is a subclass of Exception, so existing handlers still match.
            raise TypeError(
                "%s prior only works with Tori manifolds (got %s)" %
                (type(self).__name__, type(manif).__name__))
        super().__init__(manif)
class VonMises(LpriorTorus):
    """Zero-mean von Mises prior on the steps between consecutive latents.

    Parameters
    ----------
    manif : Torus
        Manifold the prior is attached to; ``manif.d`` sets the default
        size of the concentration vector.
    concentration : torch.Tensor, optional
        Initial (constrained) concentration.  Defaults to ``torch.ones(d)``.
    fixed_concentration : bool
        If True the concentration parameter does not receive gradients.
    """
    name = "VonMises"

    def __init__(self, manif, concentration=None, fixed_concentration=False):
        super().__init__(manif)
        d = manif.d
        if concentration is None:
            concentration = torch.ones(d)
        # The parameter is stored in unconstrained space; ``prms`` maps it
        # back through the same torch.distributions transform.
        to_constrained = dists.transform_to(
            dists.constraints.greater_than_eq(0))
        self.concentration = nn.Parameter(
            data=to_constrained.inv(concentration),
            requires_grad=(not fixed_concentration))

    @property
    def prms(self):
        """Return the concentration mapped back to its constrained value."""
        to_constrained = dists.transform_to(
            dists.constraints.greater_than_eq(0))
        return to_constrained(self.concentration)

    def forward(self, g, batch_idxs=None):
        """Return the log density of the steps between neighbouring points.

        Parameters
        ----------
        g : torch.Tensor
            Latents; neighbours are taken along the second-to-last axis and
            the last axis is treated as the event dimension.
        batch_idxs : optional
            Accepted for interface compatibility; not used here.

        Returns
        -------
        torch.Tensor
            ``log_prob`` of the steps, summed over the last axis.
        """
        concentration = self.prms
        # Step from point t to point t+1, composed via the manifold's own
        # ``inverse``/``gmul`` (presumably group inverse and product on the
        # torus -- confirm against the Torus implementation).
        ginv = self.manif.inverse(g)
        dg = self.manif.gmul(ginv[..., 0:-1, :], g[..., 1:, :])
        # NOTE(review): relies on ``self.d`` being set by a base class.
        loc = torch.zeros(self.d).to(g.device)
        vm = dists.VonMises(loc=loc, concentration=concentration)
        vm_all = dists.Independent(vm, 1)
        return vm_all.log_prob(dg)

    @property
    def msg(self):
        """Return a short progress string with the (mean) concentration."""
        concentration = self.prms
        # ``.item()`` alone raises when the concentration has more than one
        # element (d > 1); the mean is identical for d == 1.
        return ('concentration {:.3f}').format(
            torch.mean(concentration).item())
class IARP(LpriorTorus):
    """Order-``p`` autoregressive prior on the torus with von Mises noise.

    The predicted location at each step is
    ``link(sum_j phi[:, j] * inv_link(g[t - j - 1]))`` and the observed
    latent is scored under a von Mises distribution centred on it.

    Parameters
    ----------
    p : int
        Autoregressive order (number of previous points used).
    manif : Torus
        Manifold the prior is attached to.
    mu : torch.Tensor, optional
        Initial mean on the circle.  Defaults to ``torch.zeros(d)``.
    phi : torch.Tensor, optional
        Initial autoregressive weights.  Defaults to zeros of shape
        ``(d, p)``.
    concentration : torch.Tensor, optional
        Initial (constrained) concentration.  Defaults to ``torch.ones(d)``.
    fixed_mu, fixed_concentration : bool
        If True the corresponding parameter does not receive gradients.
    link : {"logits", "atan"}
        Map from the real line onto the circle ``(0, 2*pi)``.

    Raises
    ------
    ValueError
        If ``link`` is not one of the supported names.
    """
    name = "IARP"

    def __init__(self,
                 p,
                 manif,
                 mu=None,
                 phi=None,
                 concentration=None,
                 fixed_mu=False,
                 fixed_concentration=False,
                 link="logits"):
        super().__init__(manif)
        d = manif.d
        self.p = p

        if phi is None:
            phi = 0.0 * torch.ones(d, p)
        self.phi = nn.Parameter(data=phi, requires_grad=True)

        self.link, self.inv_link = self._build_links(link)

        # mu is stored on the real line and mapped to the circle in ``prms``.
        if mu is None:
            mu = torch.zeros(d)
        self.mu = nn.Parameter(data=self.inv_link(mu),
                               requires_grad=(not fixed_mu))

        # The concentration is stored unconstrained as well.
        if concentration is None:
            concentration = torch.ones(d)
        to_constrained = dists.transform_to(
            dists.constraints.greater_than_eq(0))
        self.concentration = nn.Parameter(
            data=to_constrained.inv(concentration),
            requires_grad=(not fixed_concentration))

    @staticmethod
    def _build_links(link):
        """Return the ``(link, inv_link)`` pair of callables for ``link``."""
        if link == "atan":
            # link: R -> (0, 2*pi).  Its inverse is tan((x - pi) / 2); the
            # previous tan(0.5 * x - pi) equals tan(0.5 * x) and did not
            # invert the link (a mu of 0 round-tripped to pi).
            return (lambda x: (2 * torch.atan(x)) + np.pi,
                    lambda x: torch.tan(0.5 * (x - np.pi)))
        if link == "logits":
            return (lambda x: np.pi * 2 * dists.utils.logits_to_probs(x),
                    lambda x: dists.utils.probs_to_logits(x / (2 * np.pi)))
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError("Link function not implemented for %s" % link)

    @property
    def prms(self):
        """Return ``(mu, phi, concentration)`` in constrained space."""
        to_constrained = dists.transform_to(
            dists.constraints.greater_than_eq(0))
        concentration = to_constrained(self.concentration)
        mu = self.link(self.mu)
        return mu, self.phi, concentration

    def forward(self, g, batch_idxs=None):
        """Return the log density of each point given its ``p`` predecessors.

        Parameters
        ----------
        g : torch.Tensor
            Latents; the sequence runs along the second-to-last axis and the
            last axis is treated as the event dimension.
        batch_idxs : optional
            Accepted for interface compatibility; not used here.

        Returns
        -------
        torch.Tensor
            ``log_prob`` of points ``p:`` along the sequence axis, summed
            over the last axis.
        """
        mu, phi, concentration = self.prms
        p = self.p
        g = (g - mu) % (np.pi * 2)  # make sure it's on the circle
        # Lag j contributes the window shifted back by j + 1 steps; all p
        # windows have the same length and are stacked on a new last axis.
        lagged = torch.stack(
            [self.inv_link(g[..., p - j - 1:-j - 1, :]) for j in range(p)],
            dim=-1)
        delta = phi * lagged
        hat = self.link(delta.sum(-1))
        vm = dists.VonMises(loc=hat, concentration=concentration)
        vm_all = dists.Independent(vm, 1)
        return vm_all.log_prob(g[..., p:, :])

    @property
    def msg(self):
        """Return a short progress string with averaged parameter values."""
        mu, phi, concentration = self.prms
        # Average the concentration like mu and phi: ``.item()`` alone
        # raises when it has more than one element (d > 1).
        return ('mu_avg {:.3f} | phi_avg {:.3f} | concentration {:.3f}').format(
            torch.mean(mu).item(),
            torch.mean(phi).item(),
            torch.mean(concentration).item())
class LARP(LpriorTorus):
    """Order-``p`` linked autoregressive prior on the torus with Normal noise.

    The predicted value at each step is
    ``sum_j phi[:, j] * inv_link(g[t - j - 1])`` and the latent is scored
    under a Normal distribution with variance ``eta``.

    Parameters
    ----------
    p : int
        Autoregressive order (number of previous points used).
    manif : Torus
        Manifold the prior is attached to.
    mu : torch.Tensor, optional
        Initial mean on the circle.  Defaults to ``torch.zeros(d)``.
    phi : torch.Tensor, optional
        Initial autoregressive weights.  Defaults to zeros of shape
        ``(d, p)``.
    eta : torch.Tensor, optional
        Initial variance.  Its square root is stored as the parameter;
        defaults to ``torch.ones(d)``.
    fixed_mu, fixed_eta : bool
        If True the corresponding parameter does not receive gradients.
    link : {"logits", "atan"}
        Map from the real line onto the circle ``(0, 2*pi)``.

    Raises
    ------
    ValueError
        If ``link`` is not one of the supported names.
    """
    name = "LinkedARP"

    def __init__(self,
                 p,
                 manif,
                 mu=None,
                 phi=None,
                 eta=None,
                 fixed_mu=False,
                 fixed_eta=False,
                 link="logits"):
        super().__init__(manif)
        d = manif.d
        self.p = p

        if phi is None:
            phi = 0.0 * torch.ones(d, p)
        self.phi = nn.Parameter(data=phi, requires_grad=True)

        self.link, self.inv_link = self._build_links(link)

        # mu is stored on the real line and mapped to the circle in ``prms``.
        if mu is None:
            mu = torch.zeros(d)
        self.mu = nn.Parameter(data=self.inv_link(mu),
                               requires_grad=(not fixed_mu))

        # eta is stored as its square root and squared again in ``prms``.
        eta = torch.ones(d) if eta is None else torch.sqrt(eta)
        # The initial value must be passed as ``data``; without it the
        # parameter is an empty tensor and the computed ``eta`` is discarded.
        self.eta = nn.Parameter(data=eta, requires_grad=(not fixed_eta))

    @staticmethod
    def _build_links(link):
        """Return the ``(link, inv_link)`` pair of callables for ``link``."""
        if link == "atan":
            # link: R -> (0, 2*pi).  Its inverse is tan((x - pi) / 2); the
            # previous tan(0.5 * x - pi) equals tan(0.5 * x) and did not
            # invert the link (a mu of 0 round-tripped to pi).
            return (lambda x: (2 * torch.atan(x)) + np.pi,
                    lambda x: torch.tan(0.5 * (x - np.pi)))
        if link == "logits":
            return (lambda x: np.pi * 2 * dists.utils.logits_to_probs(x),
                    lambda x: dists.utils.probs_to_logits(x / (2 * np.pi)))
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError("Link function not implemented for %s" % link)

    @property
    def prms(self):
        """Return ``(mu, phi, eta)`` with ``eta`` as the squared parameter."""
        mu = self.link(self.mu)
        return mu, self.phi, torch.square(self.eta)

    def forward(self, g, batch_idxs=None):
        """Return the log density of each point given its ``p`` predecessors.

        Parameters
        ----------
        g : torch.Tensor
            Latents; the sequence runs along the second-to-last axis and the
            last axis is treated as the event dimension.
        batch_idxs : optional
            Accepted for interface compatibility; not used here.

        Returns
        -------
        torch.Tensor
            ``log_prob`` of points ``p:`` along the sequence axis, summed
            over the last axis.
        """
        mu, phi, eta = self.prms
        p = self.p
        g = (g - mu) % (np.pi * 2)  # make sure it's on the circle
        # Lag j contributes the window shifted back by j + 1 steps; all p
        # windows have the same length and are stacked on a new last axis.
        lagged = torch.stack(
            [self.inv_link(g[..., p - j - 1:-j - 1, :]) for j in range(p)],
            dim=-1)
        delta = phi * lagged
        hat = delta.sum(-1)
        normal = dists.Normal(hat, scale=torch.sqrt(eta))
        normal_all = dists.Independent(normal, 1)
        # NOTE(review): ``hat`` lives in linked (real-line) space while the
        # density is evaluated on ``g`` itself rather than ``inv_link(g)``;
        # kept as in the original -- confirm this is intended.
        return normal_all.log_prob(g[..., p:, :])

    @property
    def msg(self):
        """Return a short progress string with averaged parameter values."""
        mu, phi, eta = self.prms
        # Average eta like mu and phi: ``.item()`` alone raises when it has
        # more than one element (d > 1).
        return ('mu_avg {:.3f} | phi_avg {:.3f} | eta {:.3f}').format(
            torch.mean(mu).item(),
            torch.mean(phi).item(),
            torch.mean(eta).item())