Error-reparameterised Predictive Coding (ePC)
This notebook demonstrates how to train predictive coding (PC) networks with ePC (Goemaere et al., 2025), a reparameterisation of PC that can speed up the convergence of the inference dynamics and enable the training of deeper networks than standard PC.
import jpc
import numpy as np
import jax.random as jr
import equinox as eqx
import optax
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import warnings
warnings.simplefilter('ignore') # ignore warnings
Hyperparameters
We define some global hyperparameters, including the network architecture, learning rates, batch size, etc.
# --- Network architecture (passed to jpc.make_mlp) ---
WIDTH = 128
DEPTH = 10
ACT_FN = "relu"

# --- Optimisation ---
STATE_LR = 0.5  # SGD step size for either activities (PC) or errors (ePC)
PARAM_LR = 0.001  # Adam learning rate for the model parameters
BATCH_SIZE = 64

# --- Schedule ---
TEST_EVERY = 100  # evaluate on the test set every this many training iterations
N_TRAIN_ITERS = 300
Dataset
Some utilities to fetch MNIST.
def get_mnist_loaders(batch_size):
    """Build the MNIST train and test data loaders.

    Both loaders wrap the ``MNIST`` dataset class defined in this notebook
    (flattened, normalised images with one-hot labels) and use
    ``drop_last=True``, so every batch has exactly ``batch_size`` samples.
    The trade-off is that up to ``batch_size - 1`` trailing samples of each
    split are never seen.

    Args:
        batch_size: Number of samples per batch.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    # The test set is deliberately NOT shuffled. Combined with drop_last=True,
    # shuffling would drop a different random subset of test samples on every
    # evaluation pass, making the reported test accuracy non-deterministic.
    # drop_last is kept so that every test batch has the same shape.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )
    return train_loader, test_loader
class MNIST(datasets.MNIST):
    """torchvision MNIST dataset yielding flattened images and one-hot labels.

    Args:
        train: If True, load the training split; otherwise the test split.
        normalise: If True, standardise pixel values with mean 0.1307 and
            std 0.3081 after converting the image to a tensor.
        save_dir: Directory the dataset is downloaded to / loaded from.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # Normalize expects per-channel sequences. The trailing commas
            # make these real 1-tuples (one channel) rather than
            # parenthesised scalars.
            steps.append(transforms.Normalize(mean=(0.1307,), std=(0.3081,)))
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(flat_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label
def one_hot(labels, n_classes=10):
    """One-hot encode integer class label(s) as float vectors.

    Works by indexing rows of an ``n_classes x n_classes`` identity matrix:
    a scalar label yields a length-``n_classes`` vector, and a tensor of
    labels yields one row per label.
    """
    return torch.eye(n_classes)[labels]
Train and test
def evaluate(model, test_loader):
    """Return the model's average test accuracy over ``test_loader``.

    Each batch is scored with ``jpc.test_discriminative_pc``. The per-batch
    accuracies are averaged weighted by batch size, so the result is a
    per-sample mean even if some batch (e.g. the last one) is smaller than
    the others. With equal-sized batches this equals the plain mean over
    batches.

    Args:
        model: The network to evaluate (in this notebook, the model built by
            ``jpc.make_mlp`` inside ``train``).
        test_loader: Iterable of ``(images, labels)`` torch-tensor batches.

    Returns:
        The averaged accuracy, in the units ``jpc.test_discriminative_pc``
        reports it.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    weighted_acc_sum = 0.0
    n_samples = 0
    for img_batch, label_batch in test_loader:
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
        _, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch,
            output=label_batch,
        )
        n_in_batch = img_batch.shape[0]
        weighted_acc_sum += test_acc * n_in_batch
        n_samples += n_in_batch
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples to evaluate on")
    return weighted_acc_sum / n_samples
def _iterate_n_batches(loader, n_batches):
    """Yield ``n_batches`` batches from ``loader``, re-iterating it as needed.

    Each time ``loader`` is exhausted a new pass (epoch) over it is started,
    so more batches than one epoch can be requested. Yields nothing when
    ``n_batches <= 0``.

    Raises:
        ValueError: If ``loader`` has no batches (otherwise this would loop
            forever).
    """
    if len(loader) == 0:
        raise ValueError("the data loader yields no batches")
    n_yielded = 0
    while n_yielded < n_batches:
        for batch in loader:
            yield batch
            n_yielded += 1
            if n_yielded >= n_batches:
                return


def train(
    algo,
    width,
    depth,
    act_fn,
    state_lr,
    param_lr,
    batch_size,
    test_every,
    n_train_iters,
    seed=5482,
):
    """Train an MLP on MNIST with standard PC or ePC and print test accuracy.

    Each training iteration:
    (i) initialises the activities with a feedforward pass and, for ePC,
        initialises the errors;
    (ii) runs ``len(model)`` SGD inference steps on the activities (PC) or
        the errors (ePC);
    (iii) takes one Adam step on the model parameters.

    Args:
        algo: ``"pc"`` for standard predictive coding or ``"epc"`` for
            error-reparameterised PC.
        width: Hidden-layer width passed to ``jpc.make_mlp``.
        depth: Depth passed to ``jpc.make_mlp``.
        act_fn: Activation function name passed to ``jpc.make_mlp``.
        state_lr: SGD learning rate for the activities (PC) or errors (ePC).
        param_lr: Adam learning rate for the model parameters.
        batch_size: Mini-batch size for training and testing.
        test_every: Evaluate on the test set every this many iterations.
        n_train_iters: Total number of training iterations (mini-batches).
            If this exceeds one epoch, the training data is re-iterated.
        seed: Seed of the PRNG key used to initialise the model.

    Raises:
        ValueError: If ``algo`` is not ``"pc"`` or ``"epc"``, or if the
            training loader yields no batches.
    """
    # Fail fast, before building the model or downloading any data.
    if algo not in ("pc", "epc"):
        raise ValueError(f"algo must be 'pc' or 'epc', got {algo!r}")

    key = jr.PRNGKey(seed)
    model = jpc.make_mlp(
        key,
        input_dim=784,
        width=width,
        depth=depth,
        output_dim=10,
        act_fn=act_fn,
        use_bias=True,
    )
    # Layer dimensions (input, hidden..., output) used to allocate ePC errors.
    layer_sizes = [784] + [width] * (depth - 1) + [10]

    # SGD for the inference states: activities (PC) or errors (ePC).
    state_optim = optax.sgd(state_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )

    train_loader, test_loader = get_mnist_loaders(batch_size)
    batches = _iterate_n_batches(train_loader, n_train_iters)
    for step, (img_batch, label_batch) in enumerate(batches, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # Initialise activities with a feedforward pass. The logged train
        # loss is the MSE of this feedforward prediction, before inference.
        activities = jpc.init_activities_with_ffwd(
            model=model,
            input=img_batch,
        )
        train_loss = jpc.mse_loss(activities[-1], label_batch)

        if algo == "pc":
            # Inference on the activities.
            state_opt_state = state_optim.init(activities)
            for _ in range(len(model)):
                update = jpc.update_pc_activities(
                    params=(model, None),
                    activities=activities,
                    optim=state_optim,
                    opt_state=state_opt_state,
                    output=label_batch,
                    input=img_batch,
                )
                activities = update["activities"]
                state_opt_state = update["opt_state"]
            # Learning.
            param_update_result = jpc.update_pc_params(
                params=(model, None),
                activities=activities,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch,
            )
        else:  # algo == "epc"
            # Inference on freshly initialised errors.
            errors = jpc.init_epc_errors(
                layer_sizes=layer_sizes,
                batch_size=batch_size,
            )
            state_opt_state = state_optim.init(errors)
            for _ in range(len(model)):
                update = jpc.update_epc_errors(
                    params=(model, None),
                    errors=errors,
                    optim=state_optim,
                    opt_state=state_opt_state,
                    output=label_batch,
                    input=img_batch,
                )
                errors = update["errors"]
                state_opt_state = update["opt_state"]
            # Learning.
            param_update_result = jpc.update_epc_params(
                params=(model, None),
                errors=errors,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch,
            )
        model = param_update_result["model"]
        param_opt_state = param_update_result["opt_state"]

        if not np.isfinite(train_loss):
            print(
                f"Stopping training because of divergence, train loss={train_loss}"
            )
            break

        if step % test_every == 0:
            avg_test_acc = evaluate(model=model, test_loader=test_loader)
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
Run
The code below should take about 30 s to run on a CPU.
# Train with ePC using the global hyperparameters; set algo="pc" to compare.
train(
    algo="epc",
    width=WIDTH, depth=DEPTH, act_fn=ACT_FN,
    state_lr=STATE_LR, param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY, n_train_iters=N_TRAIN_ITERS,
)
Train iter 100, train loss=0.014619, avg test accuracy=82.291664
Train iter 200, train loss=0.006187, avg test accuracy=91.185898
Train iter 300, train loss=0.007181, avg test accuracy=91.606567
For comparison, try switching to standard PC by setting `algo="pc"`.