Skip to content

μPC

Open in Colab

This notebook demonstrates how to train residual networks with μPC, a reparameterisation of PC that allows stable training of very deep (100+ layer) networks while also enabling zero-shot hyperparameter transfer (see Innocenti et al., 2025).

%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
import jpc

import jax.random as jr
import equinox as eqx
import equinox.nn as nn
import optax

import math
import random
import numpy as np
from typing import List, Callable

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import warnings
warnings.simplefilter('ignore')  # ignore warnings
# for reproducibility
def set_global_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator below.

    JAX is not touched here: it uses explicit PRNG keys (``jr.PRNGKey``).
    """
    random.seed(seed)
    np.random.seed(seed)
    # CPU generator, the current CUDA device, and all CUDA devices.
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    # Prefer reproducible cuDNN kernels over autotuned (benchmarked) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

Hyperparameters

We define some global parameters, including the network architecture, learning rate, batch size, etc. We choose a network with "only" 30 layers and 128 hidden neurons so that it can run relatively fast on a CPU, but feel free to try deeper and wider networks.

# --- Reproducibility ----------------------------------------------------
SEED = 4329

# --- Architecture -------------------------------------------------------
INPUT_DIM = 784      # flattened 28x28 MNIST image
OUTPUT_DIM = 10      # one output per digit class
WIDTH = 128          # hidden units per layer
DEPTH = 30           # total number of layers
ACT_FN = "relu"      # activation name, resolved by jpc

# --- Optimisation -------------------------------------------------------
ACTIVITY_LR = 0.5    # SGD learning rate for PC inference (activities)
PARAM_LR = 0.1       # Adam learning rate for the weights
BATCH_SIZE = 64
TEST_EVERY = 100     # evaluate on the test set every this many iterations
N_TRAIN_ITERS = 900  # total number of training iterations

Dataset

Some utilities to download and load MNIST.

def get_mnist_loaders(batch_size, normalise=True, save_dir="data"):
    """Build the MNIST train and test ``DataLoader``s.

    Args:
        batch_size: number of samples per batch for both loaders.
        normalise: forwarded to ``MNIST``; if True, pixels are standardised.
        save_dir: forwarded to ``MNIST``; directory holding the dataset.

    Returns:
        ``(train_loader, test_loader)``. Both drop the last incomplete batch,
        so every batch has exactly ``batch_size`` samples. Only the training
        loader is shuffled.
    """
    train_data = MNIST(train=True, normalise=normalise, save_dir=save_dir)
    test_data = MNIST(train=False, normalise=normalise, save_dir=save_dir)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    # No shuffling for evaluation: combined with drop_last, shuffling changed
    # which test samples were dropped at every evaluation. drop_last is kept
    # because evaluation averages per-batch accuracies with equal weight,
    # which is only a plain mean over samples when all batches are the same
    # size.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """torchvision MNIST that yields flat image vectors and one-hot labels.

    Each item is ``(img, label)``:
    - ``img`` is the ``ToTensor`` (optionally normalised) image passed
      through ``torch.flatten``;
    - ``label`` is ``one_hot(label)``.

    Args:
        train: load the training split if True, the test split otherwise.
        normalise: if True, apply ``Normalize(mean=0.1307, std=0.3081)``
            after ``ToTensor`` (presumably the MNIST pixel mean/std).
        save_dir: root directory; ``download=True`` fetches the data there.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        pipeline = [transforms.ToTensor()]
        if normalise:
            pipeline.append(transforms.Normalize(mean=0.1307, std=0.3081))
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(pipeline),
        )

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


def one_hot(labels, n_classes=10):
    """Return one-hot float encoding(s) of integer class ``labels``.

    Selects rows of the ``n_classes x n_classes`` identity matrix. An int
    gives a vector of length ``n_classes``; an index tensor of shape ``S``
    gives a tensor of shape ``S + (n_classes,)``.
    """
    return torch.eye(n_classes)[labels]

Creating a μPC model

To parameterise a model with μPC, one can use a few convenience functions of jpc: jpc.make_mlp() to create an MLP (fully connected network) and jpc.make_skip_model() to create an associated skip model. Note that μPC works only for a specific type of ResNet, namely one with one-layer skip connections at every layer except from the input to the next layer and from the penultimate layer to the output (see Innocenti et al., 2025), as shown below.

key = jr.PRNGKey(SEED)

# Fully connected network built with the μPC parameterisation ("mupc");
# the jpc default is the standard parameterisation ("sp").
model = jpc.make_mlp(
    key,
    input_dim=INPUT_DIM,
    output_dim=OUTPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    act_fn=ACT_FN,
    param_type="mupc",
)

# Companion model holding the skip connections for a DEPTH-layer network.
skip_model = jpc.make_skip_model(DEPTH)

At training and test time, we need to pass both models to the relevant jpc functions and set the argument param_type = "mupc" (the default is "sp", the standard parameterisation).

Alternatively, one could define a model class that embeds the parameterisation itself and leave the above arguments at their defaults. This solution is more elegant, but it can be harder to debug, at least for a fully connected architecture. However, if you would like to experiment with different parameterisations and more complex architectures (e.g. CNNs), we recommend this approach.

class ScaledLinear(eqx.Module):
    """Linear layer whose output is multiplied by a constant ``scaling``.

    Computes ``scaling * linear(x)``. With ``param_type="mupc"``, the weight
    matrix is re-drawn i.i.d. from a standard normal, so the layer's overall
    scale is set by ``scaling`` (chosen by the caller). Otherwise, the default
    ``eqx.nn.Linear`` initialisation is kept.

    Args:
        in_features: input dimension.
        out_features: output dimension.
        key: JAX PRNG key, split into one key for ``nn.Linear`` and one for
            the μPC weight re-initialisation.
        scaling: constant multiplier applied to the layer output.
        param_type: ``"mupc"`` re-initialises the weight as described above;
            any other value keeps the default initialisation.
        use_bias: whether the linear layer has a bias. If present, the bias
            keeps its default initialisation under either parameterisation.
    """
    linear: nn.Linear
    # Static: a plain Python float, not a trainable array leaf.
    # ``eqx.field(static=True)`` replaces the deprecated ``eqx.static_field()``.
    scaling: float = eqx.field(static=True)

    def __init__(
            self,
            in_features,
            out_features,
            *,
            key,
            scaling=1.,
            param_type="sp",
            use_bias=False
    ):
        linear_key, mupc_key = jr.split(key, 2)
        linear = nn.Linear(
            in_features,
            out_features,
            use_bias=use_bias,
            key=linear_key
        )
        if param_type == "mupc":
            # Unit-variance weights; the effective scale comes from `scaling`.
            W = jr.normal(mupc_key, linear.weight.shape)
            linear = eqx.tree_at(lambda l: l.weight, linear, W)

        self.linear = linear
        self.scaling = scaling

    def __call__(self, x):
        return self.scaling * self.linear(x)


class ResNetBlock(eqx.Module):
    """Identity-skip residual block computing ``scaled_linear(act_fn(x)) + x``.

    Args:
        in_features: input dimension (must equal ``out_features`` for the
            identity skip to be shape-compatible).
        out_features: output dimension.
        key: JAX PRNG key for the inner ``ScaledLinear``.
        scaling: output multiplier of the inner ``ScaledLinear``.
        param_type: forwarded to ``ScaledLinear`` (``"sp"`` or ``"mupc"``).
        use_bias: whether the inner linear layer has a bias.
        act_fn: activation, either a callable or a name understood by
            ``jpc.get_act_fn`` (such as the default ``"linear"``).
    """
    # Static: the activation is a function, not a trainable array.
    act_fn: Callable = eqx.field(static=True)
    scaled_linear: ScaledLinear

    def __init__(
        self,
        in_features,
        out_features,
        *,
        key,
        scaling=1.,
        param_type="sp",
        use_bias=False,
        act_fn="linear"
    ):
        # The default is a string. Resolve names to callables, otherwise
        # __call__ would try to call a str.
        if isinstance(act_fn, str):
            act_fn = jpc.get_act_fn(act_fn)
        self.act_fn = act_fn
        self.scaled_linear = ScaledLinear(
            in_features=in_features,
            out_features=out_features,
            key=key,
            scaling=scaling,
            param_type=param_type,
            use_bias=use_bias
        )

    def __call__(self, x):
        return self.scaled_linear(self.act_fn(x)) + x


class Readout(eqx.Module):
    """Final layer computing ``scaled_linear(act_fn(x))`` (no skip connection).

    Args:
        in_features: input dimension.
        out_features: output dimension.
        key: JAX PRNG key for the inner ``ScaledLinear``.
        scaling: output multiplier of the inner ``ScaledLinear``.
        param_type: forwarded to ``ScaledLinear`` (``"sp"`` or ``"mupc"``).
        use_bias: whether the inner linear layer has a bias.
        act_fn: activation, either a callable or a name understood by
            ``jpc.get_act_fn`` (such as the default ``"linear"``).
    """
    # Static: the activation is a function, not a trainable array.
    act_fn: Callable = eqx.field(static=True)
    scaled_linear: ScaledLinear

    def __init__(
        self,
        in_features,
        out_features,
        *,
        key,
        scaling=1.,
        param_type="sp",
        use_bias=False,
        act_fn="linear"
    ):
        # The default is a string. Resolve names to callables, otherwise
        # __call__ would try to call a str.
        if isinstance(act_fn, str):
            act_fn = jpc.get_act_fn(act_fn)
        self.act_fn = act_fn
        self.scaled_linear = ScaledLinear(
            in_features=in_features,
            out_features=out_features,
            key=key,
            scaling=scaling,
            param_type=param_type,
            use_bias=use_bias
        )

    def __call__(self, x):
        return self.scaled_linear(self.act_fn(x))


class FCResNet(eqx.Module):
    """Fully connected ResNet supporting the "sp" and "mupc" parameterisations.

    Layout (``depth`` layers in total):
      * layer 0: ``ScaledLinear`` from ``in_dim`` to ``width`` (no skip);
      * layers 1 .. depth-2: ``ResNetBlock`` of size ``width`` with an
        identity skip;
      * layer depth-1: ``Readout`` from ``width`` to ``out_dim`` (no skip).

    Output scalings per layer:
      * ``"sp"``: 1 everywhere.
      * ``"mupc"``: ``1/sqrt(in_dim)`` for the input layer,
        ``1/sqrt(width * depth)`` for the hidden blocks, ``1/width`` for the
        readout.

    Args:
        key: JAX PRNG key, split into one key per layer.
        in_dim: input dimension.
        width: hidden dimension.
        depth: total number of layers (at least 2).
        out_dim: output dimension.
        act_fn: activation name resolved with ``jpc.get_act_fn``.
        use_bias: whether the linear layers have biases.
        param_type: ``"sp"`` or ``"mupc"``.

    Raises:
        ValueError: if ``depth < 2`` or ``param_type`` is not supported.
    """
    layers: List[eqx.Module]

    def __init__(
            self,
            *,
            key,
            in_dim,
            width,
            depth,
            out_dim,
            act_fn="linear",
            use_bias=False,
            param_type="sp"
        ):
        if depth < 2:
            # depth=1 would still build two layers (input + readout) that
            # share the same PRNG key.
            raise ValueError(f"depth must be at least 2, got {depth}")

        act_fn = jpc.get_act_fn(act_fn)
        if param_type == "sp":
            in_scaling = hidden_scaling = out_scaling = 1.
        elif param_type == "mupc":
            in_scaling = 1 / math.sqrt(in_dim)
            hidden_scaling = 1 / math.sqrt(width * depth)
            out_scaling = 1 / width
        else:
            # Without this branch the scalings would be unbound below.
            raise ValueError(
                f"Unknown param_type {param_type!r}; expected 'sp' or 'mupc'"
            )

        keys = jr.split(key, depth)
        input_layer = ScaledLinear(
            key=keys[0],
            in_features=in_dim,
            out_features=width,
            scaling=in_scaling,
            param_type=param_type,
            use_bias=use_bias
        )
        hidden_blocks = [
            ResNetBlock(
                key=keys[i],
                in_features=width,
                out_features=width,
                scaling=hidden_scaling,
                param_type=param_type,
                use_bias=use_bias,
                act_fn=act_fn
            )
            for i in range(1, depth - 1)
        ]
        readout = Readout(
            key=keys[-1],
            in_features=width,
            out_features=out_dim,
            scaling=out_scaling,
            param_type=param_type,
            use_bias=use_bias,
            act_fn=act_fn
        )
        self.layers = [input_layer, *hidden_blocks, readout]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def __len__(self):
        return len(self.layers)

    def __getitem__(self, idx):
        return self.layers[idx]
# Same architecture as `model`, but with the μPC scalings built into the
# model class itself (selected by param_type="mupc" in FCResNet).
mupc_model = FCResNet(
    key=key,
    in_dim=INPUT_DIM,
    out_dim=OUTPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    act_fn=ACT_FN,
    param_type="mupc",
    use_bias=False,
)

The following copies the weights of the jpc model into mupc_model so that the two models have identical weights.

# Copy the jpc MLP's weights into `mupc_model` so that both models start from
# identical weights. `mupc_model` was already built with this exact
# configuration in the previous cell, so it is not rebuilt here.
if len(mupc_model) != len(model):
    raise ValueError(
        f"Layer count mismatch: mupc_model has {len(mupc_model)} layers but "
        f"model has {len(model)}; rebuild mupc_model with the current DEPTH."
    )

# NOTE(review): `model[i][1]` assumes each jpc layer is indexable with the
# linear layer at position 1 (as in the original indexing) — confirm in jpc.
mupc_model = eqx.tree_at(
    where=lambda tree: tree[0].linear.weight,
    pytree=mupc_model,
    replace=model[0][1].weight,
)
for layer_idx in range(1, len(model)):
    # Bind the index as a default argument so the selector never relies on
    # late binding of the loop variable.
    mupc_model = eqx.tree_at(
        where=lambda tree, i=layer_idx: tree[i].scaled_linear.linear.weight,
        pytree=mupc_model,
        replace=model[layer_idx][1].weight,
    )

Train and test

For training, we use the advanced API, including the functions jpc.init_activities_with_ffwd() to initialise the activities, jpc.update_activities() to perform PC inference, and jpc.update_params() to update the weights. All these functions accept skip_model and param_type as arguments. Alternatively, these functions can be replaced with jpc.make_pc_step(). For testing, we use jpc.test_discriminative_pc().

def evaluate(model, skip_model, test_loader, param_type):
    """Return the mean per-batch test accuracy of a PC model.

    Each batch is passed to ``jpc.test_discriminative_pc``. The accuracies it
    returns are averaged with equal weight per batch.

    Args:
        model: network to evaluate.
        skip_model: skip-connection model, or ``None``.
        test_loader: iterable of ``(img_batch, label_batch)`` torch tensors.
        param_type: parameterisation forwarded to jpc (``"sp"`` or ``"mupc"``).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    total_acc = 0
    n_batches = 0
    for img_batch, label_batch in test_loader:
        _, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch.numpy(),
            output=label_batch.numpy(),
            skip_model=skip_model,
            param_type=param_type
        )
        total_acc += test_acc
        n_batches += 1

    if n_batches == 0:
        # Avoid an opaque ZeroDivisionError (e.g. batch_size > test set size
        # with drop_last=True).
        raise ValueError("test_loader yielded no batches")
    return total_acc / n_batches


def train(
    seed,
    model,
    skip_model,
    param_type,
    activity_lr,
    param_lr,
    batch_size,
    test_every,
    n_train_iters
):
    """Train a predictive-coding network on MNIST and report test accuracy.

    Each iteration does three things:
    (i) initialises the activities with a feedforward pass;
    (ii) runs ``len(model)`` steps of PC inference on them with SGD
    (``activity_lr``);
    (iii) takes one Adam step (``param_lr``) on the parameters of ``model``
    and ``skip_model``.
    Every ``test_every`` iterations, the model is evaluated on the test set
    and a progress line is printed.

    Training runs for exactly ``n_train_iters`` iterations. A new pass over
    the training set starts whenever one is exhausted. Training stops early
    if the feedforward loss becomes NaN or infinite.

    Args:
        seed: seed passed to ``set_global_seed`` (e.g. fixes data shuffling).
        model: network to train.
        skip_model: skip-connection model, or ``None``.
        param_type: parameterisation forwarded to jpc (``"sp"`` or ``"mupc"``).
        activity_lr: SGD learning rate for the inference updates.
        param_lr: Adam learning rate for the parameter updates.
        batch_size: mini-batch size for training and testing.
        test_every: evaluate every this many iterations.
        n_train_iters: total number of training iterations.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    set_global_seed(seed)
    activity_optim = optax.sgd(activity_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), skip_model)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)
    if len(train_loader) == 0:
        # Without this guard the epoch loop below would never terminate.
        raise ValueError(
            f"Training loader yields no batches with batch_size={batch_size}"
        )

    step = 0
    while step < n_train_iters:
        # One pass over `train_loader` is one epoch.
        for img_batch, label_batch in train_loader:
            step += 1
            img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

            # initialise activities with a feedforward pass
            activities = jpc.init_activities_with_ffwd(
                model=model,
                input=img_batch,
                skip_model=skip_model,
                param_type=param_type
            )
            train_loss = jpc.mse_loss(activities[-1], label_batch)

            # Check for divergence before spending compute on this batch.
            if not np.isfinite(train_loss):
                print(
                    f"Stopping training because of divergence, train loss={train_loss}"
                )
                return

            # inference: one activity update per layer
            activity_opt_state = activity_optim.init(activities)
            for _ in range(len(model)):
                activity_update_result = jpc.update_activities(
                    params=(model, skip_model),
                    activities=activities,
                    optim=activity_optim,
                    opt_state=activity_opt_state,
                    output=label_batch,
                    input=img_batch,
                    param_type=param_type
                )
                activities = activity_update_result["activities"]
                activity_opt_state = activity_update_result["opt_state"]

            # learning
            param_update_result = jpc.update_params(
                params=(model, skip_model),
                activities=activities,
                optim=param_optim,
                opt_state=param_opt_state,
                output=label_batch,
                input=img_batch,
                param_type=param_type
            )
            model = param_update_result["model"]
            skip_model = param_update_result["skip_model"]
            param_opt_state = param_update_result["opt_state"]

            if step % test_every == 0:
                avg_test_acc = evaluate(
                    model=model,
                    skip_model=skip_model,
                    test_loader=test_loader,
                    param_type=param_type
                )
                print(
                    f"Train iter {step}, train loss={train_loss:4f}, "
                    f"avg test accuracy={avg_test_acc:4f}"
                )

            # Stop exactly at n_train_iters, independently of test_every.
            if step >= n_train_iters:
                return

Run

Note that on a CPU the script below should take about a minute to complete.

# `model` was built by jpc.make_mlp, so the skip model and
# param_type="mupc" must be passed explicitly.
train(
    seed=SEED,
    param_type="mupc",
    model=model,
    skip_model=skip_model,
    n_train_iters=N_TRAIN_ITERS,
    test_every=TEST_EVERY,
    batch_size=BATCH_SIZE,
    activity_lr=ACTIVITY_LR,
    param_lr=PARAM_LR,
)
Train iter 100, train loss=0.016015, avg test accuracy=85.827324
Train iter 200, train loss=0.012215, avg test accuracy=88.541664
Train iter 300, train loss=0.009235, avg test accuracy=90.805290
Train iter 400, train loss=0.008675, avg test accuracy=91.286057
Train iter 500, train loss=0.011475, avg test accuracy=91.836937
Train iter 600, train loss=0.007697, avg test accuracy=92.177483
Train iter 700, train loss=0.007377, avg test accuracy=92.778442
Train iter 800, train loss=0.009710, avg test accuracy=92.477966
Train iter 900, train loss=0.009722, avg test accuracy=93.259216

For comparison, try switching to the standard parameterisation by setting param_type = "sp".

If you are using your own μPC-parameterised model class, then you can leave the default skip_model = None and param_type = "sp", as shown below.

# `mupc_model` carries its μPC scalings and residual connections itself,
# so the defaults (no skip model, param_type="sp") are used.
train(
    seed=SEED,
    param_type="sp",
    model=mupc_model,
    skip_model=None,
    n_train_iters=N_TRAIN_ITERS,
    test_every=TEST_EVERY,
    batch_size=BATCH_SIZE,
    activity_lr=ACTIVITY_LR,
    param_lr=PARAM_LR,
)
Train iter 100, train loss=0.016063, avg test accuracy=85.787262
Train iter 200, train loss=0.012327, avg test accuracy=88.571716
Train iter 300, train loss=0.009621, avg test accuracy=90.875404
Train iter 400, train loss=0.009056, avg test accuracy=91.336136
Train iter 500, train loss=0.011603, avg test accuracy=92.007210
Train iter 600, train loss=0.007781, avg test accuracy=91.887016
Train iter 700, train loss=0.006997, avg test accuracy=92.938705
Train iter 800, train loss=0.010020, avg test accuracy=93.129005
Train iter 900, train loss=0.009978, avg test accuracy=93.279243