Theoretical PC energy of deep linear networks¤
This notebook demonstrates how to compute the theoretical PC energy at the inference equilibrium, \(\mathcal{F}^* = \mathcal{F}|_{\nabla_{\mathbf{z}} \mathcal{F} = \mathbf{0}}\), of a deep linear network with input and output pairs \((\mathbf{x}_i, \mathbf{y}_i)\) (see Innocenti et al., 2024):
\[\begin{equation}
\mathcal{F}^* = \frac{1}{2N} \sum_{i=1}^N (\mathbf{y}_i - W_{L:1}\mathbf{x}_i)^T S^{-1}(\mathbf{y}_i - W_{L:1}\mathbf{x}_i)
\end{equation}\]
where \(S = I_{d_y} + \sum_{\ell=2}^L W_{L:\ell} W_{L:\ell}^T\) and \(W_{k:\ell} = W_k \dots W_\ell\) for \(k, \ell \in \{1,\dots, L\}\) with \(k \geq \ell\).
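Before training anything, it can help to see this formula in code. Below is a minimal, hypothetical sketch (`theoretical_energy` is not part of jpc) that evaluates the equation above from a list of raw weight matrices; jpc's `linear_equilib_energy`, used in the training loop later in this notebook, computes the same closed-form quantity directly from a network.

import jax.numpy as jnp

def theoretical_energy(Ws, X, Y):
    """Hypothetical helper: evaluate the equation above from raw weights.

    Ws: [W_1, ..., W_L] with W_l of shape (d_l, d_{l-1});
    X: (N, d_x) inputs; Y: (N, d_y) targets.
    """
    # W_{L:1} = W_L ... W_1
    W_prod = Ws[0]
    for W in Ws[1:]:
        W_prod = W @ W_prod

    # S = I_{d_y} + sum_{l=2}^{L} W_{L:l} W_{L:l}^T
    d_y = Ws[-1].shape[0]
    S = jnp.eye(d_y)
    if len(Ws) >= 2:
        W_Ll = Ws[-1]                   # W_{L:L} = W_L
        S = S + W_Ll @ W_Ll.T
        for W in reversed(Ws[1:-1]):    # W_{L:L-1}, ..., W_{L:2}
            W_Ll = W_Ll @ W
            S = S + W_Ll @ W_Ll.T

    R = Y - X @ W_prod.T                # rows are y_i - W_{L:1} x_i
    quad = jnp.sum(R.T * jnp.linalg.solve(S, R.T), axis=0)
    return 0.5 * jnp.mean(quad)         # 1/(2N) sum_i r_i^T S^{-1} r_i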
%%capture
!pip install jpc
!pip install torch==2.3.1
!pip install torchvision==0.18.1
!pip install plotly==5.11.0
!pip install -U kaleido
import jpc

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import plotly.graph_objs as go
import plotly.io as pio
pio.renderers.default = 'iframe'

import warnings
warnings.simplefilter('ignore')  # silence torch/plotly warnings
Hyperparameters¤
We define some global hyperparameters, including the learning rate, batch size, and the maximum duration of the inference phase.
SEED = 0
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
MAX_T1 = 300
TEST_EVERY = 10
N_TRAIN_ITERS = 100
Dataset¤
Some utils to fetch MNIST.
#@title data utils

def get_mnist_loaders(batch_size):
    """Return shuffled train and test DataLoaders for flattened, one-hot MNIST."""
    train_data = MNIST(train=True, normalise=True)
    test_data = MNIST(train=False, normalise=True)
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    return train_loader, test_loader
class MNIST(datasets.MNIST):
    def __init__(self, train, normalise=True, save_dir="data"):
        if normalise:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=(0.1307,), std=(0.3081,)
                    )
                ]
            )
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        super().__init__(save_dir, download=True, train=train, transform=transform)

    def __getitem__(self, index):
        img, label = super().__getitem__(index)
        img = torch.flatten(img)   # (1, 28, 28) -> (784,)
        label = one_hot(label)     # int -> one-hot vector of length 10
        return img, label


def one_hot(labels, n_classes=10):
    arr = torch.eye(n_classes)
    return arr[labels]
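As a quick sanity check of these utilities (a hypothetical usage snippet, assuming MNIST downloads to `data/`), one batch should contain flattened images and one-hot labels:

loader, _ = get_mnist_loaders(batch_size=2)
imgs, labels = next(iter(loader))
print(imgs.shape, labels.shape)  # torch.Size([2, 784]) torch.Size([2, 10])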
Plotting¤
def plot_total_energies(energies):
    n_train_iters = len(energies["theory"])
    train_iters = [b+1 for b in range(n_train_iters)]

    fig = go.Figure()
    for energy_type, energy in energies.items():
        is_theory = energy_type == "theory"
        fig.add_traces(
            go.Scatter(
                x=train_iters,
                y=energy,
                name=energy_type,
                mode="lines",
                line=dict(
                    width=3,
                    dash="dash" if is_theory else "solid",
                    color="rgb(27, 158, 119)" if is_theory else "#00CC96"
                ),
                legendrank=1 if is_theory else 2
            )
        )
    fig.update_layout(
        height=300,
        width=450,
        xaxis=dict(
            title="Training iteration",
            tickvals=[1, int(train_iters[-1]/2), train_iters[-1]],
            ticktext=[1, int(train_iters[-1]/2), train_iters[-1]],
        ),
        yaxis=dict(
            title="Energy",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image("dln_theory_energy.pdf")
    return fig
Linear network¤
We'll use a linear network with 10 hidden layers of width 300 as an example.
key = jax.random.PRNGKey(SEED)
width, n_hidden = 300, 10
network = jpc.make_mlp(
    key,
    [784] + [width]*n_hidden + [10],
    act_fn="linear",
    use_bias=False
)
def evaluate(model, test_loader):
    avg_test_loss, avg_test_acc = 0, 0
    for img_batch, label_batch in test_loader:
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()
        test_loss, test_acc = jpc.test_discriminative_pc(
            model=model,
            output=label_batch,
            input=img_batch
        )
        avg_test_loss += test_loss
        avg_test_acc += test_acc

    return avg_test_loss / len(test_loader), avg_test_acc / len(test_loader)
def train(
    model,
    lr,
    batch_size,
    max_t1,
    test_every,
    n_train_iters
):
    optim = optax.adam(lr)
    opt_state = optim.init(
        (eqx.filter(model, eqx.is_array), None)
    )
    train_loader, test_loader = get_mnist_loaders(batch_size)

    num_total_energies, theory_total_energies = [], []
    for train_iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch, label_batch = img_batch.numpy(), label_batch.numpy()

        # closed-form equilibrated energy (Eq. 1)
        theory_total_energies.append(
            jpc.linear_equilib_energy(
                network=model,
                x=img_batch,
                y=label_batch
            )
        )
        # one PC inference-then-learning step, recording the numerical energies
        result = jpc.make_pc_step(
            model,
            optim,
            opt_state,
            output=label_batch,
            input=img_batch,
            max_t1=max_t1,
            record_energies=True
        )
        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
        train_loss, t_max = result["loss"], result["t_max"]
        # sum the layer energies at the final inference step
        num_total_energies.append(result["energies"][:, t_max-1].sum())

        if ((train_iter+1) % test_every) == 0:
            avg_test_loss, avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {train_iter+1}, train loss={train_loss:.4f}, "
                f"avg test accuracy={avg_test_acc:.4f}"
            )
        if (train_iter+1) >= n_train_iters:
            break

    return {
        "experiment": jnp.array(num_total_energies),
        "theory": jnp.array(theory_total_energies)
    }
Run¤
Below we train the network, recording at each iteration both the numerical energy at the end of inference and the theoretical equilibrated energy, and plot the two against each other.
energies = train(
    model=network,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    test_every=TEST_EVERY,
    max_t1=MAX_T1,
    n_train_iters=N_TRAIN_ITERS
)
plot_total_energies(energies)
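Beyond the visual comparison, the dictionary returned by `train` can be used to quantify how closely the numerical energy tracks the theory; a small sketch:

# relative gap between the numerical and theoretical total energies
rel_gap = jnp.abs(energies["experiment"] - energies["theory"]) / jnp.abs(energies["theory"])
print(f"max relative gap: {rel_gap.max():.2e}")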