Testing
JPC provides a few convenience functions to test different types of predictive coding network (PCN):
- jpc.test_discriminative_pc() for the test loss and accuracy of discriminative PCNs;
- jpc.test_generative_pc() for the accuracy and output predictions of generative PCNs; and
- jpc.test_hpc() for the accuracy of all models (amortiser, generator, and hybrid) as well as the output predictions of hybrid PCNs.
jpc.test_discriminative_pc(model: PyTree[typing.Callable], output: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, skip_model: typing.Optional[jaxtyping.PyTree[typing.Callable]] = None, loss: str = 'mse', param_type: str = 'sp') -> typing.Tuple[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[Array, '']]
Computes test metrics for a discriminative predictive coding network.
Main arguments:
model
: List of callable model (e.g. neural network) layers.
output
: Target of the discriminative model (e.g. a label).
input
: Input to the discriminative model (e.g. an image).
Other arguments:
skip_model
: Optional skip connection model.
loss
: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
param_type
: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
Returns:
Test loss and accuracy of output predictions.
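To make the two returned metrics concrete, here is a minimal sketch of what a test pass of a discriminative network computes, assuming the predictions come from a plain feedforward pass through the layers, that targets are one-hot, that "mse" is half the squared error summed over features and averaged over the batch, and that accuracy is the fraction of argmax matches. The helpers feedforward and discriminative_metrics are hypothetical and are not part of the JPC API; the exact loss scaling and accuracy units used inside jpc.test_discriminative_pc() may differ.

```python
import jax
import jax.numpy as jnp


def feedforward(layers, x):
    """Pass x through each callable layer in turn."""
    for layer in layers:
        x = layer(x)
    return x


def discriminative_metrics(layers, y, x, loss="mse"):
    """Hypothetical helper: test loss and accuracy for one batch.

    y plays the role of `output` (one-hot targets) and x the role of `input`.
    """
    preds = feedforward(layers, x)
    if loss == "mse":
        # Assumed convention: half the squared error summed over features, averaged over the batch
        test_loss = 0.5 * jnp.mean(jnp.sum((y - preds) ** 2, axis=1))
    elif loss == "ce":
        # Cross-entropy between softmax(preds) and one-hot targets, averaged over the batch
        test_loss = -jnp.mean(jnp.sum(y * jax.nn.log_softmax(preds, axis=1), axis=1))
    else:
        raise ValueError(f"unknown loss: {loss!r}")
    # Accuracy as a fraction: the predicted class (argmax) matches the target class
    acc = jnp.mean(jnp.argmax(preds, axis=1) == jnp.argmax(y, axis=1))
    return test_loss, acc


# Usage: an identity "network" that reproduces its one-hot inputs exactly
x = jnp.eye(3)
y = jnp.eye(3)
layers = [lambda h: h @ jnp.eye(3)]
loss_val, acc = discriminative_metrics(layers, y, x)  # loss 0.0, accuracy 1.0
```

Keeping the loss and accuracy in one function mirrors the fact that both are derived from the same set of output predictions.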
jpc.test_generative_pc(model: PyTree[typing.Callable], output: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], layer_sizes: PyTree[int], batch_size: int, *, skip_model: typing.Optional[jaxtyping.PyTree[typing.Callable]] = None, loss_id: str = 'mse', param_type: str = 'sp', sigma: Shaped[Array, ''] = 0.05, ode_solver: AbstractSolver = Heun(), max_t1: int = 500, dt: jaxtyping.Shaped[Array, ''] | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001, atol=0.001), weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> typing.Tuple[jaxtyping.Shaped[Array, ''], jax.Array]
Computes test metrics for a generative predictive coding network.
Gets output predictions (e.g. of an image given a label) with a feedforward pass and calculates accuracy of inferred input (e.g. of a label given an image).
Main arguments:
model
: List of callable model (e.g. neural network) layers.
output
: Observation or target of the generative model.
input
: Prior of the generative model.
key
: jax.random.PRNGKey for random initialisation of activities.
layer_sizes
: Dimensions of all layers (input, hidden and output).
batch_size
: Size of the data batch used to initialise activities.
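The key, layer_sizes and batch_size arguments, together with sigma below, are what a random activity initialisation needs. As a sketch, assuming every layer gets a (batch_size, size) array drawn from a zero-mean Gaussian with standard deviation sigma, and that any layer clamped to data is simply overwritten afterwards, the initialisation could look like this. The helper init_activities is hypothetical and not JPC's internal implementation.

```python
import jax


def init_activities(key, layer_sizes, batch_size, sigma=0.05):
    """Hypothetical helper: sample one (batch_size, size) activity array per layer.

    Each array is drawn from N(0, sigma^2) using its own subkey. Layers that are
    clamped to data would simply be overwritten afterwards.
    """
    keys = jax.random.split(key, len(layer_sizes))
    return [
        sigma * jax.random.normal(k, (batch_size, size))
        for k, size in zip(keys, layer_sizes)
    ]


# e.g. a label-to-image generator with layer sizes 10 -> 300 -> 784
acts = init_activities(jax.random.PRNGKey(0), [10, 300, 784], batch_size=64)
print([a.shape for a in acts])  # [(64, 10), (64, 300), (64, 784)]
```

Splitting the key once per layer keeps the samples for different layers independent while the whole initialisation stays reproducible from a single key.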
Other arguments:
skip_model
: Optional skip connection model.
loss_id
: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
param_type
: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
sigma
: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
ode_solver
: diffrax ODE solver to be used. Defaults to Heun, a second-order explicit Runge–Kutta method.
max_t1
: Maximum end of the integration interval (500 by default).
dt
: Integration step size. Defaults to None, since the default stepsize_controller determines it automatically.
stepsize_controller
: diffrax step-size controller. Defaults to PIDController. Note that the relative and absolute tolerances of the controller also determine the steady state at which the solver terminates.
weight_decay
: Weight decay coefficient for the weights (0 by default).
spectral_penalty
: Coefficient of a weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
activity_decay
: Activity decay coefficient for the activities (0 by default).
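The three regularisation coefficients each scale a simple quadratic term. The sketch below assumes the spectral penalty uses the squared Frobenius norm of \(\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell\), and that weight and activity decay are plain sums of squared entries; how JPC actually weights these terms and adds them to the energy is not specified here. The functions spectral_penalty and regularisation, and the argument name spectral_coef, are hypothetical.

```python
import jax.numpy as jnp


def spectral_penalty(W):
    """Assumed squared Frobenius norm ||I - W^T W||^2 of one weight matrix W."""
    gram = W.T @ W
    return jnp.sum((jnp.eye(gram.shape[0]) - gram) ** 2)


def regularisation(weights, activities, weight_decay=0.0, spectral_coef=0.0, activity_decay=0.0):
    """Hypothetical sum of the three regularisers over all layers."""
    wd = weight_decay * sum(jnp.sum(W ** 2) for W in weights)
    sp = spectral_coef * sum(spectral_penalty(W) for W in weights)
    ad = activity_decay * sum(jnp.sum(z ** 2) for z in activities)
    return wd + sp + ad


# An orthogonal matrix incurs no spectral penalty; 2I (2x2) gives ||-3I||^2 = 18
print(float(spectral_penalty(jnp.eye(3))))  # 0.0
```

The spectral term vanishes exactly when the columns of \(\mathbf{W}_\ell\) are orthonormal, which is why it is sometimes described as an orthogonality regulariser.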
Returns:
Accuracy and output predictions.
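The two returned quantities come from opposite directions through the generator: output predictions from a feedforward pass starting at the prior, and accuracy from an input inferred by clamping the output to the observation. The toy sketch below shows this under strong simplifying assumptions: a single linear generator layer, an energy containing only the output prediction error, and a fixed number of gradient-descent steps standing in for the diffrax ODE solve and its steady-state termination. The names energy, infer_input and generative_metrics are hypothetical.

```python
import jax
import jax.numpy as jnp


def energy(z, W, y):
    """Output prediction error of a one-layer linear generator z -> z @ W."""
    return 0.5 * jnp.sum((y - z @ W) ** 2)


def infer_input(W, y, key, sigma=0.05, lr=0.1, n_steps=200):
    """Infer the generator's input by gradient descent on the energy.

    Stands in for the ODE solve: fixed-step descent with no steady-state check.
    """
    z0 = sigma * jax.random.normal(key, (y.shape[0], W.shape[0]))
    grad_z = jax.grad(energy)  # gradient w.r.t. the first argument, z

    def step(_, z):
        return z - lr * grad_z(z, W, y)

    return jax.lax.fori_loop(0, n_steps, step, z0)


def generative_metrics(W, y, x, key):
    """Hypothetical helper returning (accuracy of inferred input, output predictions)."""
    preds = x @ W  # feedforward pass from the prior, e.g. label -> image
    z = infer_input(W, y, key)  # inferred input, e.g. image -> label
    acc = jnp.mean(jnp.argmax(z, axis=1) == jnp.argmax(x, axis=1))
    return acc, preds


x = jnp.eye(3)     # one-hot "labels" (the prior / input)
W = jnp.eye(3, 5)  # toy generator with orthonormal rows
y = x @ W          # observations the generator should explain
acc, preds = generative_metrics(W, y, x, jax.random.PRNGKey(0))
```

Because the rows of this toy W are orthonormal, each descent step moves z a fixed fraction of the way towards the true one-hot input, so the inferred labels match exactly.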
jpc.test_hpc(generator: PyTree[typing.Callable], amortiser: PyTree[typing.Callable], output: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], layer_sizes: PyTree[int], batch_size: int, sigma: Shaped[Array, ''] = 0.05, ode_solver: AbstractSolver = Heun(), max_t1: int = 500, dt: jaxtyping.Shaped[Array, ''] | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001, atol=0.001)) -> typing.Tuple[jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[Array, ''], jaxtyping.Shaped[Array, ''], jax.Array]
Computes test metrics for hybrid predictive coding trained in a supervised manner.
Calculates input accuracy of (i) amortiser, (ii) generator, and (iii) hybrid (amortiser + generator). Also returns output predictions (e.g. of an image given a label) with a feedforward pass of the generator.
Note
The input and output of the generator are the output and input of the amortiser, respectively.
Main arguments:
generator
: List of callable layers for the generative model.
amortiser
: List of callable layers for the model that amortises inference of the generator.
output
: Observation or target of the generative model.
input
: Prior of the generator and target of the amortiser.
key
: jax.random.PRNGKey for random initialisation of activities.
layer_sizes
: Dimensions of all layers (input, hidden and output).
batch_size
: Size of the data batch used to initialise activities.
Other arguments:
sigma
: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
ode_solver
: diffrax ODE solver to be used. Defaults to Heun, a second-order explicit Runge–Kutta method.
max_t1
: Maximum end of the integration interval (500 by default).
dt
: Integration step size. Defaults to None, since the default stepsize_controller determines it automatically.
stepsize_controller
: diffrax step-size controller. Defaults to PIDController. Note that the relative and absolute tolerances of the controller also determine the steady state at which the solver terminates.
Returns:
Accuracies of all models and output predictions.
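The three accuracies differ only in how the generator's input is obtained: directly from the amortiser, by inference from a random initialisation, or by inference initialised at the amortiser's guess (the hybrid). The toy sketch below illustrates this with linear layers, an energy containing only the output prediction error, and fixed-step gradient descent in place of the ODE solve. The helpers infer, accuracy and hpc_metrics are hypothetical, and the (amortiser, generator, hybrid) return order is an assumption based on the description above.

```python
import jax
import jax.numpy as jnp


def infer(z0, W, y, lr=0.1, n_steps=200):
    """Gradient descent on 0.5 * ||y - z @ W||^2 w.r.t. the generator input z."""
    def energy(z):
        return 0.5 * jnp.sum((y - z @ W) ** 2)

    def step(_, z):
        return z - lr * jax.grad(energy)(z)

    return jax.lax.fori_loop(0, n_steps, step, z0)


def accuracy(pred, target):
    """Fraction of rows whose argmax matches the target's argmax."""
    return jnp.mean(jnp.argmax(pred, axis=1) == jnp.argmax(target, axis=1))


def hpc_metrics(W_gen, W_amort, y, x, key, sigma=0.05):
    """Hypothetical helper: (amortiser, generator, hybrid) accuracies and output predictions.

    The generator maps x -> y; the amortiser maps y -> x, i.e. the reverse direction.
    """
    z_amort = y @ W_amort                             # amortised guess of the input
    z_rand = sigma * jax.random.normal(key, x.shape)  # random init for the generator alone
    amort_acc = accuracy(z_amort, x)
    gen_acc = accuracy(infer(z_rand, W_gen, y), x)
    hybrid_acc = accuracy(infer(z_amort, W_gen, y), x)  # amortiser init, generator refines
    output_preds = x @ W_gen                          # feedforward pass of the generator
    return amort_acc, gen_acc, hybrid_acc, output_preds


x = jnp.eye(3)
W_gen = jnp.eye(3, 5)
y = x @ W_gen
W_amort = jnp.eye(5, 3)[:, ::-1]  # reverses the class order: a deliberately poor amortiser
accs = hpc_metrics(W_gen, W_amort, y, x, jax.random.PRNGKey(0))
```

With this poor amortiser only one of the three guesses is correct, yet the hybrid still reaches full accuracy because generator inference corrects the amortised initialisation.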