(Bayesian) GPFA

Kristopher T. Jensen & Guillaume Hennequin (July 8, 2021)

In this short example notebook, we fit Gaussian Process Factor Analysis (GPFA) to neural recordings from a continuous primate reaching task and run a couple of simple analyses. While the original GPFA framework was developed by Byron Yu et al. (2009), here we fit a recent Bayesian extension by Jensen & Kao et al. (2021) that relies on automatic relevance determination to also learn the dimensionality of the latent embedding space.

Bayesian GPFA is defined by the following generative model:

\({\bf Y} \in \mathbb{R}^{N \times T}\) (data)
\({\bf X} \in \mathbb{R}^{D \times T}\) (latent variables)
\({\bf C} \in \mathbb{R}^{N \times D}\) (readout matrix)
\({\bf x}_d \sim \mathcal{GP}(0, K_{\text{RBF}})\)
\(c_{nd} \sim \mathcal{N}(c_{nd}; \mu = 0, \sigma^2 = s_d^2)\)
\(p(y_{nt} \,|\, {\bf C}, {\bf X}) = p(y_{nt} \,|\, ({\bf CX})_{nt})\)
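Here \(K_{\text{RBF}}\) denotes a standard squared-exponential (RBF) kernel over time; assuming unit variance (with the overall scale of each latent carried by \(s_d\)) and a per-dimension timescale \(\tau_d\), it takes the form \(k_d(t, t') = \exp\!\left( -\frac{(t - t')^2}{2 \tau_d^2} \right)\). The exact parameterization in the paper may differ slightly, but this is the usual convention.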
We learn the scales \(\{ s_d \}\) by maximizing a lower bound (ELBO) on the model log marginal likelihood:
\(\mathcal{L} \leq \log p({\bf Y}) = \log \int p({\bf Y} \,|\, {\bf C}, {\bf X}) \, p({\bf C}) \, p({\bf X}) \, d{\bf C} \, d{\bf X}\)
This is achieved using variational inference which also provides an estimate of the posterior distribution over our latent variables that we can use for further analyses: \(q({\bf X}) \approx p({\bf X} | {\bf Y})\).
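Concretely, for any choice of variational distribution \(q({\bf C}, {\bf X})\), a standard generic form of this bound (written here without committing to the particular factorization used in the paper) is \(\mathcal{L} = \mathbb{E}_{q({\bf C}, {\bf X})}\big[ \log p({\bf Y} \,|\, {\bf C}, {\bf X}) \big] - \text{KL}\big[ q({\bf C}, {\bf X}) \,\|\, p({\bf C}) \, p({\bf X}) \big]\), which is maximized jointly with respect to the variational parameters and hyperparameters such as \(\{ s_d \}\).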

See Jensen & Kao et al. (2021) for further details about the generative model and inference procedure.

We start by downloading an example dataset which was originally recorded by O’Doherty et al. (2018). Here we consider a single recording session where we have binned the data in 25 ms bins in advance. We have put this data on Google Drive for ease of access in this tutorial; note that the original dataset is available from https://zenodo.org/record/3854034#.YNCEy5P0nUI.

[1]:
!mkdir -p data
!wget --no-check-certificate 'https://drive.google.com/u/2/uc?id=1kYJHADLpUVtBnLxlphk1ff3AC5UB21N-&export=download' -O data/test_data.tar.gz
--2022-03-18 15:16:34--  https://drive.google.com/u/2/uc?id=1kYJHADLpUVtBnLxlphk1ff3AC5UB21N-&export=download
Resolving drive.google.com (drive.google.com)... 142.250.178.14, 2a00:1450:4009:815::200e
Connecting to drive.google.com (drive.google.com)|142.250.178.14|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://drive.google.com/uc?id=1kYJHADLpUVtBnLxlphk1ff3AC5UB21N-&export=download [following]
--2022-03-18 15:16:34--  https://drive.google.com/uc?id=1kYJHADLpUVtBnLxlphk1ff3AC5UB21N-&export=download
Reusing existing connection to drive.google.com:443.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-0o-08-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/g3tq2qr01nd9huo6vr1po3c0ird1gh7q/1647616575000/03061924325282805644/*/1kYJHADLpUVtBnLxlphk1ff3AC5UB21N-?e=download [following]
Warning: wildcards not supported in HTTP.
--2022-03-18 15:16:43--  https://doc-0o-08-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/g3tq2qr01nd9huo6vr1po3c0ird1gh7q/1647616575000/03061924325282805644/*/1kYJHADLpUVtBnLxlphk1ff3AC5UB21N-?e=download
Resolving doc-0o-08-docs.googleusercontent.com (doc-0o-08-docs.googleusercontent.com)... 172.217.169.65, 2a00:1450:4009:819::2001
Connecting to doc-0o-08-docs.googleusercontent.com (doc-0o-08-docs.googleusercontent.com)|172.217.169.65|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4041709 (3.9M) [application/x-gzip]
Saving to: ‘data/test_data.tar.gz’

data/test_data.tar. 100%[===================>]   3.85M  --.-KB/s    in 0.08s

2022-03-18 15:16:43 (50.8 MB/s) - ‘data/test_data.tar.gz’ saved [4041709/4041709]

We proceed to extract the data and check that it has been correctly downloaded and decompressed. We should now have a ~110 MB file named ‘Doherty_example.pickled’.

[2]:
!tar -xzvf data/test_data.tar.gz --directory data
!ls -ltrh data
Doherty_example.pickled
total 114M
-rw-r--r-- 1 tck29 tck29grp 110M Jun 18  2021 Doherty_example.pickled
-rw-r--r-- 1 tck29 tck29grp 3.9M Mar 18 15:16 test_data.tar.gz

We proceed to install the Bayesian GPFA implementation used in Jensen & Kao et al. This is freely available, but the codebase is still under active development with ongoing work on other latent variable models; here we use the ‘bGPFA’ branch and ignore the other models.
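In a notebook, such an installation is typically a single pip install pointing at the relevant git branch. The line below is only a sketch: the repository path is a placeholder (our assumption) and should be replaced by the address released with the paper; only the package name (mgplvm, imported below) and the branch name (‘bGPFA’) are taken from this tutorial.

!pip install git+https://github.com/<repo-owner>/<mgplvm-repo>.git@bGPFA  # hypothetical repository path: substitute the real one before running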

That sets us up to actually run the code! We start by loading a few packages and setting some default parameters.

[3]:
import torch
import mgplvm as mgp
import numpy as np
import matplotlib.pyplot as plt
import pickle
import time
from sklearn.decomposition import FactorAnalysis
from sklearn.linear_model import LinearRegression, Ridge
from scipy.interpolate import CubicSpline
from scipy.ndimage import gaussian_filter1d
plt.rcParams['font.size'] = 20
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
np.random.seed(0)
torch.manual_seed(0)
device = mgp.utils.get_device() # use GPU if available, otherwise CPU
loading

We proceed to load the data, which we stored in a ‘pickle’ format. While we could fit the full 30-minute dataset, it would take a bit too long for a real-time tutorial, so we subsample both time points and neurons, restricting our analysis to a single epoch of ~20 reaches and only considering neurons with high firing rates.

[4]:
data = pickle.load(open('data/Doherty_example.pickled', 'rb')) # load example data
binsize = 25 # binsize in ms
timepoints = np.arange(3000, 4500) #subsample ~40 seconds of data so things will run somewhat quicker
fit_data = {'Y': data['Y'][..., timepoints], 'locs': data['locs'][timepoints, :], 'targets': data['targets'][timepoints, :], 'binsize': binsize}
Y = fit_data['Y'] # these are the actual recordings and is the input to our model
targets = fit_data['targets'] # these are the target locations
locs = fit_data['locs'] # these are the hand positions

Y = Y[:, np.mean(Y, axis = (0, 2))/0.025 > 8, :] # keep only neurons with mean firing rate above 8 Hz (25 ms bins) so things will run a bit quicker
ntrials, n, T = Y.shape # Y should have shape: [number of trials (here 1) x neurons x time points]
data = torch.tensor(Y).to(device) # put the data on our GPU/CPU
ts = np.arange(Y.shape[-1]) #much easier to work in units of time bins here
fit_ts = torch.tensor(ts)[None, None, :].to(device) # put our time points on GPU/CPU

# finally let's just identify bins where the target changes
deltas = np.concatenate([np.zeros(1), np.sum(np.abs(targets[1:, :] - targets[:-1, :]), axis = 1)])
switches = np.where(deltas > 1e-5)[0] # change occurs during time bin s
dswitches = np.concatenate([np.ones(1)*10, switches[1:] - switches[:-1]]) # when the target changes during a bin there will be two discontinuities
inds = np.zeros(len(switches)).astype(bool)
inds[dswitches > 1.5] = 1 # index of the bin where the target changes or the first bin with a new target
switches = switches[inds]

We also plot the data to visually inspect how the firing rates of different neurons seem to (co)vary over time.

[5]:
### plot the activity we just loaded ###
plt.figure(figsize = (12, 6))
plt.imshow(Y[0, ...], cmap = 'Greys', aspect = 'auto', vmin = np.quantile(Y, 0.01), vmax = np.quantile(Y, 0.99))
plt.xlabel('time')
plt.ylabel('neuron')
plt.title('Raw activity', fontsize = 25)
plt.xticks([])
plt.yticks([])
plt.show()
[Figure: heatmap of the raw binned activity (‘Raw activity’); x-axis: time, y-axis: neuron]

In the following code snippet, we set a couple of model parameters relating to the optimization process and initialization. Most of the initialization is done directly from the data, but it can be useful to provide values ourselves if we have prior knowledge about e.g. the timescale of the behavior we care about.

[6]:
### set some parameters for fitting ###
ell0 = 200/binsize # initial timescale (in bins) for each dimension. This could be the ~timescale of the behavior of interest (otherwise a few hundred ms is a reasonable default)
rho = 2 # sets the intial scale of each latent (s_d in Jensen & Kao). rho=1 is a natural choice with Gaussian noise; less obvious with non-Gaussian noise but rho=1-5 works well empirically.
max_steps = 1001 # number of training iterations
n_mc = 5 # number of monte carlo samples per iteration
print_every = 100 # how often we print training progress
d_fit = 10 # let's fit up to 10 latent dimensions (in theory this could be up to the number of neurons; should be thought of as an upper bound on how high-dimensional the activity is)

Having specified our parameters, we can construct the bGPFA model. In this particular library, we need to separately specify (i) the noise model, (ii) the latent manifold (see Jensen et al. 2020 for LVMs on non-Euclidean manifolds), (iii) the prior and variational distribution (for GPFA, these are both Gaussian processes), and (iv) the observation model (for GPFA this is linear but see e.g. Wu et al. 2017 for non-linear LVMs).

For this dataset we use a negative binomial noise model which contains the Poisson model as a special case but also allows for overdispersed firing statistics.
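To make the notion of overdispersion concrete: in a common mean-dispersion parameterization of the negative binomial (the exact parameterization used in the library may differ), a count with mean \(\mu\) and dispersion \(\kappa\) has variance \(\mu + \mu^2/\kappa \geq \mu\), whereas a Poisson count has variance exactly \(\mu\); the Poisson model is recovered in the limit \(\kappa \to \infty\).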

[7]:
### construct the actual model ###
ntrials, n, T = Y.shape # Y should have shape: [number of trials (here 1) x neurons x time points]
lik = mgp.likelihoods.NegativeBinomial(n, Y=Y) # we use a negative binomial noise model in this example (recommended for ephys data)
manif = mgp.manifolds.Euclid(T, d_fit) # our latent variables live in a Euclidean space for bGPFA (see Jensen et al. 2020 for alternatives)
var_dist = mgp.rdist.GP_circ(manif, T, ntrials, fit_ts, _scale=1, ell = ell0) # circulant variational GP posterior (c.f. Jensen & Kao et al. 2021)
lprior = mgp.lpriors.Null(manif) # here the prior is defined implicitly in our variational distribution, but if we wanted to fit e.g. Factor analysis this would be a Gaussian prior
mod = mgp.models.Lvgplvm(n, T, d_fit, ntrials, var_dist, lprior, lik, Y = Y, learn_scale = False, ard = True, rel_scale = rho).to(device) #create bGPFA model with ARD

This finally sets us up to train the model! We will train it for only 1000 iterations and with a fairly high learning rate so it will finish in a reasonable amount of time. We also define a function ‘cb()’ to intermittently print the learned scale parameters \(\{ s_d \}\) which will go to zero for dimensions that are discarded by the ARD procedure.

[8]:
### training will proceed for 1000 iterations (this takes ~2 minutes) ###
t0 = time.time()
def cb(mod, i, loss):
    """here we construct an (optional) function that helps us keep track of the training"""
    if i % print_every == 0:
        sd = np.log(mod.obs.dim_scale.detach().cpu().numpy().flatten())
        print('iter:', i, 'time:', str(round(time.time()-t0))+'s', 'log scales:', np.round(sd[np.argsort(-sd)], 1))

# helper function to specify training parameters
train_ps = mgp.crossval.training_params(max_steps = max_steps, n_mc = n_mc, lrate = 7.5e-2, callback = cb, print_every = np.nan)
print('fitting', n, 'neurons and', T, 'time bins for', max_steps, 'iterations')
mod_train = mgp.crossval.train_model(mod, data, train_ps)
fitting 93 neurons and 1500 time bins for 1001 iterations
iter: 0 time: 1s log scales: [-1.4 -2.  -2.2 -2.4 -2.4 -2.5 -2.6 -2.6 -2.7 -2.8]
iter: 100 time: 12s log scales: [-1.5 -2.  -2.2 -2.4 -2.4 -2.5 -2.6 -2.7 -2.7 -2.8]
iter: 200 time: 22s log scales: [-1.6 -1.7 -2.1 -2.3 -2.4 -2.5 -2.7 -2.9 -2.9 -3. ]
iter: 300 time: 35s log scales: [-1.4 -1.8 -2.  -2.5 -2.5 -2.9 -2.9 -3.4 -3.5 -3.5]
iter: 400 time: 46s log scales: [-1.3 -1.9 -2.  -2.6 -2.7 -3.2 -3.5 -4.3 -4.3 -4.4]
iter: 500 time: 58s log scales: [-1.3 -1.9 -2.  -2.7 -2.8 -3.5 -4.4 -4.7 -4.8 -4.8]
iter: 600 time: 70s log scales: [-1.3 -2.  -2.  -2.7 -2.8 -4.1 -4.8 -5.  -5.1 -5.1]
iter: 700 time: 81s log scales: [-1.3 -2.  -2.  -2.7 -2.8 -4.7 -5.1 -5.2 -5.3 -5.3]
iter: 800 time: 95s log scales: [-1.3 -2.  -2.  -2.7 -2.8 -5.  -5.3 -5.4 -5.5 -5.5]
iter: 900 time: 108s log scales: [-1.3 -2.  -2.1 -2.7 -2.8 -5.2 -5.5 -5.5 -5.6 -5.6]
iter: 1000 time: 122s log scales: [-1.3 -2.  -2.1 -2.7 -2.8 -5.4 -5.6 -5.6 -5.7 -5.8]

Now we’re ready to analyze our new model. We start by plotting the magnitude of the posterior latent means against the prior scales \(\{ s_d \}\) to see how informative the different latent dimensions are (points in the upper right corner correspond to more informative dimensions).

[9]:
### we start by plotting 'informative' and 'discarded' dimensions ###
print('plotting informative and discarded dimensions')
dim_scales = mod.obs.dim_scale.detach().cpu().numpy().flatten() #prior scales (s_d)
dim_scales = np.log(dim_scales) #take the log of the prior scales
nus = np.sqrt(np.mean(mod.lat_dist.nu.detach().cpu().numpy()**2, axis = (0, -1))) #magnitude of the variational mean
plt.figure()
plt.scatter(dim_scales, nus, c = 'k', marker = 'x', s = 80) #top right corner are informative, lower left discarded
plt.xlabel(r'$\log \, s_d$')
plt.ylabel('latent mean scale', labelpad = 5)
plt.show()
plotting informative and discarded dimensions
[Figure: latent mean scale vs. \(\log s_d\) for each latent dimension; informative dimensions sit in the upper right]

We proceed to plot the posterior latent mean in two of the most informative dimensions (i.e. ones with high \(s_d\)) for a subset of the behavior. We contrast this with factor analysis (FA), which yields discontinuous trajectories dominated by noise since the model is unable to share information across time bins.

[10]:
### plot the inferred latent trajectories
print('plotting latent trajectories')
X = mod.lat_dist.lat_mu.detach().cpu().numpy()[0, ...] # extract inferred latents ('mu' has shape (ntrials x T x d_fit))
X = X[..., np.argsort(-dim_scales)] # sort dimensions from most to least informative (c.f. Jensen & Kao)
tplot = np.arange(300, 400) # let's only plot a shorter period (here 2.5 s) so it doesn't get too cluttered

# fit FA for comparison
fa = FactorAnalysis(2)
Xfa = fa.fit_transform(np.sqrt(Y[0, ...].T)) # sqrt the counts for variance stabilization (c.f. Yu et al. 2009)

i1, i2 = 2, 3 # which of the sorted latent dimensions to plot
fig, axs = plt.subplots(1, 2, figsize = (10, 5))
axs[0].scatter(X[tplot, i1], X[tplot, i2], c = tplot, cmap = 'coolwarm', s = 80) # plot bGPFA latents
axs[1].scatter(Xfa[tplot, 0], Xfa[tplot, 1], c = tplot, cmap = 'coolwarm', s = 80) # plot FA latents
for ax in axs:
    ax.set_xlabel('latent dim 1')
    ax.set_ylabel('latent dim 2')
    ax.set_xticks([])
    ax.set_yticks([])
axs[0].set_title('Bayesian GPFA')
axs[1].set_title('factor analysis')
plt.show()

# let's also print the learned timescales (sorted by the prior scales s_d)
taus = mod.lat_dist.ell.detach().cpu().numpy().flatten()[np.argsort(-dim_scales)]*binsize
print('learned timescales (ms):', np.round(taus).astype(int))
plotting latent trajectories
[Figure: 2D latent trajectories colored by time; left: Bayesian GPFA, right: factor analysis]
learned timescales (ms): [135  98 311 626 254 502 236 294 247 321]

To make sense of our latent trajectories, it may be useful to also take into account what’s actually occurring at the level of behavior. We therefore start by visualizing the hand kinematics during the period of reaching considered for the latent trajectories above. We can visualize both the position of the hand as it changes over time as well as the instantaneous velocity in the x and y directions.

[11]:
ts = np.arange(Y.shape[-1])*fit_data['binsize'] # measured in ms
cs = CubicSpline(ts, locs) # fit cubic spline to behavior
vels = cs(ts, 1) # velocity (first derivative)

fig, axs = plt.subplots(1, 2, figsize = (10, 5))
axs[0].plot(locs[tplot, 0], locs[tplot, 1], 'k-') # plot position

for s in switches: # plot the targets
  if s in tplot:
    axs[0].scatter([targets[s+1, 0]], [targets[s+1, 1]], marker = 'x', c = 'r')
axs[0].legend(['hand', 'target'], frameon = False)

axs[1].plot(vels[tplot, 0], vels[tplot, 1], 'k-') # plot velocity
for ax in axs:
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xticks([])
    ax.set_yticks([])
axs[0].set_title('position')
axs[1].set_title('velocity')
plt.show()
[Figure: hand position with reach targets (left) and hand velocity (right)]

Finally, we might ask whether our latent trajectories can predict the behavioral output. To investigate this, we train a linear decoder to predict the hand kinematics from the inferred firing rates \(\hat{\bf Y}\). This can also be seen as a non-linear decoder from the latent trajectories \({\bf X}\), where the first layer of the decoder is given by the learned observation model. We fit decoders with different delays between neural activity and behavior, from -150 ms (activity lags behavior) to +250 ms (activity precedes behavior). We do this to find the ‘optimal delay’ between neural activity and behavior and show that cortical activity seems to predict behavior ~100 ms into the future.

[12]:
### finally let's do a simple decoding analysis ###
print('running decoding analysis')
Ypreds = [] # decode from the inferred firing rates (this is a non-linear decoder from latents)
query = mod.lat_dist.lat_mu.detach().transpose(-1, -2).to(device)  # (ntrial, d_fit, T)
for i in range(10): # loop over mc samples to avoid memory issues
    Ypred = mod.svgp.sample(query, n_mc=100, noise=False)
    Ypred = Ypred.detach().mean(0).cpu().numpy()  # (ntrial x n x T)
    Ypreds.append(Ypred)
Ypred = np.mean(np.array(Ypreds), axis = (0,1)).T # T x n

delays = np.linspace(-150, 250, 50) # consider different behavioral delays
performance = np.zeros((len(delays), 2)) # model performance
for idelay, delay in enumerate(delays):
    vels = cs(ts+delay, 1) # velocity at time+delay
    for itest, Ytest in enumerate([Ypred]): # bGPFA
      regs = [Ridge(alpha=1e-3).fit(Ytest[::2, :], vels[::2, i]) for i in range(2)] # fit x and y vel on half the data
      scores = [regs[i].score(Ytest[1::2, :], vels[1::2, i]) for i in range(2)] # score x and y vel on the other half
      performance[idelay, itest] = np.mean(scores) # save performance
print('plotting decoding')
plt.figure()
plt.plot(delays, performance[:, 0], 'k-')
plt.axvline(delays[np.argmax(performance[:, 0])], color = 'b', ls = '--')
plt.xlim(delays[0], delays[-1])
plt.xlabel('delay (ms)')
plt.ylabel('kinematic decoding')
plt.show()

running decoding analysis
plotting decoding
[Figure: kinematic decoding performance as a function of the delay between neural activity and behavior]

In Jensen & Kao et al. (2021), we carry out further analyses of multi-region recordings, preparatory dynamics, and reaction times for bGPFA models fitted to the full reaching dataset, which the interested reader may wish to explore. However, we hope that this short notebook has provided an introduction to GPFA and its Bayesian extension, as well as some insight into the possible use cases for such models.