Utils
jpc.make_mlp(key: PRNGKeyArray, layer_sizes: PyTree[int], act_fn: str, use_bias: bool = True) -> PyTree[typing.Callable]
Creates a multi-layer perceptron compatible with predictive coding updates.
Main arguments:
key: jax.random.PRNGKey for parameter initialisation.
layer_sizes: Dimension of all layers (input, hidden and output).
act_fn: Activation function for all layers except the output. Options are linear, tanh and relu.
use_bias: True by default.
Returns:
List of callable fully connected layers.
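The returned layer list is applied sequentially, with the activation on every layer except the last. A minimal sketch in plain JAX of an equivalent structure (the helper name and initialisation scheme here are illustrative, not part of jpc):

```python
import jax
import jax.numpy as jnp

def make_dense_layers(key, layer_sizes, act, use_bias=True):
    """Sketch: a list of callable dense layers, like make_mlp returns."""
    layers = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for i, k in enumerate(keys):
        n_in, n_out = layer_sizes[i], layer_sizes[i + 1]
        W = jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out) if use_bias else None
        is_last = i == len(keys) - 1

        def layer(v, W=W, b=b, is_last=is_last):
            out = W @ v + (b if b is not None else 0.0)
            # no activation on the output layer
            return out if is_last else act(out)

        layers.append(layer)
    return layers

layers = make_dense_layers(jax.random.PRNGKey(0), [2, 10, 10, 1], jnp.tanh)
x = jnp.ones(2)
for f in layers:
    x = f(x)
print(x.shape)  # (1,)
```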
jpc.get_act_fn(name: str) -> Callable
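Maps an activation name to its callable. A sketch of such a lookup, assuming the mapping covers the options documented for make_mlp (linear, tanh and relu):

```python
import jax
import jax.numpy as jnp

# Hypothetical name-to-callable table; the option names come from the
# make_mlp documentation above, the dict itself is illustrative.
ACT_FNS = {
    "linear": lambda x: x,
    "tanh": jnp.tanh,
    "relu": jax.nn.relu,
}

def get_act_fn_sketch(name):
    try:
        return ACT_FNS[name]
    except KeyError:
        raise ValueError(f"Unknown activation: {name!r}")

print(get_act_fn_sketch("relu")(jnp.array([-1.0, 2.0])))  # [0. 2.]
```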
jpc.mse_loss(preds: ArrayLike, labels: ArrayLike) -> Array
jpc.cross_entropy_loss(logits: ArrayLike, labels: ArrayLike) -> Array
jpc.compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Array
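The three metrics above follow their conventional definitions. A sketch of those conventions in plain JAX (the exact reductions jpc applies, e.g. mean over the batch and whether accuracy is a fraction or a percentage, are assumptions here):

```python
import jax
import jax.numpy as jnp

def mse_loss(preds, labels):
    # mean squared error over all elements (reduction assumed)
    return jnp.mean((preds - labels) ** 2)

def cross_entropy_loss(logits, labels):
    # labels assumed one-hot; log-softmax for numerical stability
    return -jnp.mean(jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))

def compute_accuracy(truths, preds):
    # fraction of samples whose argmax prediction matches the target
    return jnp.mean(jnp.argmax(truths, -1) == jnp.argmax(preds, -1))

preds = jnp.array([[2.0, 0.1], [0.2, 1.5]])
one_hot = jnp.array([[1.0, 0.0], [1.0, 0.0]])
print(compute_accuracy(one_hot, preds))  # 0.5
```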
jpc.get_t_max(activities_iters: PyTree[Array]) -> Array
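Since diffrax saves a fixed number of steps (4096 by default, as noted below for compute_infer_energies) and pads the unused tail with inf, the last actual inference iteration can be recovered as the index of the last finite step. A sketch under that padding assumption (the function name and exact convention are illustrative):

```python
import jax.numpy as jnp

def get_t_max_sketch(activities_iters):
    # Assumption: steps beyond the final accepted one are inf-padded.
    first_leaf = activities_iters[0]  # any layer's (n_steps, ...) array
    finite = jnp.isfinite(first_leaf).all(
        axis=tuple(range(1, first_leaf.ndim))
    )
    return jnp.sum(finite) - 1  # index of the last finite step

acts = [jnp.concatenate([jnp.ones((3, 2)), jnp.full((2, 2), jnp.inf)])]
print(get_t_max_sketch(acts))  # 2
```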
jpc.compute_infer_energies(params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]], activities_iters: PyTree[Array], t_max: Array, y: ArrayLike, x: Optional[ArrayLike] = None, loss: str = 'MSE') -> PyTree[Array]
Calculates layer energies during predictive coding inference.
Main arguments:
params: Tuple with callable model layers and optional skip connections.
activities_iters: Layer-wise activities at every inference iteration. Note that each set of activities will have 4096 steps as its first dimension, the diffrax default.
t_max: Maximum number of inference iterations to compute energies for.
y: Observation or target of the generative model.
x: Optional prior of the generative model.
Other arguments:
loss: Loss function specified at the output layer (mean squared error 'MSE' vs cross-entropy 'CE').
Returns:
List of layer-wise energies at every inference iteration.
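In predictive coding, a layer's energy is conventionally the squared error between its activity and the prediction made from the layer below. A sketch of that per-layer quantity for a single inference iteration (the 1/2 factor, batch reduction, and handling of skip connections in jpc are assumptions here):

```python
import jax.numpy as jnp

def layer_energies(layers, activities, x):
    """Sketch: 0.5 * ||z_l - f_l(z_{l-1})||^2 for each layer l."""
    energies = []
    prev = x  # prior / input to the first layer
    for f, z in zip(layers, activities):
        energies.append(0.5 * jnp.sum((z - f(prev)) ** 2))
        prev = z
    return energies

# Toy model: two scalar layers with fixed maps.
layers = [lambda v: 2.0 * v, lambda v: v + 1.0]
acts = [jnp.array([2.0]), jnp.array([4.0])]
print(layer_energies(layers, acts, jnp.array([1.0])))
```

Summing these per-layer terms over layers gives the total energy that inference descends; tracking them per iteration (as compute_infer_energies does) shows how each layer's error evolves.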
jpc.compute_activity_norms(activities: PyTree[Array]) -> Array
Calculates the L2 norm of the activities at each layer.
jpc.compute_param_norms(params)
Calculates the L2 norm of all model parameters.
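Both norm utilities reduce over a PyTree. A sketch of a global L2 norm over an arbitrary parameter PyTree using jax.tree_util (whether jpc returns one global norm or one norm per leaf is an assumption; adapt accordingly):

```python
import jax
import jax.numpy as jnp

def pytree_l2_norm(tree):
    """Sketch: sqrt of the summed squares of every leaf in the tree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))

params = {"W": jnp.array([[3.0, 0.0]]), "b": jnp.array([4.0])}
print(pytree_l2_norm(params))  # 5.0
```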