Initialisation
Info
JPC provides three standard ways of initialising the activities: with a feedforward pass, randomly, or with an amortised network.
jpc.init_activities_with_ffwd(model: PyTree[typing.Callable], input: ArrayLike, skip_model: Optional[PyTree[Callable]] = None) -> PyTree[Array]
Initialises the activities of all layers with a feedforward pass \(\{ f_\ell(\mathbf{z}_{\ell-1}) \}_{\ell=1}^L\), where \(\mathbf{z}_0 = \mathbf{x}\) is the input.
Main arguments:
model
: List of callable model (e.g. neural network) layers.
input
: Input to the model.
Other arguments:
skip_model
: Optional skip connection model.
Returns:
List with activity values of each layer.
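To illustrate the recursion \(\mathbf{z}_\ell = f_\ell(\mathbf{z}_{\ell-1})\), here is a minimal JAX sketch that does not import JPC. It assumes each layer is a plain callable acting on a whole batch; the helper name ffwd_activities and the toy weights are hypothetical, and skip_model is ignored.

```python
import jax
import jax.numpy as jnp


def ffwd_activities(layers, x):
    """Hypothetical sketch: return [f_1(z_0), ..., f_L(z_{L-1})] with z_0 = x."""
    activities = []
    z = x  # z_0 is the input
    for f in layers:
        z = f(z)  # z_l = f_l(z_{l-1})
        activities.append(z)
    return activities


# Toy 2-layer network with fixed random weights (assumed shapes: 3 -> 4 -> 2).
key_w1, key_w2 = jax.random.split(jax.random.PRNGKey(0))
W1 = jax.random.normal(key_w1, (3, 4))
W2 = jax.random.normal(key_w2, (4, 2))
layers = [lambda z: jnp.tanh(z @ W1), lambda z: z @ W2]

x = jnp.ones((5, 3))  # batch of 5 inputs
acts = ffwd_activities(layers, x)
print([a.shape for a in acts])  # [(5, 4), (5, 2)]
```

The input x itself is not in the returned list, matching the index range \(\ell = 1, \dots, L\) of the feedforward formula.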
jpc.init_activities_from_normal(key: PRNGKeyArray, layer_sizes: PyTree[int], mode: str, batch_size: int, sigma: Array = 0.05) -> PyTree[Array]
Initialises network activities from a zero-mean Gaussian \(\sim \mathcal{N}(0, \sigma^2)\).
Main arguments:
key
: jax.random.PRNGKey for sampling.
layer_sizes
: List with the dimension of all layers (input, hidden and output).
mode
: If "supervised", all hidden layers are initialised. If "unsupervised", the input layer \(\mathbf{z}_0\) is also initialised.
batch_size
: Size of the data batch.
sigma
: Standard deviation of the Gaussian to sample activities from. Defaults to 5e-2.
Returns:
List of randomly initialised activities for each layer.
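A minimal JAX sketch of sampling \(\mathbf{z}_\ell \sim \mathcal{N}(0, \sigma^2)\) per layer, written without JPC. The helper name normal_activities is hypothetical, and the layer coverage is an assumption: in "supervised" mode it samples every layer after the input (including the output layer), while "unsupervised" also samples \(\mathbf{z}_0\). The library's exact convention may differ.

```python
import jax
import jax.numpy as jnp


def normal_activities(key, layer_sizes, mode, batch_size, sigma=0.05):
    """Hypothetical sketch: sample activities z_l ~ N(0, sigma^2) per layer."""
    if mode not in ("supervised", "unsupervised"):
        raise ValueError(f"unknown mode: {mode}")
    # Assumption: "unsupervised" also samples the input layer z_0;
    # "supervised" samples every layer after the input.
    sizes = layer_sizes if mode == "unsupervised" else layer_sizes[1:]
    keys = jax.random.split(key, len(sizes))  # one independent key per layer
    # Scaling a standard normal by sigma gives N(0, sigma^2) samples.
    return [
        sigma * jax.random.normal(k, (batch_size, size))
        for k, size in zip(keys, sizes)
    ]


key = jax.random.PRNGKey(0)
sup = normal_activities(key, [784, 300, 10], "supervised", batch_size=32)
unsup = normal_activities(key, [784, 300, 10], "unsupervised", batch_size=32)
print([a.shape for a in sup])    # [(32, 300), (32, 10)]
print([a.shape for a in unsup])  # [(32, 784), (32, 300), (32, 10)]
```

Splitting the key once per layer keeps the samples of different layers independent.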
jpc.init_activities_with_amort(amortiser: PyTree[typing.Callable], generator: PyTree[typing.Callable], input: ArrayLike) -> PyTree[Array]
Initialises the activities of all layers with an amortised network \(\{ f_{L-\ell+1}(\mathbf{z}_{L-\ell}) \}_{\ell=1}^L\), where \(\mathbf{z}_0 = \mathbf{y}\) is the input to the amortiser, i.e. the generator's target.
Note
The output order is reversed for downstream use by the generator.
Main arguments:
amortiser
: List of callable layers for the model amortising the inference of the generator.
generator
: List of callable layers for the generative model.
input
: Input to the amortiser.
Returns:
List with amortised initialisation of each layer.
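The reversed ordering can be illustrated with a minimal JAX sketch that does not import JPC. It runs the amortiser layers forward from \(\mathbf{z}_0 = \mathbf{y}\) and then reverses the list, so the first entry is \(f_L(\mathbf{z}_{L-1})\). The helper name amortised_activities and the toy weights are hypothetical, and the generator argument is left out because this sketch only shows the ordering, not how JPC uses the generator.

```python
import jax
import jax.numpy as jnp


def amortised_activities(amortiser, y):
    """Hypothetical sketch: run the amortiser from z_0 = y, then reverse.

    Returns [f_L(z_{L-1}), ..., f_1(z_0)], i.e. the deepest amortiser
    output first, so the list lines up with the generator's layer order.
    """
    activities = []
    z = y  # z_0 is the generator's target
    for f in amortiser:
        z = f(z)  # z_l = f_l(z_{l-1})
        activities.append(z)
    return activities[::-1]


# Toy amortiser mapping a 2-d target through 4-d to a 3-d latent.
k1, k2 = jax.random.split(jax.random.PRNGKey(1))
A1 = jax.random.normal(k1, (2, 4))
A2 = jax.random.normal(k2, (4, 3))
amortiser = [lambda z: jnp.tanh(z @ A1), lambda z: z @ A2]

y = jnp.ones((5, 2))
acts = amortised_activities(amortiser, y)
print([a.shape for a in acts])  # [(5, 3), (5, 4)]
```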