Skip to content

Theoretical PC energy of deep linear networks

Open in Colab

This notebook demonstrates how to compute the theoretical PC energy at the inference equilibrium, \(\mathcal{F}^* := \mathcal{F}|_{\nabla_{\mathbf{z}} \mathcal{F} = \mathbf{0}}\), for a deep linear network trained on input–output pairs \(\{(\mathbf{x}_i, \mathbf{y}_i)\}_{i=1}^N\) (see Innocenti et al., 2024):

\[\begin{equation} \mathcal{F}^* = \frac{1}{2N} \sum_{i=1}^N (\mathbf{y}_i - W_{L:1}\mathbf{x}_i)^T S^{-1}(\mathbf{y}_i - W_{L:1}\mathbf{x}_i) \end{equation}\]

where \(S = I_{d_y} + \sum_{\ell=2}^L (W_{L:\ell})(W_{L:\ell})^T\), \(I_{d_y}\) is the identity matrix of the output dimension \(d_y\), and \(W_{k:\ell} = W_k \cdots W_\ell\) for \(\ell, k \in \{1, \dots, L\}\) with \(\ell \le k\).

%%capture
!pip install torch==2.3.1
!pip install torchvision==0.18.1
!pip install plotly==5.11.0
!pip install -U kaleido
import jpc

import jax
from jax import vmap
import jax.numpy as jnp
import equinox as eqx
import equinox.nn as nn
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import plotly.graph_objs as go
import plotly.io as pio

# Render plotly figures as iframes in the notebook output.
pio.renderers.default = 'iframe'

import warnings
# Suppress all warnings globally to keep the notebook output clean.
# NOTE(review): this also hides potentially useful library warnings.
warnings.simplefilter('ignore')  # ignore warnings

Hyperparameters

We define some global parameters: the random seed, learning rate, batch size, maximum inference time (`MAX_T1`), test frequency and number of training iterations.

# Reproducibility.
SEED = 0

# Optimisation.
LEARNING_RATE = 1e-3  # step size for the Adam optimiser used in `train`
BATCH_SIZE = 64
N_TRAIN_ITERS = 100   # number of training batches to process

# Inference and logging.
MAX_T1 = 300          # forwarded to jpc.make_pc_step as `max_t1`
TEST_EVERY = 10       # evaluate on the test set every this many iterations

Dataset

Some utilities to fetch MNIST, with flattened images and one-hot labels.

#@title data utils


def get_mnist_loaders(batch_size):
    """Build the normalised MNIST train and test ``DataLoader``s.

    Args:
        batch_size: number of samples per batch, used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Both drop the last incomplete batch,
        so every batch contains exactly ``batch_size`` samples.
    """
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    # Evaluation is not shuffled. With drop_last=True, a shuffled test loader
    # would discard a different random subset of samples on every pass, so the
    # reported test metrics of the same model would vary between evaluations.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True
    )
    return train_loader, test_loader


class MNIST(datasets.MNIST):
    """torchvision MNIST dataset yielding flattened images and one-hot labels.

    Args:
        train: load the training split if True, otherwise the test split.
        normalise: if True, standardise pixels with a fixed mean/std after
            ``ToTensor``.
        save_dir: directory the dataset is downloaded to / loaded from.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        steps = [transforms.ToTensor()]
        if normalise:
            # Presumably the MNIST training-set pixel mean/std. Passed as real
            # 1-tuples (one entry per channel), as Normalize documents;
            # `(0.1307)` without a comma is just a float.
            steps.append(transforms.Normalize(mean=(0.1307,), std=(0.3081,)))
        transform = transforms.Compose(steps)
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        """Return ``(flattened_image, one_hot_label)`` for sample ``index``."""
        img, label = super().__getitem__(index)
        return torch.flatten(img), one_hot(label)


def one_hot(labels, n_classes=10):
    """Encode integer class label(s) as one-hot float vector(s).

    Args:
        labels: a class index, or a tensor/sequence of indices, each in
            ``[0, n_classes)``.
        n_classes: length of each one-hot vector.

    Returns:
        The rows of the ``n_classes x n_classes`` identity matrix selected by
        ``labels``.
    """
    return torch.eye(n_classes)[labels]

Plotting

def plot_total_energies(energies, save_path="dln_theory_energy.pdf"):
    """Plot the total energy per training iteration for each energy series.

    Args:
        energies: mapping from series name to a 1-D sequence of energies, one
            per training iteration. Must contain a ``"theory"`` entry, which is
            drawn dashed; every other series is drawn solid. All series are
            plotted against iterations ``1..len(energies["theory"])``.
        save_path: file the figure is written to with ``fig.write_image``
            (requires the kaleido package). Pass ``None`` to skip saving.

    Returns:
        The plotly ``Figure``.

    Raises:
        ValueError: if ``energies["theory"]`` is empty.
    """
    n_train_iters = len(energies["theory"])
    if n_train_iters == 0:
        raise ValueError("energies['theory'] is empty; nothing to plot.")
    train_iters = list(range(1, n_train_iters + 1))

    fig = go.Figure()
    for energy_type, energy in energies.items():
        is_theory = energy_type == "theory"
        fig.add_trace(
            go.Scatter(
                x=train_iters,
                y=energy,
                name=energy_type,
                mode="lines",
                line=dict(
                    width=3,
                    dash="dash" if is_theory else "solid",
                    color="rgb(27, 158, 119)" if is_theory else "#00CC96"
                ),
                # Smaller rank is listed first: theory heads the legend.
                legendrank=1 if is_theory else 2
            )
        )

    # Label only the first, middle and last iterations on the x-axis.
    ticks = [1, n_train_iters // 2, n_train_iters]
    fig.update_layout(
        height=300,
        width=450,
        xaxis=dict(
            title="Training iteration",
            tickvals=ticks,
            ticktext=ticks,
        ),
        yaxis=dict(
            title="Energy",
            nticks=3
        ),
        font=dict(size=16),
    )
    if save_path is not None:
        fig.write_image(save_path)
    return fig

Linear network

We'll use a deep linear network (linear activations, no biases) with 10 hidden layers of width 300 as an example.

# Build the deep linear network: 784 inputs (flattened MNIST images),
# `n_hidden` hidden layers of size `width`, and 10 outputs (one per class).
# Linear activations and no biases make the network a pure product of weight
# matrices, matching the W_{L:1} form used in the theoretical energy above.
# The key is derived from the SEED hyperparameter rather than a hard-coded 0.
key = jax.random.PRNGKey(SEED)
width, n_hidden = 300, 10
network = jpc.make_mlp(
    key,
    [784] + [width] * n_hidden + [10],
    act_fn="linear",
    use_bias=False
)

Train and test

To compute the theoretical energy, we can use `jpc.linear_equilib_energy()` (see the docs for more details). As the equation above makes clear, it needs only the model and the data.

def evaluate(model, test_loader):
    """Average the PC test loss and accuracy of ``model`` over ``test_loader``.

    Args:
        model: network evaluated with ``jpc.test_discriminative_pc``.
        test_loader: iterable of ``(img_batch, label_batch)`` torch tensor
            pairs.

    Returns:
        ``(avg_test_loss, avg_test_acc)``: unweighted means of the per-batch
        values returned by ``jpc.test_discriminative_pc``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    total_loss, total_acc, n_batches = 0, 0, 0
    for img_batch, label_batch in test_loader:
        # Convert torch tensors to NumPy arrays before handing them to jpc.
        test_loss, test_acc = jpc.test_discriminative_pc(
            model=model,
            output=label_batch.numpy(),
            input=img_batch.numpy()
        )
        total_loss += test_loss
        total_acc += test_acc
        n_batches += 1

    if n_batches == 0:
        raise ValueError("test_loader yielded no batches; cannot average.")
    return total_loss / n_batches, total_acc / n_batches


def train(
    model,
    lr,
    batch_size,
    max_t1,
    test_every,
    n_train_iters
):
    """Train ``model`` with PC on MNIST, recording theoretical and numerical energies.

    For each training batch, the theoretical equilibrium energy
    (``jpc.linear_equilib_energy``) is computed from the current weights. Then
    one PC step is taken with ``jpc.make_pc_step``, which returns the updated
    model and optimiser state together with the recorded energies.

    Args:
        model: network to train.
        lr: learning rate of the Adam optimiser.
        batch_size: batch size for both MNIST loaders.
        max_t1: forwarded to ``jpc.make_pc_step`` as ``max_t1``.
        test_every: evaluate on the test set and print progress every this
            many iterations.
        n_train_iters: number of training batches to process (non-negative;
            fewer are processed if the training loader is exhausted first).

    Returns:
        Dict with ``"experiment"`` (numerical total energy per iteration) and
        ``"theory"`` (theoretical energy per iteration), each a ``jnp`` array.
    """
    from itertools import islice

    optim = optax.adam(lr)
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    num_total_energies, theory_total_energies = [], []
    # islice stops after exactly `n_train_iters` batches, so the stopping
    # point does not depend on `test_every`.
    batches = islice(train_loader, n_train_iters)
    for train_iter, (img_batch, label_batch) in enumerate(batches, start=1):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # Theoretical energy for the weights this PC step starts from.
        theory_total_energies.append(
            jpc.linear_equilib_energy(
                network=model,
                x=img_batch,
                y=label_batch
            )
        )
        result = jpc.make_pc_step(
            model,
            optim,
            opt_state,
            output=label_batch,
            input=img_batch,
            max_t1=max_t1,
            record_energies=True
        )
        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
        train_loss, t_max = result["loss"], result["t_max"]
        # Numerical counterpart: the recorded energies at the last inference
        # step (index t_max - 1), summed over the leading axis (presumably
        # the layers).
        num_total_energies.append(result["energies"][:, t_max-1].sum())

        if train_iter % test_every == 0:
            _, avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {train_iter}, train loss={train_loss:.4f}, "
                f"avg test accuracy={avg_test_acc:.4f}"
            )

    return {
        "experiment": jnp.array(num_total_energies),
        "theory": jnp.array(theory_total_energies)
    }

Run

Below we train the network and plot the theoretical energy against the numerical one at every training iteration.

# Train the linear network, recording both the theoretical and the numerical
# total energy at every iteration, then plot the two curves together.
energies = train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    max_t1=MAX_T1,
    test_every=TEST_EVERY,
    n_train_iters=N_TRAIN_ITERS,
)
plot_total_energies(energies)
Train iter 10, train loss=0.070245, avg test accuracy=64.853767
Train iter 20, train loss=0.075006, avg test accuracy=69.961937
Train iter 30, train loss=0.055347, avg test accuracy=70.292465
Train iter 40, train loss=0.057690, avg test accuracy=78.275238
Train iter 50, train loss=0.052301, avg test accuracy=79.607368
Train iter 60, train loss=0.051747, avg test accuracy=80.909454
Train iter 70, train loss=0.053040, avg test accuracy=80.238380
Train iter 80, train loss=0.047872, avg test accuracy=81.029648
Train iter 90, train loss=0.051192, avg test accuracy=82.662262
Train iter 100, train loss=0.054825, avg test accuracy=83.533653