Skip to content

Error-reparameterised Predictive Coding (ePC)

Open in Colab

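This notebook demonstrates how to train predictive coding (PC) networks with ePC (Goemaere et al., 2025). ePC is a reparameterisation of PC in which inference is run over the prediction errors rather than the layer activities. This can make the inference dynamics converge faster and allow deeper networks to be trained than with standard PC.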

python %%capture !pip install torch==2.3.1 !pip install torchvision==0.18.1

```python import jpc

import numpy as np import jax.random as jr import equinox as eqx import optax

import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms

import warnings warnings.simplefilter('ignore') # ignore warnings ```

Hyperparameters

We define some global hyperparameters: the network architecture, the learning rates, the batch size and the training/evaluation schedule.

# --- Network architecture -------------------------------------------------
WIDTH = 128        # units per hidden layer
DEPTH = 10         # number of layers passed to jpc.make_mlp
ACT_FN = "relu"    # activation function name

# --- Optimisation ---------------------------------------------------------
STATE_LR = 0.5     # step size for the inference state: activities (PC) or errors (ePC)
PARAM_LR = 0.001   # Adam learning rate for the weights
BATCH_SIZE = 64

# --- Schedule -------------------------------------------------------------
TEST_EVERY = 100       # evaluate on the test set every this many iterations
N_TRAIN_ITERS = 300    # total number of training iterations

Dataset

Some utilities to download MNIST and load it as flattened, normalised images with one-hot labels.

def get_mnist_loaders(batch_size):
    """Build the MNIST train and test DataLoaders.

    Both splits are normalised.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    # Evaluation gains nothing from shuffling. A fixed order makes the
    # reported test accuracy reproducible across calls.
    # drop_last=True is kept so every test batch has the same size; the
    # per-batch accuracies are averaged uniformly by the caller. Note that
    # the final partial batch is therefore never evaluated.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )
    return train_loader, test_loader

class MNIST(datasets.MNIST):
    """torchvision MNIST that yields flattened images and one-hot labels.

    Args:
        train: If True, load the training split; otherwise the test split.
        normalise: If True, standardise pixels with the fixed mean/std below
            after converting them to a tensor.
        save_dir: Root directory passed to torchvision. It is created and
            populated because ``download=True`` is passed.
    """

    # Must be ``__init__`` (not ``init``); otherwise torchvision's
    # constructor runs directly and rejects ``normalise``.
    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # Normalize expects per-channel sequences; MNIST has a single
            # channel, so use 1-tuples. ``(0.1307)`` alone is just a float.
            steps.append(transforms.Normalize(mean=(0.1307,), std=(0.3081,)))
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(flattened_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label

def one_hot(labels, n_classes=10):
    """Encode class indices as one-hot float rows.

    Rows are selected from an ``n_classes x n_classes`` identity matrix, so
    ``labels`` may be a single int or an integer tensor/list of indices.
    The result gains a trailing dimension of size ``n_classes``.
    """
    return torch.eye(n_classes)[labels]

Train and test

def evaluate(model, test_loader):
    """Return the mean per-batch test accuracy of a PC network.

    Args:
        model: Network passed to ``jpc.test_discriminative_pc``.
        test_loader: Sized iterable of ``(images, labels)`` torch-tensor
            batches; each batch is converted to NumPy before evaluation.

    Returns:
        The unweighted mean of the per-batch accuracies returned by
        ``jpc.test_discriminative_pc``.

    Raises:
        ValueError: If ``test_loader`` contains no batches.
    """
    n_batches = len(test_loader)
    if n_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("test_loader is empty; cannot compute test accuracy")

    total_acc = 0
    for img_batch, label_batch in test_loader:
        _, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch.numpy(),
            output=label_batch.numpy(),
        )
        total_acc += test_acc

    return total_acc / n_batches

def train(
    algo,
    width,
    depth,
    act_fn,
    state_lr,
    param_lr,
    batch_size,
    test_every,
    n_train_iters,
    seed=5482,
):
    """Train an MLP on MNIST with standard PC or error-reparameterised PC.

    Args:
        algo: ``"pc"`` runs inference over layer activities; ``"epc"`` runs
            it over layer errors.
        width: Hidden-layer width passed to ``jpc.make_mlp``.
        depth: Depth passed to ``jpc.make_mlp``.
        act_fn: Activation function name passed to ``jpc.make_mlp``.
        state_lr: SGD step size for the inference state (activities or
            errors).
        param_lr: Adam learning rate for the model parameters.
        batch_size: Mini-batch size for both data loaders.
        test_every: Evaluate on the test set every this many iterations.
        n_train_iters: Stop after this many iterations. Training never goes
            beyond one pass over the training loader.
        seed: Seed for the PRNG key used to initialise the model.

    Raises:
        ValueError: If ``algo`` is not ``"pc"`` or ``"epc"``.
    """
    # Validate up front: an unknown algo would otherwise surface much later
    # as an UnboundLocalError on the optimiser/update variables.
    if algo not in ("pc", "epc"):
        raise ValueError(f"algo must be 'pc' or 'epc', got {algo!r}")

    key = jr.PRNGKey(seed)
    model = jpc.make_mlp(
        key,
        input_dim=784,
        width=width,
        depth=depth,
        output_dim=10,
        act_fn=act_fn,
        use_bias=True,
    )
    # Sizes of every layer, input first; ePC uses these to allocate errors.
    layer_sizes = [784] + [width] * (depth - 1) + [10]

    # PC and ePC use the same optimiser type for their inference state.
    state_optim = optax.sgd(state_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # Feed-forward pass. Its output-layer MSE is the logged train loss,
        # measured before inference and learning for this batch.
        activities = jpc.init_activities_with_ffwd(model=model, input=img_batch)
        train_loss = jpc.mse_loss(activities[-1], label_batch)

        # Initialise the inference state: activities (PC) or errors (ePC).
        if algo == "pc":
            state_opt_state = state_optim.init(activities)
        else:
            errors = jpc.init_epc_errors(
                layer_sizes=layer_sizes,
                batch_size=batch_size,
            )
            state_opt_state = state_optim.init(errors)

        # Inference: len(model) optimiser steps on the state.
        for _ in range(len(model)):
            if algo == "pc":
                state_result = jpc.update_pc_activities(
                    params=(model, None),
                    activities=activities,
                    optim=state_optim,
                    opt_state=state_opt_state,
                    output=label_batch,
                    input=img_batch,
                )
                activities = state_result["activities"]
            else:
                state_result = jpc.update_epc_errors(
                    params=(model, None),
                    errors=errors,
                    optim=state_optim,
                    opt_state=state_opt_state,
                    output=label_batch,
                    input=img_batch,
                )
                errors = state_result["errors"]
            state_opt_state = state_result["opt_state"]

        # Learning: one parameter update using the inferred state.
        if algo == "pc":
            param_result = jpc.update_pc_params(
                params=(model, None),
                activities=activities,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch,
            )
        else:
            param_result = jpc.update_epc_params(
                params=(model, None),
                errors=errors,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch,
            )
        model = param_result["model"]
        param_opt_state = param_result["opt_state"]

        if not np.isfinite(train_loss):
            print(
                f"Stopping training because of divergence, train loss={train_loss}"
            )
            break

        if step % test_every == 0:
            avg_test_acc = evaluate(model=model, test_loader=test_loader)
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        # Checked on every iteration, not only at test points, so training
        # stops at n_train_iters even when it is not a multiple of
        # test_every.
        if step >= n_train_iters:
            break

```

Run

The cell below trains the network with ePC and should take about 30 s on a CPU.

# Collect the run settings in one place, then launch training with ePC.
run_config = dict(
    algo="epc",
    width=WIDTH,
    depth=DEPTH,
    act_fn=ACT_FN,
    state_lr=STATE_LR,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS,
)
train(**run_config)

Train iter 100, train loss=0.014619, avg test accuracy=82.291664
Train iter 200, train loss=0.006187, avg test accuracy=91.185898
Train iter 300, train loss=0.007181, avg test accuracy=91.606567

For comparison, try standard PC by setting `algo="pc"`.