
Gradients¤

Note

There are two similar functions for computing the activity gradient: jpc.neg_activity_grad and jpc.compute_activity_grad. The first is used by jpc.solve_inference to integrate the inference dynamics as a gradient flow, while the second is compatible with discrete optax optimisers such as gradient descent.

jpc.neg_activity_grad(t: float | int, activities: PyTree[ArrayLike], args: Tuple[Tuple[PyTree[Callable], Optional[PyTree[Callable]]], ArrayLike, Optional[ArrayLike], str, diffrax.AbstractStepSizeController]) -> PyTree[Array] ¤

Computes the negative gradient of the energy with respect to the activities \(- \partial \mathcal{F} / \partial \mathbf{z}\).

This defines an ODE system to be integrated by solve_pc_inference.

Main arguments:

  • t: Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve.
  • activities: List of activities for each layer free to vary.
  • args: 5-tuple with (i) tuple of callable model layers and optional skip connections, (ii) network output (observation), (iii) network input (prior), (iv) loss specified at the output layer (mean squared error 'MSE' vs cross-entropy 'CE'), and (v) diffrax step size controller.

Returns:

List of negative gradients of the energy w.r.t. the activities.
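
Below is a minimal sketch of integrating this vector field directly with diffrax, mirroring what the inference solver does internally. The model construction, array shapes, and solver settings are illustrative assumptions rather than part of the jpc API.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import diffrax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical two-layer model: a list of callable layers, no skip connections.
layers = [eqx.nn.Linear(10, 32, key=k1), eqx.nn.Linear(32, 5, key=k2)]

x = jnp.ones((2, 10))                                  # network input (prior)
y = jnp.ones((2, 5))                                   # network output (observation)
activities0 = [jnp.zeros((2, 32)), jnp.zeros((2, 5))]  # initial activities

# args follows the 5-tuple layout documented above.
controller = diffrax.PIDController(rtol=1e-3, atol=1e-3)
args = ((layers, None), y, x, "MSE", controller)

# Integrate the gradient flow dz/dt = -dF/dz from t0 to t1.
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(jpc.neg_activity_grad),
    solver=diffrax.Heun(),
    t0=0.0,
    t1=20.0,
    dt0=None,
    y0=activities0,
    args=args,
    stepsize_controller=controller,
)

# Activities at the end of the integration (approximate equilibrium).
equilib_activities = jax.tree_util.tree_map(lambda z: z[-1], sol.ys)
```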


jpc.compute_activity_grad(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], y: ArrayLike, x: Optional[ArrayLike], loss_id: str = 'MSE') -> PyTree[Array] ¤

Computes the gradient of the energy with respect to the activities \(\partial \mathcal{F} / \partial \mathbf{z}\).

Note

This function differs from neg_activity_grad in that it returns the (positive) gradients, and it is called by update_activities so that any optax optimiser can be used.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • y: Observation or target of the generative model.
  • x: Optional prior of the generative model.

Other arguments:

  • loss_id: Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE').
  • energy_fn: Free energy to take the gradient of.

Returns:

List of gradients of the energy w.r.t. the activities.
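
Below is a minimal sketch of discrete inference with these gradients and an optax optimiser, in the spirit of update_activities. The model and shapes are the same illustrative assumptions as in the sketch above.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical two-layer model and data, as in the previous sketch.
layers = [eqx.nn.Linear(10, 32, key=k1), eqx.nn.Linear(32, 5, key=k2)]
params = (layers, None)
x = jnp.ones((2, 10))
y = jnp.ones((2, 5))
activities = [jnp.zeros((2, 32)), jnp.zeros((2, 5))]

activity_optim = optax.sgd(learning_rate=5e-2)
activity_opt_state = activity_optim.init(activities)

# A few discrete gradient-descent steps on the energy w.r.t. the activities.
for _ in range(10):
    activity_grads = jpc.compute_activity_grad(params, activities, y, x, loss_id="MSE")
    updates, activity_opt_state = activity_optim.update(activity_grads, activity_opt_state)
    activities = optax.apply_updates(activities, updates)
```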


jpc.compute_pc_param_grads(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], y: ArrayLike, x: Optional[ArrayLike] = None, loss_id: str = 'MSE') -> Tuple[PyTree[Array], PyTree[Array]] ¤

Computes the gradient of the PC energy with respect to model parameters \(\partial \mathcal{F} / \partial \theta\).

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • y: Observation or target of the generative model.
  • x: Optional prior of the generative model.

Other arguments:

  • loss_id: Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE').

Returns:

List of parameter gradients for each network layer.
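
Below is a minimal sketch of a parameter update with these gradients, using optax and equinox. It assumes the activities have already been relaxed to (or near) equilibrium by one of the inference routines above; the model construction is again an illustrative assumption.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical model, data, and (ideally equilibrated) activities,
# as in the earlier sketches.
layers = [eqx.nn.Linear(10, 32, key=k1), eqx.nn.Linear(32, 5, key=k2)]
params = (layers, None)
x = jnp.ones((2, 10))
y = jnp.ones((2, 5))
activities = [jnp.zeros((2, 32)), jnp.zeros((2, 5))]

param_optim = optax.adam(learning_rate=1e-3)
param_opt_state = param_optim.init(eqx.filter(params, eqx.is_array))

# Gradients of the PC energy w.r.t. the model parameters, applied with optax.
param_grads = jpc.compute_pc_param_grads(params, activities, y, x, loss_id="MSE")
updates, param_opt_state = param_optim.update(param_grads, param_opt_state)
params = eqx.apply_updates(params, updates)
```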


jpc.compute_hpc_param_grads(model: PyTree[Callable], equilib_activities: PyTree[ArrayLike], amort_activities: PyTree[ArrayLike], x: ArrayLike, y: Optional[ArrayLike] = None) -> PyTree[Array] ¤

Computes the gradient of the hybrid energy with respect to an amortiser's parameters \(\partial \mathcal{F} / \partial \theta\).

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • equilib_activities: List of equilibrated activities reached by the generator, which serve as the target for the amortiser.
  • amort_activities: List of amortiser's feedforward guesses (initialisation) for the network activities.
  • x: Input to the amortiser.
  • y: Optional target of the amortiser (for supervised training).

Note

The input \(x\) and output \(y\) are reversed compared to compute_pc_param_grads (\(x\) is the generator's target and \(y\) is its optional input or prior). Just think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.

Returns:

List of parameter gradients for each network layer.
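
Below is a minimal sketch of an amortiser update in the hybrid setting. The amortiser, shapes, and activity placeholders are illustrative assumptions; in practice, equilib_activities would come from relaxing the generator to equilibrium and amort_activities from the amortiser's feedforward pass.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical amortiser mapping data (dim 5) back to latents (dim 10);
# it runs in the reverse direction of the generator in the earlier sketches.
amortiser = [eqx.nn.Linear(5, 32, key=k1), eqx.nn.Linear(32, 10, key=k2)]

x_amort = jnp.ones((2, 5))    # amortiser input = generator's target
y_amort = jnp.ones((2, 10))   # amortiser target = generator's prior (supervised case)

# Placeholders standing in for (i) the generator's equilibrated activities and
# (ii) the amortiser's feedforward guesses; shapes are assumptions.
equilib_activities = [jnp.zeros((2, 32)), jnp.zeros((2, 10))]
amort_activities = [jnp.zeros((2, 32)), jnp.zeros((2, 10))]

amort_optim = optax.adam(learning_rate=1e-3)
amort_opt_state = amort_optim.init(eqx.filter(amortiser, eqx.is_array))

# Gradients of the hybrid energy w.r.t. the amortiser's parameters.
amort_grads = jpc.compute_hpc_param_grads(
    amortiser, equilib_activities, amort_activities, x_amort, y_amort
)
updates, amort_opt_state = amort_optim.update(amort_grads, amort_opt_state)
amortiser = eqx.apply_updates(amortiser, updates)
```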