Gradients

Info

There are two similar functions to compute the gradient of the energy with respect to the activities: jpc.neg_activity_grad() and jpc.compute_activity_grad(). The first defines the gradient flow integrated by jpc.solve_inference(), while the second provides compatibility with discrete optax optimisers such as gradient descent.

jpc.neg_activity_grad(t: float | int, activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], args: typing.Tuple[typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType], int, str, str, diffrax._step_size_controller.base.AbstractStepSizeController]) -> PyTree[jax.Array]

Computes the negative gradient of the PC energy with respect to the activities \(- ∇_{\mathbf{z}} \mathcal{F}\).

This defines an ODE system to be integrated by jpc.solve_inference().

Main arguments:

  • t: Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve().
  • activities: List of activities for each layer free to vary.
  • args: Tuple with the following entries, in order:

    (i) Tuple with callable model layers and optional skip connections,

    (ii) model output (observation),

    (iii) model input (prior),

    (iv) loss specified at the output layer ("mse" as default or "ce"),

    (v) parameterisation type ("sp" as default, "mupc", or "ntp"),

    (vi) \(\ell^2\) regulariser for the weights (0 by default),

    (vii) spectral penalty for the weights (0 by default),

    (viii) \(\ell^2\) regulariser for the activities (0 by default), and

    (ix) diffrax step size controller for the integration.

Returns:

List of negative gradients of the energy with respect to the activities.
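To make the ODE view concrete, here is a minimal NumPy sketch of a negative activity gradient for a two-layer linear network with a single free hidden layer. It is not jpc's implementation: the names `energy` and `neg_grad`, the `(W1, W2, x, y)` packing of `args`, and the squared-error energy are assumptions made for illustration, and a plain forward-Euler loop stands in for diffrax.diffeqsolve().

```python
import numpy as np

def energy(z1, W1, W2, x, y):
    # Squared-error PC energy: one term per layer's prediction error.
    e1 = z1 - W1 @ x   # hidden-layer error
    e2 = y - W2 @ z1   # output-layer error
    return 0.5 * (e1 @ e1 + e2 @ e2)

def neg_grad(t, z1, args):
    # Vector field in the (t, state, args) form that ODE solvers expect.
    # t is unused because the system is autonomous.
    W1, W2, x, y = args
    e1 = z1 - W1 @ x
    e2 = y - W2 @ z1
    return -(e1 - W2.T @ e2)

rng = np.random.default_rng(0)
W1 = rng.normal(size=(3, 2))
W2 = rng.normal(size=(2, 3))
x = rng.normal(size=2)
y = rng.normal(size=2)
args = (W1, W2, x, y)

z1 = W1 @ x                       # feedforward initialisation (hypothetical choice)
energy_start = energy(z1, *args)
dt = 0.05
for step in range(200):           # forward Euler: z <- z + dt * f(t, z, args)
    z1 = z1 + dt * neg_grad(step * dt, z1, args)
energy_end = energy(z1, *args)
```

Following the negative gradient lowers the energy, and the vector field shrinks towards zero as the activities approach equilibrium.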


jpc.compute_activity_grad(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType], loss_id: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> PyTree[jax.Array]

Computes the gradient of the PC energy with respect to the activities \(∇_{\mathbf{z}} \mathcal{F}\).

Note

This function differs from jpc.neg_activity_grad() only in the sign of the gradient (positive as opposed to negative) and is called in jpc.update_activities() for use with any optax optimiser.
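The sign relation can be checked on a toy scalar chain. In this sketch, one forward-Euler step on the negative gradient with step size `lr` coincides with one plain gradient-descent step on the positive gradient with learning rate `lr`. The functions below are hypothetical stand-ins written for this example, not the jpc functions, and the hand-written update takes the place of an optax optimiser.

```python
def activity_grad(z, w1, w2, x, y):
    # Gradient of F = 0.5*(z - w1*x)**2 + 0.5*(y - w2*z)**2 with respect to z.
    return (z - w1 * x) - w2 * (y - w2 * z)

def neg_activity_grad(t, z, args):
    # Same quantity with the sign flipped, in (t, state, args) ODE form.
    return -activity_grad(z, *args)

args = (0.5, -1.5, 2.0, 1.0)   # (w1, w2, x, y), arbitrary illustrative values
z0 = 0.3
lr = 0.1

z_descent = z0 - lr * activity_grad(z0, *args)          # gradient-descent update
z_euler = z0 + lr * neg_activity_grad(0.0, z0, args)    # forward-Euler update
```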

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • y: Observation or target of the generative model.

Other arguments:

  • x: Optional prior of the generative model.
  • loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
  • weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
  • spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
  • activity_decay: \(\ell^2\) regulariser for the activities (0 by default).

Returns:

List of gradients of the energy with respect to the activities.
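The spectral penalty listed among the arguments above can be sketched in a few lines of NumPy. This assumes the norm is the Frobenius norm and that no extra scaling factor is applied, neither of which is stated here; `spectral_penalty_value` is a hypothetical helper, not part of jpc.

```python
import numpy as np

def spectral_penalty_value(W):
    # ||I - W^T W||^2, taking the norm to be Frobenius (an assumption).
    n = W.shape[1]
    residual = np.eye(n) - W.T @ W
    return np.sum(residual ** 2)

# A matrix with orthonormal columns satisfies W^T W = I, so the penalty vanishes.
Q, _ = np.linalg.qr(np.random.default_rng(0).normal(size=(4, 3)))
penalty_orthonormal = spectral_penalty_value(Q)

# Scaling the columns by 2 gives W^T W = 4I, so I - W^T W = -3I
# and the penalty is 9 * 3 = 27.
penalty_scaled = spectral_penalty_value(2.0 * Q)
```

Under these assumptions the penalty measures how far the columns of a weight matrix are from being orthonormal.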


jpc.compute_pc_param_grads(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None, loss_id: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> typing.Tuple[jaxtyping.PyTree[jax.Array], jaxtyping.PyTree[jax.Array]]

Computes the gradient of the PC energy with respect to model parameters \(∇_θ \mathcal{F}\).

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • y: Observation or target of the generative model.

Other arguments:

  • x: Optional prior of the generative model.
  • loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
  • weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
  • spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
  • activity_decay: \(\ell^2\) regulariser for the activities (0 by default).

Returns:

List of parameter gradients for each model layer.
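For intuition, the sketch below computes parameter gradients by hand for a two-layer linear network with a squared-error energy, holding the activities fixed (as they would be after inference), and checks one entry against a finite difference. All names (`energy`, `param_grads`, `W1`, `W2`) are hypothetical. This is a hand-derived special case for linear layers and not how jpc computes its gradients.

```python
import numpy as np

def energy(W1, W2, z1, x, y):
    e1 = z1 - W1 @ x   # hidden-layer prediction error
    e2 = y - W2 @ z1   # output-layer prediction error
    return 0.5 * (e1 @ e1 + e2 @ e2)

def param_grads(W1, W2, z1, x, y):
    # For linear layers each weight gradient is minus the outer product of
    # the layer's prediction error and the layer's input.
    e1 = z1 - W1 @ x
    e2 = y - W2 @ z1
    return -np.outer(e1, x), -np.outer(e2, z1)

rng = np.random.default_rng(1)
W1, W2 = rng.normal(size=(3, 2)), rng.normal(size=(2, 3))
x, y, z1 = rng.normal(size=2), rng.normal(size=2), rng.normal(size=3)

gW1, gW2 = param_grads(W1, W2, z1, x, y)

# Central finite-difference check on a single entry of W1.
eps = 1e-6
bump = np.zeros_like(W1)
bump[0, 1] = eps
fd = (energy(W1 + bump, W2, z1, x, y) - energy(W1 - bump, W2, z1, x, y)) / (2 * eps)
```

Each layer's gradient depends only on that layer's own error and input, so the update is local to the layer.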


jpc.compute_hpc_param_grads(model: PyTree[typing.Callable], equilib_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], amort_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None) -> PyTree[jax.Array]

Computes the gradient of the hybrid PC energy with respect to the amortiser's parameters \(∇_θ \mathcal{F}\).

Warning

The input \(x\) and output \(y\) are reversed compared to jpc.compute_pc_param_grads(): here \(x\) is the generator's target and \(y\) is its optional input or prior. Think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • equilib_activities: List of equilibrated activities reached by the generator, which serve as targets for the amortiser.
  • amort_activities: List of amortiser's feedforward guesses (initialisation) for the model activities.
  • x: Input to the amortiser.
  • y: Optional target of the amortiser (for supervised training).

Returns:

List of parameter gradients for each model layer.
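As a rough illustration of the roles of equilib_activities and amort_activities, the sketch below uses a single linear amortiser layer and a squared-error mismatch between its feedforward guess and the equilibrated activity it should predict. The squared-error form and every name here (`V`, `amortiser_loss`, `amortiser_grad`) are assumptions for illustration. The exact hybrid PC energy is not spelled out in this section and may differ.

```python
import numpy as np

def amortiser_loss(V, x, z_equilib):
    # Mismatch between the amortiser's guess V @ x and its target.
    err = z_equilib - V @ x
    return 0.5 * err @ err

def amortiser_grad(V, x, z_equilib):
    # Gradient of the loss above with respect to the amortiser weights V.
    err = z_equilib - V @ x
    return -np.outer(err, x)

rng = np.random.default_rng(2)
V = rng.normal(size=(3, 4))        # hypothetical amortiser weights
x = rng.normal(size=4)             # input to the amortiser
z_equilib = rng.normal(size=3)     # equilibrated activity used as target

loss_before = amortiser_loss(V, x, z_equilib)
V_new = V - 0.01 * amortiser_grad(V, x, z_equilib)   # one small descent step
loss_after = amortiser_loss(V_new, x, z_equilib)
```

A small step along the negative gradient moves the amortiser's guess closer to the equilibrated activity.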