Error-reparameterised Predictive Coding (ePC)
This notebook demonstrates how to train predictive coding (PC) networks with ePC (Goemaere et al., 2025), a reparameterisation of PC that can speed up the convergence of the inference dynamics and allows deeper networks to be trained than with standard PC.
python
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
```python import jpc
import numpy as np import jax.random as jr import equinox as eqx import optax
import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms
import warnings warnings.simplefilter('ignore') # ignore warnings ```
Hyperparameters
We define some global parameters, including the network architecture, learning rates, batch size, etc.
# Network architecture.
WIDTH, DEPTH = 128, 10
ACT_FN = "relu"

# Optimisation: the state learning rate is used for inference of either the
# activities (PC) or the errors (ePC); the other one is for the parameters.
STATE_LR, PARAM_LR = 0.5, 1e-3

# Data and schedule.
BATCH_SIZE = 64
TEST_EVERY = 100  # evaluate on the test set every this many iterations
N_TRAIN_ITERS = 300
Dataset
Some utilities to fetch MNIST as flattened images with one-hot labels.
def get_mnist_loaders(batch_size):
    """Build MNIST train and test loaders of flattened images and one-hot labels.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    # The test set is not shuffled: with drop_last=True a shuffled loader would
    # score a different random subset on every evaluation, making accuracies
    # from different checkpoints incomparable. drop_last is kept so every batch
    # has the same shape (presumably to avoid re-tracing jitted jpc code, and
    # so the per-batch accuracy average is not skewed by a small last batch).
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """torchvision MNIST whose samples are flattened images and one-hot labels."""

    def __init__(self, train, normalise=True, save_dir="data"):
        """Load the MNIST train or test split, downloading it if necessary.

        Args:
            train: If True load the training split, otherwise the test split.
            normalise: If True, apply ``ToTensor`` followed by normalisation
                with mean 0.1307 and std 0.3081 (presumably the usual MNIST
                pixel statistics); otherwise apply ``ToTensor`` only.
            save_dir: Root directory handed to torchvision for the data files.
        """
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    # 1-tuples: one value per (single) image channel.
                    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(flattened_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    """Return the one-hot encoding of ``labels`` as rows of a float identity matrix."""
    arr = torch.eye(n_classes)
    return arr[labels]
Train and test
def evaluate(model, test_loader):
    """Return the mean per-batch test accuracy of ``model`` on ``test_loader``.

    Each batch is converted to NumPy and scored with
    ``jpc.test_discriminative_pc``; the second value it returns is treated as
    the batch accuracy, and batch accuracies are averaged with equal weight.

    Args:
        model: Network passed to ``jpc.test_discriminative_pc``.
        test_loader: Sized iterable of ``(images, labels)`` batches whose
            elements provide ``.numpy()`` (e.g. torch tensors).

    Returns:
        The accuracy averaged over ``len(test_loader)`` batches.

    Raises:
        ValueError: If ``test_loader`` contains no batches.
    """
    n_batches = len(test_loader)
    if n_batches == 0:
        # Without this guard the average below fails with ZeroDivisionError.
        raise ValueError("test_loader yields no batches; cannot compute accuracy")

    total_acc = 0
    for img_batch, label_batch in test_loader:
        _, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch.numpy(),
            output=label_batch.numpy()
        )
        total_acc += test_acc
    return total_acc / n_batches


def train(
    algo,
    width,
    depth,
    act_fn,
    state_lr,
    param_lr,
    batch_size,
    test_every,
    n_train_iters
):
    """Train an MLP classifier on MNIST with standard PC or ePC.

    For every training batch the inferred state (activities for PC, errors for
    ePC) is initialised, refined with ``len(model)`` SGD inference steps, and
    then used for one Adam update of the model parameters. Test accuracy is
    printed every ``test_every`` iterations. Training stops after
    ``n_train_iters`` iterations, when the training loader is exhausted, or
    when the feedforward training loss becomes inf/NaN.

    Args:
        algo: ``"pc"`` to infer activities or ``"epc"`` to infer errors.
        width: Hidden-layer width passed to ``jpc.make_mlp``.
        depth: Network depth passed to ``jpc.make_mlp``.
        act_fn: Activation-function name passed to ``jpc.make_mlp``.
        state_lr: SGD learning rate for the inferred state.
        param_lr: Adam learning rate for the model parameters.
        batch_size: Batch size for both the train and test loaders.
        test_every: Evaluate and print test accuracy every this many iterations.
        n_train_iters: Maximum number of training iterations.

    Raises:
        ValueError: If ``algo`` is neither ``"pc"`` nor ``"epc"``.
    """
    # Fail fast: an unknown algo would otherwise surface much later as an
    # UnboundLocalError on the never-created optimiser/state.
    if algo not in ("pc", "epc"):
        raise ValueError(f"algo must be 'pc' or 'epc', got {algo!r}")

    key = jr.PRNGKey(5482)
    model = jpc.make_mlp(
        key,
        input_dim=784,
        width=width,
        depth=depth,
        output_dim=10,
        act_fn=act_fn,
        use_bias=True
    )
    layer_sizes = [784] + [width] * (depth - 1) + [10]

    # One plain-SGD optimiser drives whichever state is inferred (activities
    # for PC, errors for ePC); parameters are learned with Adam.
    state_optim = optax.sgd(state_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # Feedforward pass initialises the activities; the logged train loss
        # is the MSE of this feedforward prediction, i.e. before inference.
        activities = jpc.init_activities_with_ffwd(
            model=model,
            input=img_batch
        )
        train_loss = jpc.mse_loss(activities[-1], label_batch)
        # Checked before doing any inference/learning work: after a break the
        # updated model is never used, so those updates would be wasted.
        if not np.isfinite(train_loss):
            print(
                f"Stopping training because of divergence, train loss={train_loss}"
            )
            break

        # initialise the state to be inferred
        if algo == "pc":
            activity_opt_state = state_optim.init(activities)
        else:
            errors = jpc.init_epc_errors(
                layer_sizes=layer_sizes,
                batch_size=batch_size
            )
            error_opt_state = state_optim.init(errors)

        # inference: len(model) steps (presumably one per layer — confirm)
        for _ in range(len(model)):
            if algo == "pc":
                activity_update_result = jpc.update_pc_activities(
                    params=(model, None),
                    activities=activities,
                    optim=state_optim,
                    opt_state=activity_opt_state,
                    output=label_batch,
                    input=img_batch
                )
                activities = activity_update_result["activities"]
                activity_opt_state = activity_update_result["opt_state"]
            else:
                error_update_result = jpc.update_epc_errors(
                    params=(model, None),
                    errors=errors,
                    optim=state_optim,
                    opt_state=error_opt_state,
                    output=label_batch,
                    input=img_batch
                )
                errors = error_update_result["errors"]
                error_opt_state = error_update_result["opt_state"]

        # learning
        if algo == "pc":
            param_update_result = jpc.update_pc_params(
                params=(model, None),
                activities=activities,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch
            )
        else:
            param_update_result = jpc.update_epc_params(
                params=(model, None),
                errors=errors,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch
            )
        model = param_update_result["model"]
        param_opt_state = param_update_result["opt_state"]

        if step % test_every == 0:
            avg_test_acc = evaluate(model=model, test_loader=test_loader)
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
        if step >= n_train_iters:
            break
Run
The script below should take ~30s to run on a CPU.
python
# ePC run; set algo="pc" to compare against standard predictive coding.
train(
    algo="epc",
    # architecture
    width=WIDTH,
    depth=DEPTH,
    act_fn=ACT_FN,
    # optimisation
    state_lr=STATE_LR,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    # schedule
    n_train_iters=N_TRAIN_ITERS,
    test_every=TEST_EVERY,
)
Train iter 100, train loss=0.014619, avg test accuracy=82.291664
Train iter 200, train loss=0.006187, avg test accuracy=91.185898
Train iter 300, train loss=0.007181, avg test accuracy=91.606567
For comparison, try switching to standard PC by setting `algo="pc"`.