Error-reparameterised Predictive Coding (ePC)
This notebook demonstrates how to train predictive coding (PC) networks with ePC (Goemaere et al., 2025), a reparameterisation of PC that can yield faster convergence of the inference dynamics and allow deeper networks to be trained than standard PC.
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
import jpc
import numpy as np
import jax.random as jr
import equinox as eqx
import optax
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import warnings
warnings.simplefilter('ignore') # ignore warnings
Hyperparameters
We define some global parameters, including the network architecture, learning rate, batch size, etc.
# --- Network architecture ---
WIDTH = 128        # size of each hidden layer of the MLP
DEPTH = 10         # number of layers (DEPTH - 1 hidden layers plus the output layer)
ACT_FN = "relu"    # activation function name handed to the MLP constructor

# --- Optimisation ---
STATE_LR = 0.5     # SGD step size for either activities (PC) or errors (ePC)
PARAM_LR = 0.001   # Adam learning rate for the network parameters
BATCH_SIZE = 64    # batch size used for both the train and the test loader

# --- Schedule ---
TEST_EVERY = 100       # evaluate on the test set every this many training iterations
N_TRAIN_ITERS = 300    # total number of training iterations to run
Dataset
Some utils to fetch MNIST.
def get_mnist_loaders(batch_size):
    """Create the MNIST training and test data loaders.

    Both loaders wrap a normalised `MNIST` dataset, shuffle their samples and
    drop the last incomplete batch, so every batch has exactly `batch_size`
    samples.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A `(train_loader, test_loader)` tuple of `DataLoader`s.
    """

    def _build_loader(train):
        # Identical loader settings for both splits; only the split differs.
        return DataLoader(
            dataset=MNIST(train=train, normalise=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )

    return _build_loader(train=True), _build_loader(train=False)
class MNIST(datasets.MNIST):
    """MNIST dataset that yields flattened images and one-hot labels.

    Args:
        train: Whether to load the training split (`True`) or the test split
            (`False`).
        normalise: If `True`, images are normalised with mean 0.1307 and
            standard deviation 0.3081 after conversion to tensors.
        save_dir: Directory the data is stored in; the dataset is downloaded
            there if needed (`download=True`).
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        pipeline = [transforms.ToTensor()]
        if normalise:
            # `Normalize` documents `mean`/`std` as per-channel sequences.
            # `(0.1307)` is just a parenthesised float, so spell the
            # single-channel statistics as real one-element tuples.
            pipeline.append(
                transforms.Normalize(mean=(0.1307,), std=(0.3081,))
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(pipeline),
        )

    def __getitem__(self, index):
        """Return the `index`-th sample as `(flattened image, one-hot label)`."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)
def one_hot(labels, n_classes=10):
    """One-hot encode `labels` by selecting rows of an identity matrix.

    Args:
        labels: Class index (or any index accepted by tensor indexing, e.g. a
            tensor of class indices).
        n_classes: Number of classes, i.e. the length of each encoding.

    Returns:
        The row(s) of the `n_classes x n_classes` identity matrix selected by
        `labels`.
    """
    return torch.eye(n_classes)[labels]
Train and test
def evaluate(model, test_loader):
    """Return the test accuracy of `model` averaged over all test batches.

    Args:
        model: Network passed on to `jpc.test_discriminative_pc`.
        test_loader: Iterable of `(image batch, label batch)` tensor pairs
            that also supports `len()`.

    Returns:
        The sum of the per-batch accuracies divided by the number of batches.
    """
    accuracy_sum = 0
    for images, labels in test_loader:
        # Batches arrive as torch tensors; hand them to jpc as NumPy arrays.
        _, batch_acc = jpc.test_discriminative_pc(
            model=model,
            input=images.numpy(),
            output=labels.numpy(),
        )
        accuracy_sum += batch_acc
    return accuracy_sum / len(test_loader)
def train(
    algo,
    width,
    depth,
    act_fn,
    state_lr,
    param_lr,
    batch_size,
    test_every,
    n_train_iters
):
    """Train an MLP on MNIST with standard PC or error-reparameterised PC.

    Each training iteration (i) runs a feedforward pass to initialise the
    activities and record the training loss, (ii) performs `len(model)`
    inference steps on the activities (PC) or errors (ePC) with SGD, and
    (iii) applies one Adam update to the parameters. Progress is printed
    every `test_every` iterations.

    Args:
        algo: Either `"pc"` (standard PC) or `"epc"` (ePC).
        width: Size of each hidden layer.
        depth: Number of layers passed to `jpc.make_mlp`.
        act_fn: Activation function name passed to `jpc.make_mlp`.
        state_lr: SGD learning rate for the activities (PC) or errors (ePC).
        param_lr: Adam learning rate for the parameters.
        batch_size: Batch size of the train and test loaders.
        test_every: Evaluate and print every this many iterations.
        n_train_iters: Stop after this many iterations (training also stops
            when the loader is exhausted or the loss diverges).

    Raises:
        ValueError: If `algo` is neither `"pc"` nor `"epc"`.
    """
    # Fail fast: without this check an unknown `algo` only surfaces as an
    # unbound-variable error inside the first training iteration, after the
    # dataset has been fetched and the model built.
    if algo not in ("pc", "epc"):
        raise ValueError(
            f'Unknown algo {algo!r}; expected "pc" or "epc".'
        )

    key = jr.PRNGKey(5482)
    model = jpc.make_mlp(
        key,
        input_dim=784,
        width=width,
        depth=depth,
        output_dim=10,
        act_fn=act_fn,
        use_bias=True
    )
    layer_sizes = [784] + [width] * (depth-1) + [10]

    # The inferred state (activities or errors) is optimised with plain SGD,
    # the parameters with Adam.
    if algo == "pc":
        activity_optim = optax.sgd(state_lr)
    else:
        error_optim = optax.sgd(state_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )

    train_loader, test_loader = get_mnist_loaders(batch_size)
    # `step` is the 1-based iteration count (avoids shadowing builtin `iter`).
    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # initialise activities or errors
        activities = jpc.init_activities_with_ffwd(
            model=model,
            input=img_batch
        )
        # Loss of the feedforward prediction, measured before any update.
        train_loss = jpc.mse_loss(activities[-1], label_batch)
        if algo == "pc":
            activity_opt_state = activity_optim.init(activities)
        else:
            errors = jpc.init_epc_errors(
                layer_sizes=layer_sizes,
                batch_size=batch_size
            )
            error_opt_state = error_optim.init(errors)

        # inference
        for _ in range(len(model)):
            if algo == "pc":
                activity_update_result = jpc.update_pc_activities(
                    params=(model, None),
                    activities=activities,
                    optim=activity_optim,
                    opt_state=activity_opt_state,
                    output=label_batch,
                    input=img_batch
                )
                activities = activity_update_result["activities"]
                activity_opt_state = activity_update_result["opt_state"]
            else:
                error_update_result = jpc.update_epc_errors(
                    params=(model, None),
                    errors=errors,
                    optim=error_optim,
                    opt_state=error_opt_state,
                    output=label_batch,
                    input=img_batch
                )
                errors = error_update_result["errors"]
                error_opt_state = error_update_result["opt_state"]

        # learning
        if algo == "pc":
            param_update_result = jpc.update_pc_params(
                params=(model, None),
                activities=activities,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch
            )
        else:
            param_update_result = jpc.update_epc_params(
                params=(model, None),
                errors=errors,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch
            )
        model = param_update_result["model"]
        param_opt_state = param_update_result["opt_state"]

        if np.isinf(train_loss) or np.isnan(train_loss):
            print(
                f"Stopping training because of divergence, train loss={train_loss}"
            )
            break

        if (step % test_every) == 0:
            avg_test_acc = evaluate(model=model, test_loader=test_loader)
            # NOTE: `:4f` is a minimum-width spec (default 6 decimals); it is
            # kept unchanged so the printed output format stays the same.
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        if step >= n_train_iters:
            break
Run
The script below should take ~30s to run on a CPU.
# Launch training with ePC using the hyperparameters defined above;
# pass algo="pc" instead to train with standard PC.
train(
algo="epc",
width=WIDTH,
depth=DEPTH,
act_fn=ACT_FN,
state_lr=STATE_LR,
param_lr=PARAM_LR,
batch_size=BATCH_SIZE,
test_every=TEST_EVERY,
n_train_iters=N_TRAIN_ITERS
)
Train iter 100, train loss=0.014619, avg test accuracy=82.291664
Train iter 200, train loss=0.006187, avg test accuracy=91.185898
Train iter 300, train loss=0.007181, avg test accuracy=91.606567
For comparison, try switching to standard PC by setting algo="pc".