Training
JPC provides two convenience functions to update the parameters of any PC-compatible model with predictive coding:
- `jpc.make_pc_step()` to perform an update using standard PC, and
- `jpc.make_hpc_step()` to use hybrid PC (Tschantz et al., 2023).
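Before the full signatures, a quick orientation sketch of a supervised `jpc.make_pc_step()` call. The Equinox model definition, the optimiser-state initialisation, the array shapes, and the result dictionary keys are all illustrative assumptions, not part of the documented API:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# A PC-compatible model is a PyTree (here, a list) of callable layers.
model = [
    eqx.nn.Sequential(
        [eqx.nn.Linear(784, 300, key=k1), eqx.nn.Lambda(jax.nn.relu)]
    ),
    eqx.nn.Linear(300, 10, key=k2),
]

optim = optax.adam(1e-3)
# Assumption: the optimiser state is initialised over the model's arrays.
opt_state = optim.init(eqx.filter(model, eqx.is_array))

x = jnp.ones((64, 784))    # input / prior (shapes are illustrative)
y = jnp.ones((64, 10))     # output / target

result = jpc.make_pc_step(model, optim, opt_state, output=y, input=x)

# Assumption: dictionary keys. The docs only state that the returned dict
# includes the updated model, optimiser state, loss, energies and activities.
model, opt_state = result["model"], result["opt_state"]
```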
```python
jpc.make_pc_step(
    model: PyTree[Callable],
    optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    opt_state: optax.OptState,
    output: ArrayLike,
    *,
    input: Optional[ArrayLike] = None,
    loss_id: str = "mse",
    param_type: str = "sp",
    ode_solver: diffrax.AbstractSolver = diffrax.Heun(),
    max_t1: int = 20,
    dt: Scalar | int | None = None,
    stepsize_controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        rtol=1e-3, atol=1e-3
    ),
    skip_model: Optional[PyTree[Callable]] = None,
    weight_decay: Scalar = 0.0,
    spectral_penalty: Scalar = 0.0,
    activity_decay: Scalar = 0.0,
    key: Optional[PRNGKeyArray] = None,
    layer_sizes: Optional[PyTree[int]] = None,
    batch_size: Optional[int] = None,
    sigma: Scalar = 0.05,
    record_activities: bool = False,
    record_energies: bool = False,
    record_every: Optional[int] = None,
    activity_norms: bool = False,
    param_norms: bool = False,
    grad_norms: bool = False,
    calculate_accuracy: bool = False,
) -> Dict
```
Performs one model parameter update with predictive coding.
Main arguments:
- `model`: List of callable model (e.g. neural network) layers.
- `optim`: Optax optimiser, e.g. `optax.sgd()`.
- `opt_state`: State of Optax optimiser.
- `output`: Observation or target of the generative model.
Note: `key`, `layer_sizes` and `batch_size` must be passed if `input = None`, since unsupervised training will be assumed and activities need to be initialised randomly.
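A minimal sketch of this unsupervised case, reusing `model`, `optim`, `opt_state` and `y` from the sketch above (the layer sizes and batch size are illustrative assumptions):

```python
# With input=None, key, layer_sizes and batch_size are required so that
# the activities can be initialised at random.
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    output=y,                      # observations only, no prior
    key=jax.random.PRNGKey(42),
    layer_sizes=[784, 300, 10],    # input, hidden and output dimensions
    batch_size=64,
    sigma=0.05,                    # std of the Gaussian init (the default)
)
```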
Other arguments:
- `input`: Optional prior of the generative model.
- `loss_id`: Loss function to use at the output layer. Options are mean squared error `"mse"` (default) or cross-entropy `"ce"`.
- `param_type`: Determines the parameterisation. Options are `"sp"` (standard parameterisation), `"mupc"` (μPC), or `"ntp"` (neural tangent parameterisation). See `_get_param_scalings()` for the specific scalings of these different parameterisations. Defaults to `"sp"`.
- `ode_solver`: diffrax ODE solver to be used. Default is `Heun`, a second-order explicit Runge–Kutta method.
- `max_t1`: Maximum end of integration region (20 by default).
- `dt`: Integration step size. Defaults to `None`, since the default `stepsize_controller` will determine it automatically.
- `stepsize_controller`: diffrax controller for integration step size. Defaults to `PIDController`. Note that the relative and absolute tolerances of the controller also determine the steady state at which the solver terminates.
- `skip_model`: Optional list of callable skip connection functions.
- `weight_decay`: Weight decay for the weights (0 by default).
- `spectral_penalty`: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
- `activity_decay`: Activity decay for the activities (0 by default).
- `key`: `jax.random.PRNGKey` for random initialisation of activities.
- `layer_sizes`: Dimension of all layers (input, hidden and output).
- `batch_size`: Dimension of data batch for activity initialisation.
- `sigma`: Standard deviation of the Gaussian from which activities are sampled for random initialisation. Defaults to 5e-2.
- `record_activities`: If `True`, returns activities at every inference iteration.
- `record_energies`: If `True`, returns layer-wise energies at every inference iteration.
- `record_every`: Integer determining the sampling frequency of the integration steps.
- `activity_norms`: If `True`, computes the \(\ell^2\) norm of the activities.
- `param_norms`: If `True`, computes the \(\ell^2\) norm of the parameters.
- `grad_norms`: If `True`, computes the \(\ell^2\) norm of the parameter gradients.
- `calculate_accuracy`: If `True`, computes the training accuracy.
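To illustrate the solver-related arguments above, the following sketch swaps the default adaptive Heun solver for fixed-step Euler integration. The diffrax classes used here are real; whether this particular combination of settings is sensible for a given model is an untested assumption:

```python
import diffrax

# Assumption: with a constant step size controller, dt must be set by hand
# since there is no adaptive controller to choose it.
result = jpc.make_pc_step(
    model,
    optim,
    opt_state,
    output=y,
    input=x,
    ode_solver=diffrax.Euler(),
    dt=0.1,
    stepsize_controller=diffrax.ConstantStepSize(),
    max_t1=50,
    record_energies=True,    # layer-wise energies at every inference step
)
```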
Returns:
Dict including the model (and optional skip model) with updated parameters, the updated optimiser state, the loss, energies, activities, and optionally other metrics (see the other arguments above).
Raises:
`ValueError` for inconsistent inputs and invalid losses.
```python
jpc.make_hpc_step(
    generator: PyTree[Callable],
    amortiser: PyTree[Callable],
    optims: Tuple[optax.GradientTransformationExtraArgs, ...],
    opt_states: Tuple[optax.OptState, ...],
    output: ArrayLike,
    *,
    input: Optional[ArrayLike] = None,
    ode_solver: diffrax.AbstractSolver = diffrax.Heun(),
    max_t1: int = 300,
    dt: Scalar | int | None = None,
    stepsize_controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
        rtol=1e-3, atol=1e-3
    ),
    record_activities: bool = False,
    record_energies: bool = False,
) -> Dict
```
Performs one update of the parameters of a hybrid predictive coding network (Tschantz et al., 2023).
Reference
```bibtex
@article{tschantz2023hybrid,
    title={Hybrid predictive coding: Inferring, fast and slow},
    author={Tschantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L},
    journal={PLoS Computational Biology},
    volume={19},
    number={8},
    pages={e1011280},
    year={2023},
    publisher={Public Library of Science San Francisco, CA USA}
}
```
Note: the input and output of the generator are the output and input of the amortiser, respectively.
Main arguments:
- `generator`: List of callable layers for the generative model.
- `amortiser`: List of callable layers for the model amortising the inference of the `generator`.
- `optims`: Optax optimisers (e.g. `optax.sgd()`), one for each model.
- `opt_states`: State of Optax optimisers, one for each model.
- `output`: Observation of the generator, input to the amortiser.
- `input`: Optional prior of the generator, target for the amortiser.
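A sketch of a `jpc.make_hpc_step()` call, mirroring the note above: the generator's input and output dimensions are the amortiser's output and input. The layer sizes, the optimiser-state initialisation and the (generator, amortiser) tuple ordering are assumptions for illustration:

```python
kg1, kg2, ka1, ka2 = jax.random.split(jax.random.PRNGKey(1), 4)

# Generator maps latents (10) to observations (784); the amortiser runs in
# the opposite direction.
generator = [
    eqx.nn.Sequential(
        [eqx.nn.Linear(10, 300, key=kg1), eqx.nn.Lambda(jax.nn.tanh)]
    ),
    eqx.nn.Linear(300, 784, key=kg2),
]
amortiser = [
    eqx.nn.Sequential(
        [eqx.nn.Linear(784, 300, key=ka1), eqx.nn.Lambda(jax.nn.tanh)]
    ),
    eqx.nn.Linear(300, 10, key=ka2),
]

gen_optim, amort_optim = optax.adam(1e-3), optax.adam(1e-3)
# Assumption: one optimiser state per model, in (generator, amortiser) order.
opt_states = (
    gen_optim.init(eqx.filter(generator, eqx.is_array)),
    amort_optim.init(eqx.filter(amortiser, eqx.is_array)),
)

y_obs = jnp.ones((64, 784))    # observations for the generator
z = jnp.ones((64, 10))         # optional prior / amortiser target

result = jpc.make_hpc_step(
    generator=generator,
    amortiser=amortiser,
    optims=(gen_optim, amort_optim),
    opt_states=opt_states,
    output=y_obs,
    input=z,
)
```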
Other arguments:
- `ode_solver`: diffrax ODE solver to be used. Default is `Heun`, a second-order explicit Runge–Kutta method.
- `max_t1`: Maximum end of integration region (300 by default).
- `dt`: Integration step size. Defaults to `None`, since the default `stepsize_controller` will determine it automatically.
- `stepsize_controller`: diffrax controller for integration step size. Defaults to `PIDController`. Note that the relative and absolute tolerances of the controller also determine the steady state at which the solver terminates.
- `record_activities`: If `True`, returns activities at every inference iteration.
- `record_energies`: If `True`, returns layer-wise energies at every inference iteration.
Returns:
Dict including models with updated parameters, optimiser state for each model, model activities, last inference step for the generator, MSE losses, and energies.