Skip to content

# μPC

Open in Colab

This notebook demonstrates how to train residual networks with μPC (Innocenti et al., 2025), a reparameterisation of PC that allows stable training of very deep (100+ layer) networks while also enabling zero-shot hyperparameter transfer. For a theoretical justification and extension of this parameterisation, see Innocenti et al., 2026.

```python
%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
```

```python import jpc

import jax.random as jr import equinox as eqx import equinox.nn as nn import optax

import math import random import numpy as np from typing import List, Callable

import torch from torch.utils.data import DataLoader from torchvision import datasets, transforms

import warnings warnings.simplefilter('ignore') # ignore warnings ```

# for reproducibility
def set_global_seed(seed):
    """Seed PyTorch, NumPy and Python's ``random`` with the same value.

    Also switches cuDNN to deterministic mode and disables its benchmark
    autotuner.

    Args:
        seed: Integer seed passed to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Hyperparameters

We define some global parameters, including the network architecture, learning rate, batch size, etc. We choose a network with "only" 30 layers and 128 hidden neurons so that it can run relatively fast on a CPU, but feel free to try deeper and wider networks.

# Global configuration for the example.
SEED = 4329

# Network architecture.
INPUT_DIM = 784
WIDTH = 128
DEPTH = 30
OUTPUT_DIM = 10
ACT_FN = "relu"

# Learning rates, batch size and training/evaluation schedule.
ACTIVITY_LR = 5e-1
PARAM_LR = 1e-1
BATCH_SIZE = 64
TEST_EVERY = 100
N_TRAIN_ITERS = 900

## Dataset

Some utils to fetch MNIST.

def get_mnist_loaders(batch_size):
    """Build the MNIST train and test data loaders.

    Both datasets are created with ``normalise=True``. Both loaders shuffle
    their data and drop the last incomplete batch.

    Args:
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader

class MNIST(datasets.MNIST):
    """MNIST dataset yielding flattened images and one-hot labels."""

    def __init__(self, train, normalise=True, save_dir="data"):
        """Download MNIST (if needed) and set up the image transform.

        Args:
            train: Whether to load the training split (else the test split).
            normalise: If true, normalise images with mean 0.1307 and
                std 0.3081 after converting them to tensors.
            save_dir: Root directory handed to ``datasets.MNIST``.
        """
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    # One-element tuples: Normalize expects per-channel
                    # sequences and MNIST images have a single channel.
                    transforms.Normalize(
                        mean=(0.1307,), std=(0.3081,)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return the ``index``-th sample as ``(flat_image, one_hot_label)``."""
        img, label = super().__getitem__(index)
        img = torch.flatten(img)
        label = one_hot(label)
        return img, label

def one_hot(labels, n_classes=10):
    """One-hot encode ``labels`` by indexing rows of an identity matrix.

    Args:
        labels: A class index, or an integer tensor of class indices.
        n_classes: Total number of classes (length of each encoding).

    Returns:
        The rows of ``torch.eye(n_classes)`` selected by ``labels``.
    """
    arr = torch.eye(n_classes)
    return arr[labels]

## Creating a μPC model

To parameterise a model with μPC, one can use a few convenience functions of jpc to create an MLP or fully connected network with jpc.make_mlp() and an associated skip model with jpc.make_skip_model(). Note that μPC works only for a specific type of ResNet, namely one with one-layer skip connections at every layer except from the input to the next layer and from the penultimate layer to the output (see Innocenti et al., 2025), as shown below.

key = jr.PRNGKey(SEED)

# MLP
model = jpc.make_mlp(
    key,
    input_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    output_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    param_type="mupc"
)

# skip model
skip_model = jpc.make_skip_model(DEPTH)

At training and test time we would need to pass both models to relevant jpc functions and change the argument param_type = "mupc" (default is "sp" for standard parameterisation).

Alternatively, one could define a model class embedding the parameterisation itself and leave the above arguments to their default. This solution is more elegant but it can be harder to debug, at least for a fully connected architecture. However, if you would like to experiment with different parameterisations and more complex architectures (e.g. CNNs), we recommend this approach.

class ScaledLinear(eqx.Module):
    """Scaled linear transformation."""
    linear: nn.Linear
    scaling: float = eqx.static_field()

    def __init__(
            self,
            in_features,
            out_features,
            *,
            key,
            scaling=1.,
            param_type="sp",
            use_bias=False
    ):
        """Create the linear layer and store its output scaling.

        Args:
            in_features: Input dimension of the linear layer.
            out_features: Output dimension of the linear layer.
            key: JAX PRNG key; split in two, one half for the layer's
                default initialisation and one for the "mupc" weights.
            scaling: Constant multiplying the layer output.
            param_type: With "mupc" the weight matrix is replaced by
                standard-normal samples; any other value keeps the
                ``nn.Linear`` default initialisation.
            use_bias: Whether the linear layer has a bias term.
        """
        keys = jr.split(key, 2)
        linear = nn.Linear(
            in_features,
            out_features,
            use_bias=use_bias,
            key=keys[0]
        )
        if param_type == "mupc":
            # Overwrite the default-initialised weights with N(0, 1) draws.
            W = jr.normal(keys[1], linear.weight.shape)
            linear = eqx.tree_at(lambda layer: layer.weight, linear, W)

        self.linear = linear
        self.scaling = scaling

    def __call__(self, x):
        """Return the linear layer's output multiplied by ``scaling``."""
        return self.scaling * self.linear(x)

class ResNetBlock(eqx.Module):
    """Identity residual block applying activation and a scaled linear layer."""
    act_fn: Callable = eqx.static_field()
    scaled_linear: ScaledLinear

    def __init__(
        self,
        in_features,
        out_features,
        *,
        key,
        scaling=1.,
        param_type="sp",
        use_bias=False,
        act_fn="linear"
    ):
        """Build the block's activation and scaled linear layer.

        Args:
            in_features: Input dimension of the linear layer.
            out_features: Output dimension of the linear layer.
            key: JAX PRNG key forwarded to ``ScaledLinear``.
            scaling: Output scaling forwarded to ``ScaledLinear``.
            param_type: Parameterisation forwarded to ``ScaledLinear``.
            use_bias: Whether the linear layer has a bias term.
            act_fn: Activation, either a callable or an activation name
                understood by ``jpc.get_act_fn``.
        """
        # The default is a name, which is not callable: resolve names to
        # functions so that __call__ can always apply self.act_fn.
        if isinstance(act_fn, str):
            act_fn = jpc.get_act_fn(act_fn)
        self.act_fn = act_fn
        self.scaled_linear = ScaledLinear(
            in_features=in_features,
            out_features=out_features,
            key=key,
            scaling=scaling,
            param_type=param_type,
            use_bias=use_bias
        )

    def __call__(self, x):
        """Return ``scaled_linear(act_fn(x)) + x``."""
        res_path = x
        x = self.act_fn(x)
        return self.scaled_linear(x) + res_path

class Readout(eqx.Module):
    """Final network layer applying activation and a scaled linear layer."""
    act_fn: Callable = eqx.static_field()
    scaled_linear: ScaledLinear

    def __init__(
        self,
        in_features,
        out_features,
        *,
        key,
        scaling=1.,
        param_type="sp",
        use_bias=False,
        act_fn="linear"
    ):
        """Build the readout's activation and scaled linear layer.

        Args:
            in_features: Input dimension of the linear layer.
            out_features: Output dimension of the linear layer.
            key: JAX PRNG key forwarded to ``ScaledLinear``.
            scaling: Output scaling forwarded to ``ScaledLinear``.
            param_type: Parameterisation forwarded to ``ScaledLinear``.
            use_bias: Whether the linear layer has a bias term.
            act_fn: Activation, either a callable or an activation name
                understood by ``jpc.get_act_fn``.
        """
        # The default is a name, which is not callable: resolve names to
        # functions so that __call__ can always apply self.act_fn.
        if isinstance(act_fn, str):
            act_fn = jpc.get_act_fn(act_fn)
        self.act_fn = act_fn
        self.scaled_linear = ScaledLinear(
            in_features=in_features,
            out_features=out_features,
            key=key,
            scaling=scaling,
            param_type=param_type,
            use_bias=use_bias
        )

    def __call__(self, x):
        """Return ``scaled_linear(act_fn(x))`` (no residual connection)."""
        x = self.act_fn(x)
        return self.scaled_linear(x)

class FCResNet(eqx.Module):
    """Fully-connected ResNet compatible with different parameterisations."""
    layers: List[eqx.Module]

    def __init__(
            self,
            *,
            key,
            in_dim,
            width,
            depth,
            out_dim,
            act_fn="linear",
            use_bias=False,
            param_type="sp"
    ):
        """Build an input layer, ``depth - 2`` residual blocks and a readout.

        Args:
            key: JAX PRNG key, split into ``depth`` per-layer keys.
            in_dim: Input dimension.
            width: Hidden dimension of every residual block.
            depth: Total number of layers.
            out_dim: Output dimension.
            act_fn: Activation name, resolved with ``jpc.get_act_fn``.
            use_bias: Whether the linear layers have bias terms.
            param_type: "sp" uses unit scalings; "mupc" scales the input
                layer by ``1/sqrt(in_dim)``, hidden layers by
                ``1/sqrt(width * depth)`` and the readout by ``1/width``.

        Raises:
            ValueError: If ``param_type`` is neither "sp" nor "mupc".
        """
        act_fn = jpc.get_act_fn(act_fn)
        if param_type == "sp":
            in_scaling = 1.
            hidden_scaling = 1.
            out_scaling = 1.
        elif param_type == "mupc":
            in_scaling = 1 / math.sqrt(in_dim)
            hidden_scaling = 1 / math.sqrt(width * depth)
            out_scaling = 1 / width
        else:
            # Fail early with a clear message instead of an unbound-local
            # error when the scalings are first used below.
            raise ValueError(
                f"Unknown param_type {param_type!r}; expected 'sp' or 'mupc'."
            )

        keys = jr.split(key, depth)

        # Assemble the layers locally and assign the field once at the end.
        layers = [
            ScaledLinear(
                key=keys[0],
                in_features=in_dim,
                out_features=width,
                scaling=in_scaling,
                param_type=param_type,
                use_bias=use_bias
            )
        ]
        layers.extend(
            ResNetBlock(
                key=keys[i],
                in_features=width,
                out_features=width,
                scaling=hidden_scaling,
                param_type=param_type,
                use_bias=use_bias,
                act_fn=act_fn
            )
            for i in range(1, depth - 1)
        )
        layers.append(
            Readout(
                key=keys[-1],
                in_features=width,
                out_features=out_dim,
                scaling=out_scaling,
                param_type=param_type,
                use_bias=use_bias,
                act_fn=act_fn
            )
        )
        self.layers = layers

    def __call__(self, x):
        """Apply every layer to ``x`` in order and return the result."""
        for f in self.layers:
            x = f(x)
        return x

    def __len__(self):
        """Return the number of layers."""
        return len(self.layers)

    def __getitem__(self, idx):
        """Return the layer at position ``idx``."""
        return self.layers[idx]

# Build the μPC model from the custom class, using the same key and
# hyperparameters as the jpc-built model.
mupc_model = FCResNet(
    key=key,
    in_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    out_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    use_bias=False,
    param_type="mupc"
)

The following makes sure that the models have identical weights.

mupc_model = FCResNet(
    key=key,
    in_dim=INPUT_DIM,
    width=WIDTH,
    depth=DEPTH,
    out_dim=OUTPUT_DIM,
    act_fn=ACT_FN,
    use_bias=False,
    param_type="mupc"
)

# Copy the first layer's weights; it is a bare ScaledLinear, so its linear
# map sits directly under `.linear`.
mupc_model = eqx.tree_at(
    where=lambda tree: tree[0].linear.weight,
    pytree=mupc_model,
    replace=model[0][1].weight
)

# The remaining layers wrap their ScaledLinear in `.scaled_linear`.
for layer_idx in range(1, len(model)):
    mupc_model = eqx.tree_at(
        where=lambda tree: tree[layer_idx].scaled_linear.linear.weight,
        pytree=mupc_model,
        replace=model[layer_idx][1].weight
    )

## Train and test

For training, we use the advanced API including the functions jpc.init_activities_with_ffwd() to initialise the activities, jpc.update_pc_activities() to perform PC inference, and jpc.update_pc_params() to update the weights. All these functions accept param_type as an argument and take the skip model either as the skip_model argument or as part of the params tuple (model, skip_model). Note, however, that one can replace these functions with jpc.make_pc_step(). For testing, we use jpc.test_discriminative_pc().

def evaluate(model, skip_model, test_loader, param_type):
    """Return the test accuracy averaged over all batches of ``test_loader``.

    Args:
        model: Model passed to ``jpc.test_discriminative_pc``.
        skip_model: Skip model passed to ``jpc.test_discriminative_pc``.
        test_loader: Iterable of ``(images, labels)`` torch batches; it must
            support ``len`` and contain at least one batch.
        param_type: Parameterisation passed to ``jpc.test_discriminative_pc``.

    Returns:
        The mean of the per-batch accuracies.
    """
    avg_test_acc = 0
    for img_batch, label_batch in test_loader:
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        _, test_acc = jpc.test_discriminative_pc(
            model=model,
            input=img_batch,
            output=label_batch,
            skip_model=skip_model,
            param_type=param_type
        )
        avg_test_acc += test_acc

    return avg_test_acc / len(test_loader)

def train(
        seed,
        model,
        skip_model,
        param_type,
        activity_lr,
        param_lr,
        batch_size,
        test_every,
        n_train_iters
):
    """Train ``model`` on MNIST with predictive coding and report accuracy.

    Each iteration initialises the activities with
    ``jpc.init_activities_with_ffwd``, runs ``len(model)`` activity updates
    (inference), then performs one parameter update (learning). Training
    makes at most one pass over the training loader.

    Args:
        seed: Seed passed to ``set_global_seed``.
        model: Model to train.
        skip_model: Skip model paired with ``model`` (may be ``None``).
        param_type: Parameterisation passed to the jpc functions.
        activity_lr: Learning rate of the SGD optimiser for the activities.
        param_lr: Learning rate of the Adam optimiser for the parameters.
        batch_size: Batch size of the train and test loaders.
        test_every: Evaluate and print metrics every this many iterations.
        n_train_iters: Stop once this many iterations have been run.
    """
    set_global_seed(seed)
    activity_optim = optax.sgd(activity_lr)
    param_optim = optax.adam(param_lr)
    param_opt_state = param_optim.init(
        (eqx.filter(model, eqx.is_array), skip_model)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    # `step` is the 1-based iteration count.
    for step, (img_batch, label_batch) in enumerate(train_loader, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # initialise activities
        activities = jpc.init_activities_with_ffwd(
            model=model,
            input=img_batch,
            skip_model=skip_model,
            param_type=param_type
        )
        activity_opt_state = activity_optim.init(activities)
        # The loss is measured on the freshly initialised output activities,
        # i.e. before any inference step.
        train_loss = jpc.mse_loss(activities[-1], label_batch)

        # inference
        for _ in range(len(model)):
            activity_update_result = jpc.update_pc_activities(
                params=(model, skip_model),
                activities=activities,
                optim=activity_optim,
                opt_state=activity_opt_state,
                output=label_batch,
                input=img_batch,
                param_type=param_type
            )
            activities = activity_update_result["activities"]
            activity_opt_state = activity_update_result["opt_state"]

        # learning
        param_update_result = jpc.update_pc_params(
            params=(model, skip_model),
            activities=activities,
            optim=param_optim,
            opt_state=param_opt_state,
            output=label_batch,
            input=img_batch,
            param_type=param_type
        )
        model = param_update_result["model"]
        skip_model = param_update_result["skip_model"]
        param_opt_state = param_update_result["opt_state"]

        if not np.isfinite(train_loss):
            print(
                f"Stopping training because of divergence, train loss={train_loss}"
            )
            break

        if (step % test_every) == 0:
            avg_test_acc = evaluate(
                model=model,
                skip_model=skip_model,
                test_loader=test_loader,
                param_type=param_type
            )
            print(
                f"Train iter {step}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )

        # Checked on every iteration so that n_train_iters is honoured even
        # when it is not a multiple of test_every.
        if step >= n_train_iters:
            break

## Run

Note that on a CPU the script below should take about a minute to complete.

# Train the jpc-built model, passing the skip model and μPC
# parameterisation explicitly.
train(
    seed=SEED,
    model=model,
    skip_model=skip_model,
    param_type="mupc",
    activity_lr=ACTIVITY_LR,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)

```
Train iter 100, train loss=0.016015, avg test accuracy=85.827324
Train iter 200, train loss=0.012215, avg test accuracy=88.541664
Train iter 300, train loss=0.009235, avg test accuracy=90.805290
Train iter 400, train loss=0.008675, avg test accuracy=91.286057
Train iter 500, train loss=0.011475, avg test accuracy=91.836937
Train iter 600, train loss=0.007697, avg test accuracy=92.177483
Train iter 700, train loss=0.007377, avg test accuracy=92.778442
Train iter 800, train loss=0.009710, avg test accuracy=92.477966
Train iter 900, train loss=0.009722, avg test accuracy=93.259216
```

For comparison, try to change to the standard parameterisation with param_type = "sp".

If you are using your own μPC-parameterised model class, then you can leave the default skip_model = None and param_type = "sp", as shown below.

# Train the custom μPC model class: the scalings live inside the model, so
# no skip model is passed and the standard parameterisation is selected.
train(
    seed=SEED,
    model=mupc_model,
    skip_model=None,
    param_type="sp",
    activity_lr=ACTIVITY_LR,
    param_lr=PARAM_LR,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS
)

```
Train iter 100, train loss=0.016063, avg test accuracy=85.787262
Train iter 200, train loss=0.012327, avg test accuracy=88.571716
Train iter 300, train loss=0.009621, avg test accuracy=90.875404
Train iter 400, train loss=0.009056, avg test accuracy=91.336136
Train iter 500, train loss=0.011603, avg test accuracy=92.007210
Train iter 600, train loss=0.007781, avg test accuracy=91.887016
Train iter 700, train loss=0.006997, avg test accuracy=92.938705
Train iter 800, train loss=0.010020, avg test accuracy=93.129005
Train iter 900, train loss=0.009978, avg test accuracy=93.279243
```