Utils

JPC provides several utilities for neural network training, including helpers to create simple models as well as common losses and metrics.

jpc.make_mlp(key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], input_dim: int, width: int, depth: int, output_dim: int, act_fn: str, use_bias: bool = False, param_type: str = 'sp') -> PyTree[typing.Callable]

Creates a multi-layer perceptron compatible with predictive coding updates.

Note

This implementation places the activation function before the linear transformation, \(\mathbf{W}_\ell \phi(\mathbf{z}_{\ell-1})\), for compatibility with the μPC scalings when param_type = "mupc" in functions including jpc.init_activities_with_ffwd(), jpc.update_activities(), and jpc.update_params().
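
In code terms, each layer applies the activation to the previous layer's activities before its own linear map. A minimal sketch of this ordering (illustrative only, not JPC's internal layer code):

```python
import jax.numpy as jnp

def pre_activation_layer(W, z_prev, act_fn=jnp.tanh):
    # Activation first, then the linear map: W_l @ phi(z_{l-1}),
    # rather than the more common phi(W_l @ z_{l-1}).
    return W @ act_fn(z_prev)
```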

Main arguments:

  • key: jax.random.PRNGKey for parameter initialisation.
  • input_dim: Input dimension.
  • width: Network width.
  • depth: Network depth.
  • output_dim: Output dimension.
  • act_fn: Activation function (for all layers except the output).
  • use_bias: Whether to include bias terms. Defaults to False.
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See jpc._get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".

Returns:

List of callable fully connected layers.
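
A usage sketch (the dimensions and the "relu" activation name are illustrative choices, not defaults):

```python
import jax
import jpc

key = jax.random.PRNGKey(0)
model = jpc.make_mlp(
    key,
    input_dim=784,
    width=128,
    depth=3,
    output_dim=10,
    act_fn="relu",   # assumed to be a supported name; see jpc.get_act_fn()
    param_type="sp",
)
```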


jpc.make_skip_model(depth: int) -> PyTree[typing.Callable]

Creates a residual network with one-layer skip connections at every layer, except from the input to the first layer and from the penultimate layer to the output.

This is used for compatibility with the μPC parameterisation when param_type = "mupc" in functions including jpc.init_activities_with_ffwd(), jpc.update_activities(), and jpc.update_params().
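
A sketch pairing it with a μPC-parameterised MLP (depth and dimensions are illustrative; the (layers, skips) tuple follows the params argument documented for jpc.compute_infer_energies() below):

```python
import jax
import jpc

key = jax.random.PRNGKey(0)
model = jpc.make_mlp(key, input_dim=784, width=128, depth=8,
                     output_dim=10, act_fn="relu", param_type="mupc")
skips = jpc.make_skip_model(depth=8)
params = (model, skips)  # callable layers + optional skip connections
```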


jpc.get_act_fn(name: str) -> typing.Callable
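
No description is given for this function; judging by the signature, it presumably maps an activation name (as accepted by act_fn in jpc.make_mlp()) to its callable. A sketch, assuming "tanh" is a supported name:

```python
import jax.numpy as jnp
import jpc

act = jpc.get_act_fn("tanh")            # "tanh" assumed to be a supported name
out = act(jnp.array([-1.0, 0.0, 1.0]))
```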


jpc.mse_loss(preds: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], labels: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> Shaped[Array, '']
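
Presumably the mean squared error between preds and labels, returned as a scalar per the Shaped[Array, ''] annotation. A usage sketch:

```python
import jax.numpy as jnp
import jpc

preds = jnp.array([[0.1, 0.9], [0.8, 0.2]])
labels = jnp.array([[0.0, 1.0], [1.0, 0.0]])
loss = jpc.mse_loss(preds, labels)  # scalar
```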


jpc.cross_entropy_loss(logits: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], labels: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> Shaped[Array, '']
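
Presumably the cross-entropy between logits and labels; whether labels are one-hot or integer-encoded is not documented here, so one-hot is assumed below:

```python
import jax
import jax.numpy as jnp
import jpc

logits = jnp.array([[2.0, -1.0, 0.5], [0.1, 1.5, -0.3]])
labels = jax.nn.one_hot(jnp.array([0, 1]), num_classes=3)  # one-hot assumed
loss = jpc.cross_entropy_loss(logits, labels)  # scalar
```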


jpc.compute_accuracy(truths: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], preds: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> Shaped[Array, '']
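
Presumably the classification accuracy of preds against truths, under the same one-hot assumption as above:

```python
import jax
import jax.numpy as jnp
import jpc

truths = jax.nn.one_hot(jnp.array([1, 0]), num_classes=2)  # one-hot assumed
preds = jnp.array([[0.2, 0.8], [0.6, 0.4]])
acc = jpc.compute_accuracy(truths, preds)
```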


jpc.get_t_max(activities_iters: PyTree[jax.Array]) -> Array
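
Given how activities_iters and t_max are described under jpc.compute_infer_energies() below, this presumably recovers the last actual inference step from the fixed-size, diffrax-padded trajectories, e.g. t_max = jpc.get_t_max(activities_iters); the padding convention itself is not documented on this page.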


jpc.compute_infer_energies(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities_iters: PyTree[jax.Array], t_max: Array, y: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, x: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex, NoneType] = None, loss: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> PyTree[jaxtyping.Shaped[Array, '']]

Calculates layer energies during predictive coding inference.

Main arguments:

  • params: Tuple with callable model layers and optional skip connections.
  • activities_iters: Layer-wise activities at every inference iteration. Note that, by diffrax's default, each set of activities has 4096 steps as its first dimension.
  • t_max: Maximum number of inference iterations to compute energies for.
  • y: Observation or target of the generative model.

Other arguments:

  • x: Optional prior of the generative model.
  • loss: Loss function to use at the output layer: mean squared error ("mse") or cross-entropy ("ce").
  • param_type: Determines the parameterisation. Options are "sp", "mupc", or "ntp".
  • weight_decay: Weight decay for the weights.
  • spectral_penalty: Spectral penalty for the weights.
  • activity_decay: Activity decay for the activities.

Returns:

List of layer-wise energies at every inference iteration.
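
A call sketch (only the signature above is taken from this page; model, skips, activities_iters, t_max, x and y are assumed to come from an inference run, e.g. via jpc.update_activities(), whose signature is not documented here):

```python
import jpc

energies = jpc.compute_infer_energies(
    (model, skips),      # callable layers + optional skip connections
    activities_iters,    # per-layer activities at every inference step
    t_max,               # e.g. jpc.get_t_max(activities_iters)
    y,                   # observation/target of the generative model
    x=x,                 # optional prior
    loss="mse",
    param_type="sp",
)
```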


jpc.compute_activity_norms(activities: PyTree[jax.Array]) -> Array

Calculates \(\ell^2\) norm of activities at each layer.
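
A sketch with placeholder activities (shapes are arbitrary):

```python
import jax.numpy as jnp
import jpc

# Two layers of activities for a batch of 64.
activities = [jnp.ones((64, 128)), jnp.ones((64, 10))]
norms = jpc.compute_activity_norms(activities)
```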


jpc.compute_param_norms(params)

Calculates \(\ell^2\) norm of all model parameters.
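
A sketch; whether this expects the layer list or the (layers, skips) tuple is not documented here, so the layer list from jpc.make_mlp() is assumed:

```python
import jax
import jpc

model = jpc.make_mlp(jax.random.PRNGKey(0), input_dim=10, width=32,
                     depth=2, output_dim=2, act_fn="tanh")
param_norms = jpc.compute_param_norms(model)  # layer list assumed as input
```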