
Continuous-time inference

The inference or activity dynamics of PC networks can be solved in either discrete or continuous time. jpc.solve_inference() leverages ODE solvers to integrate the continuous-time dynamics.

jpc.solve_inference(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities: PyTree[jax.Array | numpy.ndarray | numpy.bool | numpy.number | bool | int | float | complex], output: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None, loss_id: str = 'mse', param_type: str = 'sp', solver: AbstractSolver = Heun(), max_t1: int = 20, dt: float | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001, atol=0.001), weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0, record_iters: bool = False, record_every: int = None) -> PyTree[jax.Array]

Solves the inference (activity) dynamics of a predictive coding network.

This is a wrapper around diffrax.diffeqsolve() to integrate the gradient ODE system jpc.neg_activity_grad() defining the PC inference dynamics.

\[ \frac{d\mathbf{z}}{dt} = - \nabla_{\mathbf{z}} \mathcal{F} \]

where \(\mathcal{F}\) is the free energy, \(\mathbf{z}\) are the activities, with \(\mathbf{z}_L\) clamped to some target and \(\mathbf{z}_0\) optionally set to some prior.
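To make the dynamics concrete, here is a minimal toy sketch (not the jpc implementation) of the same gradient flow on a two-layer linear PC network, integrated with an explicit Euler step. All names (`W1`, `free_energy`, `neg_grad`, etc.) are illustrative, and the free energy is taken to be the usual sum of squared layer-wise prediction errors:

```python
import numpy as np

# Toy 2-layer linear PC network: z0 (input, clamped), z2 (target, clamped),
# and a single free activity vector z1 evolving under dz/dt = -grad F.
rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(3, 2)), rng.normal(size=(3, 3))
z0, z2 = rng.normal(size=2), rng.normal(size=3)
z1 = W1 @ z0  # feedforward initialisation of the free activities

def free_energy(z1):
    # F = sum of squared prediction errors across layers
    e1, e2 = z1 - W1 @ z0, z2 - W2 @ z1
    return 0.5 * (e1 @ e1 + e2 @ e2)

def neg_grad(z1):
    # -dF/dz1 = -(z1 - W1 z0) + W2^T (z2 - W2 z1)
    return -(z1 - W1 @ z0) + W2.T @ (z2 - W2 @ z1)

dt, F_start = 0.05, free_energy(z1)
for _ in range(200):
    z1 = z1 + dt * neg_grad(z1)  # explicit Euler step on dz/dt = -grad F

assert free_energy(z1) < F_start  # inference descends the free energy
```

jpc.solve_inference() does the same kind of integration, but via diffrax ODE solvers with adaptive step sizes rather than a fixed Euler step.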

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • output: Observation or target of the generative model.

Other arguments:

  • input: Optional prior of the generative model.
  • loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
  • solver: diffrax ODE solver to be used. Default is Heun, a 2nd-order explicit Runge–Kutta method.
  • max_t1: Maximum end of integration region (20 by default).
  • dt: Integration step size. Defaults to None since the default stepsize_controller will automatically determine it.
  • stepsize_controller: diffrax controller for the integration step size. Defaults to PIDController. Note that the relative and absolute tolerances of the controller also determine the steady state at which the solver terminates.
  • weight_decay: \(\ell^2\) regulariser for the weights (0 by default).
  • spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
  • activity_decay: \(\ell^2\) regulariser for the activities (0 by default).
  • record_iters: If True, returns all integration steps.
  • record_every: Integer determining how often integration steps are recorded (used together with record_iters).
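The interplay between tolerance-based termination and step recording can be sketched with a scalar toy ODE \(dz/dt = -z\) (i.e. \(\mathcal{F} = z^2/2\)). This is an illustrative analogue, not jpc code: `atol` stands in for the controller tolerance, the loop cap for `max_t1`, and `record_every` for the sampling frequency of recorded steps:

```python
import numpy as np

# Scalar toy ODE dz/dt = -z, integrated with fixed-step Euler.
z, dt, atol = 5.0, 0.1, 1e-3
record_every = 10
trajectory = [z]  # record_iters analogue: keep sampled iterates

for step in range(1, 201):  # max_t1 analogue: hard cap on integration steps
    dz = -z * dt
    z += dz
    if step % record_every == 0:  # subsample the recorded iterations
        trajectory.append(z)
    if abs(dz) < atol:  # tolerance doubles as a steady-state test
        break

assert abs(z) < 0.05  # activities have settled near the fixed point z* = 0
```

Tightening `atol` moves the termination point closer to the true steady state at the cost of more integration steps, which mirrors the role of the PIDController tolerances above.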

Returns:

List with solution of the activity dynamics for each layer.