
Discrete updates

JPC provides standard discrete optimisers for updating the parameters of PC networks (jpc.update_pc_params), as well as both discrete (jpc.update_pc_activities) and continuous optimisers (jpc.solve_inference) for solving the PC inference or activity dynamics.

jpc.update_pc_activities(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], optim: optax.GradientTransformation, opt_state: optax.OptState, output: ArrayLike, *, input: Optional[ArrayLike] = None, loss_id: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> Dict

Updates activities of a predictive coding network with a given optax optimiser.

Warning

param_type = "mupc" (μPC) assumes that one is using jpc.make_mlp() to create the model.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer that is free to vary.
  • optim: optax optimiser, e.g. optax.sgd().
  • opt_state: State of optax optimiser.
  • output: Observation or target of the generative model.

Other arguments:

  • input: Optional prior of the generative model.
  • loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
  • weight_decay: Weight decay for the weights (0 by default).
  • spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
  • activity_decay: Activity decay for the activities (0 by default).

Returns:

Dictionary with energy, updated activities, activity gradients, and optimiser state.
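
Example:

A minimal sketch of a discrete inference loop. The jpc.make_mlp() arguments, the activity initialisation helper, and the returned dictionary keys ("activities", "opt_state") are assumptions based on the descriptions above, not guaranteed API:

```python
import jax
import jax.numpy as jnp
import optax
import jpc

key = jax.random.PRNGKey(0)

# Hypothetical model setup: the exact make_mlp() arguments may differ.
model = jpc.make_mlp(key, input_dim=10, width=64, depth=3, output_dim=5, act_fn="relu")
params = (model, None)  # (layers, optional skip connections)

x = jnp.ones((8, 10))   # prior of the generative model
y = jnp.zeros((8, 5))   # observation/target

# Initialise the free activities, e.g. with a feedforward pass
# (assumed helper and signature).
activities = jpc.init_activities_with_ffwd(model=model, input=x)

optim = optax.sgd(learning_rate=1e-2)
opt_state = optim.init(activities)

# Run a fixed number of discrete inference (activity update) steps.
for _ in range(20):
    result = jpc.update_pc_activities(
        params=params,
        activities=activities,
        optim=optim,
        opt_state=opt_state,
        output=y,
        input=x,
    )
    activities = result["activities"]  # assumed key names
    opt_state = result["opt_state"]
```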


jpc.update_pc_params(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities: PyTree[ArrayLike], optim: optax.GradientTransformation, opt_state: optax.OptState, output: ArrayLike, *, input: Optional[ArrayLike] = None, loss_id: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> Dict

Updates parameters of a predictive coding network with a given optax optimiser.

Warning

param_type = "mupc" (μPC) assumes that one is using jpc.make_mlp() to create the model.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities: List of activities for each layer that is free to vary.
  • optim: optax optimiser, e.g. optax.sgd().
  • opt_state: State of optax optimiser.
  • output: Observation or target of the generative model.

Other arguments:

  • input: Optional prior of the generative model.
  • loss_id: Loss function to use at the output layer. Options are mean squared error "mse" (default) or cross-entropy "ce".
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
  • weight_decay: Weight decay for the weights (0 by default).
  • spectral_penalty: Weight spectral penalty of the form \(||\mathbf{I} - \mathbf{W}_\ell^T \mathbf{W}_\ell||^2\) (0 by default).
  • activity_decay: Activity decay for the activities (0 by default).

Returns:

Dictionary containing the model (and optional skip model) with updated parameters, along with the parameter gradients and the optimiser state.
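
Example:

In a typical PC training step, the inference loop above is followed by a parameter update at the converged activities. A sketch reusing the setup from the jpc.update_pc_activities example, where the equinox-based filtering of trainable arrays and the returned key names are assumptions:

```python
import equinox as eqx
import optax

# Separate optimiser for the parameters.
param_optim = optax.adam(learning_rate=1e-3)

# Assumption: jpc models are equinox modules, so the optimiser state
# is initialised over the array leaves only.
param_opt_state = param_optim.init(eqx.filter(params, eqx.is_array))

# One parameter update at the activities found by inference.
result = jpc.update_pc_params(
    params=params,
    activities=activities,
    optim=param_optim,
    opt_state=param_opt_state,
    output=y,
    input=x,
)
params = (result["model"], result["skip_model"])  # assumed key names
param_opt_state = result["opt_state"]
```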


jpc.update_bpc_activities(top_down_model: PyTree[Callable], bottom_up_model: PyTree[Callable], activities: PyTree[ArrayLike], optim: optax.GradientTransformation, opt_state: optax.OptState, output: ArrayLike, input: ArrayLike, *, skip_model: Optional[PyTree[Callable]] = None, param_type: str = 'sp') -> Dict

Updates activities of a bidirectional PC network.

Main arguments:

  • top_down_model: List of callable model (e.g. neural network) layers for the forward model.
  • bottom_up_model: List of callable model (e.g. neural network) layers for the backward model.
  • activities: List of activities for each layer that is free to vary.
  • optim: optax optimiser, e.g. optax.sgd().
  • opt_state: State of optax optimiser.
  • output: Target of the top_down_model and input to the bottom_up_model.
  • input: Input to the top_down_model and target of the bottom_up_model.

Other arguments:

  • skip_model: Optional skip connection model.
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".

Returns:

Dictionary with energy, updated activities, activity gradients, and optimiser state.
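
Example:

A sketch for the bidirectional case, where a top-down (generative) model and a bottom-up (recognition) model with mirrored dimensions share one set of activities. Constructor arguments and returned key names are assumptions:

```python
import jax
import jax.numpy as jnp
import optax
import jpc

td_key, bu_key = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical constructors: the top-down model maps input -> output,
# the bottom-up model maps output -> input.
top_down_model = jpc.make_mlp(td_key, input_dim=10, width=64, depth=3, output_dim=5)
bottom_up_model = jpc.make_mlp(bu_key, input_dim=5, width=64, depth=3, output_dim=10)

x = jnp.ones((8, 10))
y = jnp.zeros((8, 5))

# Assumed initialisation helper, as in the PC examples above.
activities = jpc.init_activities_with_ffwd(model=top_down_model, input=x)

optim = optax.sgd(learning_rate=1e-2)
opt_state = optim.init(activities)

for _ in range(20):
    result = jpc.update_bpc_activities(
        top_down_model=top_down_model,
        bottom_up_model=bottom_up_model,
        activities=activities,
        optim=optim,
        opt_state=opt_state,
        output=y,
        input=x,
    )
    activities = result["activities"]  # assumed key names
    opt_state = result["opt_state"]
```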


jpc.update_bpc_params(top_down_model: PyTree[Callable], bottom_up_model: PyTree[Callable], activities: PyTree[ArrayLike], top_down_optim: optax.GradientTransformation, bottom_up_optim: optax.GradientTransformation, top_down_opt_state: optax.OptState, bottom_up_opt_state: optax.OptState, output: ArrayLike, input: ArrayLike, *, skip_model: Optional[PyTree[Callable]] = None, param_type: str = 'sp') -> Dict

Updates parameters of a bidirectional PC network.

Main arguments:

  • top_down_model: List of callable model (e.g. neural network) layers for the forward model.
  • bottom_up_model: List of callable model (e.g. neural network) layers for the backward model.
  • activities: List of activities for each layer that is free to vary.
  • top_down_optim: optax optimiser for the top-down model.
  • bottom_up_optim: optax optimiser for the bottom-up model.
  • top_down_opt_state: State of the top-down optimiser.
  • bottom_up_opt_state: State of the bottom-up optimiser.
  • output: Target of the top_down_model and input to the bottom_up_model.
  • input: Input to the top_down_model and target of the bottom_up_model.

Other arguments:

  • skip_model: Optional skip connection model.
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".

Returns:

Dictionary containing both models with updated parameters, along with the parameter gradients and both optimiser states.
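
Example:

A sketch of one joint parameter update, continuing from the jpc.update_bpc_activities example above. Each model gets its own optimiser and state; the equinox-based filtering and the returned key names are assumptions:

```python
import equinox as eqx
import optax

td_optim = optax.adam(learning_rate=1e-3)
bu_optim = optax.adam(learning_rate=1e-3)

# Assumption: models are equinox modules, so optimiser states are
# initialised over their array leaves.
td_opt_state = td_optim.init(eqx.filter(top_down_model, eqx.is_array))
bu_opt_state = bu_optim.init(eqx.filter(bottom_up_model, eqx.is_array))

result = jpc.update_bpc_params(
    top_down_model=top_down_model,
    bottom_up_model=bottom_up_model,
    activities=activities,
    top_down_optim=td_optim,
    bottom_up_optim=bu_optim,
    top_down_opt_state=td_opt_state,
    bottom_up_opt_state=bu_opt_state,
    output=y,
    input=x,
)
top_down_model = result["top_down_model"]      # assumed key names
bottom_up_model = result["bottom_up_model"]
td_opt_state = result["top_down_opt_state"]
bu_opt_state = result["bottom_up_opt_state"]
```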