⚙️ JPC from Scratch
This notebook is a walk-through of how Predictive Coding (PC) is implemented in JPC. It was developed as a lab session of an MSc course at the University of Sussex. We are going to implement all the core functionalities of JPC from scratch, building towards the training of a simple feedforward network to classify MNIST.
If you're not familiar with JAX, have a look at their docs, but we will explain all the necessary concepts below. JAX is basically numpy for GPUs and other hardware accelerators. We will also use: (i) Equinox, which allows you to define neural nets with PyTorch-like syntax; and (ii) Optax, which provides a range of common machine learning optimisers such as gradient descent and Adam.
Installations & imports
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad
from jax.tree_util import tree_map
import equinox as eqx
import equinox.nn as nn
from equinox import filter_grad
import optax
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import warnings
warnings.simplefilter("ignore")
Hyperparameters
We define some global parameters related to the data, network, optimisers, etc.
SEED = 827
INPUT_DIM, OUTPUT_DIM = 28*28, 10
NETWORK_WIDTH = 300
ACTIVITY_LR = 1e-1
INFERENCE_STEPS = 20
PARAM_LR = 1e-3
BATCH_SIZE = 64
TEST_EVERY = 50
N_TRAIN_ITERS = 500
Dataset
Some utils to fetch MNIST.
def get_mnist_loaders(batch_size):
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    def __init__(self, train, normalise=True, save_dir="data"):
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.1307,), std=(0.3081,)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]
PC energy
First, recall that PC can be derived as a variational inference algorithm under certain assumptions. In particular, if we assume
* a Dirac delta (point mass) posterior and
* a hierarchical Gaussian generative model,
we get the standard PC energy
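\begin{equation}
\mathcal{F} = \frac{1}{2N}\sum_{i=1}^{N} \sum_{\ell=1}^L \|\mathbf{z}_{\ell, i} - f_\ell(W_\ell \mathbf{z}_{\ell-1, i} + \mathbf{b}_\ell)\|^2_2
\end{equation}
which is just a sum of squared prediction errors at each network layer. Here we are being a little more precise than in the lecture, including multiple (\(N\)) data points and biases \(\mathbf{b}_\ell\). For supervised training, as below, the first activity \(\mathbf{z}_{0,i}\) is fixed to the \(i\)-th input and the last one \(\mathbf{z}_{L,i}\) to its target, so only the hidden activities are free.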
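To connect Eq. 1 to code, here is a minimal sketch that evaluates it literally, looping over data points \(i\) and layers \(\ell\), for a tiny two-layer network with random weights. The sizes, the seed, the choice of tanh for \(f_1\) and identity for \(f_2\), and the helper names are our own assumptions for illustration. The sketch checks the loops against a batched version in which each row of an activity matrix is one data point.

```python
import jax.numpy as jnp
import jax.random as jr

# Toy sizes (assumptions for illustration): N data points, widths d0 -> d1 -> d2
N, d0, d1, d2 = 4, 3, 5, 2
k = jr.split(jr.PRNGKey(0), 6)
Ws = [jr.normal(k[0], (d1, d0)), jr.normal(k[1], (d2, d1))]  # W_1, W_2
bs = [jr.normal(k[2], (d1,)), jr.normal(k[3], (d2,))]        # b_1, b_2
fs = [jnp.tanh, lambda a: a]  # f_1 = tanh, f_2 = identity (linear output)

x = jr.normal(k[4], (N, d0))  # z_0: fixed to the input
y = jr.normal(k[5], (N, d2))  # z_L: fixed to the target
z1 = jnp.zeros((N, d1))       # free hidden activities
zs = [x, z1, y]               # z_0, z_1, z_2


def energy_literal(Ws, bs, fs, zs):
    """Eq. 1 term by term: 1/(2N) sum_i sum_l ||z_{l,i} - f_l(W_l z_{l-1,i} + b_l)||^2."""
    n = zs[0].shape[0]
    total = 0.0
    for i in range(n):                 # data points
        for l in range(1, len(zs)):    # layers 1..L
            pred = fs[l - 1](Ws[l - 1] @ zs[l - 1][i] + bs[l - 1])
            total += jnp.sum((zs[l][i] - pred) ** 2)
    return total / (2 * n)


def energy_batched(Ws, bs, fs, zs):
    """Same quantity, computing all data points of a layer at once."""
    n = zs[0].shape[0]
    errs = [zs[l] - fs[l - 1](zs[l - 1] @ Ws[l - 1].T + bs[l - 1])
            for l in range(1, len(zs))]
    return sum(jnp.sum(e ** 2) for e in errs) / (2 * n)


print(energy_literal(Ws, bs, fs, zs), energy_batched(Ws, bs, fs, zs))
```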
🤔 Food for thought: Think about how the form of this energy could change depending on other assumptions we make about the generative model. See, for example, Learning on Arbitrary Graph Topologies via Predictive Coding by Salvatori et al. (2022).
Let's start by implementing this energy below. The function takes the model (with all its parameters), some initialised activities, and an input and output. Given these, it sums the squared prediction errors at each layer and divides by the batch size.
NOTE: below we use vmap, one of the core JAX transforms, which vectorises a function over a leading axis, in this case over the data points in a batch. See their docs for more details.
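As a minimal illustration (the layer sizes, seed and names are our own), vmap turns a function written for a single example into one that maps over the leading (batch) axis. The result matches an explicit Python loop over the examples.

```python
import jax.numpy as jnp
import jax.random as jr
from jax import vmap

key_w, key_x = jr.split(jr.PRNGKey(0))
W = jr.normal(key_w, (3, 5))  # maps a 5-dim input to a 3-dim output
b = jnp.ones(3)
X = jr.normal(key_x, (8, 5))  # a batch of 8 inputs, one per row


def layer(x):
    """Written for ONE input vector of shape (5,)."""
    return jnp.tanh(W @ x + b)


batched = vmap(layer)(X)                   # shape (8, 3), all rows at once
looped = jnp.stack([layer(x) for x in X])  # same result, one example at a time
print(batched.shape)  # (8, 3)
```

Equinox layers such as nn.Linear are likewise written for a single (unbatched) input, which is why the code below wraps every layer call in vmap.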
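def pc_energy_fn(model, activities, input, output):
    batch_size = output.shape[0]
    n_activity_layers = len(activities) - 1
    n_layers = len(model) - 1

    # output layer: the prediction is compared with the target (output)
    eL = output - vmap(model[-1])(activities[-2])
    energies = [jnp.sum(eL ** 2)]

    # intermediate layers: each activity is compared with the prediction from the layer below
    for act_l, net_l in zip(
        range(1, n_activity_layers),
        range(1, n_layers)
    ):
        err = activities[act_l] - vmap(model[net_l])(activities[act_l - 1])
        energies.append(jnp.sum(err ** 2))

    # first layer: predicted from the input
    e1 = activities[0] - vmap(model[0])(input)
    energies.append(jnp.sum(e1 ** 2))

    # note: we divide by N rather than 2N as in Eq. 1; this only rescales the gradients
    return jnp.sum(jnp.array(energies)) / batch_size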
Now let's test it. To do so, we first need a model. Below we use Equinox to create a simple feedforward network with 2 hidden layers and tanh activations. Note that we split the model into parts with nn.Sequential: the output of each part defines one of the activities that PC optimises during inference (more on this below).
❓ Question: Think about other ways in which we could split the layers, for example by separating the non-linearities. Can you think of potential issues with this?
# jax uses explicit random number generators (see https://jax.readthedocs.io/en/latest/random-numbers.html)
key = jr.PRNGKey(SEED)
subkeys = jr.split(key, 3)
model = [
    nn.Sequential(
        [
            nn.Linear(INPUT_DIM, NETWORK_WIDTH, key=subkeys[0]),
            nn.Lambda(jnp.tanh)
        ],
    ),
    nn.Sequential(
        [
            nn.Linear(NETWORK_WIDTH, NETWORK_WIDTH, key=subkeys[1]),
            nn.Lambda(jnp.tanh)
        ],
    ),
    nn.Linear(NETWORK_WIDTH, OUTPUT_DIM, key=subkeys[2]),
]
model
[Sequential(
   layers=(
     Linear(
       weight=f32[300,784],
       bias=f32[300],
       in_features=784,
       out_features=300,
       use_bias=True
     ),
     Lambda(fn=<wrapped function tanh>)
   )
 ),
 Sequential(
   layers=(
     Linear(
       weight=f32[300,300],
       bias=f32[300],
       in_features=300,
       out_features=300,
       use_bias=True
     ),
     Lambda(fn=<wrapped function tanh>)
   )
 ),
 Linear(
   weight=f32[10,300],
   bias=f32[10],
   in_features=300,
   out_features=10,
   use_bias=True
 )]
The last thing we need is to initialise the activities. For this, we use a feedforward pass, as is common in practice.
❓ Question: Can you think of other ways of initialising the activities?
def init_activities_with_ffwd(model, input):
    activities = [vmap(model[0])(input)]
    for l in range(1, len(model)):
        layer_output = vmap(model[l])(activities[l - 1])
        activities.append(layer_output)
    return activities
Let's test it on an MNIST sample.
# get a data sample
train_loader, test_loader = get_mnist_loaders(BATCH_SIZE)
img_batch, label_batch = next(iter(train_loader))
# we need to turn the torch.Tensor data into numpy arrays for jax
img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
# let's check our initialised activities
activities = init_activities_with_ffwd(model, img_batch)
for i, a in enumerate(activities):
    print(f"activity z at layer {i+1}: {a.shape}")
activity z at layer 1: (64, 300)
activity z at layer 2: (64, 300)
activity z at layer 3: (64, 10)
Now we have everything we need to test our PC energy function: a model, activities, and some data.
pc_energy_fn(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
Array(1.2335204, dtype=float32)
And it works!
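Where does this value come from? With a feedforward initialisation, each hidden activity is exactly the prediction of the layer below. All hidden prediction errors therefore vanish, and the initial energy is just the output error: the squared distance between the network's prediction and the label, divided by the batch size. The minimal sketch below checks this on a toy network in plain JAX. The layer sizes, seed and helper names are our own assumptions, and the errors are laid out as in pc_energy_fn.

```python
import jax.numpy as jnp
import jax.random as jr

# Toy network (assumption): a list of (W, b, f) triples, tanh hidden layers, linear output
N, sizes = 6, [4, 8, 8, 3]
keys = jr.split(jr.PRNGKey(1), len(sizes) + 1)
layers = [
    (0.5 * jr.normal(keys[i], (sizes[i + 1], sizes[i])), jnp.zeros(sizes[i + 1]),
     jnp.tanh if i < len(sizes) - 2 else (lambda a: a))
    for i in range(len(sizes) - 1)
]
x = jr.normal(keys[-2], (N, sizes[0]))
y = jr.normal(keys[-1], (N, sizes[-1]))


def forward(layer, z):
    W, b, f = layer
    return f(z @ W.T + b)


def ffwd_init(layers, x):
    """Feedforward initialisation: each activity is the prediction of the layer below."""
    acts = [forward(layers[0], x)]
    for layer in layers[1:]:
        acts.append(forward(layer, acts[-1]))
    return acts


def layer_energies(layers, acts, x, y):
    """Squared error of each layer summed over the batch (last layer predicts y)."""
    inputs = [x] + acts[:-1]   # what each layer receives
    targets = acts[:-1] + [y]  # what each layer tries to predict
    return [jnp.sum((t - forward(layer, z)) ** 2)
            for layer, z, t in zip(layers, inputs, targets)]


acts = ffwd_init(layers, x)
energies = layer_energies(layers, acts, x, y)
print([float(e) for e in energies])  # hidden-layer entries should be (numerically) zero
```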
Energy gradients
How do we minimise the PC energy we defined above (Eq. 1)? Recall from the lecture that we do this in two phases: first with respect to the activities (inference) and then with respect to the weights (learning).
So we just need to take the gradients of the energy with respect to these variables. We are going to use autodiff, which is built into JAX by design (see the docs). If you're familiar with PyTorch, you are probably used to loss.backward() for this, which can feel abstruse at times. JAX, on the other hand, follows a functional (as opposed to object-oriented) programming style, and its syntax stays very close to the maths, as you can see below.
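Autodiff will compute these gradients for us, but it is instructive to see what it computes for the activities. In the notation of Eq. 1, write \(\mathbf{a}_{\ell,i} = W_\ell \mathbf{z}_{\ell-1,i} + \mathbf{b}_\ell\) for the pre-activation and \(\boldsymbol{\varepsilon}_{\ell,i} = \mathbf{z}_{\ell,i} - f_\ell(\mathbf{a}_{\ell,i})\) for the prediction error, and assume each \(f_\ell\) acts elementwise. A hidden activity \(\mathbf{z}_{\ell,i}\) (\(1 \le \ell \le L-1\)) appears in exactly two terms of the energy: once as the target of layer \(\ell\), and once as the input to layer \(\ell+1\). Differentiating both terms gives

\begin{equation*}
\frac{\partial \mathcal{F}}{\partial \mathbf{z}_{\ell,i}} = \frac{1}{N}\Big[\boldsymbol{\varepsilon}_{\ell,i} - W_{\ell+1}^\top \big(f'_{\ell+1}(\mathbf{a}_{\ell+1,i}) \odot \boldsymbol{\varepsilon}_{\ell+1,i}\big)\Big].
\end{equation*}

Inference therefore pulls each activity towards the prediction from the layer below while reducing the error it causes in the layer above, using only quantities from neighbouring layers. The first and last activities are fixed to the input and target, so they are not updated.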
# note how close this code is to the maths:
# this can be read as "take the gradient of the energy...
# ...with respect to the 2nd argument (the activities)"
def compute_activity_grad(model, activities, input, output):
    return grad(pc_energy_fn, argnums=1)(
        model,
        activities,
        input,
        output
    )
Let's test this out.
dFdzs = compute_activity_grad(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
for i, dFdz in enumerate(dFdzs):
    print(f"activity gradient dFdz shape at layer {i+1}: {dFdz.shape}")
activity gradient dFdz shape at layer 1: (64, 300)
activity gradient dFdz shape at layer 2: (64, 300)
activity gradient dFdz shape at layer 3: (64, 10)
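As a sanity check of grad with argnums=1, here is a minimal sketch on a toy network with one tanh hidden layer and a linear output layer. It uses plain JAX, and the sizes, seeds and names are our own choices. Like pc_energy_fn, the toy energy sums the squared errors and divides by N (not 2N). The expected gradient for the hidden activity is therefore \(\frac{2}{N}\big[\boldsymbol{\varepsilon}_1 - W_2^\top \boldsymbol{\varepsilon}_2\big]\) per data point, where \(\boldsymbol{\varepsilon}_1\) is the hidden prediction error and \(\boldsymbol{\varepsilon}_2\) the output error (the output layer is linear, so its derivative is 1). The last entry of the activities is never used by the energy because the target takes its place, so its gradient comes out as zero. The same holds for the last activity in pc_energy_fn, which compares against output rather than activities[-1].

```python
import jax.numpy as jnp
import jax.random as jr
from jax import grad

N = 5
k = jr.split(jr.PRNGKey(2), 7)
W1, b1 = jr.normal(k[0], (6, 4)), jr.normal(k[1], (6,))
W2, b2 = jr.normal(k[2], (3, 6)), jr.normal(k[3], (3,))
x = jr.normal(k[4], (N, 4))
y = jr.normal(k[5], (N, 3))
params = (W1, b1, W2, b2)


def energy(params, acts, x, y):
    """Toy PC energy (sum of squared errors / N). acts = [z1, z2];
    z2 is ignored because the output layer is compared with the target y."""
    W1, b1, W2, b2 = params
    z1 = acts[0]
    e1 = z1 - jnp.tanh(x @ W1.T + b1)
    e2 = y - (z1 @ W2.T + b2)  # linear output layer
    return (jnp.sum(e1 ** 2) + jnp.sum(e2 ** 2)) / x.shape[0]


acts = [jr.normal(k[6], (N, 6)), jnp.zeros((N, 3))]
dz1, dz2 = grad(energy, argnums=1)(params, acts, x, y)

# Hand-derived gradient for z1; each row of e2 @ W2 is W2^T e2_i for one data point
e1 = acts[0] - jnp.tanh(x @ W1.T + b1)
e2 = y - (acts[0] @ W2.T + b2)
dz1_manual = (2 / N) * (e1 - e2 @ W2)
print(jnp.max(jnp.abs(dz1 - dz1_manual)))
```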
Now we do the same and take the gradient of the energy with respect to the parameters.
Technical note: below we use Equinox's convenience function filter_grad rather than JAX's native grad. This is because the model contains leaves that are not arrays, such as the activation functions wrapped in nn.Lambda, which have no parameters and cannot be differentiated. filter_grad differentiates only with respect to the floating-point arrays in its first argument (here the weights and biases) and returns None for everything else, while grad alone would throw an error.
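To see what this filtering does, here is a simplified sketch of the idea in plain JAX. It is not Equinox's actual implementation, and grad_wrt_arrays is a hypothetical helper. The sketch splits the leaves of the first argument into floating-point arrays and everything else, differentiates only the arrays, and puts None in place of the rest, mirroring how filter_grad reports non-differentiable leaves.

```python
import jax
import jax.numpy as jnp


def grad_wrt_arrays(fn):
    """Hypothetical helper: gradient of fn w.r.t. the floating-point array
    leaves of its first argument; every other leaf gets None."""
    def wrapped(tree, *args):
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        is_diff = [isinstance(v, jax.Array) and jnp.issubdtype(v.dtype, jnp.floating)
                   for v in leaves]

        def fn_of_arrays(arrays):
            # rebuild the full tree, slotting the differentiable arrays back in
            it = iter(arrays)
            full = [next(it) if d else v for v, d in zip(leaves, is_diff)]
            return fn(jax.tree_util.tree_unflatten(treedef, full), *args)

        arrays = [v for v, d in zip(leaves, is_diff) if d]
        array_grads = iter(jax.grad(fn_of_arrays)(arrays))
        grads = [next(array_grads) if d else None for d in is_diff]
        return jax.tree_util.tree_unflatten(treedef, grads)
    return wrapped


# A toy "model" mixing a weight matrix with a (non-differentiable) activation function
model = {"W": jnp.array([[0.5, -1.0], [2.0, 0.1]]), "act": jnp.tanh}
x = jnp.array([1.0, 2.0])


def loss(model, x):
    return jnp.sum(model["act"](model["W"] @ x))


# jax.grad(loss)(model, x) would fail here: model["act"] is a function, not an array
grads = grad_wrt_arrays(loss)(model, x)
print(grads["act"])  # None
```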
# note that, compared to the previous function,...
# ...we just change the argument with respect to which...
# ...we are differentiating (the first one, i.e. the model)
def compute_param_grad(model, activities, input, output):
    return filter_grad(pc_energy_fn)(
        model,
        activities,
        input,
        output
    )
And let's test it.
param_grads = compute_param_grad(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
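A property of these weight gradients that is often highlighted is locality. The gradient for \(W_\ell\) involves only that layer's prediction error and its input activity. Take the linear output layer of a toy network, with the energy in the convention of pc_energy_fn (squared errors divided by N). Its weight gradient reduces to \(-\frac{2}{N}\sum_i \boldsymbol{\varepsilon}_{L,i}\mathbf{z}_{L-1,i}^\top\), an outer product of error and presynaptic activity. The minimal sketch below checks this against jax.grad in plain JAX; the sizes, seed and names are our own choices.

```python
import jax.numpy as jnp
import jax.random as jr
from jax import grad

N = 7
k = jr.split(jr.PRNGKey(3), 4)
W2, b2 = jr.normal(k[0], (3, 5)), jr.normal(k[1], (3,))
z1 = jr.normal(k[2], (N, 5))  # (already inferred) hidden activities
y = jr.normal(k[3], (N, 3))   # target for the output layer


def output_energy(W2, b2, z1, y):
    """Output-layer term of the energy for a linear layer: sum of squared errors / N."""
    return jnp.sum((y - (z1 @ W2.T + b2)) ** 2) / z1.shape[0]


dW2 = grad(output_energy)(W2, b2, z1, y)  # autodiff w.r.t. the first argument, W2

eps = y - (z1 @ W2.T + b2)            # prediction errors, shape (N, 3)
dW2_local = -(2 / N) * eps.T @ z1     # -(2/N) * sum_i eps_i z1_i^T, shape (3, 5)
print(dW2.shape)  # (3, 5)
```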
Updates
Before putting everything together, let's wrap our gradients into update functions. This also allows us to use JAX's jit transformation (here through Equinox's filter_jit wrapper), which compiles your code the first time it is executed so that subsequent calls with inputs of the same shape reuse the compiled version and run more efficiently (see the docs for more details).
These functions take an (Optax) optimiser such as gradient descent and its current state, in addition to the previous arguments (model, activities and data).
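One way to see the "compile once, reuse afterwards" behaviour is a minimal sketch with a counter of our own (not part of JAX). Python side effects inside a jitted function run only while JAX traces it, so counting them shows when (re)compilation happens. A second call with an array of the same shape and dtype reuses the compiled version, while a new shape triggers a new trace. eqx.filter_jit behaves the same way for array arguments, and additionally holds non-array arguments (such as the optimiser) static.

```python
import jax
import jax.numpy as jnp

n_traces = 0


@jax.jit
def squared_norm(x):
    global n_traces
    n_traces += 1  # Python side effect: runs at trace time only
    return jnp.sum(x ** 2)


squared_norm(jnp.ones(3))   # first call: trace + compile
squared_norm(jnp.zeros(3))  # same shape/dtype: cached, no new trace
after_same_shape = n_traces
squared_norm(jnp.ones(4))   # new shape: traced and compiled again
print(after_same_shape, n_traces)  # 1 2
```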
@eqx.filter_jit
def update_activities(model, activities, optim, opt_state, input, output):
    activity_grads = compute_activity_grad(
        model=model,
        activities=activities,
        input=input,
        output=output
    )
    activity_updates, activity_opt_state = optim.update(
        updates=activity_grads,
        state=opt_state,
        params=activities
    )
    activities = eqx.apply_updates(
        model=activities,
        updates=activity_updates
    )
    # return the *updated* optimiser state so that it is carried to the next step
    return activities, optim, activity_opt_state
# note that the only difference with the above function is...
# ...the variable we are updating (parameters vs activities)
@eqx.filter_jit
def update_params(model, activities, optim, opt_state, input, output):
    param_grads = compute_param_grad(
        model=model,
        activities=activities,
        input=input,
        output=output
    )
    param_updates, param_opt_state = optim.update(
        updates=param_grads,
        state=opt_state,
        params=model
    )
    model = eqx.apply_updates(
        model=model,
        updates=param_updates
    )
    # return the *updated* optimiser state (for Adam it carries the step count and moments)
    return model, optim, param_opt_state
Putting everything together: Training and testing
Now that we have our activity and parameter updates, we just need to wrap them in a training and test loop.
# note: the test accuracy computation below could be sped up...
# ...with jit in a separate function
def evaluate(model, test_loader):
    avg_test_acc = 0
    for test_iter, (img_batch, label_batch) in enumerate(test_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
        preds = init_activities_with_ffwd(model, img_batch)[-1]
        test_acc = jnp.mean(
            jnp.argmax(label_batch, axis=1) == jnp.argmax(preds, axis=1)
        ) * 100
        avg_test_acc += test_acc
    return avg_test_acc / len(test_loader)
def train(
    model,
    activity_lr,
    inference_steps,
    param_lr,
    batch_size,
    test_every,
    n_train_iters
):
    # define optimisers for activities and parameters
    activity_optim = optax.sgd(activity_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(eqx.filter(model, eqx.is_array))
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for train_iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # initialise activities
        activities = init_activities_with_ffwd(model, img_batch)
        activity_opt_state = activity_optim.init(activities)

        # training loss of the feedforward prediction (before inference)
        train_loss = jnp.mean((label_batch - activities[-1])**2)

        # inference
        for t in range(inference_steps):
            activities, activity_optim, activity_opt_state = update_activities(
                model=model,
                activities=activities,
                optim=activity_optim,
                opt_state=activity_opt_state,
                input=img_batch,
                output=label_batch
            )

        # learning
        model, param_optim, param_opt_state = update_params(
            model=model,
            activities=activities,  # note how we use the optimised activities
            optim=param_optim,
            opt_state=param_opt_state,
            input=img_batch,
            output=label_batch
        )

        if ((train_iter+1) % test_every) == 0:
            avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {train_iter+1}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        if (train_iter+1) >= n_train_iters:
            break
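As the comment at the top of the evaluation code suggests, the per-batch accuracy can be moved into its own jitted function. Here is a minimal sketch. batch_accuracy is a hypothetical name; like evaluate, it takes one-hot labels and network outputs, and the toy inputs are made up for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def batch_accuracy(labels_one_hot, preds):
    """Percentage of rows where the predicted class (argmax of preds) matches the label."""
    correct = jnp.argmax(labels_one_hot, axis=1) == jnp.argmax(preds, axis=1)
    return jnp.mean(correct) * 100


labels = jnp.eye(3)[jnp.array([0, 2, 1, 2])]  # one-hot labels for classes 0, 2, 1, 2
preds = jnp.array([[2.0, 0.1, 0.3],   # predicts 0 (correct)
                   [0.2, 0.1, 0.9],   # predicts 2 (correct)
                   [0.8, 0.5, 0.1],   # predicts 0 (wrong)
                   [0.1, 0.3, 0.2]])  # predicts 1 (wrong)
print(batch_accuracy(labels, preds))  # 50.0
```

With this helper, the accuracy lines in evaluate would reduce to test_acc = batch_accuracy(label_batch, preds).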
Run
Let's test our implementation.
train(
    model=model,
    activity_lr=ACTIVITY_LR,
    inference_steps=INFERENCE_STEPS,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)
Train iter 50, train loss=0.065566, avg test accuracy=72.726364
Train iter 100, train loss=0.046521, avg test accuracy=76.292068
Train iter 150, train loss=0.042710, avg test accuracy=86.568512
Train iter 200, train loss=0.029598, avg test accuracy=89.082535
Train iter 250, train loss=0.031486, avg test accuracy=89.222755
Train iter 300, train loss=0.016624, avg test accuracy=91.296074
Train iter 350, train loss=0.025201, avg test accuracy=92.648239
Train iter 400, train loss=0.018597, avg test accuracy=92.968750
Train iter 450, train loss=0.019027, avg test accuracy=94.130608
Train iter 500, train loss=0.014850, avg test accuracy=93.760017
🥳 Great, we see that our model is learning! This model was not tuned, and you can probably improve the performance by tweaking some of the hyperparameters (e.g. try a higher number of inference steps).
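To see what more inference steps buy, here is a minimal sketch on a toy network in plain JAX. The sizes, seed, step size and names are our own assumptions. Gradient descent on the free hidden activity lowers the energy at every step, so more steps bring the activities closer to the energy minimum before the weights are updated. With a linear output layer and the input and output fixed, this toy energy is a convex quadratic in the hidden activity, so a small enough step size guarantees a monotone decrease. Deeper nonlinear networks such as the one trained above offer no such guarantee.

```python
import jax.numpy as jnp
import jax.random as jr
from jax import grad

N = 8
k = jr.split(jr.PRNGKey(4), 4)
W1, b1 = jr.normal(k[0], (6, 4)), jnp.zeros(6)
W2, b2 = 0.3 * jr.normal(k[1], (3, 6)), jnp.zeros(3)
x = jr.normal(k[2], (N, 4))
y = jr.normal(k[3], (N, 3))


def energy(z1):
    """Toy PC energy (sum of squared errors / N); only the hidden activity z1 is free."""
    e1 = z1 - jnp.tanh(x @ W1.T + b1)
    e2 = y - (z1 @ W2.T + b2)
    return (jnp.sum(e1 ** 2) + jnp.sum(e2 ** 2)) / N


z1 = jnp.tanh(x @ W1.T + b1)  # feedforward initialisation
activity_lr, n_steps = 0.1, 20
energies = [float(energy(z1))]
for _ in range(n_steps):
    z1 = z1 - activity_lr * grad(energy)(z1)  # one inference (gradient descent) step
    energies.append(float(energy(z1)))

print(f"energy after 0, 5, 20 steps: {energies[0]:.4f}, {energies[5]:.4f}, {energies[-1]:.4f}")
```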
Even if you didn't follow all the implementation details, you should now have at least an idea of how PC works in practice. Indeed, this is basically the core code behind a new PC library our lab will soon release: JPC. Play around with the notebook examples there, where you can learn how to train a variety of PC networks.