Theoretical tools
JPC provides the following theoretical tools that can be used to study deep linear networks (DLNs) trained with PC:
- jpc.compute_linear_equilib_energy() to compute the theoretical PC energy at the solution of the activities for DLNs;
- jpc.compute_linear_activity_hessian() to compute the theoretical Hessian of the energy with respect to the activities of DLNs;
- jpc.compute_linear_activity_solution() to compute the analytical PC inference solution for DLNs.
jpc.compute_linear_equilib_energy(network: PyTree[equinox.nn._linear.Linear], x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> Array
Computes the theoretical PC energy at the solution of the activities for a deep linear network (Innocenti et al., 2024):
\[
\mathcal{F}^* = \frac{1}{2N} \sum_{i=1}^N (\mathbf{y}_i - \mathbf{W}_{L:1}\mathbf{x}_i)^T \mathbf{S}^{-1} (\mathbf{y}_i - \mathbf{W}_{L:1}\mathbf{x}_i)
\]
where \(\mathbf{S} = \mathbf{I}_{d_y} + \sum_{\ell=2}^L (\mathbf{W}_{L:\ell})(\mathbf{W}_{L:\ell})^T\) and \(\mathbf{W}_{k:\ell} = \mathbf{W}_k \dots \mathbf{W}_\ell\) for \(\ell, k \in 1,\dots, L\).
Note
This expression assumes no biases. It could also be generalised to other network architectures (e.g. ResNets) and parameterisations (see Innocenti et al., 2025). Note, however, that the equilibrated energy for ResNets and other parameterisations can still be computed by obtaining the activity solution with jpc.compute_linear_activity_solution() and then plugging it into the standard PC energy jpc.pc_energy_fn().
Example
In practice, this means that if, at any point in training, you run the inference dynamics of any linear PC network to equilibrium, then jpc.pc_energy_fn() will return the same energy value as this function. For a demonstration, see this example notebook.
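Below is a minimal usage sketch based on the signature documented above; the network sizes, random keys, and data are illustrative assumptions rather than part of the JPC API.

```python
import jax
import equinox as eqx
import jpc

key = jax.random.PRNGKey(0)
wkey1, wkey2, wkey3, xkey, ykey = jax.random.split(key, 5)

# Deep linear network as a list of Equinox Linear layers
# (no biases, matching the assumption of the formula above).
network = [
    eqx.nn.Linear(10, 32, use_bias=False, key=wkey1),
    eqx.nn.Linear(32, 32, use_bias=False, key=wkey2),
    eqx.nn.Linear(32, 5, use_bias=False, key=wkey3),
]

x = jax.random.normal(xkey, (8, 10))  # batch of inputs (leading batch dimension assumed)
y = jax.random.normal(ykey, (8, 5))   # batch of targets

# Theoretical energy at the equilibrium of the inference dynamics.
equilib_energy = jpc.compute_linear_equilib_energy(network=network, x=x, y=y)
print(equilib_energy)  # scalar: mean total energy over the batch
```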
Reference
@article{innocenti2025only,
title={Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?},
author={Innocenti, Francesco and Achour, El Mehdi and Singh, Ryan and Buckley, Christopher L},
journal={Advances in Neural Information Processing Systems},
volume={37},
pages={53649--53683},
year={2025}
}
Main arguments:
- network: Linear network defined as a list of Equinox Linear layers.
- x: Network input.
- y: Network output.
Returns:
Mean total theoretical energy over a data batch.
jpc.compute_linear_activity_hessian(Ws: PyTree[jax.Array], *, use_skips: bool = False, param_type: str = 'sp', activity_decay: bool = False, diag: bool = True, off_diag: bool = True) -> Array
Computes the theoretical Hessian matrix of the PC energy with respect to the activities of a linear network, \(\mathbf{H}_{\mathbf{z}} \in \mathbb{R}^{NH \times NH}\) with blocks \((\mathbf{H}_{\mathbf{z}})_{\ell k} := \partial^2 \mathcal{F} / \partial \mathbf{z}_\ell \partial \mathbf{z}_k\), where \(N\) and \(H\) are the width and number of hidden layers, respectively (Innocenti et al., 2025).
Info
This function can be used (i) to study the inference landscape of linear PC networks and (ii) to compute the analytical activity solution with jpc.compute_linear_activity_solution().
Warning
This implementation is heavily hard-coded for quick experimental iteration with different models and parameterisations. The computation of the blocks could be implemented much more elegantly by fetching the transformation applied at each layer.
Reference
@article{innocenti2025mu,
title={$\mu$PC: Scaling Predictive Coding to 100+ Layer Networks},
author={Innocenti, Francesco and Achour, El Mehdi and Buckley, Christopher L},
journal={arXiv preprint arXiv:2505.13124},
year={2025}
}
Main arguments:
- Ws: List of all the network weight matrices.
Other arguments:
- use_skips: Whether to assume one-layer skip connections at every layer except from the input and to the output. False by default.
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these parameterisations. Defaults to "sp".
- activity_decay: Whether to apply an \(\ell^2\) regulariser to the activities. False by default.
- diag: Whether to compute the diagonal blocks of the Hessian. True by default.
- off_diag: Whether to compute the off-diagonal blocks of the Hessian. True by default.
Returns:
The activity Hessian matrix of size \(NH \times NH\), where \(N\) is the width and \(H\) is the number of hidden layers.
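A hedged usage sketch follows; the weight shapes assume the Equinox (out_features, in_features) convention and equal hidden widths, which are assumptions here rather than documented requirements.

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)
dims = [10, 32, 32, 32, 5]  # input, three hidden layers of width N = 32, output

# Weight matrices of shape (out_features, in_features), assumed to follow
# the Equinox convention used for linear networks elsewhere in JPC.
Ws = [
    jax.random.normal(k, (dims[i + 1], dims[i])) / jnp.sqrt(dims[i])
    for i, k in enumerate(keys)
]

# Activity Hessian under the standard parameterisation.
hess = jpc.compute_linear_activity_hessian(Ws, param_type="sp")
print(hess.shape)  # expected (N*H, N*H) = (96, 96) for N = 32, H = 3

# The eigenvalue spread characterises the conditioning of the inference landscape.
eigvals = jnp.linalg.eigvalsh(hess)
print(eigvals.min(), eigvals.max())
```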
jpc.compute_linear_activity_solution(network: PyTree[equinox.nn._linear.Linear], x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, use_skips: bool = False, param_type: str = 'sp', activity_decay: bool = False) -> PyTree[jax.Array]
Computes the theoretical solution for the PC activities of a linear network (Innocenti et al., 2025):
\[
\mathbf{z}^* = \mathbf{H}_{\mathbf{z}}^{-1} \mathbf{b}
\]
where \(\mathbf{H}_{\mathbf{z}} \in \mathbb{R}^{NH \times NH}\), with blocks \((\mathbf{H}_{\mathbf{z}})_{\ell k} := \partial^2 \mathcal{F} / \partial \mathbf{z}_\ell \partial \mathbf{z}_k\), is the Hessian of the energy with respect to the activities, and \(\mathbf{b} \in \mathbb{R}^{NH}\) is a sparse vector depending only on the data and associated weights. The activity Hessian is computed analytically using jpc.compute_linear_activity_hessian().
Info
This can be used to study how linear PC networks learn when they perform perfect inference. An example notebook demonstration is in the works!
Reference
@article{innocenti2025mu,
title={$\mu$PC: Scaling Predictive Coding to 100+ Layer Networks},
author={Innocenti, Francesco and Achour, El Mehdi and Buckley, Christopher L},
journal={arXiv preprint arXiv:2505.13124},
year={2025}
}
Main arguments:
- network: Linear network defined as a list of Equinox Linear layers.
- x: Network input.
- y: Network output.
Other arguments:
- use_skips: Whether to assume one-layer skip connections at every layer except from the input and to the output. False by default.
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these parameterisations. Defaults to "sp".
- activity_decay: Whether to apply an \(\ell^2\) regulariser to the activities. False by default.
Returns:
List of theoretical activities for each layer.
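A minimal sketch of the documented signature, reusing the illustrative linear network from the earlier example; the resulting activities can then be plugged into jpc.pc_energy_fn() to recover the equilibrated energy, as noted above.

```python
import jax
import equinox as eqx
import jpc

key = jax.random.PRNGKey(0)
wkey1, wkey2, wkey3, xkey, ykey = jax.random.split(key, 5)

# Deep linear network as a list of Equinox Linear layers (no biases).
network = [
    eqx.nn.Linear(10, 32, use_bias=False, key=wkey1),
    eqx.nn.Linear(32, 32, use_bias=False, key=wkey2),
    eqx.nn.Linear(32, 5, use_bias=False, key=wkey3),
]

x = jax.random.normal(xkey, (8, 10))  # batch of inputs
y = jax.random.normal(ykey, (8, 5))   # batch of targets

# Analytical equilibrium of the PC inference dynamics.
z_star = jpc.compute_linear_activity_solution(network=network, x=x, y=y)
for layer_idx, z in enumerate(z_star):
    print(layer_idx, z.shape)
```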