
Initialisation

JPC provides three ways of initialising the activities of a predictive coding (PC) network:

jpc.init_activities_with_ffwd(model: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, skip_model: typing.Optional[jaxtyping.PyTree[typing.Callable]] = None, param_type: str = 'sp') -> PyTree[jax.Array]

Initialises the activity of each layer with a feedforward pass \(\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L\), where \(f_\ell(\cdot)\) is the callable transformation of layer \(\ell\) and \(\mathbf{z}_0 = \mathbf{x}\) is the input.

Warning

Setting param_type = "mupc" (μPC) assumes that the model was created with jpc.make_mlp().

Main arguments:

  • model: List of callable model layers (e.g. the layers of a neural network).
  • input: Input to the model.

Other arguments:

  • skip_model: Optional skip connection model.
  • param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".

Returns:

List with activity values of each layer.
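The feedforward initialisation amounts to a single forward pass that keeps every layer's output. The following is a minimal plain-NumPy sketch of that recursion, \(\mathbf{z}_\ell = f_\ell(\mathbf{z}_{\ell-1})\), not jpc's implementation: the helper ffwd_init_sketch and the toy MLP are hypothetical, each callable is assumed to act on a whole batch, and skip_model and the param_type scalings are deliberately left out.

```python
import numpy as np


def ffwd_init_sketch(layers, x):
    """Hypothetical helper: feedforward activity initialisation.

    Computes z_l = f_l(z_{l-1}) with z_0 = x and returns [z_1, ..., z_L].
    Assumes each callable maps a (batch, in) array to a (batch, out) array.
    Skip connections and param_type scalings are NOT modelled here.
    """
    activities = []
    z = x  # z_0 = x
    for f in layers:
        z = f(z)  # z_l = f_l(z_{l-1})
        activities.append(z)
    return activities


# Toy 3-layer MLP with random weights (illustrative only).
rng = np.random.default_rng(0)
sizes = [4, 8, 8, 2]
weights = [rng.normal(size=(m, n)) / np.sqrt(m) for m, n in zip(sizes[:-1], sizes[1:])]
layers = [lambda z, W=W: np.tanh(z @ W) for W in weights[:-1]]  # hidden layers
layers.append(lambda z, W=weights[-1]: z @ W)  # linear output layer

x = rng.normal(size=(16, sizes[0]))  # batch of 16 inputs
acts = ffwd_init_sketch(layers, x)
print([a.shape for a in acts])  # [(16, 8), (16, 8), (16, 2)]
```

Because every intermediate output is kept, the returned list has one entry per layer, which is the structure the iterative inference step then refines.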


jpc.init_activities_from_normal(key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], layer_sizes: PyTree[int], mode: str, batch_size: int, sigma: Shaped[Array, ''] = 0.05) -> PyTree[jax.Array]

Initialises network activities from a zero-mean Gaussian \(z_i \sim \mathcal{N}(0, \sigma^2)\).

Main arguments:

  • key: jax.random.PRNGKey for sampling.
  • layer_sizes: List with dimension of all layers (input, hidden and output).
  • mode: If "supervised", all hidden layers are initialised. If "unsupervised", the input layer \(\mathbf{z}_0\) is also initialised.
  • batch_size: Dimension of data batch.
  • sigma: Standard deviation for Gaussian to sample activities from. Defaults to 5e-2.

Returns:

List of randomly initialised activities for each layer.
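A minimal NumPy sketch of the Gaussian initialisation follows. The helper normal_init_sketch is hypothetical: it takes an integer seed and uses NumPy's random Generator in place of a jax.random.PRNGKey, and it assumes that "supervised" mode samples every layer after the input (including the output layer) while "unsupervised" mode also samples \(\mathbf{z}_0\).

```python
import numpy as np


def normal_init_sketch(seed, layer_sizes, mode, batch_size, sigma=0.05):
    """Hypothetical helper: sample activities z_i ~ N(0, sigma^2) per layer.

    Assumption: "supervised" samples every layer after the input; "unsupervised"
    additionally samples the input layer z_0. Uses NumPy instead of jax.random.
    """
    if mode not in ("supervised", "unsupervised"):
        raise ValueError(f"unknown mode: {mode!r}")
    start = 0 if mode == "unsupervised" else 1  # skip z_0 when supervised
    rng = np.random.default_rng(seed)
    return [
        sigma * rng.standard_normal((batch_size, size))  # zero mean, std sigma
        for size in layer_sizes[start:]
    ]


sup = normal_init_sketch(0, [4, 8, 8, 2], "supervised", batch_size=16)
unsup = normal_init_sketch(0, [4, 8, 8, 2], "unsupervised", batch_size=16)
print([a.shape for a in sup])    # [(16, 8), (16, 8), (16, 2)]
print([a.shape for a in unsup])  # [(16, 4), (16, 8), (16, 8), (16, 2)]
```

The small default sigma keeps the initial activities close to zero, so the starting point does not depend on the model's weights.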


jpc.init_activities_with_amort(amortiser: PyTree[typing.Callable], generator: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> PyTree[jax.Array]

Initialises the layers' activities with an amortised network \(\{ f_{L-\ell+1}(\mathbf{z}_{L-\ell}) \}_{\ell=1}^L\), where \(\mathbf{z}_0 = \mathbf{y}\) is the input to the amortiser, i.e. the generator's target.

Note

The output order is reversed for downstream use by the generator.

Main arguments:

  • amortiser: List of callable layers of the model that amortises inference for the generator.
  • generator: List of callable layers for the generative model.
  • input: Input to the amortiser.

Returns:

List with amortised initialisation of each layer.
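The amortised initialisation can be sketched as a forward pass through the amortiser starting from \(\mathbf{y}\), followed by a reversal of the collected activities. This is a hypothetical NumPy helper (amort_init_sketch) with a toy tanh amortiser, not jpc's implementation; in particular, it does not use the generator argument, and any further processing jpc may apply is not reproduced.

```python
import numpy as np


def amort_init_sketch(amortiser, y):
    """Hypothetical helper: amortised activity initialisation.

    Runs the amortiser forward from z_0 = y (the generator's target), keeps
    every intermediate output, then reverses the list so that it follows the
    generator's layer order. The generator itself is not used here.
    """
    activities = []
    z = y
    for f in amortiser:
        z = f(z)
        activities.append(z)
    return activities[::-1]  # reversed for downstream use by the generator


# Toy amortiser mapping targets (dim 2) to latents (dim 4), illustrative only.
rng = np.random.default_rng(0)
amort_sizes = [2, 8, 8, 4]
Ws = [rng.normal(size=(m, n)) / np.sqrt(m) for m, n in zip(amort_sizes[:-1], amort_sizes[1:])]
amortiser = [lambda z, W=W: np.tanh(z @ W) for W in Ws]

y = rng.normal(size=(16, 2))  # batch of 16 generator targets
acts = amort_init_sketch(amortiser, y)
print([a.shape for a in acts])  # [(16, 4), (16, 8), (16, 8)]
```

After the reversal, the first entry is the amortiser's final output, which plays the role of the generator's topmost (latent) layer.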