Skip to content

Error-reparameterised Predictive Coding (ePC)

Open in Colab

This notebook demonstrates how to train predictive coding (PC) networks with ePC (Goemaere et al., 2025). ePC is a reparameterisation of PC in terms of prediction errors. It can speed up convergence of the inference dynamics and enable training of deeper networks than standard PC.

%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
import jpc

import numpy as np
import jax.random as jr
import equinox as eqx
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import warnings
warnings.simplefilter('ignore')  # ignore warnings

Hyperparameters

We define some global hyperparameters, including the network architecture, learning rates, batch size, etc.

# Network architecture.
WIDTH, DEPTH, ACT_FN = 128, 10, "relu"

# Learning rates: STATE_LR is used for inference over either activities (PC)
# or errors (ePC); PARAM_LR is used for the parameter (weight) updates.
STATE_LR, PARAM_LR = 5e-1, 1e-3

# Data batching and training schedule.
BATCH_SIZE, TEST_EVERY, N_TRAIN_ITERS = 64, 100, 300

Dataset

Some utilities to download and load MNIST.

def get_mnist_loaders(batch_size):
    """Build the MNIST train and test data loaders.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects whose
        items are ``(flattened_image, one_hot_label)`` pairs (see ``MNIST``).
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    # drop_last=True keeps every training batch at exactly `batch_size`;
    # `train` sizes the ePC errors from `batch_size`, so a short final batch
    # would presumably not match them.
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    # The test set is NOT shuffled: combined with drop_last=True, shuffling
    # would drop a different random subset of test images on every
    # evaluation, making reported test accuracy non-reproducible. Without
    # shuffling, the same trailing remainder (< batch_size images) is
    # excluded every time.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """torchvision MNIST variant yielding flattened images and one-hot labels.

    Args:
        train: load the training split if True, otherwise the test split.
        normalise: if True, standardise pixels with mean 0.1307 and std
            0.3081 (the commonly used MNIST statistics) after ``ToTensor``.
        save_dir: directory the dataset is downloaded to / loaded from.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # MNIST images have a single channel, so the per-channel mean/std
            # are given as explicit 1-tuples (a bare `(0.1307)` is a float).
            steps.append(
                transforms.Normalize(mean=(0.1307,), std=(0.3081,))
            )
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(flattened_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


def one_hot(labels, n_classes=10):
    """Encode integer class label(s) as one-hot float vectors.

    Rows of an ``n_classes`` x ``n_classes`` identity matrix are selected by
    ``labels``: a scalar label gives a 1-D vector, a tensor of labels gives
    one row per label.
    """
    return torch.eye(n_classes)[labels]

Train and test

def evaluate(model, test_loader):
    """Return the test accuracy of ``model`` averaged over all test samples.

    Each batch's accuracy from ``jpc.test_discriminative_pc`` is weighted by
    the batch's sample count. The result is therefore correct even if the
    final batch is smaller than the others; with equal-sized batches it
    equals the plain mean over batches.

    Args:
        model: the network passed to ``jpc.test_discriminative_pc``.
        test_loader: iterable of ``(images, labels)`` torch tensor batches.

    Returns:
        The sample-weighted mean accuracy, in whatever units
        ``jpc.test_discriminative_pc`` reports (assumed to be a per-batch
        mean accuracy — confirm against the jpc docs).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    weighted_acc_sum = 0.0
    n_samples = 0
    for img_batch, label_batch in test_loader:
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        _, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch,
            output=label_batch
        )
        batch_n = len(img_batch)
        weighted_acc_sum += test_acc * batch_n
        n_samples += batch_n

    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return weighted_acc_sum / n_samples


def train(
      algo,
      width,
      depth,
      act_fn,
      state_lr,
      param_lr,
      batch_size,
      test_every,
      n_train_iters
):
    """Train an MLP classifier on MNIST with standard PC or ePC.

    Args:
        algo: ``"pc"`` to run inference over layer activities, or ``"epc"``
            to run inference over the reparameterised errors.
        width: hidden-layer width passed to ``jpc.make_mlp``.
        depth: number of layers passed to ``jpc.make_mlp``.
        act_fn: activation function name passed to ``jpc.make_mlp``.
        state_lr: SGD learning rate for the inference state (activities for
            PC, errors for ePC).
        param_lr: Adam learning rate for the model parameters.
        batch_size: batch size for the training and test loaders.
        test_every: evaluate and print test accuracy every this many
            training iterations.
        n_train_iters: maximum number of training iterations. At most one
            pass over the training set is made.

    Raises:
        ValueError: if ``algo`` is not ``"pc"`` or ``"epc"``.
    """
    # Validate up front; otherwise an unknown algo surfaces later as a
    # confusing NameError on an undefined optimiser / result variable.
    if algo not in ("pc", "epc"):
        raise ValueError(f"algo must be 'pc' or 'epc', got {algo!r}")

    key = jr.PRNGKey(5482)
    model = jpc.make_mlp(
        key,
        input_dim=784,
        width=width,
        depth=depth,
        output_dim=10,
        act_fn=act_fn,
        use_bias=True
    )
    # Layer sizes of the MLP above; used to shape the ePC errors.
    layer_sizes = [784] + [width] * (depth - 1) + [10]

    # Both algorithms use plain SGD on their inference state.
    state_optim = optax.sgd(state_lr)

    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for t, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # A feedforward pass initialises the PC activities and also gives the
        # training loss reported for this iteration (for either algorithm).
        activities = jpc.init_activities_with_ffwd(
            model=model,
            input=img_batch
        )
        train_loss = jpc.mse_loss(activities[-1], label_batch)

        if algo == "pc":
            state_opt_state = state_optim.init(activities)
        else:  # "epc"
            errors = jpc.init_epc_errors(
                layer_sizes=layer_sizes,
                batch_size=batch_size
            )
            state_opt_state = state_optim.init(errors)

        # Inference: len(model) state-update steps per batch.
        for _ in range(len(model)):
            if algo == "pc":
                state_update_result = jpc.update_pc_activities(
                    params=(model, None),
                    activities=activities,
                    optim=state_optim,
                    opt_state=state_opt_state,
                    output=label_batch,
                    input=img_batch
                )
                activities = state_update_result["activities"]
            else:
                state_update_result = jpc.update_epc_errors(
                    params=(model, None),
                    errors=errors,
                    optim=state_optim,
                    opt_state=state_opt_state,
                    output=label_batch,
                    input=img_batch
                )
                errors = state_update_result["errors"]
            state_opt_state = state_update_result["opt_state"]

        # Learning: one parameter update using the state reached by inference.
        if algo == "pc":
            param_update_result = jpc.update_pc_params(
                params=(model, None),
                activities=activities,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch
            )
        else:
            param_update_result = jpc.update_epc_params(
                params=(model, None),
                errors=errors,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch
            )

        model = param_update_result["model"]
        param_opt_state = param_update_result["opt_state"]

        if not np.isfinite(train_loss):
            print(
                f"Stopping training because of divergence, train loss={train_loss}"
            )
            break

        if (t + 1) % test_every == 0:
            avg_test_acc = evaluate(model=model, test_loader=test_loader)
            print(
                f"Train iter {t+1}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        # Checked on every iteration, not only on evaluation iterations, so
        # budgets that are not a multiple of `test_every` are honoured.
        if (t + 1) >= n_train_iters:
            break

Run

The cell below should take about 30 seconds to run on a CPU.

# Train with ePC using the global hyperparameters defined above.
train(
    algo="epc",
    # architecture
    width=WIDTH, depth=DEPTH, act_fn=ACT_FN,
    # optimisation
    state_lr=STATE_LR, param_lr=PARAM_LR,
    # data and schedule
    batch_size=BATCH_SIZE, test_every=TEST_EVERY, n_train_iters=N_TRAIN_ITERS,
)
Train iter 100, train loss=0.014619, avg test accuracy=82.291664
Train iter 200, train loss=0.006187, avg test accuracy=91.185898
Train iter 300, train loss=0.007181, avg test accuracy=91.606567

For comparison, try standard PC by setting `algo="pc"`.