Initialisation
JPC provides three ways of initialising the activities of a predictive coding (PC) network:
- jpc.init_activities_with_ffwd() for a feedforward pass (standard),
- jpc.init_activities_from_normal() for random initialisation, and
- jpc.init_activities_with_amort() for use of an amortised network.
jpc.init_activities_with_ffwd(model: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex], *, skip_model: typing.Optional[jaxtyping.PyTree[typing.Callable]] = None, param_type: str = 'sp') -> PyTree[jax.Array]
Initialises the layers' activity with a feedforward pass \(\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L\) where \(f_\ell(\cdot)\) is some callable layer transformation and \(\mathbf{z}_0 = \mathbf{x}\) is the input.
Warning
param_type = "mupc" (μPC) assumes that one is using jpc.make_mlp() to create the model.
Main arguments:
- model: List of callable model (e.g. neural network) layers.
- input: Input to the model.
Other arguments:
- skip_model: Optional skip connection model.
- param_type: Determines the parameterisation. Options are "sp" (standard parameterisation), "mupc" (μPC), or "ntp" (neural tangent parameterisation). See _get_param_scalings() for the specific scalings of these different parameterisations. Defaults to "sp".
Returns:
List with activity values of each layer.
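The feedforward initialisation \(\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L\) can be sketched in plain JAX as below. This is a minimal sketch, not JPC's implementation: it assumes each layer is a callable mapping the activity of layer \(\ell-1\) to layer \(\ell\), it ignores skip_model and param_type, and the helper ffwd_init_sketch and the toy weights are hypothetical.

```python
import jax
import jax.numpy as jnp

def ffwd_init_sketch(layers, x):
    """Hypothetical helper (not JPC's API).

    Returns [f_1(z_0), ..., f_L(z_{L-1})] with z_0 = x.
    Skip connections and parameterisation scalings are NOT modelled here.
    """
    activities = []
    z = x  # z_0 is the input
    for f in layers:
        z = f(z)  # z_l = f_l(z_{l-1})
        activities.append(z)
    return activities

# Toy 2-layer network built from plain functions (assumed shapes: 3 -> 4 -> 2).
key_w1, key_w2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(key_w1, (3, 4))
W2 = jax.random.normal(key_w2, (4, 2))
layers = [lambda z: jnp.tanh(z @ W1), lambda z: z @ W2]

x = jnp.ones((8, 3))  # batch of 8 inputs
acts = ffwd_init_sketch(layers, x)
print([a.shape for a in acts])  # [(8, 4), (8, 2)]
```

Note that the input \(\mathbf{z}_0\) itself is not part of the returned list, matching the set notation above, which starts at \(\ell = 1\).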
jpc.init_activities_from_normal(key: typing.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, 2]], layer_sizes: PyTree[int], mode: str, batch_size: int, sigma: Shaped[Array, ''] = 0.05) -> PyTree[jax.Array]
Initialises network activities from a zero-mean Gaussian \(z_i \sim \mathcal{N}(0, \sigma^2)\).
Main arguments:
- key: jax.random.PRNGKey for sampling.
- layer_sizes: List with the dimension of every layer (input, hidden and output).
- mode: If "supervised", all hidden layers are initialised. If "unsupervised", the input layer \(\mathbf{z}_0\) is also initialised.
- batch_size: Size of the data batch.
- sigma: Standard deviation of the Gaussian from which activities are sampled. Defaults to 5e-2.
Returns:
List of randomly initialised activities for each layer.
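Sampling \(z_i \sim \mathcal{N}(0, \sigma^2)\) for each layer can be sketched with jax.random as below. This is a minimal sketch under stated assumptions: the helper normal_init_sketch is hypothetical, and it assumes that "supervised" mode samples every layer after the input (so the output layer is included, to be clamped later) while "unsupervised" mode also samples \(\mathbf{z}_0\). JPC's exact layer indexing may differ.

```python
import jax
import jax.numpy as jnp

def normal_init_sketch(key, layer_sizes, mode, batch_size, sigma=0.05):
    """Hypothetical helper (not JPC's API): z_l ~ N(0, sigma^2) per layer.

    Assumption: "unsupervised" also samples the input layer z_0,
    "supervised" samples every layer after the input.
    """
    start = 0 if mode == "unsupervised" else 1
    sizes = layer_sizes[start:]
    keys = jax.random.split(key, len(sizes))  # one independent key per layer
    return [
        sigma * jax.random.normal(k, (batch_size, n))  # scale N(0, 1) by sigma
        for k, n in zip(keys, sizes)
    ]

key = jax.random.PRNGKey(0)
sizes = [784, 300, 300, 10]
sup = normal_init_sketch(key, sizes, "supervised", batch_size=64)
unsup = normal_init_sketch(key, sizes, "unsupervised", batch_size=64)
print([z.shape for z in sup])    # [(64, 300), (64, 300), (64, 10)]
print([z.shape for z in unsup])  # [(64, 784), (64, 300), (64, 300), (64, 10)]
```

Splitting the key once per layer keeps the layers' samples independent, which is the usual JAX idiom instead of reusing a single key.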
jpc.init_activities_with_amort(amortiser: PyTree[typing.Callable], generator: PyTree[typing.Callable], input: typing.Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> PyTree[jax.Array]
Initialises the layers' activities with an amortised network \(\{ f_{L-\ell+1}(\mathbf{z}_{L-\ell}) \}_{\ell=1}^L\) where \(\mathbf{z}_0 = \mathbf{y}\) is the input to the amortiser, i.e. the generator's target.
Note
The output order is reversed for downstream use by the generator.
Main arguments:
- amortiser: List of callable layers for the model amortising the inference of the generator.
- generator: List of callable layers for the generative model.
- input: Input to the amortiser.
Returns:
List with amortised initialisation of each layer.
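The core idea, a forward pass through the amortiser starting from \(\mathbf{y}\) followed by reversing the outputs so they line up with the generator's layers, can be sketched as below. This is a minimal sketch, not JPC's implementation: the hypothetical helper amort_init_sketch omits the generator argument entirely, and the real function may treat the boundary layers differently.

```python
import jax
import jax.numpy as jnp

def amort_init_sketch(amortiser, y):
    """Hypothetical helper (not JPC's API).

    Feeds y through the amortiser and returns the activities in reverse
    order, so the first entry corresponds to the generator's first layer.
    """
    activities = []
    z = y  # z_0 = y, the amortiser's input (the generator's target)
    for f in amortiser:
        z = f(z)
        activities.append(z)
    return activities[::-1]  # reversed for downstream use by the generator

# Toy amortiser mapping 10-dim targets to a 2-dim latent via a 5-dim hidden layer.
k1, k2 = jax.random.split(jax.random.PRNGKey(1))
A1 = jax.random.normal(k1, (10, 5))
A2 = jax.random.normal(k2, (5, 2))
amortiser = [lambda z: jnp.tanh(z @ A1), lambda z: z @ A2]

y = jnp.ones((16, 10))
acts = amort_init_sketch(amortiser, y)
print([a.shape for a in acts])  # [(16, 2), (16, 5)]
```

In this toy setup a matching generator would map 2 -> 5 -> 10, so the reversed list runs from the generator's input side towards its target.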