Applying mGPLVM to synthetic neural data generated from circular latents

Calvin Kao & Kris Jensen (16 March 2022)

[1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import optim

import mgplvm as mgp

# Fix both random number generators so data generation and fitting are reproducible.
torch.manual_seed(1)
np.random.seed(0)

# Use the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loading

Generate synthetic data

Here, we generate synthetic neural data from the following generative model:

\[\begin{aligned} \theta_{t} &\sim U(0, 2\pi)\\ f_{i} &\sim \mathcal{GP}(0, K(\theta, \theta'))\\ y_{it} &\sim \mathcal{N}(f_{i}(\theta_t), \sigma_i^2) \end{aligned}\]
[2]:
# Sizes of the synthetic problem:
#   d         -- latent dimensionality (d = 1 is a ring, i.e. the torus T(1))
#   n         -- number of neurons
#   m         -- number of conditions / time points
#   n_z       -- number of inducing points
#   n_samples -- number of samples (leading axis of the data array Y)
d, n, m, n_z, n_samples = 1, 50, 100, 50, 1
[3]:
# Build a synthetic-data generator on the d-dimensional torus (a ring for d = 1)
# and draw one dataset; Y is indexed (sample, neuron, condition).
gen = mgp.syndata.Gen(
    mgp.syndata.Torus(d),
    n,
    m,
    variability=0.1,
    l=0.5,
    n_samples=n_samples,
    sigma=0.1,
)
Y = gen.gen_data()
print(f"Dimension of neural data Y: {Y.shape}")
Dimension of neural data Y: (1, 50, 100)
[4]:
# Histogram of the ground-truth angles used to generate the data.
fig, ax = plt.subplots()
ax.hist(gen.gs[0][0, :, 0], color="k", rwidth=0.95)
ax.set_xlim(0, 2 * np.pi)
ax.set_xlabel("ground-truth $\\theta$")
ax.set_ylabel("counts")
plt.show()
../_images/notebooks_synthetic_torus_5_0.png
[5]:
# Raw data as an image: one row per neuron, one column per condition / time point.
print("Raw data")  # plain string: the f-prefix had no placeholders
plt.figure()
Yplot = Y[0, :, :]  # first (and here only) sample
plt.imshow(Yplot, cmap="Greys", aspect="auto")
plt.xlabel("time")
plt.ylabel("neuron")
plt.xlim(0, m)
plt.ylim(0, n)
plt.xticks([])
plt.yticks([])
plt.show()
Raw data
../_images/notebooks_synthetic_torus_6_1.png

Let’s now construct the mGPLVM model!

[6]:
def build_model():
    """Construct an mGPLVM (``mgp.models.SvgpLvm``) with a torus latent space.

    Reads the notebook-level globals ``n`` (neurons), ``m`` (conditions),
    ``d`` (latent dimensionality), ``n_z`` (inducing points), ``n_samples``
    and ``device``. None of the components are initialised from the data.

    Returns:
        The ``SvgpLvm`` model, moved to ``device``.
    """
    # specify manifold, kernel and rdist
    manif = mgp.manifolds.Torus(m, d)  # latent distribution manifold
    lat_dist = mgp.rdist.ReLie(manif, m, n_samples)  # construct ReLie distribution
    # Exponentiated quadratic (RBF) kernel over the manifold's distance function.
    kernel = mgp.kernels.QuadExp(n, manif.distance)
    lik = mgp.likelihoods.Gaussian(n)  # Gaussian likelihood
    lprior = mgp.lpriors.Uniform(manif)  # Uniform prior on the manifold
    z = manif.inducing_points(n, n_z)  # build n_z inducing points
    model = mgp.models.SvgpLvm(
        n, m, n_samples, z, kernel, lik, lat_dist, lprior, whiten=True
    ).to(device)
    return model
[7]:
data = torch.tensor(Y, device=device, dtype=torch.get_default_dtype())
model = build_model()

lrate = 5e-2  # optimiser learning rate
train_opts = {
    "lrate": lrate,
    "max_steps": 1000,
    "n_mc": 64,
    "print_every": 100,
    # Derived from the learning rate so the two cannot drift apart.
    "burnin": 30 / lrate,
    "optimizer": optim.Adam,
}

# train model
progress = mgp.optimisers.svgp.fit(data, model, **train_opts)
iter   0 | elbo -5.604 | kl  0.002 | loss  5.604 | |mu| 0.110 | sig 1.500 | scale 1.000 | ell 2.000 | lik_sig 1.000 |
iter 100 | elbo -1.012 | kl  0.003 | loss  1.013 | |mu| 0.655 | sig 1.342 | scale 0.993 | ell 2.024 | lik_sig 0.948 |
iter 200 | elbo -0.686 | kl  0.008 | loss  0.688 | |mu| 0.751 | sig 1.014 | scale 0.984 | ell 2.073 | lik_sig 0.645 |
iter 300 | elbo -0.136 | kl  0.026 | loss  0.146 | |mu| 1.366 | sig 0.397 | scale 0.963 | ell 2.197 | lik_sig 0.272 |
iter 400 | elbo  0.343 | kl  0.052 | loss -0.318 | |mu| 1.711 | sig 0.108 | scale 0.912 | ell 1.823 | lik_sig 0.153 |
iter 500 | elbo  0.528 | kl  0.066 | loss -0.490 | |mu| 1.746 | sig 0.054 | scale 0.875 | ell 1.471 | lik_sig 0.129 |
iter 600 | elbo  0.568 | kl  0.074 | loss -0.522 | |mu| 1.749 | sig 0.038 | scale 0.846 | ell 1.296 | lik_sig 0.124 |
iter 700 | elbo  0.564 | kl  0.078 | loss -0.511 | |mu| 1.742 | sig 0.031 | scale 0.807 | ell 1.224 | lik_sig 0.128 |
iter 800 | elbo  0.612 | kl  0.080 | loss -0.552 | |mu| 1.733 | sig 0.028 | scale 0.775 | ell 1.161 | lik_sig 0.111 |
iter 900 | elbo  0.657 | kl  0.082 | loss -0.593 | |mu| 1.739 | sig 0.025 | scale 0.745 | ell 1.096 | lik_sig 0.103 |
[8]:
# Training loss per iteration, as returned by `fit`.
plt.figure()
plt.plot(progress, "k")
plt.xlabel("training iterations")
plt.ylabel("training loss")
plt.show()  # render like the other cells instead of echoing a Text object
[8]:
Text(0, 0.5, 'training loss')
../_images/notebooks_synthetic_torus_10_1.png

Now that we have fit the model, let’s see if we have correctly inferred the ground-truth latents. On the circle they are only identifiable up to an arbitrary global rotation (and possibly a reflection). To check this, we plot:

  1. ground-truth latents against inferred latents, and

  2. the pairwise distances between ground-truth latents against those between inferred latents.

We know we have inferred the latents correctly when both plots fall on a straight line. For the first plot, the line has slope ±1 (the sign depends on whether the inferred latents are reflected) and wraps around at the \(0\) and \(2\pi\) boundaries:

[9]:
# First variational parameter of the latent distribution (presumably the
# latent means), moved to the CPU; ground truth comes from the generator.
learned_latents = model.lat_dist.prms[0].data.cpu()
true_latents = gen.gs[0]

# All pairwise distances between inferred latents, flattened to 1-D.
learned_swapped = learned_latents.transpose(1, 2)
plearn = model.lat_dist.manif.distance(learned_swapped, learned_swapped)
plearn = plearn.numpy().flatten()

# The same for the ground-truth latents, using the generator's manifold.
true_manifold = gen.manifold.manifs[0]
ptrue = true_manifold.distance(true_latents, true_latents).flatten()
[10]:
# Inferred vs. ground-truth angles before alignment.
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot(true_latents[0], learned_latents[0], "ko", alpha=0.2)
ax.set_xlabel("ground truth $\\theta$")
ax.set_ylabel("inferred $\\theta$")
plt.show()
../_images/notebooks_synthetic_torus_13_0.png
[11]:
# Pairwise latent distances: inferred (x) against ground truth (y).
fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(plearn, ptrue, "ko", markersize=1.5, alpha=0.2)
ax.set_xlabel("inferred pairwise distance")
ax.set_ylabel("true pairwise distance")
plt.show()
../_images/notebooks_synthetic_torus_14_0.png
[12]:
def generate_binary_array(n, l):
    """Extend the arrays in ``l`` with every combination of ``n`` signs (+/-1).

    Starting from an empty ``l``, the first step seeds
    ``[array([-1]), array([1])]``. Each later step appends ``-1`` to every
    array, then ``+1`` to every array, so the first coordinate varies fastest.
    For ``n == 0``, ``l`` is returned unchanged.

    Args:
        n: number of sign coordinates to add; must be non-negative.
        l: list of 1-D numpy arrays to extend (pass ``[]`` to start fresh).

    Returns:
        A list of ``len(l) * 2**n`` arrays (``2**n`` when ``l`` is empty).

    Raises:
        ValueError: if ``n`` is negative (previously this recursed forever).
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    arrays = l
    for _ in range(n):
        if len(arrays) == 0:
            arrays = [np.array([-1]), np.array([1])]
        else:
            # all "-1" extensions first, then all "+1" extensions
            arrays = [
                np.concatenate([a, [sign]]) for sign in (-1, 1) for a in arrays
            ]
    return arrays


def align_torus(x, target, n_restarts=5):
    """Fit the per-dimension reflection and rotation mapping ``x`` onto ``target``.

    For every sign pattern (+/-1 per latent dimension) from
    ``generate_binary_array``, runs L-BFGS from ``n_restarts`` random initial
    rotations. Each run minimises the mean ``Torus.distance`` between the
    transformed ``x`` and ``target``. The best (signs, rotation) pair is kept.

    Args:
        x: inferred latents as a torch tensor whose last axis is the latent
            dimension (in this notebook, ``(n_samples, m, d)``).
        target: ground-truth latents (array-like) in the same layout as ``x``.
        n_restarts: random rotation initialisations per sign pattern.

    Returns:
        A function mapping latents in the frame of ``x`` to the frame of
        ``target``, wrapped into ``[0, 2*pi)``.

    Raises:
        RuntimeError: if no finite alignment loss was found (e.g. all NaN).
    """
    target = torch.as_tensor(target)  # no copy/warning if already a tensor
    n_dims = x.shape[-1]  # latent dimensionality, instead of the global `d`

    def dist(newmus, params):
        mus = mgp.manifolds.Torus.gmul(newmus, params)
        loss = mgp.manifolds.Torus.distance(mus, target)
        return loss.mean()

    optloss = np.inf
    optcoords, optparams = None, None

    for coords in generate_binary_array(n_dims, []):
        coords = torch.tensor(coords).reshape(1, n_dims)
        newmus = coords * x

        for _ in range(n_restarts):  # random restarts to avoid local minima
            params = torch.rand(n_dims) * 2 * np.pi
            params.requires_grad_()
            optimizer = optim.LBFGS([params])

            def closure():
                optimizer.zero_grad()
                loss = dist(newmus, params)
                loss.backward()
                return loss

            optimizer.step(closure)
            # Final loss only needs a forward pass; no gradients required.
            with torch.no_grad():
                loss = dist(newmus, params).item()
            if loss < optloss:  # NaN losses never win
                optloss = loss
                optcoords = coords
                optparams = params.detach().cpu()

    if optcoords is None:
        raise RuntimeError("align_torus: no finite alignment loss was found")

    def f(latents):
        aligned = mgp.manifolds.Torus.gmul(optcoords * latents, optparams)
        return (aligned + 2 * np.pi) % (2 * np.pi)

    return f
[13]:
# Fit the reflection + rotation that maps the inferred latents onto the truth.
align = align_torus(x=learned_latents, target=true_latents)
[14]:
# Aligned inferred latents vs. ground truth: if the latents were recovered,
# the points should lie close to the identity line (modulo wrap at 0 / 2*pi).
plt.figure(figsize=(4, 4))
plt.plot(align(learned_latents)[0, :, 0], true_latents[0, :, 0], "ko", alpha=0.2)
plt.xlabel("aligned inferred $\\theta$")
plt.ylabel("ground truth $\\theta$")
plt.show()
../_images/notebooks_synthetic_torus_17_0.png

Let’s now plot the inferred tuning curves!

[15]:
# `align` maps the model's latent coordinates to ground-truth angles (see the
# scatter of align(learned_latents) vs. the truth). So we evaluate the model on
# a grid in its own latent frame and plot each prediction at the aligned angle.
# Feeding align(theta) to the model would only be right if `align` were its own
# inverse.
query = torch.tensor(
    np.linspace(0, 2 * np.pi, 50), dtype=torch.get_default_dtype(), device=device
)[None, None, ...]

fmean, fvar = model.obs.predict(query, full_cov=False)
fstd = fvar.sqrt()

# Ground-truth angle of each model-frame query point, sorted for line plots.
aligned_query = align(query.cpu()).detach().numpy()[0, 0, :]
order = np.argsort(aligned_query)
xs = aligned_query[order]

# Loop-invariant conversions, hoisted; shapes are (neuron, query point).
fmean_np = fmean.detach().cpu().numpy()[0][:, order]
fstd_np = fstd.detach().cpu().numpy()[0][:, order]
# Noise-free ground-truth responses at the same angles.
true_y = gen.gen_data(gs_in=xs[None, ...], sigma=np.zeros((n, 1)))

# Neurons whose `gprefs` (presumably preferred angles) are closest to four
# reference angles.
inds = [np.argmin((gen.gprefs[0] - val) ** 2) for val in 0.25 + np.arange(4) * 1.5]
plt.figure(figsize=(8, 6))

for i in inds:
    # distinct names so the global `m` (number of conditions) is not clobbered
    mu_i, std_i = fmean_np[i], fstd_np[i]
    plt.plot(xs, mu_i, "k--")
    plt.fill_between(xs, mu_i - 2 * std_i, mu_i + 2 * std_i, color="k", alpha=0.2)
    plt.plot(xs, true_y[0, i, :], "g-")

plt.xlabel(r"$\theta$ [radian]")
plt.ylabel(r"rate [a.u.]")
plt.xlim(0, 2 * np.pi)
plt.show()
../_images/notebooks_synthetic_torus_19_0.png