Updates

jpc.update_activities(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], optim: optax._src.base.GradientTransformation | optax._src.base.GradientTransformationExtraArgs, opt_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], output: ArrayLike, input: Optional[ArrayLike] = None) -> Dict

Updates activities of a predictive coding network with a given Optax optimiser.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • optim: Optax optimiser, e.g. optax.sgd().
  • opt_state: State of Optax optimiser.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generative model.

Returns:

Dictionary with energy, updated activities, activity gradients, optimiser, and updated optimiser state.


jpc.update_params(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], optim: optax._src.base.GradientTransformation | optax._src.base.GradientTransformationExtraArgs, opt_state: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], output: ArrayLike, input: Optional[ArrayLike] = None) -> Dict

Updates parameters of a predictive coding network with a given Optax optimiser.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer free to vary.
  • optim: Optax optimiser, e.g. optax.sgd().
  • opt_state: State of Optax optimiser.
  • output: Observation or target of the generative model.
  • input: Optional prior of the generative model.

Returns:

Dictionary with the model (and optional skip model) holding the updated parameters, the parameter gradients, the optimiser, and the updated optimiser state.