Skip to content

Discriminative PC on MNIST¤

Open in Colab

This notebook demonstrates how to train a simple feedforward network with predictive coding (PC) to discriminate or classify MNIST digits.

%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
import jpc

import jax
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import warnings
warnings.simplefilter('ignore')  # ignore warnings

Hyperparameters¤

We define some global parameters, including the network architecture, learning rate, batch size, etc.

# Seed for the JAX PRNG key used to initialise the network weights.
SEED = 0

# Network architecture, passed to jpc.make_mlp() below.
INPUT_DIM = 784
WIDTH = 300
DEPTH = 3
OUTPUT_DIM = 10
ACT_FN = "relu"

# Optimisation and evaluation schedule, passed to train() below.
LEARNING_RATE = 1e-3   # step size given to optax.adam
BATCH_SIZE = 64        # batch size of both the train and test loaders
TEST_EVERY = 100       # evaluate on the test loader every this many iterations
N_TRAIN_ITERS = 300    # number of training iterations before stopping

Dataset¤

Some utils to fetch MNIST.

def get_mnist_loaders(batch_size):
    """Create the MNIST train and test ``DataLoader`` objects.

    Both loaders wrap the normalised ``MNIST`` dataset, shuffle their
    samples, and drop the final incomplete batch so that every batch
    contains exactly ``batch_size`` items.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Build both datasets first (train, then test), then wrap them.
    train_set, test_set = (
        MNIST(train=is_train, normalise=True) for is_train in (True, False)
    )
    shared = dict(batch_size=batch_size, shuffle=True, drop_last=True)
    train_loader = DataLoader(dataset=train_set, **shared)
    test_loader = DataLoader(dataset=test_set, **shared)
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """torchvision MNIST that yields flattened images and one-hot labels.

    Args:
        train: forwarded to ``datasets.MNIST`` to select the train or test
            split.
        normalise: if true, apply ``transforms.Normalize`` after
            ``transforms.ToTensor`` using the mean/std constants below.
        save_dir: root directory forwarded to ``datasets.MNIST``
            (constructed with ``download=True``).
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # One-element tuples (note the trailing commas): ``(0.1307)`` is
            # merely a parenthesised float, whereas Normalize documents its
            # arguments as per-channel sequences.
            steps.append(
                transforms.Normalize(mean=(0.1307,), std=(0.3081,))
            )
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(img, label)`` for ``index``.

        The transformed image is flattened to a 1-D tensor and the integer
        label is converted to a one-hot vector via ``one_hot``.
        """
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    """Return the one-hot encoding of ``labels``.

    Rows of an ``n_classes`` x ``n_classes`` identity matrix are selected by
    ``labels``, so an integer label gives a single vector and a tensor of
    labels gives one row per label.
    """
    return torch.eye(n_classes)[labels]

Network¤

For jpc to work, we need to provide a network with callable layers. This is easy to do with the PyTorch-like nn.Sequential() in equinox. For example, we can define a ReLU MLP with two hidden layers as follows:

# Root PRNG key derived from the global seed.
key = jax.random.PRNGKey(SEED)
# Split into 4 keys, discard the first and keep three subkeys,
# one for each Linear layer below.
_, *subkeys = jax.random.split(key, 4)
# The network is a plain Python list of callable layers:
# two (Linear -> ReLU) hidden layers followed by a linear output layer.
network = [
    nn.Sequential(
        [
            nn.Linear(784, 300, key=subkeys[0]),
            nn.Lambda(jax.nn.relu)
        ],
    ),
    nn.Sequential(
        [
            nn.Linear(300, 300, key=subkeys[1]),
            nn.Lambda(jax.nn.relu)
        ],
    ),
    nn.Linear(300, 10, key=subkeys[2]),
]

You can also use jpc.make_mlp() to define a multi-layer perceptron (MLP) or fully connected network.

# Rebuild `network` with the jpc helper, using the global hyperparameters.
# This overwrites the hand-written list defined above.
# NOTE(review): this passes the same `key` that was already split above
# rather than a fresh subkey — confirm the reuse is intended.
network = jpc.make_mlp(
    key,
    input_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    output_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    use_bias=True
)
# Show the resulting list of layers.
print(network)
[Sequential(
  layers=(
    Lambda(fn=Identity()),
    Linear(
      weight=f32[300,784],
      bias=f32[300],
      in_features=784,
      out_features=300,
      use_bias=True
    )
  )
), Sequential(
  layers=(
    Lambda(fn=<PjitFunction of <function relu at 0x10cea9c60>>),
    Linear(
      weight=f32[300,300],
      bias=f32[300],
      in_features=300,
      out_features=300,
      use_bias=True
    )
  )
), Sequential(
  layers=(
    Lambda(fn=<PjitFunction of <function relu at 0x10cea9c60>>),
    Linear(
      weight=f32[10,300],
      bias=f32[10],
      in_features=300,
      out_features=10,
      use_bias=True
    )
  )
)]

Train and test¤

A PC network can be updated in a single line of code with jpc.make_pc_step(). Similarly, we can use jpc.test_discriminative_pc() to compute the network accuracy. Note that these functions are already "jitted" for optimised performance. Below we simply wrap each of these functions in training and test loops, respectively.

def evaluate(model, test_loader):
    """Average the test loss and accuracy of ``model`` over ``test_loader``.

    Each batch is converted to NumPy arrays and scored with
    ``jpc.test_discriminative_pc``; the per-batch losses and accuracies are
    summed and divided by the number of batches.

    Args:
        model: the network passed through to ``jpc.test_discriminative_pc``.
        test_loader: iterable of ``(images, labels)`` torch batches that
            also supports ``len()``.

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple.
    """
    n_batches = len(test_loader)
    loss_sum = acc_sum = 0
    for images, labels in test_loader:
        batch_loss, batch_acc = jpc.test_discriminative_pc(
            model=model,
            input=images.numpy(),
            output=labels.numpy()
        )
        loss_sum += batch_loss
        acc_sum += batch_acc

    return loss_sum / n_batches, acc_sum / n_batches


def train(
      model,
      lr,
      batch_size,
      test_every,
      n_train_iters
):
    """Train ``model`` with predictive coding on MNIST, printing progress.

    Makes a single pass over the training loader, applying one
    ``jpc.make_pc_step`` update per batch with an Adam optimiser. Every
    ``test_every`` iterations the current model is evaluated on the test
    loader and the latest train loss and average test accuracy are printed.
    Training stops after ``n_train_iters`` iterations, or earlier if the
    training loader is exhausted.

    Args:
        model: list of callable layers to train.
        lr: learning rate for ``optax.adam``.
        batch_size: batch size for both the train and test loaders.
        test_every: evaluate and print every this many iterations.
        n_train_iters: maximum number of training iterations.
    """
    optim = optax.adam(lr)
    # The optimiser state is initialised on a (model arrays, None) tuple.
    # NOTE(review): presumably the second slot mirrors an optional extra
    # model expected by jpc.make_pc_step — confirm against the jpc docs.
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    # `step` is 1-based so it can be compared and printed directly.
    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        result = jpc.make_pc_step(
            model=model,
            optim=optim,
            opt_state=opt_state,
            output=label_batch,
            input=img_batch
        )
        model, opt_state = result["model"], result["opt_state"]
        train_loss = result["loss"]
        if (step % test_every) == 0:
            _, avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
        # Check the budget on every iteration, not only on evaluation
        # iterations, so training stops at exactly n_train_iters even when
        # it is not a multiple of test_every.
        if step >= n_train_iters:
            break

Run¤

# Train the MLP built above with the global hyperparameters; train() prints
# the train loss and average test accuracy every TEST_EVERY iterations.
train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)
Train iter 100, train loss=0.007197, avg test accuracy=93.309296
Train iter 200, train loss=0.005052, avg test accuracy=95.462738
Train iter 300, train loss=0.006984, avg test accuracy=95.903442