Gradients
Note
There are two similar functions to compute the activity gradient: jpc.neg_activity_grad and jpc.compute_activity_grad. The first is used by jpc.solve_inference as a gradient flow, while the second is for compatibility with discrete optax optimisers such as gradient descent.
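To make the sign convention concrete, here is a minimal sketch with a toy quadratic energy (an assumption for illustration, not the jpc energy): neg_activity_grad corresponds to the gradient-flow direction \(-\partial \mathcal{F} / \partial \mathbf{z}\), while compute_activity_grad returns \(+\partial \mathcal{F} / \partial \mathbf{z}\) for a discrete optimiser step.

```python
import jax
import jax.numpy as jnp

# Toy quadratic energy, used only to illustrate the sign convention of the two
# functions; this is not the jpc energy implementation.
energy = lambda z: 0.5 * jnp.sum(z ** 2)
z = jnp.ones(3)

neg_grad = -jax.grad(energy)(z)  # what neg_activity_grad returns: -dF/dz (gradient flow)
grad = jax.grad(energy)(z)       # what compute_activity_grad returns: +dF/dz (for optax)
```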
jpc.neg_activity_grad(t: float | int, activities: PyTree[ArrayLike], args: Tuple[Tuple[PyTree[Callable], Optional[PyTree[Callable]]], ArrayLike, Optional[ArrayLike], str, diffrax._step_size_controller.base.AbstractStepSizeController]) -> PyTree[Array]
Computes the negative gradient of the energy with respect to the activities \(- \partial \mathcal{F} / \partial \mathbf{z}\).
This defines an ODE system to be integrated by solve_pc_inference.
Main arguments:
- t: Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve.
- activities: List of activities for each layer free to vary.
- args: 5-tuple with (i) a tuple of callable model layers and optional skip connections, (ii) the network output (observation), (iii) the network input (prior), (iv) the loss specified at the output layer (MSE vs cross-entropy), and (v) a diffrax controller for step size integration.
Returns:
List of negative gradients of the energy w.r.t. the activities.
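As a rough illustration of how such a vector field is integrated, the sketch below wires a toy stand-in with the same (t, activities, args) signature into diffrax.diffeqsolve. The toy energy, solver, tolerances and shapes are assumptions made for illustration; in practice jpc.solve_inference performs this integration for you.

```python
import diffrax
import jax
import jax.numpy as jnp

# A toy stand-in with the same (t, activities, args) signature as
# jpc.neg_activity_grad; the quadratic energy below is NOT the jpc energy.
def toy_neg_activity_grad(t, activities, args):
    W, y = args
    energy = lambda z: 0.5 * jnp.sum((y - W @ z) ** 2)
    return -jax.grad(energy)(activities)

W = jnp.eye(3)
y = jnp.array([1.0, 0.0, -1.0])
z0 = jnp.zeros(3)

controller = diffrax.PIDController(rtol=1e-3, atol=1e-3)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(toy_neg_activity_grad),
    solver=diffrax.Heun(),
    t0=0.0,
    t1=10.0,
    dt0=None,
    y0=z0,
    args=(W, y),
    stepsize_controller=controller,
)
equilibrated = sol.ys[-1]  # activities at the end of integration
```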
jpc.compute_activity_grad(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], y: ArrayLike, x: Optional[ArrayLike], loss_id: str = 'MSE') -> PyTree[Array]
Computes the gradient of the energy with respect to the activities \(\partial \mathcal{F} / \partial \mathbf{z}\).
Note
This function differs from neg_activity_grad, which computes the negative gradient, and is called in update_activities so that any Optax optimiser can be used.
Main arguments:
- params: Tuple with callable model layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- y: Observation or target of the generative model.
- x: Optional prior of the generative model.
Other arguments:
- loss_id: Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE').
- energy_fn: Free energy to take the gradient of.
Returns:
List of gradients of the energy w.r.t. the activities.
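For illustration, the sketch below runs a few discrete inference steps with optax, using a toy gradient function that returns the same kind of pytree of activity gradients. The toy energy, shapes and learning rate are assumptions; with jpc the gradients would come from compute_activity_grad instead.

```python
import jax
import jax.numpy as jnp
import optax

# Toy gradient function with the same output type as compute_activity_grad
# (a pytree of activity gradients); in practice the gradients would come from
#   grads = jpc.compute_activity_grad(params, activities, y, x)
def toy_activity_grads(activities, y):
    energy = lambda zs: 0.5 * sum(jnp.sum((y - z) ** 2) for z in zs)
    return jax.grad(energy)(activities)

y = jnp.ones(4)
activities = [jnp.zeros(4), jnp.zeros(4)]

optim = optax.sgd(learning_rate=0.05)
opt_state = optim.init(activities)

for _ in range(10):  # a few discrete inference steps
    grads = toy_activity_grads(activities, y)
    updates, opt_state = optim.update(grads, opt_state, activities)
    activities = optax.apply_updates(activities, updates)
```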
jpc.compute_pc_param_grads(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], y: ArrayLike, x: Optional[ArrayLike] = None, loss_id: str = 'MSE') -> Tuple[PyTree[Array], PyTree[Array]]
Computes the gradient of the PC energy with respect to model parameters \(\partial \mathcal{F} / \partial \theta\).
Main arguments:
- params: Tuple with callable model layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- y: Observation or target of the generative model.
- x: Optional prior of the generative model.
Other arguments:
- loss_id: Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE').
Returns:
List of parameter gradients for each network layer.
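A hypothetical parameter update built on top of this could look as follows. The single weight matrix, toy energy and optimiser settings are illustrative stand-ins; in practice jpc.compute_pc_param_grads supplies the gradients for the full layer tuple.

```python
import jax
import jax.numpy as jnp
import optax

# Toy parameter-gradient function standing in for compute_pc_param_grads;
# with jpc the call would be
#   grads = jpc.compute_pc_param_grads(params, activities, y, x)
def toy_param_grads(W, activities, y):
    (z,) = activities
    energy = lambda W: 0.5 * jnp.sum((y - W @ z) ** 2)
    return jax.grad(energy)(W)

W = jnp.eye(3)
y = jnp.array([1.0, 0.5, -0.5])
activities = [jnp.array([0.2, -0.1, 0.4])]  # e.g. equilibrated by inference

optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(W)

grads = toy_param_grads(W, activities, y)
updates, opt_state = optim.update(grads, opt_state, W)
W = optax.apply_updates(W, updates)
```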
jpc.compute_hpc_param_grads(model: PyTree[typing.Callable], equilib_activities: PyTree[ArrayLike], amort_activities: PyTree[ArrayLike], x: ArrayLike, y: Optional[ArrayLike] = None) -> PyTree[Array]
Computes the gradient of the hybrid energy with respect to an amortiser's parameters \(\partial \mathcal{F} / \partial \theta\).
Main arguments:
- model: List of callable model (e.g. neural network) layers.
- equilib_activities: List of equilibrated activities reached by the generator, serving as targets for the amortiser.
- amort_activities: List of the amortiser's feedforward guesses (initialisation) for the network activities.
- x: Input to the amortiser.
- y: Optional target of the amortiser (for supervised training).
Note
The input \(x\) and output \(y\) are reversed compared to compute_pc_param_grads (\(x\) is the generator's target and \(y\) is its optional input or prior). Just think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.
Returns:
List of parameter gradients for each network layer.
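To sketch the idea, the following toy example pulls a linear amortiser's feedforward guess toward the generator's equilibrated activities and applies the resulting gradient with optax. The linear amortiser, shapes and learning rate are assumptions; in jpc the gradients would come from compute_hpc_param_grads.

```python
import jax
import jax.numpy as jnp
import optax

# Toy hybrid step standing in for compute_hpc_param_grads: the amortiser's
# feedforward guess is pulled toward the generator's equilibrated activities.
# With jpc the gradients would come from
#   grads = jpc.compute_hpc_param_grads(model, equilib_activities,
#                                       amort_activities, x, y)
def toy_hybrid_energy(W_amort, x, equilib_activities):
    amort_guess = W_amort @ x            # amortiser's feedforward pass
    (target,) = equilib_activities       # generator's equilibrated state
    return 0.5 * jnp.sum((target - amort_guess) ** 2)

x = jnp.array([1.0, -1.0])                           # amortiser input (generator's target)
equilib_activities = [jnp.array([0.3, 0.7, -0.2])]   # from the generator's inference
W_amort = jnp.zeros((3, 2))

grads = jax.grad(toy_hybrid_energy)(W_amort, x, equilib_activities)

optim = optax.sgd(learning_rate=0.1)
opt_state = optim.init(W_amort)
updates, opt_state = optim.update(grads, opt_state, W_amort)
W_amort = optax.apply_updates(W_amort, updates)
```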