Skip to content

Discriminative Predictive Coding

Open in Colab

This notebook demonstrates how to train a simple feedforward network with predictive coding (PC) to discriminate or classify MNIST digits.

python %%capture !pip install torch==2.3.1 !pip install torchvision==0.18.1

```python import jpc

import jax import equinox as eqx import equinox.nn as nn import optax

import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms

import warnings warnings.simplefilter('ignore') # ignore warnings ```

Hyperparameters

We define some global parameters, including the network architecture, learning rate, batch size, etc.

# Seed for the JAX PRNG used to initialise the network.
SEED = 0

# Architecture: an MLP mapping flattened MNIST images (784 values) to scores
# for the 10 digit classes.
INPUT_DIM = 784
WIDTH = 300
DEPTH = 3
OUTPUT_DIM = 10
ACT_FN = "relu"

# Optimisation and evaluation schedule (TEST_EVERY and N_TRAIN_ITERS are
# counted in training iterations, i.e. batches).
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
TEST_EVERY = 100
N_TRAIN_ITERS = 300

Dataset

Some utilities to fetch MNIST and wrap it in PyTorch data loaders.

def get_mnist_loaders(batch_size):
    """Build the MNIST train and test data loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
        Both drop the last incomplete batch, so every batch holds exactly
        ``batch_size`` samples. Only the training loader is shuffled.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    # No shuffling at test time: combined with ``drop_last=True``, a shuffled
    # loader would drop a different random subset of test images on every
    # pass, making the reported test metrics non-deterministic.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """torchvision MNIST dataset yielding flattened images and one-hot labels.

    Args:
        train: If True, load the training split; otherwise the test split.
        normalise: If True, standardise pixels with mean 0.1307 and standard
            deviation 0.3081 (presumably the MNIST training-set statistics)
            after converting each image to a tensor.
        save_dir: Root directory the dataset is downloaded to / read from.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # Normalize takes one mean/std per channel; MNIST has one channel.
            steps.append(transforms.Normalize(mean=(0.1307,), std=(0.3081,)))
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(steps),
        )

    def __getitem__(self, index):
        """Return ``(flattened_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label


def one_hot(labels, n_classes=10):
    """One-hot encode integer ``labels`` by indexing rows of ``torch.eye``.

    Args:
        labels: Integer class index (or index tensor) in ``[0, n_classes)``.
        n_classes: Number of classes, i.e. the length of each encoding.

    Returns:
        The selected row(s) of a float ``n_classes x n_classes`` identity.
    """
    arr = torch.eye(n_classes)
    return arr[labels]

Network

For jpc to work, we need to provide the network as a list of callable layers. This is easy to do with the PyTorch-like nn.Sequential() in equinox. For example, we can define a ReLU MLP with two hidden layers as follows:

# Root PRNG key, split into one subkey per layer (the first split output is
# discarded).
key = jax.random.PRNGKey(SEED)
_, *subkeys = jax.random.split(key, 4)

# Two hidden layers (Linear followed by ReLU), then a plain Linear read-out
# producing 10 class scores.
network = [
    nn.Sequential(
        [
            nn.Linear(784, 300, key=subkeys[0]),
            nn.Lambda(jax.nn.relu),
        ],
    ),
    nn.Sequential(
        [
            nn.Linear(300, 300, key=subkeys[1]),
            nn.Lambda(jax.nn.relu),
        ],
    ),
    nn.Linear(300, 10, key=subkeys[2]),
]

You can also use jpc.make_mlp() to define a multi-layer perceptron (MLP) or fully connected network.

# Alternative: let jpc build the MLP from the hyperparameters defined earlier.
# This rebinds ``network``, replacing the hand-built version.
network = jpc.make_mlp(
    key,
    input_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    output_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    use_bias=True,
)
print(network)

[Sequential( layers=( Lambda(fn=Identity()), Linear( weight=f32[300,784], bias=f32[300], in_features=784, out_features=300, use_bias=True ) ) ), Sequential( layers=( Lambda(fn=<PjitFunction of <function relu at 0x10cea9c60>>), Linear( weight=f32[300,300], bias=f32[300], in_features=300, out_features=300, use_bias=True ) ) ), Sequential( layers=( Lambda(fn=<PjitFunction of <function relu at 0x10cea9c60>>), Linear( weight=f32[10,300], bias=f32[10], in_features=300, out_features=10, use_bias=True ) ) )]

Train and test

A PC network can be updated in a single line of code with jpc.make_pc_step(). Similarly, we can use jpc.test_discriminative_pc() to compute the network's test loss and accuracy. Note that these functions are already "jitted" for optimised performance. Below, we simply wrap each of these functions in training and test loops, respectively.

def evaluate(model, test_loader):
    """Average the PC test loss and accuracy of ``model`` over ``test_loader``.

    Args:
        model: Network to evaluate (e.g. the layer list from
            ``jpc.make_mlp``).
        test_loader: Sized iterable of ``(images, labels)`` torch batches.

    Returns:
        ``(avg_test_loss, avg_test_acc)``: the per-batch values returned by
        ``jpc.test_discriminative_pc``, averaged with equal weight per batch.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    n_batches = len(test_loader)
    if n_batches == 0:
        raise ValueError(
            "test_loader yields no batches; cannot average test metrics"
        )

    avg_test_loss, avg_test_acc = 0, 0
    for img_batch, label_batch in test_loader:
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        test_loss, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch,
            output=label_batch,
        )
        avg_test_loss += test_loss
        avg_test_acc += test_acc

    return avg_test_loss / n_batches, avg_test_acc / n_batches


def train(model, lr, batch_size, test_every, n_train_iters):
    """Train ``model`` on MNIST with predictive coding, reporting test accuracy.

    Runs exactly ``n_train_iters`` PC updates (``jpc.make_pc_step``) with an
    Adam optimiser, cycling through the training set for as many epochs as
    needed. Every ``test_every`` iterations it evaluates on the test set and
    prints the current train loss and average test accuracy.

    Args:
        model: Network to train (e.g. the layer list from ``jpc.make_mlp``).
        lr: Adam learning rate.
        batch_size: Batch size used for both the train and test loaders.
        test_every: Evaluate and print metrics every this many iterations.
        n_train_iters: Total number of training iterations to run.

    Raises:
        ValueError: If training is requested but the training loader yields
            no batches (which would otherwise loop forever).
    """
    optim = optax.adam(lr)
    # NOTE(review): the ``None`` second element mirrors the pair structure
    # presumably expected by ``jpc.make_pc_step`` -- confirm against jpc docs.
    opt_state = optim.init((eqx.filter(model, eqx.is_array), None))
    train_loader, test_loader = get_mnist_loaders(batch_size)
    if n_train_iters > 0 and len(train_loader) == 0:
        raise ValueError(
            "train_loader yields no batches; "
            "is batch_size larger than the training set?"
        )

    n_iters = 0
    # Re-iterate the (reshuffled) loader until the iteration budget is spent,
    # instead of silently stopping after a single epoch.
    while n_iters < n_train_iters:
        for img_batch, label_batch in train_loader:
            img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

            result = jpc.make_pc_step(
                model=model,
                optim=optim,
                opt_state=opt_state,
                output=label_batch,
                input=img_batch,
            )
            model, opt_state = result["model"], result["opt_state"]
            train_loss = result["loss"]
            n_iters += 1

            if n_iters % test_every == 0:
                _, avg_test_acc = evaluate(model, test_loader)
                print(
                    f"Train iter {n_iters}, train loss={train_loss:4f}, "
                    f"avg test accuracy={avg_test_acc:4f}"
                )
            # Checked on every iteration (not only on test iterations) so that
            # training stops after exactly n_train_iters steps even when
            # n_train_iters is not a multiple of test_every.
            if n_iters >= n_train_iters:
                break

Run

# Train the network built with jpc.make_mlp using the global hyperparameters.
train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS,
)

Train iter 100, train loss=0.007197, avg test accuracy=93.309296 Train iter 200, train loss=0.005052, avg test accuracy=95.462738 Train iter 300, train loss=0.006984, avg test accuracy=95.903442