Utils
JPC provides several standard utilities for neural network training, including creation of simple models, losses, and metrics.
jpc.make_mlp(key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], input_dim: int, width: int, depth: int, output_dim: int, act_fn: str, use_bias: bool = False, param_type: str = 'sp') -> PyTree[typing.Callable]
Creates a multi-layer perceptron compatible with predictive coding updates.
Note
This implementation places the activation function before the linear transformation, \(\mathbf{W}_\ell \phi(\mathbf{z}_{\ell-1})\), for compatibility with the μPC scalings when param_type = "mupc" in functions including jpc.init_activities_with_ffwd(), jpc.update_activities(), and jpc.update_params().
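The ordering in the note can be illustrated with a small NumPy sketch (jax.numpy provides the same functions). It assumes a batch-major layout, no biases, and an identity activation on the input layer, so that \(\phi\) only acts on hidden activities before the next linear map. The helper names init_layers and forward are hypothetical and not part of JPC, and the 1/sqrt(fan_in) initialisation is purely illustrative rather than any of JPC's parameterisations.

```python
import numpy as np

def init_layers(rng, dims):
    """Hypothetical helper: one weight matrix per layer, no biases.
    The 1/sqrt(fan_in) scale is illustrative only."""
    return [rng.standard_normal((d_out, d_in)) / np.sqrt(d_in)
            for d_in, d_out in zip(dims[:-1], dims[1:])]

def forward(weights, x, act=np.tanh):
    """Compute z_l = W_l phi(z_{l-1}) for every layer.

    Assumption: the first layer applies the identity to the raw input,
    so phi is only applied to hidden activities.
    """
    activities = []
    z = x
    for l, W in enumerate(weights):
        phi = (lambda a: a) if l == 0 else act
        z = phi(z) @ W.T  # (batch, d_in) @ (d_in, d_out)
        activities.append(z)
    return activities

rng = np.random.default_rng(0)
weights = init_layers(rng, [4, 8, 8, 3])
x = rng.standard_normal((5, 4))
acts = forward(weights, x)
print([a.shape for a in acts])  # [(5, 8), (5, 8), (5, 3)]
```

Because the nonlinearity sits in front of each weight matrix, every layer output is a purely linear function of the transformed previous activity, which is the form \(\mathbf{W}_\ell \phi(\mathbf{z}_{\ell-1})\) referred to in the note.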
Main arguments:
key: jax.random.PRNGKey for parameter initialisation.
input_dim: Input dimension.
width: Network width.
depth: Network depth.
output_dim: Output dimension.
act_fn: Activation function (for all layers except the output).
use_bias: Whether the layers include bias terms. False by default.
param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See jpc._get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
Returns:
List of callable fully connected layers.
jpc.make_skip_model(depth: int) -> PyTree[typing.Callable]
Creates a residual network with one-layer skip connections at every layer, except from the input to the first hidden layer and from the penultimate layer to the output.
This is used for compatibility with the μPC parameterisation when param_type = "mupc" in functions including jpc.init_activities_with_ffwd(), jpc.update_activities(), and jpc.update_params().
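The skip pattern described above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not JPC's implementation: make_skips and forward_with_skips are hypothetical names, None marks a layer without a skip, each skip is an identity map, and activations are applied before the linear map as in jpc.make_mlp. In a constant-width MLP the two excluded connections are also the only ones whose input and output dimensions can differ.

```python
import numpy as np

def make_skips(depth):
    """Hypothetical stand-in for a skip model: an identity skip at every layer
    except the first (input -> first hidden) and the last (penultimate -> output)."""
    return [None if l in (0, depth - 1) else (lambda z: z) for l in range(depth)]

def forward_with_skips(weights, skips, x, act=np.tanh):
    """Residual forward pass: z_l = W_l phi(z_{l-1}) + z_{l-1} where a skip exists.
    Assumption: identity activation on the input layer."""
    z = x
    activities = []
    for l, (W, skip) in enumerate(zip(weights, skips)):
        phi = (lambda a: a) if l == 0 else act
        out = phi(z) @ W.T
        if skip is not None:
            out = out + skip(z)  # add the identity skip from the previous layer
        z = out
        activities.append(z)
    return activities

rng = np.random.default_rng(0)
dims = [4, 8, 8, 8, 3]  # input, three hidden layers of width 8, output
weights = [rng.standard_normal((o, i)) / np.sqrt(i) for i, o in zip(dims[:-1], dims[1:])]
skips = make_skips(len(weights))
print([s is not None for s in skips])  # [False, True, True, False]
acts = forward_with_skips(weights, skips, rng.standard_normal((2, 4)))
```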
jpc.get_act_fn(name: str) -> typing.Callable
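jpc.get_act_fn is listed without a description; judging from its signature, it maps a string name (such as the act_fn argument of jpc.make_mlp) to a callable. A minimal NumPy sketch of that kind of lookup follows. The registered names are illustrative assumptions, not JPC's actual list.

```python
import numpy as np

# Hypothetical registry; the names JPC actually accepts are not listed here.
_ACT_FNS = {
    "linear": lambda x: x,
    "tanh": np.tanh,
    "relu": lambda x: np.maximum(x, 0.0),
}

def get_act_fn(name):
    """Return the activation callable registered under `name`."""
    try:
        return _ACT_FNS[name]
    except KeyError:
        raise ValueError(
            f"Unknown activation {name!r}; options: {sorted(_ACT_FNS)}"
        ) from None

relu = get_act_fn("relu")
print(relu(np.array([-1.0, 2.0])))  # [0. 2.]
```

Raising a ValueError for unknown names keeps a typo in act_fn from failing silently.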
jpc.mse_loss(preds: ArrayLike, labels: ArrayLike) -> Shaped[Array, '']
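jpc.mse_loss has no description here. For reference, a NumPy sketch of a mean squared error between predictions and labels, averaged over all elements, is shown below. This is one common convention; JPC's exact reduction and any factor of 1/2 are not documented on this page.

```python
import numpy as np

def mse_loss(preds, labels):
    """Mean squared error averaged over all elements (one common convention;
    JPC's exact reduction and scaling may differ)."""
    preds, labels = np.asarray(preds, dtype=float), np.asarray(labels, dtype=float)
    return np.mean((labels - preds) ** 2)

print(mse_loss([[1.0, 2.0]], [[0.0, 4.0]]))  # 2.5
```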
jpc.cross_entropy_loss(logits: ArrayLike, labels: ArrayLike) -> Shaped[Array, '']
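jpc.cross_entropy_loss is also undocumented here. A NumPy sketch of softmax cross-entropy, assuming one-hot labels and averaging over the batch, is given below; how JPC encodes labels and reduces over the batch is an assumption. The log-softmax subtracts the row maximum before exponentiating to avoid overflow.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def cross_entropy_loss(logits, labels):
    """Softmax cross-entropy with one-hot labels, averaged over the batch
    (assumed conventions)."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    return -np.mean(np.sum(labels * log_softmax(logits), axis=-1))

# Uniform logits over 4 classes give a loss of log(4), whatever the label.
print(cross_entropy_loss([[0.0, 0.0, 0.0, 0.0]], [[0, 1, 0, 0]]))  # ≈ 1.3863
```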
jpc.compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Shaped[Array, '']
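A NumPy sketch of classification accuracy follows. It assumes one-hot truths and class scores in preds, and counts a sample as correct when the argmax of the two agrees. The result is a fraction in [0, 1]; the scale JPC reports is not stated on this page.

```python
import numpy as np

def compute_accuracy(truths, preds):
    """Fraction of samples whose predicted class (argmax of preds) matches
    the true class (argmax of one-hot truths)."""
    truths, preds = np.asarray(truths), np.asarray(preds)
    return np.mean(np.argmax(truths, axis=1) == np.argmax(preds, axis=1))

truths = [[1, 0], [0, 1], [0, 1]]
preds = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]]
print(compute_accuracy(truths, preds))  # ≈ 0.667 (2 of 3 correct)
```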
jpc.get_t_max(activities_iters: PyTree[jax.Array]) -> Array
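jpc.get_t_max is undocumented. The documentation of jpc.compute_infer_energies notes that activities_iters carry 4096 steps by default in diffrax, and diffrax pads the entries for steps the solver did not take with inf when every step is saved. A plausible reading is therefore that get_t_max locates the last real solver step. The NumPy sketch below follows that assumption and returns the index of the last step whose activities are all finite; whether JPC returns this index or a step count is not stated here.

```python
import numpy as np

def get_t_max(activities_iters):
    """Index of the last solver step with finite activities, assuming unused
    steps are padded with inf. Only the first layer is inspected, and at
    least one step is assumed to be finite."""
    first_layer = np.asarray(activities_iters[0])
    # A step is valid if every entry recorded at that step is finite.
    finite_rows = np.all(
        np.isfinite(first_layer.reshape(first_layer.shape[0], -1)), axis=1
    )
    return int(np.nonzero(finite_rows)[0].max())

steps = np.full((4096, 2, 3), np.inf)  # padded buffer: (steps, batch, width)
steps[:20] = 0.0                       # pretend the solver took 20 steps
print(get_t_max([steps]))  # 19
```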
jpc.compute_infer_energies(params: typing.Tuple[jaxtyping.PyTree[typing.Callable], typing.Optional[jaxtyping.PyTree[typing.Callable]]], activities_iters: PyTree[jax.Array], t_max: Array, y: ArrayLike, *, x: typing.Optional[ArrayLike] = None, loss: str = 'mse', param_type: str = 'sp', weight_decay: Shaped[Array, ''] = 0.0, spectral_penalty: Shaped[Array, ''] = 0.0, activity_decay: Shaped[Array, ''] = 0.0) -> PyTree[jaxtyping.Shaped[Array, '']]
Calculates layer energies during predictive coding inference.
Main arguments:
params: Tuple with callable model layers and optional skip connections.
activities_iters: Layer-wise activities at every inference iteration. Note that each set of activities will have 4096 steps as its first dimension by default in diffrax.
t_max: Maximum number of inference iterations to compute energies for.
y: Observation or target of the generative model.
Other arguments:
x: Optional prior of the generative model.
loss: Loss function to use at the output layer (mean squared error "mse" vs cross-entropy "ce"). Defaults to "mse".
param_type: Determines the parameterisation. Options are "sp", "mupc", or "ntp". Defaults to "sp".
weight_decay: Weight decay coefficient for the weights. Defaults to 0.
spectral_penalty: Spectral penalty coefficient for the weights. Defaults to 0.
activity_decay: Activity decay coefficient for the activities. Defaults to 0.
Returns:
List of layer-wise energies at every inference iteration.
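To make "layer-wise energies at every inference iteration" concrete, here is a NumPy sketch of the standard predictive coding energy in its simplest case: MSE output, input clamped to x, output clamped to y, no skip connections, standard parameterisation, and no regularisers. The per-layer energy is assumed to be \(E_\ell = \frac{1}{2}\,\mathrm{mean}_{\text{batch}} \|\mathbf{z}_\ell - \mathbf{W}_\ell \phi(\mathbf{z}_{\ell-1})\|^2\), with an identity activation on the input. layer_energies and infer_energies are hypothetical names, and the output layout (steps by layers) may differ from JPC's.

```python
import numpy as np

def layer_energies(weights, hidden, x, y, act=np.tanh):
    """Energy of each layer at one inference step (assumed form):
    E_l = 0.5 * batch-mean ||z_l - W_l phi(z_{l-1})||^2, with z_0 = x and
    the output clamped to y."""
    zs = [x, *hidden, y]
    energies = []
    for l, W in enumerate(weights):
        phi = (lambda a: a) if l == 0 else act  # identity on the input
        err = zs[l + 1] - phi(zs[l]) @ W.T      # prediction error at layer l + 1
        energies.append(0.5 * np.mean(np.sum(err ** 2, axis=1)))
    return np.array(energies)

def infer_energies(weights, hidden_iters, t_max, x, y):
    """Layer energies for steps 0..t_max (inclusive); each array in
    hidden_iters has the inference step as its first dimension."""
    return np.stack([
        layer_energies(weights, [h[t] for h in hidden_iters], x, y)
        for t in range(t_max + 1)
    ])  # shape (t_max + 1, n_layers)

rng = np.random.default_rng(0)
dims = [3, 5, 5, 2]
weights = [rng.standard_normal((o, i)) for i, o in zip(dims[:-1], dims[1:])]
x = rng.standard_normal((4, 3))
y = rng.standard_normal((4, 2))
# Feedforward activities make every hidden-layer prediction error zero.
z1 = x @ weights[0].T
z2 = np.tanh(z1) @ weights[1].T
hidden_iters = [np.stack([z1] * 3), np.stack([z2] * 3)]  # 3 identical steps
E = infer_energies(weights, hidden_iters, t_max=2, x=x, y=y)
print(E.shape)  # (3, 3)
```

With activities initialised by a feedforward pass, only the output layer carries energy at the first step; inference then redistributes that error across the hidden layers.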
jpc.compute_activity_norms(activities: PyTree[jax.Array]) -> Array
Calculates the \(\ell^2\) norm of the activities at each layer.
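A NumPy sketch of per-layer activity norms follows. It assumes one norm per layer taken over every entry of that layer's activity array (batch and features together); whether JPC instead averages over the batch is not stated here.

```python
import numpy as np

def compute_activity_norms(activities):
    """One l2 (Frobenius) norm per layer over all entries of its activities."""
    return np.array([np.linalg.norm(np.ravel(a)) for a in activities])

print(compute_activity_norms([np.array([[3.0, 4.0]]), np.ones((2, 2))]))  # [5. 2.]
```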
jpc.compute_param_norms(params)
Calculates the \(\ell^2\) norm of all model parameters.
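Model parameters in JPC are PyTrees (for example the (layers, skip connections) tuple taken by jpc.compute_infer_energies, where the skip part may be None). The NumPy sketch below walks a nested list/tuple/dict as a simple stand-in for jax.tree_util.tree_leaves, which also skips None. It reports both per-array norms and a single global norm, because this page does not specify which of the two JPC returns. iter_arrays is a hypothetical helper.

```python
import numpy as np

def iter_arrays(tree):
    """Yield every array leaf of a nested list/tuple/dict, skipping None
    (a simple stand-in for jax.tree_util.tree_leaves)."""
    if tree is None:
        return
    if isinstance(tree, dict):
        for v in tree.values():
            yield from iter_arrays(v)
    elif isinstance(tree, (list, tuple)):
        for v in tree:
            yield from iter_arrays(v)
    else:
        yield np.asarray(tree)

def compute_param_norms(params):
    """Per-array l2 norms and the global l2 norm over all parameters."""
    norms = np.array([np.linalg.norm(np.ravel(p)) for p in iter_arrays(params)])
    return norms, float(np.sqrt(np.sum(norms ** 2)))

params = ([np.array([[3.0, 4.0]]), np.array([[12.0]])], None)  # (layers, no skips)
per_array, total = compute_param_norms(params)
print(per_array, total)  # per-array norms 5 and 12, global norm 13.0
```

The global norm equals the square root of the sum of squared per-array norms, so either quantity can be recovered from the per-array values.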