
Testing

jpc.test_discriminative_pc(model: PyTree[typing.Callable], output: ArrayLike, input: ArrayLike, loss: str = 'MSE', skip_model: Optional[PyTree[Callable]] = None) -> Tuple[Array, Array]

Computes test metrics for a discriminative predictive coding network.

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • output: Target of the discriminative model (e.g. labels).
  • input: Input to the discriminative model (e.g. images).

Other arguments:

  • loss: Loss function to use at the output layer: mean squared error ('MSE') or cross-entropy ('CE').
  • skip_model: Optional list of callable skip connection functions.

Returns:

Test loss and accuracy of output predictions.
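As a rough illustration of what these two metrics involve, the sketch below runs a feedforward pass through a list of callables and scores the final layer against the targets. It is not JPC's implementation: the helper name `discriminative_test_metrics`, the batch-averaged loss normalisation, and measuring accuracy as argmax agreement (returned as a fraction rather than a percentage) are assumptions.

```python
import jax
import jax.numpy as jnp


def discriminative_test_metrics(layers, output, input, loss="MSE"):
    """Hypothetical helper: test loss and accuracy for a feedforward stack."""
    # Feedforward pass: apply each callable layer in turn to the input.
    preds = input
    for layer in layers:
        preds = layer(preds)

    batch_size = output.shape[0]
    if loss == "MSE":
        # Half the summed squared error, averaged over the batch (assumed normalisation).
        test_loss = 0.5 * jnp.sum((output - preds) ** 2) / batch_size
    elif loss == "CE":
        # Cross-entropy against one-hot targets, averaged over the batch.
        test_loss = -jnp.sum(output * jax.nn.log_softmax(preds, axis=-1)) / batch_size
    else:
        raise ValueError(f"unknown loss: {loss!r}")

    # Accuracy: fraction of samples whose predicted class matches the target class.
    acc = jnp.mean(jnp.argmax(preds, axis=-1) == jnp.argmax(output, axis=-1))
    return test_loss, acc


# Usage: one random linear layer mapping 4 input features to 3 classes.
W = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
layers = [lambda x: x @ W]
x = jnp.ones((2, 4))
y = jax.nn.one_hot(jnp.array([0, 2]), 3)
test_loss, acc = discriminative_test_metrics(layers, y, x, loss="CE")
```

With 'MSE' the loss compares raw predictions with the targets; with 'CE' the predictions are treated as logits and passed through a log-softmax first.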


jpc.test_generative_pc(model: PyTree[typing.Callable], output: ArrayLike, input: ArrayLike, key: PRNGKeyArray, layer_sizes: PyTree[int], batch_size: int, sigma: Array = 0.05, ode_solver: AbstractSolver = Heun(scan_kind=None), max_t1: int = 500, dt: Array | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001,atol=0.001,pcoeff=0,icoeff=1,dcoeff=0,dtmin=None,dtmax=None,force_dtmin=True,step_ts=None,jump_ts=None,factormin=0.2,factormax=10.0,norm=<function rms_norm>,safety=0.9,error_order=None), skip_model: Optional[PyTree[Callable]] = None) -> Tuple[Array, Array]

Computes test metrics for a generative predictive coding network.

Computes output predictions (e.g. of an image given a label) with a feedforward pass, and the accuracy of the inferred input (e.g. of a label given an image).
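The two metrics described above can be sketched for a toy one-layer linear generative model `output ≈ input @ W`, where the input is label-like and the output image-like. Output predictions come from a single feedforward pass; the input is inferred by gradient descent on the squared prediction error with the output clamped, starting from a small Gaussian sample. The single linear layer, the fixed-step gradient descent (used here in place of a Diffrax ODE solve), and the helper names `energy` and `generative_test_metrics` are illustrative assumptions, not JPC's code.

```python
import jax
import jax.numpy as jnp


def energy(x, W, y):
    # Squared prediction error of the generative model, with the output y clamped.
    return 0.5 * jnp.sum((y - x @ W) ** 2)


def generative_test_metrics(W, output, input, key, sigma=0.05, n_steps=200):
    """Hypothetical helper for a one-layer linear generative model."""
    # (1) Output predictions: feedforward pass from the input (e.g. label -> image).
    output_preds = input @ W

    # (2) Inference: start from a small Gaussian sample and descend the energy
    # with respect to the input activity while the output stays clamped.
    x = sigma * jax.random.normal(key, input.shape)
    # Step size from the largest singular value of W, so descent is stable.
    lr = 1.0 / jnp.linalg.norm(W, ord=2) ** 2
    grad_fn = jax.grad(energy)

    def step(_, x):
        return x - lr * grad_fn(x, W, output)

    x = jax.lax.fori_loop(0, n_steps, step, x)

    # Accuracy of the inferred input (e.g. label given image), as a fraction.
    acc = jnp.mean(jnp.argmax(x, axis=-1) == jnp.argmax(input, axis=-1))
    return acc, output_preds


# Usage: three one-hot "labels" generating 8-dimensional "images".
labels = jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
W = jax.random.normal(jax.random.PRNGKey(1), (3, 8))
acc, preds = generative_test_metrics(W, labels @ W, labels, jax.random.PRNGKey(0))
```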

Main arguments:

  • model: List of callable model (e.g. neural network) layers.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generative model.
  • key: jax.random.PRNGKey for random initialisation of activities.
  • layer_sizes: Dimension of all layers (input, hidden and output).
  • batch_size: Size of the data batch, used to initialise the activities.
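The key, layer_sizes and batch_size arguments, together with sigma (listed under the other arguments), describe how activities are randomly initialised before inference. A minimal sketch, assuming one activity array of shape `(batch_size, size)` per layer drawn from a zero-mean Gaussian with standard deviation `sigma`; the helper name `init_activities_randomly` is hypothetical, and JPC may treat some layers (e.g. a clamped output) differently.

```python
import jax
import jax.numpy as jnp


def init_activities_randomly(key, layer_sizes, batch_size, sigma=0.05):
    """Hypothetical helper: one Gaussian activity array per layer."""
    # Independent subkey per layer so the layers' samples are uncorrelated.
    keys = jax.random.split(key, len(layer_sizes))
    return [
        sigma * jax.random.normal(k, (batch_size, size))
        for k, size in zip(keys, layer_sizes)
    ]


activities = init_activities_randomly(jax.random.PRNGKey(0), [10, 300, 784], batch_size=64)
print([a.shape for a in activities])  # [(64, 10), (64, 300), (64, 784)]
```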

Other arguments:

  • sigma: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
  • ode_solver: Diffrax ODE solver to use. Defaults to Heun, a second-order explicit Runge–Kutta method.
  • max_t1: Maximum end time of the integration interval (500 by default).
  • dt: Integration step size. Defaults to None, since the default stepsize_controller determines it automatically.
  • stepsize_controller: Diffrax step-size controller. Defaults to PIDController. Note that the controller's relative and absolute tolerances also determine the steady state at which the solver terminates.
  • skip_model: Optional list of callable skip connection functions.
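The note on tolerances reflects a steady-state stopping rule: integration of the activity dynamics ends early once the vector field is small relative to the state. The sketch below mirrors that idea with a fixed-step Euler integrator in plain JAX rather than Diffrax. The test `rms(dx/dt) < atol + rtol * rms(x)` is modelled on Diffrax's steady-state event; the helper name `integrate_to_steady_state` and the toy dynamics are assumptions for illustration.

```python
import jax.numpy as jnp


def rms(x):
    # Root-mean-square norm, as used by Diffrax's default controller norm.
    return jnp.sqrt(jnp.mean(x ** 2))


def integrate_to_steady_state(vector_field, x0, dt=0.1, max_t1=500, rtol=1e-3, atol=1e-3):
    """Hypothetical fixed-step Euler integrator with a steady-state stop."""
    x, t = x0, 0.0
    while t < max_t1:
        dxdt = vector_field(x)
        # Stop once the dynamics have (approximately) stopped changing x.
        if rms(dxdt) < atol + rtol * rms(x):
            break
        x = x + dt * dxdt
        t += dt
    return x, t


# Usage: a toy gradient flow dx/dt = -(x - target) that settles at `target`.
target = jnp.array([1.0, -2.0, 0.5])
x_ss, t_stop = integrate_to_steady_state(lambda x: -(x - target), jnp.zeros(3))
```

Tighter rtol and atol delay the stop, so the returned state sits closer to the true fixed point; max_t1 caps the integration time when no steady state is reached.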

Returns:

Accuracy and output predictions.


jpc.test_hpc(generator: PyTree[typing.Callable], amortiser: PyTree[typing.Callable], output: ArrayLike, input: ArrayLike, key: PRNGKeyArray, layer_sizes: PyTree[int], batch_size: int, sigma: Array = 0.05, ode_solver: AbstractSolver = Heun(scan_kind=None), max_t1: int = 500, dt: Array | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001,atol=0.001,pcoeff=0,icoeff=1,dcoeff=0,dtmin=None,dtmax=None,force_dtmin=True,step_ts=None,jump_ts=None,factormin=0.2,factormax=10.0,norm=<function rms_norm>,safety=0.9,error_order=None)) -> Tuple[Array, Array, Array, Array]

Computes test metrics for hybrid predictive coding trained in a supervised manner.

Calculates the input accuracy of (i) the amortiser, (ii) the generator, and (iii) the hybrid model (amortiser + generator). Also returns output predictions (e.g. of an image given a label) from a feedforward pass of the generator.

Note

The input and output of the generator are the output and input of the amortiser, respectively.
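The three accuracies differ mainly in how the generator's inference is initialised. A rough sketch under stated assumptions: the amortiser is a single linear map from output (e.g. image) to input (e.g. label), the generator is a single linear map in the opposite direction, and inference is fixed-step gradient descent on the generator's squared prediction error rather than a Diffrax solve. The helper names `accuracy`, `infer_input` and `hpc_test_accuracies` are hypothetical, and the return order is that of this sketch only.

```python
import jax
import jax.numpy as jnp


def accuracy(preds, targets):
    # Fraction of samples whose argmax class matches the target class.
    return jnp.mean(jnp.argmax(preds, axis=-1) == jnp.argmax(targets, axis=-1))


def infer_input(x0, W_gen, output, n_steps, lr):
    # Gradient descent on 0.5 * ||output - x @ W_gen||^2 with respect to x.
    grad_fn = jax.grad(lambda x: 0.5 * jnp.sum((output - x @ W_gen) ** 2))
    return jax.lax.fori_loop(0, n_steps, lambda _, x: x - lr * grad_fn(x), x0)


def hpc_test_accuracies(W_gen, W_amort, output, input, key, sigma=0.05, n_steps=20):
    """Hypothetical helper with a linear generator and a linear amortiser."""
    lr = 1.0 / jnp.linalg.norm(W_gen, ord=2) ** 2
    # (i) Amortiser: a single feedforward guess of the input from the output.
    amort_guess = output @ W_amort
    # (ii) Generator alone: inference from a small Gaussian initialisation.
    x_rand = sigma * jax.random.normal(key, input.shape)
    gen_x = infer_input(x_rand, W_gen, output, n_steps, lr)
    # (iii) Hybrid: the same inference, initialised at the amortiser's guess.
    hybrid_x = infer_input(amort_guess, W_gen, output, n_steps, lr)
    # Output predictions: feedforward pass of the generator from the input.
    output_preds = input @ W_gen
    return (
        accuracy(amort_guess, input),
        accuracy(gen_x, input),
        accuracy(hybrid_x, input),
        output_preds,
    )
```

Initialising from the amortiser's guess is what makes the hybrid scheme cheap: when the guess is already good, the generator's inference has little left to correct.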

Main arguments:

  • generator: List of callable layers for the generative model.
  • amortiser: List of callable layers for the model that amortises the generator's inference.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generator, target for the amortiser.
  • key: jax.random.PRNGKey for random initialisation of activities.
  • layer_sizes: Dimension of all layers (input, hidden and output).
  • batch_size: Size of the data batch, used to initialise the activities.

Other arguments:

  • sigma: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
  • ode_solver: Diffrax ODE solver to use. Defaults to Heun, a second-order explicit Runge–Kutta method.
  • max_t1: Maximum end time of the integration interval (500 by default).
  • dt: Integration step size. Defaults to None, since the default stepsize_controller determines it automatically.
  • stepsize_controller: Diffrax step-size controller. Defaults to PIDController. Note that the controller's relative and absolute tolerances also determine the steady state at which the solver terminates.

Returns:

Accuracies of all models and output predictions.