Updates
jpc.update_activities(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    opt_state: optax.OptState,
    output: ArrayLike,
    input: Optional[ArrayLike] = None
) -> Dict
Updates activities of a predictive coding network with a given Optax optimiser.
Main arguments:
- `params`: Tuple with callable model layers and optional skip connections.
- `activities`: List of activities for each layer free to vary.
- `optim`: Optax optimiser, e.g. `optax.sgd()`.
- `opt_state`: State of the Optax optimiser.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
Returns:
Dictionary with energy, updated activities, activity gradients, optimiser, and updated optimiser state.
jpc.update_params(
    params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
    activities: PyTree[ArrayLike],
    optim: optax.GradientTransformation | optax.GradientTransformationExtraArgs,
    opt_state: optax.OptState,
    output: ArrayLike,
    input: Optional[ArrayLike] = None
) -> Dict
Updates parameters of a predictive coding network with a given Optax optimiser.
Main arguments:
- `params`: Tuple with callable model layers and optional skip connections.
- `activities`: List of activities for each layer free to vary.
- `optim`: Optax optimiser, e.g. `optax.sgd()`.
- `opt_state`: State of the Optax optimiser.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
Returns:
Dictionary with model (and optional skip model) with updated parameters, parameter gradients, optimiser, and updated optimiser state.