
Inference

jpc.solve_inference(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], output: ArrayLike, input: Optional[ArrayLike] = None, loss_id: str = 'MSE', solver: AbstractSolver = Heun(scan_kind=None), max_t1: int = 20, dt: float | int = None, stepsize_controller: AbstractStepSizeController = PIDController(rtol=0.001,atol=0.001,pcoeff=0,icoeff=1,dcoeff=0,dtmin=None,dtmax=None,force_dtmin=True,step_ts=None,jump_ts=None,factormin=0.2,factormax=10.0,norm=<function rms_norm>,safety=0.9,error_order=None), record_iters: bool = False, record_every: int = None) -> PyTree[Array]

Solves the inference (activity) dynamics of a predictive coding network.

This is a wrapper around diffrax.diffeqsolve that integrates the gradient ODE system _neg_activity_grad, which defines the PC inference dynamics:

\[ \partial \mathbf{z} / \partial t = - \partial \mathcal{F} / \partial \mathbf{z} \]

where \(\mathcal{F}\) is the free energy and \(\mathbf{z}\) are the activities, with \(\mathbf{z}_L\) clamped to some target and \(\mathbf{z}_0\) optionally clamped to some prior.
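To make the gradient flow concrete, here is a minimal sketch in plain JAX that uses neither jpc nor diffrax. It assumes a tiny linear network with hypothetical weights W1 and W2, a simplified free energy equal to half the sum of squared prediction errors, the input clamped as z_0 and the target clamped as z_L, and a fixed-step Heun integrator (the default solver, but without adaptive step sizes).

```python
import jax
import jax.numpy as jnp

# Hypothetical weights of a tiny linear network x -> z1 -> y (illustration only).
W1 = jnp.array([[1.0, 0.5], [0.0, 1.0]])  # predicts the hidden activity (dim 2) from x
W2 = jnp.array([[0.5, 0.2]])              # predicts the output (dim 1) from z1
x = jnp.array([1.0, -1.0])                # z_0, clamped to the prior/input
y = jnp.array([2.0])                      # z_L, clamped to the target

DT = 0.1       # fixed step size (assumption; jpc adapts it by default)
N_STEPS = 200  # number of Heun steps


def free_energy(z1):
    """Simplified F: half the sum of squared prediction errors; only z1 is free."""
    e1 = z1 - W1 @ x   # prediction error at the hidden layer
    e2 = y - W2 @ z1   # prediction error at the clamped output layer
    return 0.5 * jnp.sum(e1 ** 2) + 0.5 * jnp.sum(e2 ** 2)


def vector_field(z1):
    """Right-hand side of the inference ODE: dz/dt = -dF/dz."""
    return -jax.grad(free_energy)(z1)


@jax.jit
def heun_solve(z1):
    """Fixed-step Heun integration (explicit trapezoidal rule) of the activity ODE."""
    def step(_, z):
        k1 = vector_field(z)
        k2 = vector_field(z + DT * k1)
        return z + 0.5 * DT * (k1 + k2)
    return jax.lax.fori_loop(0, N_STEPS, step, z1)


z_init = jnp.zeros(2)
z_final = heun_solve(z_init)
# F is quadratic here, so its minimizer solves (I + W2^T W2) z1 = W1 x + W2^T y.
z_star = jnp.linalg.solve(jnp.eye(2) + W2.T @ W2, W1 @ x + W2.T @ y)
print(free_energy(z_init), free_energy(z_final))
```

Because the simplified energy is quadratic, the integrated activities should settle at the closed-form minimizer z_star, and the free energy should decrease along the way.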

Main arguments:

  • params: Tuple of the callable model layers and the (optional) skip connections; the second element may be None.
  • activities: List of activities, one for each layer that is free to vary.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generative model.
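One way to picture how these four arguments fit together is as inputs to a free-energy function. The pc_energy below is a hypothetical sketch written for illustration, not jpc's implementation. It assumes params = (layers, skips) with layers[l] predicting activity l from activity l-1, replaces the last activity by the clamped output, drops the first layer's prediction error when input is None, adds skips[l] as a residual-style term, and does no batch averaging.

```python
import jax.numpy as jnp


def pc_energy(params, activities, output, input=None):
    """Hypothetical PC free energy: half the sum of squared prediction errors.

    params:     (layers, skips); layers[l] maps activity l-1 to a prediction of
                activity l. skips may be None or a list of callables/None.
    activities: one array per layer; the last entry is replaced by `output`.
    input:      optional prior z_0; if None, layer 0 contributes no error term.
    """
    layers, skips = params
    zs = list(activities[:-1]) + [output]  # clamp z_L to the target
    energy = 0.0
    for l, layer in enumerate(layers):
        if l == 0 and input is None:
            continue  # no prior, so nothing predicts the first activity
        prev = input if l == 0 else zs[l - 1]
        pred = layer(prev)
        if skips is not None and skips[l] is not None:
            pred = pred + skips[l](prev)  # residual-style skip (assumption)
        energy = energy + 0.5 * jnp.sum((zs[l] - pred) ** 2)
    return energy


# Usage with two toy "layers" and scalar activities.
layers = [lambda v: 2.0 * v, lambda v: v + 1.0]
activities = [jnp.array([0.0]), jnp.array([0.0])]
output = jnp.array([3.0])
prior = jnp.array([0.5])
print(pc_energy((layers, None), activities, output, prior))  # 2.5
print(pc_energy((layers, None), activities, output))         # 2.0
```

With a prior, the first layer contributes an error of -1 (energy 0.5) and the output layer an error of 2 (energy 2.0); without one, only the output term remains.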

Other arguments:

  • loss_id: Loss function to use at the output layer ('MSE' for mean squared error or 'CE' for cross-entropy). Defaults to 'MSE'.
  • solver: Diffrax (ODE) solver to be used. Defaults to Heun, a second-order explicit Runge-Kutta method.
  • max_t1: Maximum end time of the integration interval (20 by default).
  • dt: Integration step size. Defaults to None since the default stepsize_controller will automatically determine it.
  • stepsize_controller: Diffrax controller for adaptive step-size selection. Defaults to PIDController. Note that the controller's relative and absolute tolerances also determine the steady-state condition used to terminate the solver.
  • record_iters: If True, returns all integration steps.
  • record_every: Integer setting the sampling frequency of the recorded integration steps.
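The interplay between solver, dt, max_t1 and the controller tolerances can be illustrated with a simplified re-implementation in JAX. This is not diffrax's code: it is a sketch that loosely mirrors the defaults in the signature (Heun with an embedded Euler error estimate, an integral-only controller with safety 0.9 and step-size factors clipped to [0.2, 10], an RMS norm). It stops either at max_t1 or once the vector-field norm drops below atol + rtol times the state norm, an assumed form of the steady-state test mentioned for stepsize_controller. The function and argument names are hypothetical.

```python
import jax.numpy as jnp


def rms_norm(v):
    """Root-mean-square norm, returned as a Python float."""
    return float(jnp.sqrt(jnp.mean(v ** 2)))


def solve_to_steady_state(vf, z0, max_t1=20.0, dt0=0.1, rtol=1e-3, atol=1e-3,
                          safety=0.9, factormin=0.2, factormax=10.0):
    """Adaptive Heun integration of dz/dt = vf(z), stopped at steady state or max_t1.

    Hypothetical sketch: the local error is estimated as the difference between
    the Heun (2nd order) and Euler (1st order) proposals, and the step size is
    updated with an integral-only controller, factor = safety * err**(-1/2).
    """
    t, z, dt, n_accepted = 0.0, z0, dt0, 0
    while t < max_t1:
        dt = min(dt, max_t1 - t)             # do not step past max_t1
        k1 = vf(z)
        z_euler = z + dt * k1                # 1st-order proposal
        k2 = vf(z_euler)
        z_heun = z + 0.5 * dt * (k1 + k2)    # 2nd-order proposal
        # Error scaled by the mixed absolute/relative tolerance, then RMS-normed.
        scale = atol + rtol * jnp.maximum(jnp.abs(z), jnp.abs(z_heun))
        err = rms_norm((z_heun - z_euler) / scale)
        if err <= 1.0:                       # accept the step
            t, z, n_accepted = t + dt, z_heun, n_accepted + 1
            # Assumed steady-state test, reusing the controller's tolerances.
            if rms_norm(vf(z)) < atol + rtol * rms_norm(z):
                break
        # Grow or shrink the next step; a rejected step is retried with a smaller dt.
        factor = factormax if err == 0.0 else safety * err ** -0.5
        dt *= min(factormax, max(factormin, factor))
    return t, z, n_accepted


# Usage: a linear gradient flow dz/dt = -(A z - b) with known steady state A^{-1} b.
A = jnp.diag(jnp.array([1.0, 2.0]))
b = jnp.array([1.0, 4.0])
z_star = jnp.linalg.solve(A, b)
t, z, n_accepted = solve_to_steady_state(lambda z: -(A @ z - b), jnp.zeros(2))
print(t, z, n_accepted)
```

Tightening rtol and atol in this sketch both shrinks the accepted steps and delays the steady-state stop, which is why the tolerances act as a termination criterion as well as an accuracy setting.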

Returns:

List with the solution of the activity dynamics for each layer.