Energy functions
JPC provides two main PC energy functions:
- `jpc.pc_energy_fn()` for standard PC networks, and
- `jpc.hpc_energy_fn()` for hybrid PC models (Tschantz et al., 2023).
```python
jpc.pc_energy_fn(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    y: ArrayLike,
    *,
    x: Optional[ArrayLike] = None,
    loss: str = "mse",
    param_type: str = "sp",
    weight_decay: Scalar = 0.0,
    spectral_penalty: Scalar = 0.0,
    activity_decay: Scalar = 0.0,
    record_layers: bool = False
) -> Scalar | Array
```
Computes the free energy for a neural network with optional skip connections, which for the default squared-error loss takes the form

$$\mathcal{F}(\mathbf{z}; θ) = \frac{1}{N} \sum_{i=1}^N \sum_{\ell=1}^L \| \mathbf{z}_{i,\ell} - f_\ell(\mathbf{z}_{i,\ell-1}; θ) \|^2$$

given parameters \(θ\), activities \(\mathbf{z}\), output \(\mathbf{z}_L = \mathbf{y}\), and optional input \(\mathbf{z}_0 = \mathbf{x}\) for supervised training. The activity of each layer \(\mathbf{z}_\ell\) is some function \(f_\ell\) of the previous layer, e.g. ReLU\((\mathbf{W}_\ell \mathbf{z}_{\ell-1} + \mathbf{b}_\ell)\) for a fully connected layer with biases and ReLU as activation.
Note
The input \(x\) and output \(y\) correspond to the prior and observation of the generative model, respectively.
Main arguments:
- `params`: Tuple with callable model (e.g. neural network) layers and optional skip connections.
- `activities`: List of activities for each layer free to vary.
- `y`: Observation or target of the generative model.
Other arguments:
- `x`: Optional prior of the generative model (for supervised training).
- `loss`: Loss function to use at the output layer. Options are mean squared error `"mse"` (default) or cross-entropy `"ce"`.
- `param_type`: Determines the parameterisation. Options are `"sp"` (standard parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent parameterisation). See `_get_param_scalings()` for the specific scalings of these different parameterisations. Defaults to `"sp"`.
- `weight_decay`: \(\ell^2\) regulariser for the weights (0 by default).
- `spectral_penalty`: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- `activity_decay`: \(\ell^2\) regulariser for the activities (0 by default).
- `record_layers`: If `True`, returns the energy of each layer.
Returns:
The total or layer-wise energy normalised by the batch size.
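As a usage illustration, the sketch below builds a small two-layer network as a list of callables and evaluates its energy. The Equinox layer construction and the feedforward initialisation of the activities are assumptions made for the example, not requirements of `jpc.pc_energy_fn()` itself.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# A two-layer network as a list of callable layers (illustrative; any
# PyTree of callables with compatible shapes should work).
model = [
    eqx.nn.Sequential([eqx.nn.Linear(10, 64, key=k1), eqx.nn.Lambda(jax.nn.relu)]),
    eqx.nn.Linear(64, 3, key=k2),
]

x = jnp.ones((8, 10))  # prior / input (batch of 8)
y = jnp.ones((8, 3))   # observation / target

# One common choice: initialise the activities with a feedforward pass.
activities, z = [], x
for layer in model:
    z = jax.vmap(layer)(z)
    activities.append(z)

# Total energy; pass record_layers=True for per-layer energies instead.
energy = jpc.pc_energy_fn(
    params=(model, None),  # (layers, skip connections); no skips here
    activities=activities,
    y=y,
    x=x,
)
```

During inference, the gradient of this energy with respect to `activities` drives the activity updates, while the gradient with respect to `params` drives learning.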
```python
jpc.hpc_energy_fn(
    model: PyTree[Callable],
    equilib_activities: PyTree[ArrayLike],
    amort_activities: PyTree[ArrayLike],
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    record_layers: bool = False
) -> Scalar | Array
```
Computes the free energy of an amortised PC network (Tschantz et al., 2023)
given the equilibrated activities of the generator \(\mathbf{z}^*\) (target for the amortiser), the feedforward guesses of the amortiser \(\hat{\mathbf{z}}\), the amortiser's parameters \(θ\), input \(\mathbf{z}_0 = \mathbf{x}\), and optional output \(\mathbf{z}_L = \mathbf{y}\) for supervised training.
Note
The input \(x\) and output \(y\) are reversed compared to `pc_energy_fn()` (\(x\) is the generator's target and \(y\) is its optional input or prior). Just think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.
Reference
```bibtex
@article{tschantz2023hybrid,
  title={Hybrid predictive coding: Inferring, fast and slow},
  author={Tschantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L},
  journal={PLoS Computational Biology},
  volume={19},
  number={8},
  pages={e1011280},
  year={2023},
  publisher={Public Library of Science San Francisco, CA USA}
}
```
Main arguments:
- `model`: List of callable model (e.g. neural network) layers.
- `equilib_activities`: List of equilibrated activities reached by the generator and target for the amortiser.
- `amort_activities`: List of the amortiser's feedforward guesses (initialisation) for the network activities.
- `x`: Input to the amortiser.
- `y`: Optional target of the amortiser (for supervised training).
Other arguments:
- `record_layers`: If `True`, returns energies for each layer.
Returns:
The total or layer-wise energy normalised by batch size.
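A minimal sketch of a call, reusing the same style of layer list as above. The amortiser construction and the stand-in equilibrated activities are assumptions for illustration; in practice `equilib_activities` would come from running the generator's inference to equilibrium.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Amortiser mapping data to latents (illustrative layer choices).
amortiser = [
    eqx.nn.Sequential([eqx.nn.Linear(10, 64, key=k1), eqx.nn.Lambda(jax.nn.tanh)]),
    eqx.nn.Linear(64, 3, key=k2),
]

x = jnp.ones((8, 10))  # amortiser input (the generator's target)

# Feedforward guesses of the amortiser for the network activities.
amort_activities, z = [], x
for layer in amortiser:
    z = jax.vmap(layer)(z)
    amort_activities.append(z)

# Stand-ins for the generator's equilibrated activities; in practice
# these are obtained by running the generator's inference dynamics.
equilib_activities = [jnp.zeros_like(a) for a in amort_activities]

energy = jpc.hpc_energy_fn(
    model=amortiser,
    equilib_activities=equilib_activities,
    amort_activities=amort_activities,
    x=x,
)
```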
```python
jpc._get_param_scalings(
    model: PyTree[Callable],
    input: ArrayLike,
    *,
    skip_model: Optional[PyTree[Callable]] = None,
    param_type: str = "sp"
) -> list[float]
```
Gets layer scalings for a given parameterisation.
Warning
`param_type = "mupc"` (μPC) assumes that one is using `jpc.make_mlp()` to create the model.
Main arguments:
- `model`: List of callable model (e.g. neural network) layers.
- `input`: Input to the model.
Other arguments:
- `skip_model`: Optional skip connection model.
- `param_type`: Determines the parameterisation. Options are `"sp"` (standard parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent parameterisation). Defaults to `"sp"`.
Returns:
List with scalings for each layer.
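For instance, assuming the two-layer `model` and batched input `x` from the first example above, one could inspect the scalings as below (whether `input` expects a single example or a batch is an assumption here):

```python
# Per-layer scalings under the standard parameterisation; "mupc" and
# "ntp" instead rescale layers, e.g. by width-dependent factors.
scalings = jpc._get_param_scalings(model, x, param_type="sp")
print(scalings)  # one float per layer
```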