# Discriminative Predictive Coding
This notebook demonstrates how to train a simple feedforward network with predictive coding (PC) to discriminate or classify MNIST digits.
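As a brief refresher, PC trains a network by minimising a layered energy rather than backpropagating a single loss. Assuming the standard Gaussian formulation, this energy is the sum of squared prediction errors across layers

$$\mathcal{F} = \frac{1}{2} \sum_{\ell=1}^{L} \lVert \mathbf{z}_\ell - f_\ell(\mathbf{z}_{\ell-1}) \rVert^2,$$

where $\mathbf{z}_0$ is clamped to the input (here an MNIST image) and $\mathbf{z}_L$ to the target (here a one-hot label). Inference first relaxes the activities $\mathbf{z}_\ell$ to reduce this energy, and the weights of each layer function $f_\ell$ are then updated given the equilibrated activities.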
```python
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
```
```python
import jpc

import jax
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import warnings
warnings.simplefilter('ignore')  # ignore warnings
```
## Hyperparameters
We define some global parameters, including the network architecture, learning rate, batch size, etc.
```python
SEED = 0

INPUT_DIM = 784
WIDTH = 300
DEPTH = 3
OUTPUT_DIM = 10
ACT_FN = "relu"

LEARNING_RATE = 1e-3
BATCH_SIZE = 64
TEST_EVERY = 100
N_TRAIN_ITERS = 300
```
## Dataset
Some utils to fetch MNIST.
```python
def get_mnist_loaders(batch_size):
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    def __init__(self, train, normalise=True, save_dir="data"):
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.1307), std=(0.3081)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]
```
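As a quick sanity check of these utilities (an illustrative snippet, assuming the hyperparameters defined above), a single batch should contain flattened 784-dimensional images and 10-dimensional one-hot labels:

```python
# Fetch one batch and confirm the expected shapes
train_loader, _ = get_mnist_loaders(BATCH_SIZE)
img_batch, label_batch = next(iter(train_loader))
print(img_batch.shape)    # torch.Size([64, 784])
print(label_batch.shape)  # torch.Size([64, 10])
```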
## Network
For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in equinox. For example, we can define a ReLU MLP with two hidden layers as follows:
```python
key = jax.random.PRNGKey(SEED)
_, *subkeys = jax.random.split(key, 4)
network = [
    nn.Sequential(
        [
            nn.Linear(784, 300, key=subkeys[0]),
            nn.Lambda(jax.nn.relu)
        ],
    ),
    nn.Sequential(
        [
            nn.Linear(300, 300, key=subkeys[1]),
            nn.Lambda(jax.nn.relu)
        ],
    ),
    nn.Linear(300, 10, key=subkeys[2]),
]
```
You can also use jpc.make_mlp() to define a multi-layer perceptron (MLP) or fully connected network.
```python
network = jpc.make_mlp(
    key,
    input_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    output_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    use_bias=True
)
print(network)
```
```
[Sequential(
  layers=(
    Lambda(fn=Identity()),
    Linear(
      weight=f32[300,784],
      bias=f32[300],
      in_features=784,
      out_features=300,
      use_bias=True
    )
  )
), Sequential(
  layers=(
    Lambda(fn=<PjitFunction of <function relu at 0x10cea9c60>>),
    Linear(
      weight=f32[300,300],
      bias=f32[300],
      in_features=300,
      out_features=300,
      use_bias=True
    )
  )
), Sequential(
  layers=(
    Lambda(fn=<PjitFunction of <function relu at 0x10cea9c60>>),
    Linear(
      weight=f32[10,300],
      bias=f32[10],
      in_features=300,
      out_features=10,
      use_bias=True
    )
  )
)]
```
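Since the network is simply a list of callable layers, a forward pass is just sequential application of each block. A minimal sketch with a dummy input (illustrative, using the network and hyperparameters defined above):

```python
import jax.numpy as jnp

x = jnp.zeros(INPUT_DIM)  # dummy flattened "image"
for layer in network:     # each block is a callable equinox module
    x = layer(x)
print(x.shape)  # (10,): one logit per MNIST class
```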
## Train and test
A PC network can be updated in a single line of code with jpc.make_pc_step(). Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already "jitted" for optimised performance. Below we simply wrap each of these functions in training and test loops, respectively.
```python
def evaluate(model, test_loader):
    avg_test_loss, avg_test_acc = 0, 0
    for _, (img_batch, label_batch) in enumerate(test_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        test_loss, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch,
            output=label_batch
        )
        avg_test_loss += test_loss
        avg_test_acc += test_acc

    return avg_test_loss / len(test_loader), avg_test_acc / len(test_loader)


def train(
    model,
    lr,
    batch_size,
    test_every,
    n_train_iters
):
    optim = optax.adam(lr)
    # initialise the optimiser over the model's arrays (the second slot
    # of the tuple is unused here)
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    for iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        result = jpc.make_pc_step(
            model=model,
            optim=optim,
            opt_state=opt_state,
            output=label_batch,
            input=img_batch
        )
        model, opt_state = result["model"], result["opt_state"]
        train_loss = result["loss"]
        if ((iter+1) % test_every) == 0:
            _, avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {iter+1}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
            if (iter+1) >= n_train_iters:
                break
```
## Run
```python
train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)
```
```
Train iter 100, train loss=0.007197, avg test accuracy=93.309296
Train iter 200, train loss=0.005052, avg test accuracy=95.462738
Train iter 300, train loss=0.006984, avg test accuracy=95.903442
```
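In only 300 training iterations, the PC-trained network reaches roughly 96% test accuracy on MNIST.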