

jpc.test_discriminative_pc(model: PyTree[typing.Callable], output: ArrayLike, input: ArrayLike, loss: str = 'MSE', skip_model: Optional[PyTree[Callable]] = None) -> Tuple[Array, Array]

Computes test metrics for a discriminative predictive coding network.

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • output: Target of the discriminative model (e.g. a label).
  • input: Input to the discriminative model (e.g. an image).

Other arguments:

  • loss: Loss function to use at the output layer: mean squared error ('MSE') or cross-entropy ('CE').
  • skip_model: Optional list of callable skip connection functions.


Returns:

Test loss and accuracy of output predictions.
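As a rough illustration of what "test loss and accuracy" can mean for a discriminative network, here is a minimal sketch in plain JAX, not jpc's own implementation. The helpers `feedforward` and `discriminative_test_metrics` are hypothetical; the sketch assumes one-hot targets, ignores `skip_model`, and jpc's exact loss scaling may differ.

```python
import jax
import jax.numpy as jnp


def feedforward(layers, x):
    """Apply each callable layer in order (hypothetical helper)."""
    for layer in layers:
        x = layer(x)
    return x


def discriminative_test_metrics(layers, output, input, loss="MSE"):
    """Sketch of test loss and accuracy for a discriminative network.

    `input` (e.g. images) is propagated forward; `output` holds one-hot targets.
    """
    preds = feedforward(layers, input)
    if loss == "MSE":
        test_loss = jnp.mean((output - preds) ** 2)
    elif loss == "CE":
        # Softmax cross-entropy against one-hot targets.
        log_probs = jax.nn.log_softmax(preds, axis=-1)
        test_loss = -jnp.mean(jnp.sum(output * log_probs, axis=-1))
    else:
        raise ValueError(f"unknown loss: {loss!r}")
    # Accuracy: fraction of samples whose predicted class matches the target.
    acc = jnp.mean(jnp.argmax(preds, axis=-1) == jnp.argmax(output, axis=-1))
    return test_loss, acc


# Usage: an identity "network" on one-hot data gives zero MSE and full accuracy.
x = jnp.eye(3)
mse, acc = discriminative_test_metrics([lambda z: z], output=x, input=x)
ce, _ = discriminative_test_metrics([lambda z: z], output=x, input=x, loss="CE")
```

Accuracy is computed by comparing argmax indices, which is why one-hot (or logit-like) targets are assumed here.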

jpc.test_generative_pc(model: PyTree[typing.Callable], output: ArrayLike, input: ArrayLike, key: PRNGKeyArray, layer_sizes: PyTree[int], batch_size: int, sigma: Array = 0.05, ode_solver: AbstractSolver = Heun(), max_t1: int = 500, dt: Array | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001, atol=0.001), skip_model: Optional[PyTree[Callable]] = None) -> Tuple[Array, Array]

Computes test metrics for a generative predictive coding network.

Gets output predictions (e.g. of an image given a label) with a feedforward pass and calculates the accuracy of the inferred input (e.g. of a label given an image).
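The two quantities described above can be sketched separately in plain JAX, assuming the inferred input has already been obtained by inference with the image clamped at the output layer. `generative_test_metrics`, `feedforward` and the toy weights `W` are hypothetical names for illustration only, and labels are assumed to be one-hot.

```python
import jax.numpy as jnp


def feedforward(layers, x):
    """Apply each callable layer in order (hypothetical helper)."""
    for layer in layers:
        x = layer(x)
    return x


def generative_test_metrics(layers, input, inferred_input):
    """Sketch: `input` holds one-hot labels, `inferred_input` the label
    estimates recovered by inference with the image clamped at the output."""
    # Output predictions: feedforward pass from label to image.
    output_preds = feedforward(layers, input)
    # Accuracy of the inferred input against the true labels.
    acc = jnp.mean(jnp.argmax(inferred_input, axis=-1) == jnp.argmax(input, axis=-1))
    return acc, output_preds


W = jnp.ones((3, 4))  # hypothetical weights: 3 classes -> 4 "pixels"
labels = jnp.eye(3)
inferred = jnp.array([[0.1, 0.8, 0.1], [0.9, 0.0, 0.1], [0.0, 0.2, 0.8]])
acc, preds = generative_test_metrics([lambda z: z @ W], labels, inferred)
```

Here only the last inferred label matches its target, so the accuracy is one third.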

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generative model.
  • key: jax.random.PRNGKey for random initialisation of activities.
  • layer_sizes: Dimension of all layers (input, hidden and output).
  • batch_size: Dimension of data batch for activity initialisation.
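The roles of `key`, `layer_sizes`, `batch_size` and `sigma` in activity initialisation can be pictured with the sketch below. `init_activities` is a hypothetical helper, not part of jpc; it samples every layer, whereas which layers the library actually keeps clamped to data is not stated here.

```python
import jax
import jax.numpy as jnp


def init_activities(key, layer_sizes, batch_size, sigma=0.05):
    """Sample one (batch_size, size) Gaussian activity array per layer."""
    keys = jax.random.split(key, len(layer_sizes))
    return [sigma * jax.random.normal(k, (batch_size, size))
            for k, size in zip(keys, layer_sizes)]


# E.g. a 10-class label layer, one hidden layer and a 784-pixel image layer.
acts = init_activities(jax.random.PRNGKey(0), [10, 32, 784], batch_size=8)
```

Splitting the key once per layer keeps the samples for different layers independent.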

Other arguments:

  • sigma: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
  • ode_solver: Diffrax ODE solver to use. Defaults to Heun, a second-order explicit Runge-Kutta method.
  • max_t1: Maximum end time of the integration interval (500 by default).
  • dt: Integration step size. Defaults to None, since the default stepsize_controller determines it automatically.
  • stepsize_controller: Diffrax step-size controller used during integration. Defaults to PIDController. Note that the relative and absolute tolerances of the controller also determine the steady-state condition used to terminate the solver.
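To show how a tolerance can double as a steady-state criterion, here is a sketch of inference as gradient flow on a squared-error energy. It deliberately swaps the Diffrax Heun solver and PID controller for fixed-step explicit Euler with a simple gradient-magnitude check; `energy`, `infer` and the toy identity layer are hypothetical and only illustrate the idea.

```python
import jax
import jax.numpy as jnp


def energy(free, output, layers):
    """Sum of squared prediction errors; the top layer is clamped to `output`."""
    acts = list(free) + [output]
    return sum(0.5 * jnp.sum((acts[i + 1] - f(acts[i])) ** 2)
               for i, f in enumerate(layers))


def infer(free, output, layers, dt=0.1, max_t1=500.0, tol=1e-3):
    """Explicit-Euler gradient flow dz/dt = -dE/dz on the free activities.

    Stops at t = max_t1 or once every gradient entry is below `tol`,
    a crude stand-in for a tolerance-based steady-state check.
    """
    grad_fn = jax.jit(jax.grad(lambda z: energy(z, output, layers)))
    t = 0.0
    while t < max_t1:
        grads = grad_fn(free)
        if max(float(jnp.max(jnp.abs(g))) for g in grads) < tol:
            break  # treated as steady state
        free = [z - dt * g for z, g in zip(free, grads)]
        t += dt
    return free, t


# One identity layer: the free input activity should relax to the clamped output.
y = jnp.array([[1.0, 0.0, 0.0]])
z, t_end = infer([jnp.zeros((1, 3))], y, [lambda a: a @ jnp.eye(3)])
```

Loosening `tol` ends integration earlier with a less settled state; an adaptive solver makes the same trade-off through its relative and absolute tolerances.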


Returns:

Accuracy and output predictions.

jpc.test_hpc(generator: PyTree[typing.Callable], amortiser: PyTree[typing.Callable], output: ArrayLike, input: ArrayLike, key: PRNGKeyArray, layer_sizes: PyTree[int], batch_size: int, sigma: Array = 0.05, ode_solver: AbstractSolver = Heun(), max_t1: int = 500, dt: Array | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001, atol=0.001)) -> Tuple[Array, Array, Array, Array]

Computes test metrics for hybrid predictive coding trained in a supervised manner.

Calculates the input accuracy of (i) the amortiser, (ii) the generator, and (iii) the hybrid model (amortiser + generator). Also returns output predictions (e.g. of an image given a label) with a feedforward pass of the generator.


The input and output of the generator are the output and input of the amortiser, respectively.
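The three accuracies can be sketched as follows, assuming, as in hybrid predictive coding generally, that the hybrid estimate is the generator's inference initialised at the amortiser's prediction rather than at random. This is not jpc's implementation: `hpc_test_metrics`, `accuracy`, `feedforward` and the inference routine `generator_infer` are hypothetical, and the toy usage passes a no-op inference just to exercise the wiring.

```python
import jax
import jax.numpy as jnp


def feedforward(layers, x):
    """Apply each callable layer in order (hypothetical helper)."""
    for layer in layers:
        x = layer(x)
    return x


def accuracy(pred, target):
    return jnp.mean(jnp.argmax(pred, axis=-1) == jnp.argmax(target, axis=-1))


def hpc_test_metrics(generator, amortiser, generator_infer, output, input, random_init):
    """Sketch of the three input accuracies plus generator output predictions.

    `generator_infer(init, output)` is a hypothetical routine returning the
    input estimate after generator inference started from `init`.
    """
    amort_pred = feedforward(amortiser, output)        # amortiser: output -> input
    gen_pred = generator_infer(random_init, output)    # generator from a random start
    hybrid_pred = generator_infer(amort_pred, output)  # generator from the amortised guess
    output_preds = feedforward(generator, input)       # generator: input -> output
    return (accuracy(amort_pred, input), accuracy(gen_pred, input),
            accuracy(hybrid_pred, input), output_preds)


labels = jnp.eye(3)
images = jnp.eye(3)  # toy "images" identical to the labels
init = 0.05 * jax.random.normal(jax.random.PRNGKey(0), (3, 3))
amort_acc, gen_acc, hybrid_acc, preds = hpc_test_metrics(
    generator=[lambda z: z], amortiser=[lambda x: x],
    generator_infer=lambda z0, out: z0,  # no-op inference, for illustration only
    output=images, input=labels, random_init=init)
```

Note how the amortiser consumes `output` and predicts `input`, mirroring the reversed roles of the generator.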

Main arguments:

  • generator: List of callable layers for the generative model.
  • amortiser: List of callable layers for the model that amortises the inference of the generator.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generator, target for the amortiser.
  • key: jax.random.PRNGKey for random initialisation of activities.
  • layer_sizes: Dimension of all layers (input, hidden and output).
  • batch_size: Dimension of data batch for initialisation of activities.

Other arguments:

  • sigma: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
  • ode_solver: Diffrax ODE solver to use. Defaults to Heun, a second-order explicit Runge-Kutta method.
  • max_t1: Maximum end time of the integration interval (500 by default).
  • dt: Integration step size. Defaults to None, since the default stepsize_controller determines it automatically.
  • stepsize_controller: Diffrax step-size controller used during integration. Defaults to PIDController. Note that the relative and absolute tolerances of the controller also determine the steady-state condition used to terminate the solver.


Returns:

Accuracies of all models and output predictions.