Gradients
Info
There are two similar functions to compute the gradient of the energy with
respect to the activities: jpc.neg_activity_grad()
and jpc.compute_activity_grad().
The first is used by jpc.solve_inference()
as the vector field of a gradient flow, while the second is intended for
discrete optax optimisers such as gradient descent.
jpc.neg_activity_grad(t: float | int, activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], args: typing.Tuple[typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType], int, str, str, diffrax._step_size_controller.base.AbstractStepSizeController]) -> PyTree[jax.Array]
Computes the negative gradient of the PC energy with respect to the activities \(-\nabla_{\mathbf{z}} \mathcal{F}\).
This defines an ODE system to be integrated by jpc.solve_inference().
Main arguments:
- t: Time step of the ODE system, used for downstream integration by diffrax.diffeqsolve().
- activities: List of activities for each layer free to vary.
- args: Tuple with:
  (i) Tuple with callable model layers and optional skip connections,
  (ii) model output (observation),
  (iii) model input (prior),
  (iv) loss specified at the output layer ("mse" as default or "ce"),
  (v) parameterisation type ("sp" as default, "mupc", or "ntp"),
  (vi) \(\ell^2\) regulariser for the weights (0 by default),
  (vii) spectral penalty for the weights (0 by default),
  (viii) \(\ell^2\) regulariser for the activities (0 by default), and
  (ix) diffrax controller for step size integration.
Returns:
List of negative gradients of the energy with respect to the activities.
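To make the gradient-flow usage concrete, below is a minimal sketch of the pattern jpc.solve_inference() follows: a vector field with the (t, activities, args) signature is wrapped in a diffrax ODE term and integrated. Note that the energy here is a toy quadratic stand-in rather than the PC energy, and toy_neg_grad, the shapes, and the solver settings are all illustrative choices, not part of the jpc API.

```python
import diffrax
import jax.numpy as jnp

def toy_neg_grad(t, z, args):
    # Stand-in for jpc.neg_activity_grad: -dF/dz for a toy energy
    # F(z) = 0.5 * ||z - mu||^2, so the flow relaxes z towards mu.
    mu = args
    return -(z - mu)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(toy_neg_grad),   # vector field with (t, y, args) signature
    diffrax.Heun(),                  # explicit solver
    t0=0.0,
    t1=5.0,
    dt0=0.1,
    y0=jnp.zeros(3),                 # initial "activities"
    args=jnp.ones(3),                # extra arguments passed to the vector field
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-3),
)
relaxed_z = sol.ys[-1]               # approximately equal to mu
```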
jpc.compute_activity_grad(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType], loss_id: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> PyTree[jax.Array]
Computes the gradient of the PC energy with respect to the activities \(\nabla_{\mathbf{z}} \mathcal{F}\).
Note
This function differs from jpc.neg_activity_grad()
only in the sign of the gradient (positive as opposed to negative) and
is called in jpc.update_activities()
for use with any optax
optimiser.
Main arguments:
- params: Tuple with callable model layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- y: Observation or target of the generative model.
Other arguments:
- x: Optional prior of the generative model.
- loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
- weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
- spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- activity_decay: \(\ell^2\) regulariser for the activities (0 by default).
Returns:
List of gradients of the energy with respect to the activities.
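A minimal sketch of one discrete inference step with an optax optimiser, along the lines described in the note above. It assumes equinox-style callable layers, batched arrays, and one activity array per layer; the layer sizes, shapes, zero-initialised activities, and learning rate are illustrative assumptions rather than anything prescribed by the API.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical two-layer model: a list of callable layers and no skip connections.
layers = [eqx.nn.Linear(10, 32, key=k1), eqx.nn.Linear(32, 3, key=k2)]
params = (layers, None)

x = jnp.ones((4, 10))  # prior / input batch
y = jnp.ones((4, 3))   # observation / target batch

# One activity array per layer free to vary (zero-initialised for illustration).
activities = [jnp.zeros((4, 32)), jnp.zeros((4, 3))]

optim = optax.sgd(learning_rate=5e-2)
opt_state = optim.init(activities)

# Single discrete inference step: positive gradient of the energy w.r.t. the
# activities, handed to the optax optimiser (plain gradient descent here).
grads = jpc.compute_activity_grad(params, activities, y, x=x)
updates, opt_state = optim.update(grads, opt_state)
activities = optax.apply_updates(activities, updates)
```

In practice this step would be repeated until the activities (approximately) reach an energy minimum before any parameter update.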
jpc.compute_pc_param_grads(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None, loss_id: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> typing.Tuple[jaxtyping.PyTree[jax.Array], jaxtyping.PyTree[jax.Array]]
Computes the gradient of the PC energy with respect to model parameters \(\nabla_{\theta} \mathcal{F}\).
Main arguments:
- params: Tuple with callable model layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- y: Observation or target of the generative model.
Other arguments:
- x: Optional prior of the generative model.
- loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
- weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
- spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- activity_decay: \(\ell^2\) regulariser for the activities (0 by default).
Returns:
List of parameter gradients for each model layer.
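A corresponding sketch of a parameter update after inference, using the standard equinox/optax training pattern. The layers, shapes, optimiser, and the placeholder "equilibrated" activities are hypothetical; only the jpc.compute_pc_param_grads call itself follows the signature documented above.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

layers = [eqx.nn.Linear(10, 32, key=k1), eqx.nn.Linear(32, 3, key=k2)]
params = (layers, None)       # (callable layers, optional skip connections)

x = jnp.ones((4, 10))         # prior / input batch
y = jnp.ones((4, 3))          # observation / target batch

# Placeholder for activities obtained from (approximately) converged inference.
activities = [jnp.zeros((4, 32)), jnp.zeros((4, 3))]

optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(eqx.filter(params, eqx.is_array))

# Gradient of the energy w.r.t. the parameters, applied as a standard optax step.
param_grads = jpc.compute_pc_param_grads(params, activities, y, x=x)
updates, opt_state = optim.update(
    eqx.filter(param_grads, eqx.is_array), opt_state
)
params = eqx.apply_updates(params, updates)
```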
jpc.compute_hpc_param_grads(model: PyTree[typing.Callable], equilib_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], amort_activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None) -> PyTree[jax.Array]
Computes the gradient of the hybrid PC energy with respect to the amortiser's parameters \(\nabla_{\theta} \mathcal{F}\).
Warning
The input \(x\) and output \(y\) are reversed compared to
jpc.compute_pc_param_grads()
(\(x\) is the generator's target and \(y\) is its optional input or prior).
Just think of \(x\) and \(y\) as the actual input and output of the
amortiser, respectively.
Main arguments:
- model: List of callable model (e.g. neural network) layers.
- equilib_activities: List of equilibrated activities reached by the generator and target for the amortiser.
- amort_activities: List of amortiser's feedforward guesses (initialisation) for the model activities.
- x: Input to the amortiser.
- y: Optional target of the amortiser (for supervised training).
Returns:
List of parameter gradients for each model layer.
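A hedged sketch of computing amortiser gradients with the reversed input/output convention described in the warning above: the amortiser runs in the opposite direction to the generator, taking the generator's target as its input. The layer sizes, shapes, and the placeholder activity values are purely illustrative; only the jpc.compute_hpc_param_grads call follows the documented signature.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical amortiser mapping the generator's target (dim 3) back towards
# its prior (dim 10), i.e. running in the reverse direction of the generator.
amortiser = [eqx.nn.Linear(3, 32, key=k1), eqx.nn.Linear(32, 10, key=k2)]

x = jnp.ones((4, 3))    # amortiser input  = generator's target
y = jnp.ones((4, 10))   # amortiser target = generator's prior (optional)

# Placeholders for (i) activities equilibrated by generator inference and
# (ii) the amortiser's feedforward guesses used to initialise that inference.
equilib_activities = [jnp.zeros((4, 32)), jnp.zeros((4, 10))]
amort_activities = [jnp.zeros((4, 32)), jnp.zeros((4, 10))]

amort_grads = jpc.compute_hpc_param_grads(
    amortiser, equilib_activities, amort_activities, x, y
)
```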