
Utils

jpc.make_mlp(key: PRNGKeyArray, layer_sizes: PyTree[int], act_fn: str, use_bias: bool = True) -> PyTree[typing.Callable]

Creates a multi-layer perceptron compatible with predictive coding updates.

Main arguments:

  • key: jax.random.PRNGKey for parameter initialisation.
  • layer_sizes: Dimension of all layers (input, hidden and output).
  • act_fn: Activation function for all layers except the output. Options are linear, tanh and relu.
  • use_bias: Whether each layer includes a bias term. True by default.

Returns:

List of callable fully connected layers.
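To illustrate what "a list of callable fully connected layers" means, here is a minimal NumPy sketch. It is not jpc's implementation: the name `make_mlp_sketch`, the integer `seed` used in place of a JAX key, and the uniform initialisation are assumptions. Each layer is a plain closure over its own weight matrix and bias.

```python
import numpy as np

def make_mlp_sketch(seed, layer_sizes, act_fn, use_bias=True):
    """Hypothetical NumPy stand-in for jpc.make_mlp; not the library's code."""
    acts = {
        "linear": lambda z: z,
        "tanh": np.tanh,
        "relu": lambda z: np.maximum(z, 0.0),
    }
    act = acts[act_fn]
    rng = np.random.default_rng(seed)  # stands in for the JAX PRNG key
    n_layers = len(layer_sizes) - 1
    layers = []
    for i, (d_in, d_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        # Assumed init: uniform in [-1/sqrt(d_in), 1/sqrt(d_in)]
        lim = 1.0 / np.sqrt(d_in)
        W = rng.uniform(-lim, lim, size=(d_out, d_in))
        b = rng.uniform(-lim, lim, size=d_out) if use_bias else np.zeros(d_out)
        # Activation on every layer except the output, which stays linear
        f = (lambda z: z) if i == n_layers - 1 else act
        # Default arguments bind this layer's W, b and f to its closure
        layers.append(lambda x, W=W, b=b, f=f: f(W @ x + b))
    return layers

layers = make_mlp_sketch(seed=0, layer_sizes=[3, 8, 2], act_fn="tanh")
out = np.ones(3)
for layer in layers:
    out = layer(out)
print(out.shape)  # (2,)
```

Returning a list rather than a single composed function keeps every layer addressable, which is what layer-wise predictive coding updates need.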


jpc.get_act_fn(name: str) -> Callable
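The page gives no description here. Judging only from the signature and the options listed for make_mlp (linear, tanh, relu), a plausible reading is a name-to-function lookup. The sketch below assumes exactly that; `get_act_fn_sketch` and its error message are hypothetical.

```python
from typing import Callable

import numpy as np

def get_act_fn_sketch(name: str) -> Callable:
    """Hypothetical name -> activation lookup; not jpc's implementation."""
    table = {
        "linear": lambda x: x,
        "tanh": np.tanh,
        "relu": lambda x: np.maximum(x, 0.0),
    }
    if name not in table:
        raise ValueError(f"Unknown activation {name!r}; expected one of {sorted(table)}")
    return table[name]

relu = get_act_fn_sketch("relu")
print(relu(np.array([-1.0, 2.0])))  # [0. 2.]
```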


jpc.mse_loss(preds: ArrayLike, labels: ArrayLike) -> Array


jpc.cross_entropy_loss(logits: ArrayLike, labels: ArrayLike) -> Array
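Neither loss is described on this page. The NumPy sketch below shows the standard definitions under stated assumptions: inputs are batch-first, labels for cross-entropy are one-hot, and both losses are summed over features and averaged over the batch. jpc's exact scaling (for example a factor of 1/2 in the MSE) may differ, and the `_sketch` names are hypothetical.

```python
import numpy as np

def mse_loss_sketch(preds, labels):
    # Squared error summed over features, averaged over the batch (assumed)
    preds, labels = np.asarray(preds), np.asarray(labels)
    return np.sum((labels - preds) ** 2) / labels.shape[0]

def cross_entropy_loss_sketch(logits, labels):
    # Numerically stable log-softmax: subtract the row max before exponentiating
    logits, labels = np.asarray(logits), np.asarray(labels)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Mean negative log-likelihood of the one-hot labels
    return -np.sum(labels * log_probs) / labels.shape[0]

print(mse_loss_sketch([[1.0, 2.0]], [[1.0, 0.0]]))  # 4.0
```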


jpc.compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Array
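As a rough illustration only, assuming one-hot `truths` and per-class scores in `preds` (batch-first), classification accuracy is the fraction of rows whose arg-max classes agree. The sketch returns a fraction in [0, 1]; whether jpc reports a fraction or a percentage is not stated here.

```python
import numpy as np

def compute_accuracy_sketch(truths, preds):
    # Compare predicted and true class indices row by row (assumed one-hot truths)
    truths, preds = np.asarray(truths), np.asarray(preds)
    return np.mean(np.argmax(truths, axis=-1) == np.argmax(preds, axis=-1))

print(compute_accuracy_sketch([[1, 0], [0, 1]], [[0.9, 0.1], [0.8, 0.2]]))  # 0.5
```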


jpc.get_t_max(activities_iters: PyTree[Array]) -> Array
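This function is not described here. One plausible reading, given that diffrax saves a fixed number of steps (4096 by default) and pads unused saved steps with inf, is that it finds the last step the solver actually took. The NumPy sketch below assumes that padding convention and returns the index of the last fully finite step; `get_t_max_sketch` is hypothetical.

```python
import numpy as np

def get_t_max_sketch(activities_iters):
    """Index of the last saved solver step, assuming unused rows are inf-padded."""
    first = np.asarray(activities_iters[0])
    # Flatten all non-time axes; a step counts as saved if every entry is finite
    flat = first.reshape(first.shape[0], -1)
    n_saved = int(np.sum(np.all(np.isfinite(flat), axis=1)))
    return n_saved - 1

acts = np.full((4096, 2, 3), np.inf)  # (steps, batch, width), mostly padding
acts[:10] = 0.0                       # only the first 10 steps were taken
print(get_t_max_sketch([acts]))       # 9
```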


jpc.compute_infer_energies(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities_iters: PyTree[Array], t_max: Array, y: ArrayLike, x: Optional[ArrayLike] = None, loss: str = 'MSE') -> PyTree[Array]

Calculates layer energies during predictive coding inference.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities_iters: Layer-wise activities at every inference iteration. Note that, by diffrax default, each set of activities has 4096 steps as its first dimension.
  • t_max: Maximum number of inference iterations to compute energies for.
  • y: Observation or target of the generative model.
  • x: Optional prior of the generative model.

Other arguments:

  • loss: Loss function at the output layer: mean squared error ('MSE') or cross-entropy ('CE').

Returns:

List of layer-wise energies at every inference iteration.
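The computation can be sketched in NumPy for the simplest case. Assumptions, none confirmed by this page: a feedforward chain with no skip connections, an MSE output, batch-first arrays, `x` clamping the input and `y` clamping the output, `t_max` treated as the index of the last step to include, and the standard predictive coding energy ½‖z_l − f_l(z_{l−1})‖² averaged over the batch. jpc's exact scaling may differ.

```python
import numpy as np

def compute_infer_energies_sketch(layers, activities_iters, t_max, y, x):
    """Per-layer energies at each inference step (hypothetical, MSE, no skips).

    layers: L callables mapping (batch, d_in) -> (batch, d_out).
    activities_iters: L-1 arrays of shape (T, batch, d_l) for the hidden layers.
    Returns an array of shape (L, t_max + 1).
    """
    L = len(layers)
    energies = np.zeros((L, t_max + 1))
    for t in range(t_max + 1):
        # Clamp the input to x and the output to y; hidden activities vary with t
        z = [x] + [a[t] for a in activities_iters] + [y]
        for l in range(L):
            err = z[l + 1] - layers[l](z[l])  # prediction error at layer l
            energies[l, t] = 0.5 * np.sum(err ** 2) / err.shape[0]
    return energies

identity = lambda z: z
x = np.ones((1, 2))
y = np.ones((1, 2))
hidden_iters = [np.stack([np.zeros((1, 2)), np.ones((1, 2))])]  # T = 2
E = compute_infer_energies_sketch([identity, identity], hidden_iters, 1, y, x)
print(E)  # energies drop to zero once the hidden activity matches x and y
```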


jpc.compute_activity_norms(activities: PyTree[Array]) -> Array

Calculates the l2 norm of the activities at each layer.
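A minimal NumPy sketch, assuming the norm is taken over each layer's whole activity array (batch included) and that the activities arrive as a list of arrays; `compute_activity_norms_sketch` is hypothetical.

```python
import numpy as np

def compute_activity_norms_sketch(activities):
    # One l2 norm per layer, computed over every entry of that layer's array
    return np.array([np.sqrt(np.sum(np.asarray(a) ** 2)) for a in activities])

print(compute_activity_norms_sketch([np.array([[3.0, 4.0]]), np.zeros((2, 5))]))  # [5. 0.]
```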


jpc.compute_param_norms(params)

Calculates the l2 norm of all model parameters.
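As a hedged illustration, assuming `params` is a nested structure (like the `(model, skip_model)` tuple used by compute_infer_energies) whose leaves are arrays, with `None` for an absent skip model, one norm per parameter array can be computed as below. In JAX one would typically flatten with `jax.tree_util.tree_leaves`; the small `_leaves` helper is a hypothetical NumPy-only stand-in, as is `compute_param_norms_sketch`.

```python
import numpy as np

def _leaves(tree):
    """Yield array leaves from nested lists, tuples and dicts, skipping None."""
    if tree is None:
        return
    if isinstance(tree, dict):
        for v in tree.values():
            yield from _leaves(v)
    elif isinstance(tree, (list, tuple)):
        for v in tree:
            yield from _leaves(v)
    else:
        yield np.asarray(tree)

def compute_param_norms_sketch(params):
    # One l2 norm per parameter array (weights and biases alike)
    return np.array([np.sqrt(np.sum(p ** 2)) for p in _leaves(params)])

model = [(np.array([[3.0, 4.0]]), np.array([0.0]))]  # one layer: (W, b)
print(compute_param_norms_sketch((model, None)))     # [5. 0.]
```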