Energy functions

JPC provides two main PC energy functions:

jpc.pc_energy_fn(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None, loss: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0, record_layers: bool = False) -> jaxtyping.Shaped[Array, ''] | jax.Array

Computes the free energy for a neural network with optional skip connections of the form

\[ \mathcal{F}(\mathbf{z}; \theta) = \frac{1}{N} \sum_{i=1}^N \sum_{\ell=1}^L \| \mathbf{z}_{i, \ell} - f_\ell(\mathbf{z}_{i, \ell-1}; \theta) \|^2 \]

given parameters \(\theta\), activities \(\mathbf{z}\), output \(\mathbf{z}_L = \mathbf{y}\), and optional input \(\mathbf{z}_0 = \mathbf{x}\) for supervised training. The activity of each layer \(\mathbf{z}_\ell\) is some function of the previous layer, e.g. ReLU\((\mathbf{W}_\ell \mathbf{z}_{\ell-1} + \mathbf{b}_\ell)\) for a fully connected layer with biases and a ReLU activation.

Note

The input \(x\) and output \(y\) correspond to the prior and observation of the generative model, respectively.

Main arguments:

  • params: Tuple with callable model (e.g. neural network) layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • y: Observation or target of the generative model.

Other arguments:

  • x: Optional prior of the generative model (for supervised training).
  • loss: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
  • weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
  • spectral_penalty: Weight spectral penalty of the form \(\|\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell\|^2\) (0 by default).
  • activity_decay: \(\ell^2\) regulariser for the activities (0 by default).
  • record_layers: If True, returns the energy of each layer.

Returns:

The total or layer-wise energy normalised by the batch size.
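
For example, the energy of a small fully connected network can be computed as below. This is a minimal sketch: the equinox layer construction and the feedforward initialisation via init_activities_with_ffwd() follow the library's usual conventions, while the shapes and data are purely illustrative.

    import jax
    import jax.numpy as jnp
    import equinox.nn as nn
    import jpc

    key = jax.random.PRNGKey(0)
    k1, k2, k3, k4 = jax.random.split(key, 4)

    # Two-hidden-layer network as a list of callable layers (no skip connections).
    model = [
        nn.Sequential([nn.Linear(10, 100, key=k1), nn.Lambda(jax.nn.relu)]),
        nn.Sequential([nn.Linear(100, 100, key=k2), nn.Lambda(jax.nn.relu)]),
        nn.Linear(100, 3, key=k3),
    ]

    # Toy supervised batch: input x (prior) and one-hot targets y (observation).
    x = jax.random.normal(k4, (32, 10))
    y = jax.nn.one_hot(jnp.arange(32) % 3, 3)

    # Initialise the activities with a feedforward pass.
    activities = jpc.init_activities_with_ffwd(model=model, input=x)

    # Total free energy, normalised by the batch size.
    energy = jpc.pc_energy_fn(
        params=(model, None),  # (layers, optional skip connections)
        activities=activities,
        y=y,
        x=x
    )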


jpc.hpc_energy_fn(model: PyTree[typing.Callable], equilib_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], amort_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None, record_layers: bool = False) -> jaxtyping.Shaped[Array, ''] | jax.Array

Computes the free energy of an amortised PC network (Tschantz et al., 2023)

\[ \mathcal{F}(\mathbf{z}^*, \hat{\mathbf{z}}; \theta) = \frac{1}{N} \sum_{i=1}^N \sum_{\ell=1}^L \| \mathbf{z}^*_{i, \ell} - f_\ell(\hat{\mathbf{z}}_{i, \ell-1}; \theta) \|^2 \]

given the equilibrated activities of the generator \(\mathbf{z}^*\) (target for the amortiser), the feedforward guesses of the amortiser \(\hat{\mathbf{z}}\), the amortiser's parameters \(\theta\), input \(\mathbf{z}_0 = \mathbf{x}\), and optional output \(\mathbf{z}_L = \mathbf{y}\) for supervised training.

Note

The input \(x\) and output \(y\) are reversed compared to pc_energy_fn() (\(x\) is the generator's target and \(y\) is its optional input or prior). Just think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.

Reference
@article{tschantz2023hybrid,
    title={Hybrid predictive coding: Inferring, fast and slow},
    author={Tschantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L},
    journal={PLoS computational biology},
    volume={19},
    number={8},
    pages={e1011280},
    year={2023},
    publisher={Public Library of Science San Francisco, CA USA}
}

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • equilib_activities: List of equilibrated activities reached by the generator and target for the amortiser.
  • amort_activities: List of amortiser's feedforward guesses (initialisation) for the network activities.
  • x: Input to the amortiser.
  • y: Optional target of the amortiser (for supervised training).

Other arguments:

  • record_layers: If True, returns energies for each layer.

Returns:

The total or layer-wise energy normalised by the batch size.
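
The sketch below shows the shape of a call, under the same illustrative conventions as the pc_energy_fn() example above. Note that the equilibrated activities here are shape-matching placeholders only; in practice they would come from running the generator's inference to equilibrium.

    import jax
    import equinox.nn as nn
    import jpc

    key = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(key, 3)

    # Amortiser running in the reverse direction of the generator.
    amortiser = [
        nn.Sequential([nn.Linear(10, 100, key=k1), nn.Lambda(jax.nn.relu)]),
        nn.Linear(100, 3, key=k2),
    ]

    # Input to the amortiser (= the generator's target).
    x = jax.random.normal(k3, (32, 10))

    # Feedforward guesses of the amortiser.
    amort_activities = jpc.init_activities_with_ffwd(model=amortiser, input=x)

    # Placeholder: in practice, these would be the activities reached by
    # equilibrating the generator's inference dynamics.
    equilib_activities = [a.copy() for a in amort_activities]

    energy = jpc.hpc_energy_fn(
        model=amortiser,
        equilib_activities=equilib_activities,
        amort_activities=amort_activities,
        x=x
    )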


jpc._get_param_scalings(model: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, skip_model: typing.Optional[jaxtyping.PyTree[typing.Callable]] = None, param_type: str = 'sp') -> list[float]

Gets layer scalings for a given parameterisation.

Warning

param_type = "mupc" (μPC) assumes that one is using jpc.make_mlp() to create the model.

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • input: Input to the model.

Other arguments:

  • skip_model: Optional skip connection model.
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). Defaults to "sp".

Returns:

List with scalings for each layer.
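
For instance, a minimal sketch inspecting the scalings of different parameterisations, reusing the illustrative model and input conventions from the pc_energy_fn() example above:

    import jax
    import equinox.nn as nn
    import jpc

    key = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(key, 3)
    model = [
        nn.Sequential([nn.Linear(10, 100, key=k1), nn.Lambda(jax.nn.relu)]),
        nn.Linear(100, 3, key=k2),
    ]
    x = jax.random.normal(k3, (32, 10))

    # One scaling factor per layer of the model.
    sp_scalings = jpc._get_param_scalings(model, x, param_type="sp")
    ntp_scalings = jpc._get_param_scalings(model, x, param_type="ntp")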