Gradients
Info

There are two similar functions to compute the gradient of the energy with respect to the activities: `jpc.neg_activity_grad()` and `jpc.compute_activity_grad()`. The first is used by `jpc.solve_inference()` as a gradient flow, while the second is for compatibility with discrete optax optimisers such as gradient descent.
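To make the distinction concrete, here is a minimal, self-contained sketch with a stand-in quadratic energy (not jpc's actual energy or API): `jpc.neg_activity_grad()` provides the vector field of the gradient flow \(\dot{\mathbf{z}} = -\nabla_{\mathbf{z}} \mathcal{F}\), while `jpc.compute_activity_grad()` returns \(\nabla_{\mathbf{z}} \mathcal{F}\) for discrete updates of the form \(\mathbf{z} \leftarrow \mathbf{z} - \eta \nabla_{\mathbf{z}} \mathcal{F}\).

```python
import jax
import jax.numpy as jnp

def energy(z):
    # Stand-in quadratic energy F(z); jpc's real energy depends on the model.
    return 0.5 * jnp.sum(z ** 2)

grad_F = jax.grad(energy)
z = jnp.ones(3)

# Continuous route: neg_activity_grad supplies dz/dt = -grad F(z),
# which an ODE solver (diffrax) integrates over time.
dz_dt = -grad_F(z)

# Discrete route: compute_activity_grad supplies grad F(z), which a
# discrete optax optimiser consumes, e.g. one gradient-descent step.
lr = 0.1
z_next = z - lr * grad_F(z)
```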
```python
jpc.neg_activity_grad(
    t: float | int,
    activities: PyTree[ArrayLike],
    args: Tuple[
        Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
        ArrayLike,
        Optional[ArrayLike],
        int,
        str,
        str,
        diffrax.AbstractStepSizeController,
    ],
) -> PyTree[jax.Array]
```
Computes the negative gradient of the PC energy with respect to the activities, \(-\nabla_{\mathbf{z}} \mathcal{F}\). This defines an ODE system to be integrated by `jpc.solve_inference()`.
Main arguments:

- `t`: Time step of the ODE system, used for downstream integration by `diffrax.diffeqsolve()`.
- `activities`: List of activities for each layer free to vary.
- `args`: Tuple with:
  (i) tuple with callable model layers and optional skip connections,
  (ii) model output (observation),
  (iii) model input (prior),
  (iv) loss specified at the output layer (`"mse"` as default, or `"ce"`),
  (v) parameterisation type (`"sp"` as default, `"mupc"`, or `"ntp"`),
  (vi) \(\ell^2\) regulariser for the weights (0 by default),
  (vii) spectral penalty for the weights (0 by default),
  (viii) \(\ell^2\) regulariser for the activities (0 by default), and
  (ix) diffrax controller for step size integration.
Returns:
List of negative gradients of the energy with respect to the activities.
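As a usage illustration, the following sketch wires `jpc.neg_activity_grad()` into `diffrax.diffeqsolve()` by hand. The toy model, the activity shapes, and the exact layout of `args` are assumptions made for illustration; in practice `jpc.solve_inference()` assembles all of this for you.

```python
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Toy two-layer model; jpc expects (layers, optional skip connections).
layers = [eqx.nn.Linear(2, 8, key=k1), eqx.nn.Linear(8, 2, key=k2)]

x = jnp.ones((1, 2))   # prior (batch of 1)
y = jnp.zeros((1, 2))  # observation

# Assumed activity shapes: one batched array per layer output.
activities = [jnp.zeros((1, 8)), jnp.zeros((1, 2))]

controller = diffrax.PIDController(rtol=1e-3, atol=1e-3)
args = (
    (layers, None),  # (i) layers and skip connections
    y,               # (ii) output / observation
    x,               # (iii) input / prior
    "mse",           # (iv) output loss
    "sp",            # (v) parameterisation
    0.0,             # (vi) weight l2 regulariser
    0.0,             # (vii) spectral penalty
    0.0,             # (viii) activity l2 regulariser
    controller,      # (ix) step size controller
)

# neg_activity_grad(t, activities, args) has exactly the (t, y, args)
# signature an ODETerm vector field requires.
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(jpc.neg_activity_grad),
    solver=diffrax.Heun(),
    t0=0.0, t1=20.0, dt0=None,
    y0=activities,
    args=args,
    stepsize_controller=controller,
)
```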
```python
jpc.compute_activity_grad(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    y: ArrayLike,
    *,
    x: Optional[ArrayLike],
    loss_id: str = "mse",
    param_type: str = "sp",
    weight_decay: Scalar = 0.0,
    spectral_penalty: Scalar = 0.0,
    activity_decay: Scalar = 0.0,
) -> PyTree[jax.Array]
```
Computes the gradient of the PC energy with respect to the activities, \(\nabla_{\mathbf{z}} \mathcal{F}\).
Note

This function differs from `jpc.neg_activity_grad()` only in the sign of the gradient (positive as opposed to negative) and is called by `jpc.update_activities()` for use with any optax optimiser.
Main arguments:

- `params`: Tuple with callable model layers and optional skip connections.
- `activities`: List of activities for each layer free to vary.
- `y`: Observation or target of the generative model.

Other arguments:

- `x`: Optional prior of the generative model.
- `loss_id`: Loss function to use at the output layer. Options are mean squared error `"mse"` (default) or cross-entropy `"ce"`.
- `param_type`: Determines the parameterisation. Options are `"sp"` (standard parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent parameterisation). See `_get_param_scalings()` for the specific scalings of these different parameterisations. Defaults to `"sp"`.
- `weight_decay`: \(\ell^2\) regulariser for the weights (0 by default).
- `spectral_penalty`: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- `activity_decay`: \(\ell^2\) regulariser for the activities (0 by default).
Returns:
List of gradients of the energy with respect to the activities.
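For example, a handful of discrete gradient-descent steps on the activities might look like the following sketch. The toy model and activity shapes are assumptions made for illustration; `jpc.update_activities()` wraps this pattern for you.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Toy two-layer model; jpc expects (layers, optional skip connections).
layers = [eqx.nn.Linear(2, 8, key=k1), eqx.nn.Linear(8, 2, key=k2)]
params = (layers, None)

x = jnp.ones((1, 2))   # prior
y = jnp.zeros((1, 2))  # observation

# Assumed activity shapes: one batched array per layer output.
activities = [jnp.zeros((1, 8)), jnp.zeros((1, 2))]

# Any discrete optax optimiser can consume the (positive) gradients.
optim = optax.sgd(learning_rate=0.5)
opt_state = optim.init(activities)

for _ in range(10):
    grads = jpc.compute_activity_grad(params, activities, y, x=x)
    updates, opt_state = optim.update(grads, opt_state)
    activities = optax.apply_updates(activities, updates)
```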
```python
jpc.compute_pc_param_grads(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    y: ArrayLike,
    *,
    x: Optional[ArrayLike] = None,
    loss_id: str = "mse",
    param_type: str = "sp",
    weight_decay: Scalar = 0.0,
    spectral_penalty: Scalar = 0.0,
    activity_decay: Scalar = 0.0,
) -> Tuple[PyTree[jax.Array], PyTree[jax.Array]]
```
Computes the gradient of the PC energy with respect to the model parameters, \(\nabla_{\theta} \mathcal{F}\).
Main arguments:

- `params`: Tuple with callable model layers and optional skip connections.
- `activities`: List of activities for each layer free to vary.
- `y`: Observation or target of the generative model.

Other arguments:

- `x`: Optional prior of the generative model.
- `loss_id`: Loss function to use at the output layer. Options are mean squared error `"mse"` (default) or cross-entropy `"ce"`.
- `param_type`: Determines the parameterisation. Options are `"sp"` (standard parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent parameterisation). See `_get_param_scalings()` for the specific scalings of these different parameterisations. Defaults to `"sp"`.
- `weight_decay`: \(\ell^2\) regulariser for the weights (0 by default).
- `spectral_penalty`: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- `activity_decay`: \(\ell^2\) regulariser for the activities (0 by default).
Returns:
List of parameter gradients for each model layer.
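A sketch of a full parameter update using these gradients with optax is shown below. The toy model, the activity shapes, and the equinox filtering pattern are illustrative assumptions; the activities would normally be equilibrated first, e.g. with `jpc.solve_inference()`.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Toy two-layer model; jpc expects (layers, optional skip connections).
layers = [eqx.nn.Linear(2, 8, key=k1), eqx.nn.Linear(8, 2, key=k2)]
params = (layers, None)

x = jnp.ones((1, 2))   # prior
y = jnp.zeros((1, 2))  # observation

# Hypothetical equilibrated activities; in practice these come from
# running inference to convergence.
activities = [jnp.zeros((1, 8)), jnp.zeros((1, 2))]

param_grads = jpc.compute_pc_param_grads(params, activities, y, x=x)

# Assuming the gradients mirror the structure of `params`, the standard
# equinox + optax update pattern applies.
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(params, eqx.is_array))
updates, opt_state = optim.update(param_grads, opt_state)
params = eqx.apply_updates(params, updates)
```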
```python
jpc.compute_hpc_param_grads(
    model: PyTree[Callable],
    equilib_activities: PyTree[ArrayLike],
    amort_activities: PyTree[ArrayLike],
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
) -> PyTree[jax.Array]
```
Computes the gradient of the hybrid PC energy with respect to the amortiser's parameters, \(\nabla_{\theta} \mathcal{F}\).
Warning

The input \(x\) and output \(y\) are reversed compared to `jpc.compute_pc_param_grads()` (\(x\) is the generator's target and \(y\) is its optional input or prior). Just think of \(x\) and \(y\) as the actual input and output of the amortiser, respectively.
Main arguments:

- `model`: List of callable model (e.g. neural network) layers.
- `equilib_activities`: List of equilibrated activities reached by the generator and target for the amortiser.
- `amort_activities`: List of the amortiser's feedforward guesses (initialisation) for the model activities.
- `x`: Input to the amortiser.
- `y`: Optional target of the amortiser (for supervised training).
Returns:
List of parameter gradients for each model layer.
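A final sketch of computing amortiser gradients, with the argument-order convention from the warning above. The toy amortiser and activity shapes are assumptions for illustration; in practice the equilibrated activities come from the generator's inference and the feedforward guesses from the amortiser's forward pass.

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Toy amortiser: maps data (amortiser input x) towards latents (target y).
amortiser = [eqx.nn.Linear(2, 8, key=k1), eqx.nn.Linear(8, 2, key=k2)]

x = jnp.ones((1, 2))   # amortiser input (the generator's target)
y = jnp.zeros((1, 2))  # amortiser target (the generator's optional prior)

# Hypothetical shapes: equilibrated generator activities and the
# amortiser's feedforward initialisation, one batched array per layer.
equilib_activities = [jnp.zeros((1, 8)), jnp.zeros((1, 2))]
amort_activities = [jnp.zeros((1, 8)), jnp.zeros((1, 2))]

grads = jpc.compute_hpc_param_grads(
    amortiser, equilib_activities, amort_activities, x, y=y
)
```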