
⚙️ JPC from Scratch

This notebook is a walk-through of how Predictive Coding (PC) is implemented in JPC. It was developed as a lab session of an MSc course at the University of Sussex. We are going to implement all the core functionalities of JPC from scratch, building towards the training of a simple feedforward network to classify MNIST.

If you're not familiar with JAX, have a look at their docs, but we will explain all the necessary concepts below. JAX is basically numpy for GPUs and other hardware accelerators. We will also use: (i) Equinox, which allows you to define neural nets with PyTorch-like syntax; and (ii) Optax, which provides a range of common machine learning optimisers such as gradient descent and Adam.

Installations & imports

```python
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
```

```python
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad
from jax.tree_util import tree_map

import equinox as eqx
import equinox.nn as nn
from equinox import filter_grad
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import warnings
warnings.simplefilter("ignore")
```

Hyperparameters

We define some global parameters related to the data, network, optimisers, etc.

```python
SEED = 827

INPUT_DIM, OUTPUT_DIM = 28*28, 10
NETWORK_WIDTH = 300

ACTIVITY_LR = 1e-1
INFERENCE_STEPS = 20

PARAM_LR = 1e-3
BATCH_SIZE = 64

TEST_EVERY = 50
N_TRAIN_ITERS = 500
```

Dataset

Some utils to fetch MNIST.

```python
def get_mnist_loaders(batch_size):
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    def __init__(self, train, normalise=True, save_dir="data"):
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.1307), std=(0.3081)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        # flatten the 28x28 image into a 784-dimensional vector
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]
```

PC energy

First, recall that PC can be derived as a variational inference algorithm under certain assumptions. In particular, if we assume

* a Dirac delta (point mass) posterior and
* a hierarchical Gaussian generative model,

we get the standard PC energy

\[\begin{equation} \mathcal{F} = \frac{1}{2N}\sum_{i=1}^{N} \sum_{\ell=1}^L ||\mathbf{z}_{\ell, i} - f_\ell(W_\ell \mathbf{z}_{\ell-1, i} + \mathbf{b}_\ell)||^2_2 \end{equation}\]

which is just a sum of squared prediction errors at each network layer. Here we are being a little bit more precise than in the lecture, including multiple (\(N\)) data points and biases \(\mathbf{b}_\ell\).

🤔 Food for thought: Think about how the form of this energy could change depending on other assumptions we make about the generative model. See, for example, Learning on Arbitrary Graph Topologies via Predictive Coding by Salvatori et al. (2022).

Let's start by implementing this energy below. The function takes the model (with all the parameters), some initialised activities, and some input and output. Given these, it sums the squared prediction errors at each layer and divides by the batch size.

NOTE: below we use vmap, one of the core JAX transforms that allows you to vectorise operations, in this case for multiple data points or over a batch. See their docs for more details.
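To make the role of vmap concrete, here is a minimal sketch on a toy layer. The names `W`, `b` and `layer` are hypothetical and only for illustration. Equinox layers such as `nn.Linear` are written for a single input vector, which is why we need vmap to apply them to a whole batch.

```python
import jax.numpy as jnp
from jax import vmap

# hypothetical toy layer acting on a SINGLE input vector of shape (3,)
W = jnp.arange(6.0).reshape(2, 3)
b = jnp.ones(2)

def layer(x):
    return jnp.tanh(W @ x + b)

batch = jnp.ones((4, 3))  # 4 data points with 3 features each

# vmap maps `layer` over the leading (batch) axis
batched_out = vmap(layer)(batch)

# this matches an explicit Python loop over the data points
looped_out = jnp.stack([layer(x) for x in batch])

print(batched_out.shape)
```

The vectorised version avoids the Python loop, which matters once the function is compiled and run on an accelerator.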

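```python
def pc_energy_fn(model, activities, input, output):
    batch_size = output.shape[0]
    n_activity_layers = len(activities) - 1
    n_layers = len(model) - 1

    # prediction error at the output layer, which is clamped to the target
    eL = output - vmap(model[-1])(activities[-2])
    energies = [jnp.sum(eL ** 2)]

    # prediction errors at the intermediate layers
    for act_l, net_l in zip(
            range(1, n_activity_layers),
            range(1, n_layers)
    ):
        err = activities[act_l] - vmap(model[net_l])(activities[act_l - 1])
        energies.append(jnp.sum(err ** 2))

    # prediction error at the first layer, which receives the input
    e1 = activities[0] - vmap(model[0])(input)
    energies.append(jnp.sum(e1 ** 2))

    # note: the constant factor 1/2 of Eq. 1 is omitted here,
    # which only rescales the energy and its gradients
    return jnp.sum(jnp.array(energies)) / batch_size
```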

Now let's test it. To do so, we first need a model. Below we use Equinox to create a simple feedforward network with 2 hidden layers and Tanh activations. Note that we split the model into different parts with nn.Sequential to define the activities which PC will optimise over (during inference, more on this below).

Question: Think about other ways in which we could split the layers, for example by separating the non-linearities. Can you think of potential issues with this?

```python
# jax uses explicit random number generators (see https://jax.readthedocs.io/en/latest/random-numbers.html)
key = jr.PRNGKey(SEED)
subkeys = jr.split(key, 3)

model = [
    nn.Sequential(
        [
            nn.Linear(INPUT_DIM, NETWORK_WIDTH, key=subkeys[0]),
            nn.Lambda(jnp.tanh)
        ],
    ),
    nn.Sequential(
        [
            nn.Linear(NETWORK_WIDTH, NETWORK_WIDTH, key=subkeys[1]),
            nn.Lambda(jnp.tanh)
        ],
    ),
    nn.Linear(NETWORK_WIDTH, OUTPUT_DIM, key=subkeys[2]),
]
model
```

```
[Sequential(
   layers=(
     Linear(
       weight=f32[300,784],
       bias=f32[300],
       in_features=784,
       out_features=300,
       use_bias=True
     ),
     Lambda(fn=<wrapped function tanh>)
   )
 ),
 Sequential(
   layers=(
     Linear(
       weight=f32[300,300],
       bias=f32[300],
       in_features=300,
       out_features=300,
       use_bias=True
     ),
     Lambda(fn=<wrapped function tanh>)
   )
 ),
 Linear(
   weight=f32[10,300],
   bias=f32[10],
   in_features=300,
   out_features=10,
   use_bias=True
 )]
```

The last thing we need is to initialise the activities. For this, we will use a feedforward pass as often done in practice.

Question: Can you think of other ways of initialising the activities?

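```python
def init_activities_with_ffwd(model, input):
    # the first activity is the first layer applied to the input...
    activities = [vmap(model[0])(input)]
    # ...and each subsequent activity is the next layer applied to the previous one
    for l in range(1, len(model)):
        layer_output = vmap(model[l])(activities[l - 1])
        activities.append(layer_output)

    return activities
```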

Let's test it on an MNIST sample.

```python
# get a data sample
train_loader, test_loader = get_mnist_loaders(BATCH_SIZE)
img_batch, label_batch = next(iter(train_loader))

# we need to turn the torch.Tensor data into numpy arrays for jax
img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

# let's check our initialised activities
activities = init_activities_with_ffwd(model, img_batch)
for i, a in enumerate(activities):
    print(f"activity z at layer {i+1}: {a.shape}")
```

```
activity z at layer 1: (64, 300)
activity z at layer 2: (64, 300)
activity z at layer 3: (64, 10)
```

Ok so now we have everything to test our PC energy function: model, activities, and some data.

```python
pc_energy_fn(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
```

```
Array(1.2335204, dtype=float32)
```

And it works!
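One useful sanity check: because the activities were initialised with a feedforward pass, every hidden-layer prediction error is zero, so the energy should reduce to the squared error at the output layer divided by the batch size. The sketch below checks this on a toy network. The names `toy_model`, `ffwd` and `energy` are hypothetical stand-ins that mirror the structure of the functions above, with random weights and made-up data.

```python
import jax.numpy as jnp
import jax.random as jr
from jax import vmap

# hypothetical toy network: two tanh layers and a linear readout,
# each written as a plain function of a single input vector
k1, k2, k3, kx = jr.split(jr.PRNGKey(0), 4)
W1 = jr.normal(k1, (5, 4))
W2 = jr.normal(k2, (5, 5))
W3 = jr.normal(k3, (2, 5))
toy_model = [
    lambda x: jnp.tanh(W1 @ x),
    lambda x: jnp.tanh(W2 @ x),
    lambda x: W3 @ x,
]

def ffwd(model, x):
    # feedforward initialisation of the activities
    acts = [vmap(model[0])(x)]
    for layer in model[1:]:
        acts.append(vmap(layer)(acts[-1]))
    return acts

def energy(model, acts, x, y):
    # same structure as pc_energy_fn: output, hidden and first-layer errors
    errs = [y - vmap(model[-1])(acts[-2])]
    for l in range(1, len(model) - 1):
        errs.append(acts[l] - vmap(model[l])(acts[l - 1]))
    errs.append(acts[0] - vmap(model[0])(x))
    return sum(jnp.sum(e ** 2) for e in errs) / y.shape[0]

x = jr.normal(kx, (8, 4))
y = jnp.eye(2)[jnp.arange(8) % 2]  # toy one-hot targets
acts = ffwd(toy_model, x)

# with feedforward initialisation only the output error contributes
output_error = jnp.sum((y - acts[-1]) ** 2) / y.shape[0]
print(energy(toy_model, acts, x, y), output_error)
```

The two printed numbers should agree. The energy spreads over the hidden layers only once inference starts moving the activities away from their feedforward values.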

Energy gradients

How do we minimise the PC energy we defined above (Eq. 1)? Recall from the lecture that we do this in two phases: first with respect to the activities (inference) and then with respect to the weights (learning).

\[\begin{equation} \textit{Inference:} - \frac{\partial \mathcal{F}}{\partial \mathbf{z}_\ell} \end{equation}\]
\[\begin{equation} \textit{Learning:} - \frac{\partial \mathcal{F}}{\partial W_\ell} \end{equation}\]
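For reference, these gradients can be written out by hand from Eq. 1. This is only a sketch: the prediction error \(\boldsymbol{\epsilon}_{\ell, i}\), the pre-activation \(\mathbf{a}_{\ell, i}\), the derivative \(f'_\ell\) and the element-wise product \(\odot\) are notation introduced here, not taken from the lecture. We assume the output layer is clamped to the target, so the activity gradient applies to the hidden layers \(\ell < L\).

\[\begin{equation} \boldsymbol{\epsilon}_{\ell, i} = \mathbf{z}_{\ell, i} - f_\ell(\mathbf{a}_{\ell, i}), \quad \mathbf{a}_{\ell, i} = W_\ell \mathbf{z}_{\ell-1, i} + \mathbf{b}_\ell \end{equation}\]
\[\begin{equation} \frac{\partial \mathcal{F}}{\partial \mathbf{z}_{\ell, i}} = \frac{1}{N}\left[\boldsymbol{\epsilon}_{\ell, i} - W_{\ell+1}^T \left(f'_{\ell+1}(\mathbf{a}_{\ell+1, i}) \odot \boldsymbol{\epsilon}_{\ell+1, i}\right)\right] \end{equation}\]
\[\begin{equation} \frac{\partial \mathcal{F}}{\partial W_\ell} = -\frac{1}{N}\sum_{i=1}^{N} \left(f'_\ell(\mathbf{a}_{\ell, i}) \odot \boldsymbol{\epsilon}_{\ell, i}\right) \mathbf{z}_{\ell-1, i}^T \end{equation}\]

Each activity is pulled towards its own prediction and towards values that reduce the error of the layer above, and both gradients depend only on neighbouring layers. In practice we never need to code these expressions, since autodiff computes them for us.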

So we just need to take these gradients of the energy. We are going to use autodiff, which JAX embeds by design (see the docs). If you're familiar with PyTorch, you are probably used to loss.backward() for this, which might feel abstruse at times. JAX, on the other hand, follows a functional (as opposed to object-oriented) style, and its syntax is very close to the maths as you can see below.
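As a warm-up, here is a minimal sketch of `grad` and its `argnums` argument on a toy energy whose gradients are easy to derive by hand. The function `toy_energy` and its inputs are hypothetical and are not part of JPC.

```python
import jax.numpy as jnp
from jax import grad

# hypothetical toy "energy": squared error between z and a linear prediction w * x
def toy_energy(w, z, x):
    return 0.5 * jnp.sum((z - w * x) ** 2)

w = 2.0
z = jnp.array([1.0, 4.0])
x = jnp.array([1.0, 1.0])

# argnums selects the positional argument we differentiate with respect to
dE_dw = grad(toy_energy, argnums=0)(w, z, x)  # analytic: -sum((z - w * x) * x)
dE_dz = grad(toy_energy, argnums=1)(w, z, x)  # analytic: z - w * x

print(dE_dw, dE_dz)
```

Here the residual `z - w * x` is `[-1, 2]`, so the hand-derived gradients are `-1` for `w` and `[-1, 2]` for `z`, which is what autodiff returns.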

```python
# note how close this code is to the maths
# this can be read as "take the gradient of the energy...
# ...with respect to the 2nd argument (the activities)"
def compute_activity_grad(model, activities, input, output):
    return grad(pc_energy_fn, argnums=1)(
        model,
        activities,
        input,
        output
    )
```

Let's test this out.

```python
dFdzs = compute_activity_grad(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
for i, dFdz in enumerate(dFdzs):
    print(f"activity gradient dFdz shape at layer {i+1}: {dFdz.shape}")
```

```
activity gradient dFdz shape at layer 1: (64, 300)
activity gradient dFdz shape at layer 2: (64, 300)
activity gradient dFdz shape at layer 3: (64, 10)
```

Now we do the same and take the gradient of the energy with respect to the parameters.

Technical note: below we use Equinox's convenience function filter_grad rather than JAX's native grad. This is because things like activation functions do not have parameters and so we do not want to differentiate them. filter_grad automatically filters these non-differentiable objects for us, while grad alone would throw an error.

```python
# note that, compared to the previous function,...
# ...we just change the argument with respect to which...
# ...we are differentiating (the first, or in this case the model)
def compute_param_grad(model, activities, input, output):
    return filter_grad(pc_energy_fn)(
        model,
        activities,
        input,
        output
    )
```

And let's test it.

```python
param_grads = compute_param_grad(
    model=model,
    activities=activities,
    input=img_batch,
    output=label_batch
)
```

Updates

Before putting everything together, let's wrap our gradients into update functions. This will also allow us to use JAX's jit transform (here through Equinox's filter_jit wrapper), which compiles your code the first time it's executed so that it can be run more efficiently the next time (see the docs for more details).

These functions take an (Optax) optimiser such as gradient descent in addition to the previous arguments (model, activities and data).

```python
@eqx.filter_jit
def update_activities(model, activities, optim, opt_state, input, output):
    activity_grads = compute_activity_grad(
        model=model,
        activities=activities,
        input=input,
        output=output
    )
    activity_updates, opt_state = optim.update(
        updates=activity_grads,
        state=opt_state,
        params=activities
    )
    activities = eqx.apply_updates(
        model=activities,
        updates=activity_updates
    )
    # return the updated optimiser state so that stateful optimisers work too
    return activities, optim, opt_state


# note that the only difference with the above function is...
# ...the variable we are updating (parameters vs activities)
@eqx.filter_jit
def update_params(model, activities, optim, opt_state, input, output):
    param_grads = compute_param_grad(
        model=model,
        activities=activities,
        input=input,
        output=output
    )
    param_updates, opt_state = optim.update(
        updates=param_grads,
        state=opt_state,
        params=model
    )
    model = eqx.apply_updates(
        model=model,
        updates=param_updates
    )
    # return the updated optimiser state (Adam keeps its moment estimates there)
    return model, optim, opt_state
```
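To see what the jit decorator buys us, here is a minimal sketch with plain `jax.jit` on a hypothetical toy function. `eqx.filter_jit` works the same way, except that it treats non-array arguments (such as the optimiser or activation functions) as static.

```python
import jax
import jax.numpy as jnp

# hypothetical toy function to compile
def sq_error_energy(z, pred):
    return 0.5 * jnp.sum((z - pred) ** 2)

fast_energy = jax.jit(sq_error_energy)

z, pred = jnp.ones(1000), jnp.zeros(1000)
first = fast_energy(z, pred)   # traced and compiled on this first call
second = fast_energy(z, pred)  # same shapes and dtypes: reuses the compiled code

print(first, second)
```

Calling a jitted function with inputs of a different shape or dtype triggers a new compilation, which is one reason the data loaders above drop the last incomplete batch.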

Putting everything together: Training and testing

Now that we have our activity and parameter updates, we just need to wrap them in a training and test loop.

```python
# note: the test accuracy computation below could be sped up...
# ...with jit in a separate function
def evaluate(model, test_loader):
    avg_test_acc = 0
    for test_iter, (img_batch, label_batch) in enumerate(test_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # at test time the prediction is just a feedforward pass
        preds = init_activities_with_ffwd(model, img_batch)[-1]
        test_acc = jnp.mean(
            jnp.argmax(label_batch, axis=1) == jnp.argmax(preds, axis=1)
        ) * 100
        avg_test_acc += test_acc

    return avg_test_acc / len(test_loader)


def train(
      model,
      activity_lr,
      inference_steps,
      param_lr,
      batch_size,
      test_every,
      n_train_iters
):
    # define optimisers for activities and parameters
    activity_optim = optax.sgd(activity_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(eqx.filter(model, eqx.is_array))

    train_loader, test_loader = get_mnist_loaders(batch_size)
    for train_iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # initialise activities
        activities = init_activities_with_ffwd(model, img_batch)
        activity_opt_state = activity_optim.init(activities)

        # calculate loss
        train_loss = jnp.mean((label_batch - activities[-1])**2)

        # inference
        for t in range(inference_steps):
            activities, activity_optim, activity_opt_state = update_activities(
                model=model,
                activities=activities,
                optim=activity_optim,
                opt_state=activity_opt_state,
                input=img_batch,
                output=label_batch
            )

        # learning
        model, param_optim, param_opt_state = update_params(
            model=model,
            activities=activities,  # note how we use the optimised activities
            optim=param_optim,
            opt_state=param_opt_state,
            input=img_batch,
            output=label_batch
        )
        if ((train_iter+1) % test_every) == 0:
            avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {train_iter+1}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        # stop once the requested number of training iterations is reached,
        # even when it is not a multiple of test_every
        if (train_iter+1) >= n_train_iters:
            break
```
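The comment at the top of the block mentions that the accuracy computation could be jitted in a separate function. One possible way to do this is sketched below. The name `batch_accuracy` is hypothetical, and the toy predictions and labels are made up for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def batch_accuracy(preds, one_hot_labels):
    """Percentage of rows where the predicted and true classes agree."""
    correct = jnp.argmax(preds, axis=1) == jnp.argmax(one_hot_labels, axis=1)
    return jnp.mean(correct) * 100

# toy batch of 4 predictions over 2 classes; 3 of them match the labels
preds = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
labels = jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])

acc = batch_accuracy(preds, labels)
print(acc)
```

Inside `evaluate`, such a function could replace the inline `jnp.mean(...)` expression. Since all test batches have the same shape, it would be compiled only once.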

Run

Let's test our implementation.

```python
train(
    model=model,
    activity_lr=ACTIVITY_LR,
    inference_steps=INFERENCE_STEPS,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)
```

```
Train iter 50, train loss=0.065566, avg test accuracy=72.726364
Train iter 100, train loss=0.046521, avg test accuracy=76.292068
Train iter 150, train loss=0.042710, avg test accuracy=86.568512
Train iter 200, train loss=0.029598, avg test accuracy=89.082535
Train iter 250, train loss=0.031486, avg test accuracy=89.222755
Train iter 300, train loss=0.016624, avg test accuracy=91.296074
Train iter 350, train loss=0.025201, avg test accuracy=92.648239
Train iter 400, train loss=0.018597, avg test accuracy=92.968750
Train iter 450, train loss=0.019027, avg test accuracy=94.130608
Train iter 500, train loss=0.014850, avg test accuracy=93.760017
```

🥳 Great, we see that our model is learning! This model was not tuned, and you can probably improve the performance by tweaking some of the hyperparameters (e.g. try a higher number of inference steps).

Even if you didn't follow all the implementation details, you should now have at least an idea of how PC works in practice. Indeed, this is basically the core code behind a new PC library our lab will soon release: JPC. Play around with the notebook examples there to learn how to train a variety of PC networks.