Discrete updates
JPC provides standard discrete optimisers for updating the parameters of PC networks (jpc.update_params), and both discrete (jpc.update_activities) and continuous (jpc.solve_inference) optimisers for solving the PC inference or activity dynamics.
jpc.update_activities(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    opt_state: OptState,
    output: ArrayLike,
    *,
    input: Optional[ArrayLike] = None,
    loss_id: str = "mse",
    param_type: str = "sp",
    weight_decay: Scalar = 0.0,
    spectral_penalty: Scalar = 0.0,
    activity_decay: Scalar = 0.0,
) -> Dict
Updates activities of a predictive coding network with a given optax optimiser.
Warning: param_type = "mupc" (μPC) assumes that the model was created with jpc.make_mlp().
Main arguments:
- params: Tuple with callable model layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- optim: optax optimiser, e.g. optax.sgd().
- opt_state: State of the optax optimiser.
- output: Observation or target of the generative model.
Other arguments:
- input: Optional prior of the generative model.
- loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these parameterisations. Defaults to "sp".
- weight_decay: Weight decay for the weights (0 by default).
- spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- activity_decay: Activity decay for the activities (0 by default).
Returns:
Dictionary with energy, updated activities, activity gradients, and optimiser state.
jpc.update_params(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    opt_state: OptState,
    output: ArrayLike,
    *,
    input: Optional[ArrayLike] = None,
    loss_id: str = "mse",
    param_type: str = "sp",
    weight_decay: Scalar = 0.0,
    spectral_penalty: Scalar = 0.0,
    activity_decay: Scalar = 0.0,
) -> Dict
Updates parameters of a predictive coding network with a given optax optimiser.
Warning: param_type = "mupc" (μPC) assumes that the model was created with jpc.make_mlp().
Main arguments:
- params: Tuple with callable model layers and optional skip connections.
- activities: List of activities for each layer free to vary.
- optim: optax optimiser, e.g. optax.sgd().
- opt_state: State of the optax optimiser.
- output: Observation or target of the generative model.
Other arguments:
- input: Optional prior of the generative model.
- loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these parameterisations. Defaults to "sp".
- weight_decay: Weight decay for the weights (0 by default).
- spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- activity_decay: Activity decay for the activities (0 by default).
Returns:
Dictionary with the model (and optional skip model) with updated parameters, parameter gradients, and optimiser state.