
Training¤

jpc.make_pc_step(model: PyTree[typing.Callable], optim: optax._src.base.GradientTransformation | optax._src.base.GradientTransformationExtraArgs, opt_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], output: ArrayLike, input: Optional[ArrayLike] = None, loss_id: str = 'MSE', ode_solver: AbstractSolver = Heun(scan_kind=None), max_t1: int = 20, dt: Array | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001,atol=0.001,pcoeff=0,icoeff=1,dcoeff=0,dtmin=None,dtmax=None,force_dtmin=True,step_ts=None,jump_ts=None,factormin=0.2,factormax=10.0,norm=<function rms_norm>,safety=0.9,error_order=None), skip_model: Optional[PyTree[Callable]] = None, key: Optional[PRNGKeyArray] = None, layer_sizes: Optional[PyTree[int]] = None, batch_size: Optional[int] = None, sigma: Array = 0.05, record_activities: bool = False, record_energies: bool = False, record_every: int = None, activity_norms: bool = False, param_norms: bool = False, grad_norms: bool = False, calculate_accuracy: bool = False) -> Dict ¤

Updates network parameters with predictive coding.

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • optim: Optax optimiser, e.g. optax.sgd().
  • opt_state: State of Optax optimiser.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generative model.

Note

key, layer_sizes and batch_size must be passed if input is None, since unsupervised training will be assumed and activities need to be initialised randomly.
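
For example, a minimal unsupervised call might look as follows. This is only a sketch: the architecture, layer sizes, use of Equinox layers and optimiser choice are illustrative assumptions, not part of the API.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

# Hypothetical generative model mapping a 10-d latent to 784-d observations.
key, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
model = [eqx.nn.Linear(10, 300, key=k1), eqx.nn.Linear(300, 784, key=k2)]
optim = optax.sgd(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

y = jnp.ones((32, 784))  # observations only; no prior is available

# With input=None, key, layer_sizes and batch_size must be supplied so the
# activities can be initialised randomly (see the note above).
result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=None,
    key=key,
    layer_sizes=[10, 300, 784],
    batch_size=y.shape[0],
)
```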

Other arguments:

  • loss_id: Loss function for the output layer (mean squared error 'MSE' vs cross-entropy 'CE').
  • ode_solver: Diffrax ODE solver to be used. Default is Heun, a second-order explicit Runge-Kutta method.
  • max_t1: Maximum end of integration region (20 by default).
  • dt: Integration step size. Defaults to None since the default stepsize_controller will automatically determine it.
  • stepsize_controller: Diffrax controller for adaptive integration step sizing. Defaults to PIDController. Note that the controller's relative and absolute tolerances also determine the steady-state criterion used to terminate the solver. A fixed-step alternative is sketched after this list.
  • skip_model: Optional list of callable skip connection functions.
  • key: jax.random.PRNGKey for random initialisation of activities.
  • layer_sizes: Dimension of all layers (input, hidden and output).
  • batch_size: Dimension of data batch for activity initialisation.
  • sigma: Standard deviation of the Gaussian used to randomly initialise the activities. Defaults to 5e-2.
  • record_activities: If True, returns activities at every inference iteration.
  • record_energies: If True, returns layer-wise energies at every inference iteration.
  • record_every: int determining the sampling frequency of the integration steps.
  • activity_norms: If True, computes l2 norm of the activities.
  • param_norms: If True, computes l2 norm of the parameters.
  • grad_norms: If True, computes l2 norm of parameter gradients.
  • calculate_accuracy: If True, computes the training accuracy.
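
As referenced in the stepsize_controller argument above, the default adaptive controller can be swapped for a fixed-step configuration built from Diffrax objects. The values below are purely illustrative; whether a fixed step suits a given problem is a separate judgement.

```python
import diffrax

# Fixed-step alternative to the default adaptive PIDController.
ode_solver = diffrax.Heun()
dt = 0.1
stepsize_controller = diffrax.ConstantStepSize()

# These would then be forwarded as
# jpc.make_pc_step(..., ode_solver=ode_solver, dt=dt,
#                  stepsize_controller=stepsize_controller, max_t1=20)
```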

Returns:

Dict including the model (and optional skip model) with updated parameters, the optimiser, the updated optimiser state, the loss, energies, activities, and optionally other metrics (see the other arguments above).

Raises:

  • ValueError for inconsistent inputs and invalid losses.
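
A hedged sketch of a supervised training loop built around make_pc_step follows. The layer sizes, learning rate, dummy data and the returned dictionary keys ("model", "opt_state", "loss") are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

# Hypothetical discriminative setup: 784-d inputs, 10-d targets.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
model = [eqx.nn.Linear(784, 300, key=k1), eqx.nn.Linear(300, 10, key=k2)]
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# Dummy data standing in for a real data loader.
x = jnp.ones((64, 784))   # input / prior
y = jnp.zeros((64, 10))   # output / target

for step in range(100):
    result = jpc.make_pc_step(
        model=model,
        optim=optim,
        opt_state=opt_state,
        output=y,
        input=x,
    )
    # Carry the updated model and optimiser state into the next step
    # (dictionary keys assumed).
    model, opt_state = result["model"], result["opt_state"]
    if step % 10 == 0:
        print(step, result["loss"])
```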

jpc.make_hpc_step(generator: PyTree[typing.Callable], amortiser: PyTree[typing.Callable], optims: Tuple[optax._src.base.GradientTransformationExtraArgs], opt_states: Tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], output: ArrayLike, input: Optional[ArrayLike] = None, ode_solver: AbstractSolver = Heun(scan_kind=None), max_t1: int = 300, dt: Array | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001,atol=0.001,pcoeff=0,icoeff=1,dcoeff=0,dtmin=None,dtmax=None,force_dtmin=True,step_ts=None,jump_ts=None,factormin=0.2,factormax=10.0,norm=<function rms_norm>,safety=0.9,error_order=None), record_activities: bool = False, record_energies: bool = False) -> Dict ¤

Updates parameters of a hybrid predictive coding network.

Reference
@article{tschantz2023hybrid,
  title={Hybrid predictive coding: Inferring, fast and slow},
  author={Tschantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L},
  journal={PLoS Computational Biology},
  volume={19},
  number={8},
  pages={e1011280},
  year={2023},
  publisher={Public Library of Science San Francisco, CA USA}
}

Note

The input and output of the generator are the output and input of the amortiser, respectively.

Main arguments:

  • generator: List of callable layers for the generative model.
  • amortiser: List of callable layers for the model amortising inference of the generator.
  • optims: Optax optimisers (e.g. optax.sgd()), one for each model.
  • opt_states: State of Optax optimisers, one for each model.
  • output: Observation of the generator, input to the amortiser.
  • input: Optional prior of the generator, target for the amortiser.

Other arguments:

  • ode_solver: Diffrax ODE solver to be used. Default is Heun, a second-order explicit Runge-Kutta method.
  • max_t1: Maximum end of integration region (300 by default).
  • dt: Integration step size. Defaults to None since the default stepsize_controller will automatically determine it.
  • stepsize_controller: Diffrax controller for adaptive integration step sizing. Defaults to PIDController. Note that the controller's relative and absolute tolerances also determine the steady-state criterion used to terminate the solver.
  • record_activities: If True, returns activities at every inference iteration.
  • record_energies: If True, returns layer-wise energies at every inference iteration.

Returns:

Dict including models with updated parameters, optimiser and state for each model, model activities, last inference step for the generator, MSE losses, and energies.
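
A hedged sketch of a single hybrid update with make_hpc_step follows. The architectures, learning rates, the (generator, amortiser) ordering of optims/opt_states and any returned dictionary keys are assumptions, not part of the documented API.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

# The amortiser runs in the reverse direction of the generator, so its
# input is the generator's output and vice versa (see the note above).
kg1, kg2, ka1, ka2 = jax.random.split(jax.random.PRNGKey(0), 4)
generator = [eqx.nn.Linear(10, 300, key=kg1), eqx.nn.Linear(300, 784, key=kg2)]
amortiser = [eqx.nn.Linear(784, 300, key=ka1), eqx.nn.Linear(300, 10, key=ka2)]

gen_optim, amort_optim = optax.adam(1e-3), optax.adam(1e-3)
# NOTE: pairing order (generator first, amortiser second) is an assumption.
optims = (gen_optim, amort_optim)
opt_states = (
    gen_optim.init(eqx.filter(generator, eqx.is_array)),
    amort_optim.init(eqx.filter(amortiser, eqx.is_array)),
)

x = jnp.ones((32, 10))    # prior of the generator / target of the amortiser
y = jnp.ones((32, 784))   # observation of the generator / input of the amortiser

result = jpc.make_hpc_step(
    generator=generator,
    amortiser=amortiser,
    optims=optims,
    opt_states=opt_states,
    output=y,
    input=x,
)
```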