Energy functions
JPC provides two main PC energy functions:
- jpc.pc_energy_fn() for standard PC networks, and
- jpc.hpc_energy_fn() for hybrid PC models (Tschantz et al., 2023).
jpc.pc_energy_fn(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex],
    y: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex,
    *,
    x: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | None = None,
    loss: str = 'mse',
    param_type: str = 'sp',
    weight_decay: Shaped[Array, ''] = 0.0,
    spectral_penalty: Shaped[Array, ''] = 0.0,
    activity_decay: Shaped[Array, ''] = 0.0,
    record_layers: bool = False
) -> Shaped[Array, ''] | jax.Array
Computes the free energy for a neural network with optional skip connections of the form

\[
\mathcal{F}(\mathbf{z}; θ) = \frac{1}{2N} \sum_{i=1}^N \sum_{\ell=1}^L \| \mathbf{z}_{i,\ell} - f_\ell(\mathbf{z}_{i,\ell-1}; θ) \|^2
\]

given parameters \(θ\), activities \(\mathbf{z}\), output \(\mathbf{z}_L = \mathbf{y}\), and optional input \(\mathbf{z}_0 = \mathbf{x}\) for supervised training. The activity of each layer \(\mathbf{z}_\ell\) is some function of the previous layer, e.g. ReLU\((\mathbf{W}_\ell \mathbf{z}_{\ell-1} + \mathbf{b}_\ell)\) for a fully connected layer with biases and a ReLU activation.
Note
The input \(x\) and output \(y\) correspond to the prior and observation of the generative model, respectively.
Main arguments:
- params: Tuple with callable model (e.g. neural network) layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- y: Observation or target of the generative model.
Other arguments:
- x: Optional prior of the generative model (for supervised training).
- loss: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
- weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
- spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- activity_decay: \(\ell^2\) regulariser for the activities (0 by default).
- record_layers: If True, returns the energy of each layer.
Returns:
The total or layer-wise energy normalised by the batch size.
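For illustration, a minimal sketch of calling pc_energy_fn() on a toy supervised problem. It assumes a model written as a list of Equinox layers (JPC models are PyTrees of callables) and activities initialised with jpc.init_activities_with_ffwd() from JPC's initialisation utilities; the layer sizes, shapes, and variable names are hypothetical, not taken from this page.

```python
import equinox.nn as nn
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

# A 3-layer network as a PyTree (list) of callable layers.
model = [
    nn.Sequential([nn.Linear(10, 300, key=k1), nn.Lambda(jax.nn.relu)]),
    nn.Sequential([nn.Linear(300, 300, key=k2), nn.Lambda(jax.nn.relu)]),
    nn.Linear(300, 5, key=k3),
]

# Illustrative batch: x is the prior (input), y the observation (target).
x = jnp.ones((32, 10))
y = jnp.zeros((32, 5))

# Feedforward initialisation of the activities, one array per layer.
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# Total free energy; pass record_layers=True for per-layer energies.
energy = jpc.pc_energy_fn(
    params=(model, None),  # (layers, optional skip connections)
    activities=activities,
    y=y,
    x=x,
)
```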
jpc.hpc_energy_fn(
    model: PyTree[Callable],
    equilib_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex],
    amort_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex],
    x: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex,
    y: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex | None = None,
    record_layers: bool = False
) -> Shaped[Array, ''] | jax.Array
Computes the free energy of an amortised PC network (Tschantz et al., 2023)
given the equilibrated activities of the generator \(\mathbf{z}^*\) (target for the amortiser), the feedforward guesses of the amortiser \(\hat{\mathbf{z}}\), the amortiser's parameters \(θ\), input \(\mathbf{z}_0 = \mathbf{x}\), and optional output \(\mathbf{z}_L = \mathbf{y}\) for supervised training.
Note
The input \(x\) and output \(y\) are reversed compared to pc_energy_fn() (\(x\) is the generator's target and \(y\) is its optional input or prior). Just think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.
Reference
@article{tschantz2023hybrid,
  title={Hybrid predictive coding: Inferring, fast and slow},
  author={Tschantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L},
  journal={PLoS Computational Biology},
  volume={19},
  number={8},
  pages={e1011280},
  year={2023},
  publisher={Public Library of Science San Francisco, CA USA}
}
Main arguments:
- model: List of callable model (e.g. neural network) layers.
- equilib_activities: List of equilibrated activities reached by the generator and target for the amortiser.
- amort_activities: List of the amortiser's feedforward guesses (initialisation) for the network activities.
- x: Input to the amortiser.
- y: Optional target of the amortiser (for supervised training).
Other arguments:
- record_layers: If True, returns the energy of each layer.
Returns:
The total or layer-wise energy normalised by batch size.
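A hedged sketch of the hybrid energy computation. The amortiser, its input and target, and the generator's equilibrated activities are assumed to be available from earlier steps (equilibrating the generator is not shown); only the two jpc calls follow the signatures documented on this page.

```python
import jpc

# Assumed available from earlier steps (not shown): `amortiser`, a list of
# callable layers mapping data to latents; `x`, the amortiser's input (the
# generator's target); `y`, its optional target; and `equilib_activities`,
# the generator's activities after relaxation to equilibrium on this batch.
amort_activities = jpc.init_activities_with_ffwd(model=amortiser, input=x)

energy = jpc.hpc_energy_fn(
    model=amortiser,
    equilib_activities=equilib_activities,
    amort_activities=amort_activities,
    x=x,
    y=y,
)
```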
jpc._get_param_scalings(
    model: PyTree[Callable],
    input: jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex,
    *,
    skip_model: Optional[PyTree[Callable]] = None,
    param_type: str = 'sp'
) -> list[float]
Gets layer scalings for a given parameterisation.
Warning
param_type = "mupc" (μPC) assumes
that one is using jpc.make_mlp()
to create the model.
Main arguments:
- model: List of callable model (e.g. neural network) layers.
- input: Input to the model.
Other arguments:
- skip_model: Optional skip connection model.
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). Defaults to "sp".
Returns:
List with scalings for each layer.
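A self-contained sketch of querying the scalings for a toy model; the layers and shapes are hypothetical. Under "sp" the scalings are expected to be trivial, since the standard parameterisation applies no rescaling; per the warning above, "mupc" additionally assumes a jpc.make_mlp() model.

```python
import equinox.nn as nn
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Two-layer toy model (hypothetical shapes).
model = [
    nn.Sequential([nn.Linear(10, 32, key=k1), nn.Lambda(jax.nn.tanh)]),
    nn.Linear(32, 5, key=k2),
]

# Per-layer scalings under the standard parameterisation.
scalings = jpc._get_param_scalings(
    model=model,
    input=jnp.ones((8, 10)),
    param_type="sp",
)
print(scalings)  # one scaling factor per layer
```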