# Source code for mgplvm.kernels.stationary

import abc
import math
import torch
from torch import nn, Tensor
from ..utils import softplus, inv_softplus
from .kernel import Kernel
from typing import Tuple, List
import numpy as np


class Stationary(Kernel, metaclass=abc.ABCMeta):
    """Abstract base class for stationary kernels of the form k(x, y) = f(d(x, y)),
    holding the shared ``scale`` and lengthscale (``ell``) hyperparameters."""

    def __init__(self,
                 n: int,
                 distance,
                 d=None,
                 ell=None,
                 scale=None,
                 learn_scale=True,
                 Y: np.ndarray = None,
                 eps: float = 1e-6,
                 ell_byneuron: bool = True):
        """
        Parameters
        ----------
        n : int
            number of batches (neurons)
        distance :
            distance function
        d : Optional[int]
            dimension of the input variables
            if provided, there is a separate length scale for each input dimension
        ell : Optional[np.ndarray]
            lengthscale hyperparameter
            it should have dimensions n x d if d is not None and n if d is None
        scale : Optional[np.ndarray]
            scale hyperparameter (std)
            it should have dimension n
        learn_scale : bool
            optimises the scale hyperparameter if true
        Y : Optional[np.ndarray]
            data matrix used for initializing the scale hyperparameter
        eps : float
            minimum ell
        ell_byneuron : bool
            if True (and d is given), each neuron gets its own per-dimension
            lengthscales; otherwise one shared set of d lengthscales
        """
        super(Stationary, self).__init__()
        self.eps = eps

        # Squared-scale initialization; priority: explicit scale > data-driven > ones.
        if scale is not None:
            _scale_sqr = torch.tensor(scale,
                                      dtype=torch.get_default_dtype()).square()
        elif Y is not None:
            # Mean squared activity per neuron. Cast to the default dtype so this
            # branch matches the others (np.mean would otherwise yield float64).
            _scale_sqr = torch.tensor(np.mean(Y**2, axis=(0, -1)),
                                      dtype=torch.get_default_dtype())
        else:
            _scale_sqr = torch.ones(n,)
        # Stored through inv_softplus so the constrained value stays positive.
        self._scale_sqr = nn.Parameter(data=inv_softplus(_scale_sqr),
                                       requires_grad=learn_scale)

        # ARD (automatic relevance determination): separate ell per input dim.
        self.ard = (d is not None)
        if ell is None:
            if d is None:
                _ell = inv_softplus(2 * torch.ones(n,))
            elif ell_byneuron:
                _ell = inv_softplus(2 * torch.ones(n, d))
            else:
                _ell = inv_softplus(2 * torch.ones(1, d))
        else:
            if d is not None:
                assert ell.shape[-1] == d
            _ell = inv_softplus(
                torch.tensor(ell, dtype=torch.get_default_dtype()))
        # Shift by eps so that softplus(_ell) + eps reproduces the requested init.
        self._ell = nn.Parameter(data=_ell - self.eps, requires_grad=True)

        self.distance = distance

    def diagK(self, x: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : Tensor
            input tensor of dims (... n x d x mx)

        Returns
        -------
        diagK : Tensor
            diagonal of kernel K(x,x) with dims (... n x mx)

        Note
        ----
        For a stationary kernel, the diagonal is the constant vector
        (scale_sqr, scale_sqr, ..., scale_sqr) of length mx.
        """
        shp = list(x.shape)
        del shp[-2]  # drop the input dimension d -> (... n x mx)
        scale_sqr = self.scale_sqr[:, None]
        return torch.ones(shp).to(scale_sqr.device) * scale_sqr

    def trK(self, x: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : Tensor
            input tensor of dims (... n x d x mx)

        Returns
        -------
        trK : Tensor
            trace of kernel K(x,x) with dims (... n)

        Note
        ----
        For a stationary kernel, the trace is scale_sqr * mx.
        """
        scale_sqr = self.scale_sqr
        return torch.ones(x.shape[:-2]).to(
            scale_sqr.device) * scale_sqr * x.shape[-1]

    @property
    def prms(self) -> Tuple[Tensor, Tensor]:
        # (squared scale, lengthscale) in constrained (positive) space
        return self.scale_sqr, self.ell

    @property
    def scale_sqr(self) -> Tensor:
        return softplus(self._scale_sqr)

    @property
    def scale(self) -> Tensor:
        # small constant keeps the sqrt gradient finite at zero
        return (self.scale_sqr + 1e-20).sqrt()

    @property
    def ell(self) -> Tensor:
        # eps enforces a strictly positive minimum lengthscale
        return softplus(self._ell) + self.eps

    @property
    def msg(self):
        return (' scale {:.3f} | ell {:.3f} |').format(self.scale.mean().item(),
                                                       self.ell.mean().item())
class QuadExp(Stationary):
    """Quadratic exponential (RBF / squared exponential) kernel."""

    def __init__(self,
                 n: int,
                 distance,
                 d=None,
                 ell=None,
                 scale=None,
                 learn_scale=True,
                 Y: np.ndarray = None,
                 eps: float = 1e-6,
                 ell_byneuron: bool = True):
        """
        Quadratic exponential kernel

        Parameters
        ----------
        n : int
            number of batches (neurons)
        distance :
            distance function
        d : Optional[int]
            dimension of the input variables
            if provided, there is a separate length scale for each input dimension
        ell : Optional[np.ndarray]
            lengthscale hyperparameter
            it should have dimensions n x d if d is not None and n if d is None
        scale : Optional[np.ndarray]
            scale hyperparameter
            it should have dimension n
        learn_scale : bool
            optimises the scale hyperparameter if true
        Y : Optional[np.ndarray]
            data matrix used for initializing the scale hyperparameter
        eps : float
            minimum ell
        ell_byneuron : bool
            if True (and d is given), separate lengthscales per neuron
        """
        super().__init__(n, distance, d, ell, scale, learn_scale, Y, eps,
                         ell_byneuron)

    def K(self, x: Tensor, y: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : Tensor
            input tensor of dims (... n x d x mx)
        y : Tensor
            input tensor of dims (... n x d x my)

        Returns
        -------
        kxy : Tensor
            quadratic exponential kernel with dims (... n x mx x my)
        """
        scale_sqr, ell = self.prms
        # reshape ell so it broadcasts against the (... n x d x m) inputs
        ell = ell[:, :, None] if self.ard else ell[:, None, None]
        sq_dist = self.distance(x, y, ell=ell)  # dims (... n x mx x my)
        return scale_sqr[:, None, None] * torch.exp(-0.5 * sq_dist)
class Exp(QuadExp):
    """Exponential kernel: k(x, y) = scale_sqr * exp(-d(x, y))."""

    def __init__(self,
                 n: int,
                 distance,
                 d=None,
                 ell=None,
                 scale=None,
                 learn_scale=True,
                 Y: np.ndarray = None,
                 eps: float = 1E-6,
                 ell_byneuron: bool = True):
        """
        Exponential kernel

        Parameters
        ----------
        n : int
            number of batches (neurons)
        distance :
            distance function
        d : Optional[int]
            dimension of the input variables
        ell : Optional[np.ndarray]
            lengthscale hyperparameter
        scale : Optional[np.ndarray]
            scale hyperparameter, one per neuron
        learn_scale : bool
            optimises the scale hyperparameter if true
        Y : Optional[np.ndarray]
            data matrix used for initializing the scale hyperparameter
        eps : float
            minimum ell
        ell_byneuron : bool
            if True (and d is given), separate lengthscales per neuron
            (previously this option was not forwarded to the parent class)
        """
        super().__init__(n,
                         distance,
                         d,
                         ell,
                         scale,
                         learn_scale,
                         Y=Y,
                         eps=eps,
                         ell_byneuron=ell_byneuron)

    def K(self, x: Tensor, y: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : Tensor
            input tensor of dims (... n x d x mx)
        y : Tensor
            input tensor of dims (... n x d x my)

        Returns
        -------
        kxy : Tensor
            exponential kernel with dims (... n x mx x my)
        """
        scale_sqr, ell = self.prms
        if self.ard:
            ell = ell[:, :, None]  # (n x d x 1) / (1 x d x 1)
        else:
            ell = ell[:, None, None]  # (n x 1 x 1)
        # NOTE(review): self.distance appears to return the SQUARED distance
        # ||x-y||^2 (hence the sqrt below) — confirm against the distance impl.
        distance = self.distance(x, y, ell=ell)  # dims (... n x mx x my)
        # small constant keeps the sqrt gradient finite at zero distance
        stable_distance = torch.sqrt(distance + 1e-20)
        kxy = scale_sqr[:, None, None] * torch.exp(-stable_distance)
        return kxy
class Matern(Stationary):
    """Matern kernel with smoothness parameter nu in {0.5, 1.5, 2.5}."""

    def __init__(self,
                 n: int,
                 distance,
                 d=None,
                 nu=1.5,
                 ell=None,
                 scale=None,
                 learn_scale=True,
                 Y=None,
                 eps: float = 1E-6,
                 ell_byneuron: bool = True):
        '''
        Parameters
        ----------
        n : int
            number of neurons/readouts
        distance :
            a squared distance function
        d : Optional[int]
            dimension of the input variables
        nu : float
            smoothness parameter; must be one of 0.5, 1.5, 2.5
        ell : Optional[np.ndarray]
            lengthscale hyperparameter
        scale : Optional[np.ndarray]
            scale hyperparameter, one per neuron
        learn_scale : bool
            optimises the scale hyperparameter if true
        Y : Optional[np.ndarray]
            data matrix used for initializing the scale hyperparameter
        eps : float
            minimum ell
        ell_byneuron : bool
            if True (and d is given), separate lengthscales per neuron
            (previously this option was not forwarded to the parent class)

        Raises
        ------
        ValueError
            if nu is not one of the implemented values

        Note
        -----
        based on the gpytorch implementation:
        https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/kernels/matern_kernel.py
        '''
        super().__init__(n,
                         distance,
                         d,
                         ell,
                         scale,
                         learn_scale,
                         Y=Y,
                         eps=eps,
                         ell_byneuron=ell_byneuron)
        if nu not in (0.5, 1.5, 2.5):
            # ValueError (a subclass of Exception) is the conventional type
            # for an unsupported argument value.
            raise ValueError("only nu=0.5, 1.5, 2.5 implemented")
        self.nu = nu

    def K(self, x: Tensor, y: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : Tensor
            input tensor of dims (... n x d x mx)
        y : Tensor
            input tensor of dims (... n x d x my)

        Returns
        -------
        kxy : Tensor
            matern kernel with dims (... n x mx x my)
        """
        scale_sqr, ell = self.prms
        if self.ard:
            ell = ell[:, :, None]
        else:
            ell = ell[:, None, None]
        # self.distance returns the squared distance (see __init__ docstring);
        # the small constant keeps the sqrt gradient finite at zero.
        distance = (self.distance(x, y, ell=ell) + 1E-20).sqrt()
        z1 = torch.exp(-math.sqrt(self.nu * 2) * distance)
        if self.nu == 0.5:
            z2 = 1
        elif self.nu == 1.5:
            z2 = (math.sqrt(3) * distance).add(1)
        elif self.nu == 2.5:
            z2 = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2)
        else:
            # unreachable when constructed through __init__ (nu is validated
            # there); fail loudly instead of with a NameError on z2
            raise ValueError("only nu=0.5, 1.5, 2.5 implemented")
        return scale_sqr[:, None, None] * z1 * z2

    @property
    def msg(self):
        return (' nu {:.1f} | scale {:.3f} | ell {:.3f} |').format(
            self.nu, self.scale.mean().item(), self.ell.mean().item())