Initialisation
JPC provides three ways of initialising the activities of a PC network:
- jpc.init_activities_with_ffwd() for a feedforward pass (standard),
- jpc.init_activities_from_normal() for random initialisation, and
- jpc.init_activities_with_amort() for use of an amortised network.
jpc.init_activities_with_ffwd(model: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, skip_model: typing.Optional[jaxtyping.PyTree[typing.Callable]] = None, param_type: str = 'sp') -> PyTree[jax.Array]
Initialises the layers' activity with a feedforward pass \(\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L\) where \(f_\ell(\cdot)\) is some callable layer transformation and \(\mathbf{z}_0 = \mathbf{x}\) is the input.
Warning
param_type = "mupc" (μPC) assumes that one is using jpc.make_mlp() to create the model.
Main arguments:
- model: List of callable model (e.g. neural network) layers.
- input: Input to the model.
Other arguments:
- skip_model: Optional skip connection model.
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
Returns:
List with activity values of each layer.
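A minimal sketch of a feedforward initialisation, assuming plain callable layers (in practice the model would typically be built with jpc.make_mlp()); the layer widths, weight scales, and batch size here are illustrative:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2, kx = jax.random.split(key, 3)

# The model is simply a list of callable layers; plain functions are used
# here for illustration (jpc.make_mlp() would normally build the layers).
W1 = 0.1 * jax.random.normal(k1, (10, 32))
W2 = 0.1 * jax.random.normal(k2, (32, 5))
model = [
    lambda z: jnp.tanh(z @ W1),  # z_1 = f_1(z_0)
    lambda z: z @ W2,            # z_2 = f_2(z_1)
]

x = jax.random.normal(kx, (64, 10))  # batch of inputs, z_0
activities = jpc.init_activities_with_ffwd(model=model, input=x)
# Expect one array per layer, e.g. shapes (64, 32) and (64, 5).
```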
jpc.init_activities_from_normal(key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], layer_sizes: PyTree[int], mode: str, batch_size: int, sigma: Shaped[Array, ''] = 0.05) -> PyTree[jax.Array]
Initialises network activities from a zero-mean Gaussian \(z_i \sim \mathcal{N}(0, \sigma^2)\).
Main arguments:
- key: jax.random.PRNGKey for sampling.
- layer_sizes: List with the dimension of all layers (input, hidden and output).
- mode: If "supervised", all hidden layers are initialised. If "unsupervised", the input layer \(\mathbf{z}_0\) is also initialised.
- batch_size: Dimension of the data batch.
- sigma: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
Returns:
List of randomly initialised activities for each layer.
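A minimal sketch of random initialisation; the layer sizes and batch size are illustrative:

```python
import jax
import jpc

key = jax.random.PRNGKey(0)
layer_sizes = [10, 32, 32, 5]  # input, two hidden layers, output

# In "supervised" mode the hidden layers are sampled from N(0, sigma^2);
# in "unsupervised" mode the input layer is sampled as well.
activities = jpc.init_activities_from_normal(
    key=key,
    layer_sizes=layer_sizes,
    mode="supervised",
    batch_size=64,
    sigma=0.05,
)
```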
jpc.init_activities_with_amort(amortiser: PyTree[typing.Callable], generator: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> PyTree[jax.Array]
Initialises layers' activity with an amortised network \(\{ f_{L-\ell+1}(\mathbf{z}_{L-\ell}) \}_{\ell=1}^L\) where \(\mathbf{z}_0 = \mathbf{y}\) is the input or generator's target.
Note
The output order is reversed for downstream use by the generator.
Main arguments:
- amortiser: List of callable layers for the model amortising the inference of the generator.
- generator: List of callable layers for the generative model.
- input: Input to the amortiser.
Returns:
List with amortised initialisation of each layer.
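A minimal sketch of amortised initialisation, again assuming plain callable layers for both networks; the dimensions and the simple weight matrices are illustrative:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
kg1, kg2, ka1, ka2, ky = jax.random.split(key, 5)

# Generator: latent (5) -> hidden (32) -> data (10).
Wg1 = 0.1 * jax.random.normal(kg1, (5, 32))
Wg2 = 0.1 * jax.random.normal(kg2, (32, 10))
generator = [lambda z: jnp.tanh(z @ Wg1), lambda z: z @ Wg2]

# Amortiser: runs in the opposite direction, data (10) -> hidden (32) -> latent (5).
Wa1 = 0.1 * jax.random.normal(ka1, (10, 32))
Wa2 = 0.1 * jax.random.normal(ka2, (32, 5))
amortiser = [lambda z: jnp.tanh(z @ Wa1), lambda z: z @ Wa2]

# The amortiser's input y is the generator's target.
y = jax.random.normal(ky, (64, 10))
activities = jpc.init_activities_with_amort(
    amortiser=amortiser, generator=generator, input=y
)
# The returned activities are ordered for downstream use by the generator.
```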