Discriminative PC on MNIST
This notebook demonstrates how to train a simple feedforward network with predictive coding (PC) to discriminate or classify MNIST digits.
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
import jpc
import jax
import equinox as eqx
import equinox.nn as nn
import optax
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import warnings
# Suppress every warning category for the rest of the session so the
# notebook output (e.g. the training log below) stays uncluttered.
warnings.simplefilter('ignore') # ignore warnings
Hyperparameters
We define some global parameters, including the network architecture, activation function, learning rate, batch size, and the training/evaluation schedule.
# --- Model -----------------------------------------------------------------
SEED = 0                            # seed for jax.random.PRNGKey (weight init)
LAYER_SIZES = [784, 300, 300, 10]   # input, two hidden layers, 10 output classes
ACT_FN = "relu"                     # name of the activation function

# --- Optimisation & schedule -------------------------------------------------
LEARNING_RATE = 0.001               # Adam step size
BATCH_SIZE = 64                     # samples per train/test batch
TEST_EVERY = 100                    # evaluate on the test set every N iterations
N_TRAIN_ITERS = 300                 # total number of training iterations
Dataset
Some utilities to download MNIST and wrap it in PyTorch data loaders.
#@title data utils
def get_mnist_loaders(batch_size):
    """Build the MNIST train and test DataLoaders.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Training batches are shuffled and the
        last incomplete batch is dropped, so every training step sees exactly
        ``batch_size`` samples. The test loader keeps a fixed order and keeps
        its final partial batch, so each evaluation covers the whole test split.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    # Evaluation must be deterministic and use every test image. Combining
    # shuffle=True with drop_last=True would skip a random remainder of
    # len(test_data) % batch_size images on every evaluation.
    # NOTE(review): the final partial batch has a different shape, so a
    # jitted evaluation function presumably compiles once more for it.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False
    )
    return train_loader, test_loader
class MNIST(datasets.MNIST):
    """torchvision MNIST that yields flattened images and one-hot labels.

    Each item is ``(img, label)``:

    * ``img`` is the image after the configured transforms, flattened to a
      1-D tensor (784 entries for a 28x28 digit).
    * ``label`` is the one-hot encoding produced by ``one_hot``.

    Args:
        train: If True load the training split, otherwise the test split.
        normalise: If True, standardise pixels with the commonly used MNIST
            mean (0.1307) and std (0.3081) after ``ToTensor``.
        save_dir: Directory the dataset is stored in (downloaded if missing).
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # MNIST has a single channel, so mean/std are one-element tuples.
            # Note that ``(0.1307)`` without a comma is just a float.
            steps.append(transforms.Normalize(mean=(0.1307,), std=(0.3081,)))
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(flattened_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)
def one_hot(labels, n_classes=10):
    """Return the one-hot encoding of ``labels`` as rows of an identity matrix.

    Args:
        labels: A class index, or anything that can index a tensor (e.g. a
            list of indices), with values in ``[0, n_classes)``.
        n_classes: Length of each one-hot vector.

    Returns:
        ``torch.eye(n_classes)`` indexed by ``labels``. That is a vector of
        length ``n_classes`` for a scalar label, or one row per label
        otherwise, in torch's default float dtype.
    """
    identity = torch.eye(n_classes)
    # Row ``i`` of the identity matrix is exactly the one-hot vector for class ``i``.
    return identity[labels]
# Build the MLP by hand with Equinox, driven by the hyperparameters above:
# every hidden layer is a Linear followed by the activation, and the output
# layer is a plain Linear with no activation.
key = jax.random.PRNGKey(SEED)
# One subkey per Linear layer (len(LAYER_SIZES) - 1 of them); the first
# output of the split is discarded.
_, *subkeys = jax.random.split(key, len(LAYER_SIZES))
hidden_act = getattr(jax.nn, ACT_FN)  # e.g. "relu" -> jax.nn.relu
network = [
    nn.Sequential(
        [
            nn.Linear(in_dim, out_dim, key=layer_key),
            nn.Lambda(hidden_act)
        ],
    )
    for in_dim, out_dim, layer_key in zip(
        LAYER_SIZES[:-2], LAYER_SIZES[1:-1], subkeys[:-1]
    )
]
network.append(nn.Linear(LAYER_SIZES[-2], LAYER_SIZES[-1], key=subkeys[-1]))
You can also use the utility `jpc.make_mlp` to define a multi-layer perceptron (MLP), or fully connected network, with a given activation function (see the `jpc.make_mlp` documentation for more details).
# Build the network with jpc's MLP helper; this replaces any `network`
# defined earlier. The activation comes from the ACT_FN hyperparameter so it
# is configured in one place, together with the layer sizes.
network = jpc.make_mlp(key, LAYER_SIZES, act_fn=ACT_FN)
print(network)
Train and test
A PC network can be updated in a single line of code with `jpc.make_pc_step()` (see its documentation for more details). Similarly, we can use `jpc.test_discriminative_pc()` to compute the network's test loss and accuracy (see its documentation). Note that these functions are already "jitted" for optimised performance. Below we simply wrap each of these functions in training and test loops, respectively.
def evaluate(model, test_loader):
    """Return the test loss and accuracy of a PC network, averaged per sample.

    Runs ``jpc.test_discriminative_pc`` on every batch of ``test_loader``.
    Each batch's result is weighted by its number of samples, so a smaller
    final batch does not count as much as a full one. With equally sized
    batches this equals the plain mean over batches.

    Args:
        model: The PC network passed to ``jpc.test_discriminative_pc``.
        test_loader: Iterable of ``(img_batch, label_batch)`` torch tensors.
            Presumably these are flattened images and one-hot labels, as
            produced by this notebook's MNIST loaders.

    Returns:
        Tuple ``(avg_test_loss, avg_test_acc)``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    total_loss, total_acc, n_samples = 0.0, 0.0, 0
    for img_batch, label_batch in test_loader:
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
        test_loss, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch,
            output=label_batch
        )
        # NOTE(review): assumes both returned values are batch means, which is
        # how the per-batch averaging here and in the original code treats them.
        batch_size = img_batch.shape[0]
        total_loss += test_loss * batch_size
        total_acc += test_acc * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return total_loss / n_samples, total_acc / n_samples
def train(
      model,
      lr,
      batch_size,
      test_every,
      n_train_iters
):
    """Train a PC network on MNIST with Adam and periodically report test accuracy.

    Each iteration performs one ``jpc.make_pc_step`` on a training batch.
    Every ``test_every`` iterations the model is evaluated on the full test
    set and the current train loss and average test accuracy are printed.
    The train loader is re-iterated across epochs until ``n_train_iters``
    iterations have run, so the number of iterations is not capped at one
    epoch.

    Args:
        model: The PC network to train.
        lr: Adam learning rate.
        batch_size: Samples per batch for the train and test loaders.
        test_every: Evaluate (and print) every this many iterations.
        n_train_iters: Total number of training iterations.

    Returns:
        The trained model.

    Raises:
        ValueError: If the training loader yields no batches (this would
            otherwise loop forever).
    """
    optim = optax.adam(lr)
    # jpc's optimiser state covers a (model, skip_model) pair; there is no
    # skip model here, hence the None.
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)
    if len(train_loader) == 0:
        raise ValueError("training loader yields no batches; reduce batch_size")

    iter_idx = 0
    while iter_idx < n_train_iters:  # one pass of the inner loop = one epoch
        for img_batch, label_batch in train_loader:
            img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
            result = jpc.make_pc_step(
                model,
                optim,
                opt_state,
                output=label_batch,
                input=img_batch
            )
            model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
            train_loss = result["loss"]
            iter_idx += 1

            if iter_idx % test_every == 0:
                _, avg_test_acc = evaluate(model, test_loader)
                print(
                    f"Train iter {iter_idx}, train loss={train_loss:.4f}, "
                    f"avg test accuracy={avg_test_acc:.4f}"
                )
            if iter_idx >= n_train_iters:
                break
    return model
Run
# Launch training with the hyperparameters defined at the top of the notebook;
# test accuracy is printed every TEST_EVERY iterations.
train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS,
)