Info
JPC provides two types of API depending on the use case:
- a simple, high-level API that allows you to train and test models with predictive coding (PC) in a few lines of code, and
- a more advanced API offering greater flexibility as well as additional features.
Basic usage
At a high level, JPC provides a single convenience function, `jpc.make_pc_step`, to update the parameters of a neural network with PC.
```python
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

# toy data
x = jnp.array([1., 1., 1.])
y = -x

# define model and optimiser
key = jr.PRNGKey(0)
model = jpc.make_mlp(key, layer_sizes=[3, 100, 100, 3], act_fn="tanh")
optim = optax.adam(1e-3)
opt_state = optim.init(
    (eqx.filter(model, eqx.is_array), None)
)

# perform one training step with PC
update_result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=x
)

# updated model and optimiser
model = update_result["model"]
optim, opt_state = update_result["optim"], update_result["opt_state"]
```
`jpc.make_pc_step` takes a model, an optax optimiser and its state, and the data (an output and, optionally, an input). The model needs to be compatible with PC updates in the sense that it is split into callable layers (see the example notebooks). Also note that the `input` is not needed for unsupervised training. In fact, `jpc.make_pc_step` can be used for classification and generation tasks, for supervised as well as unsupervised training (again, see the example notebooks).
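To make "split into callable layers" concrete, the sketch below builds a small MLP as a list of per-layer callables in plain JAX and evaluates the standard PC energy, the sum over layers of squared prediction errors. The helpers `init_params`, `apply_layer` and `pc_energy` are hypothetical and not part of JPC's API; the sketch only illustrates why the input can be dropped for unsupervised training: without `x`, the first activity `z_0` simply becomes one more free variable to infer.

```python
import jax.numpy as jnp
import jax.random as jr


def init_params(key, sizes):
    """Random (W, b) pairs, one per layer of an MLP with the given sizes."""
    keys = jr.split(key, len(sizes) - 1)
    return [(jr.normal(k, (m, n)) / jnp.sqrt(n), jnp.zeros(m))
            for k, n, m in zip(keys, sizes[:-1], sizes[1:])]


def apply_layer(p, z, last):
    """One callable layer: tanh(W z + b), or linear for the last layer."""
    h = p[0] @ z + p[1]
    return h if last else jnp.tanh(h)


def pc_energy(params, free, y, x=None):
    """Sum over layers of 0.5 * ||z_l - f_l(z_{l-1})||^2.

    The output z_L is clamped to y. If x is given it clamps z_0 (supervised);
    otherwise z_0 is the first entry of `free` and is inferred (unsupervised).
    """
    z = ([x] if x is not None else []) + list(free) + [y]
    assert len(z) == len(params) + 1, "need one activity per layer boundary"
    n = len(params)
    return sum(0.5 * jnp.sum((z[l + 1] - apply_layer(p, z[l], l == n - 1)) ** 2)
               for l, p in enumerate(params))


params = init_params(jr.PRNGKey(0), [3, 100, 100, 3])
x = jnp.ones(3)
y = -x

# Supervised: a feedforward pass from x gives the free hidden activities z_1, z_2.
hidden, z = [], x
for p in params[:-1]:
    z = apply_layer(p, z, last=False)
    hidden.append(z)
y_hat = apply_layer(params[-1], hidden[-1], last=True)
E_sup = pc_energy(params, hidden, y, x)

# Unsupervised: no input, so z_0 (here started at zero) is one more free activity.
E_unsup = pc_energy(params, [jnp.zeros(3)] + hidden, y)
```

At the feedforward initialisation every hidden error is exactly zero, so the supervised energy reduces to the output error `0.5 * ||y - y_hat||^2`.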
Under the hood, `jpc.make_pc_step` uses diffrax to solve the activity (inference) dynamics of PC. Many default arguments, such as the ODE solver and its settings, can be changed, and there is an option to record a variety of metrics such as loss, accuracy, and energies. See the docs for more details.
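As a rough picture of what the ODE solver integrates, the sketch below runs PC inference as the gradient flow `dz/dt = -dF/dz` on the hidden activities, discretised with fixed-step explicit Euler in plain JAX instead of diffrax, and then takes one gradient step on the weights at the inferred activities. The helper names, the step size `dt`, the number of steps and the plain SGD update are illustrative assumptions, not JPC defaults.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_params(key, sizes):
    """Random (W, b) pairs for an MLP with the given layer sizes."""
    keys = jr.split(key, len(sizes) - 1)
    return [(jr.normal(k, (m, n)) / jnp.sqrt(n), jnp.zeros(m))
            for k, n, m in zip(keys, sizes[:-1], sizes[1:])]


def apply_layer(p, z, last):
    """tanh(W z + b) for hidden layers, W z + b for the last layer."""
    h = p[0] @ z + p[1]
    return h if last else jnp.tanh(h)


def energy(hidden, params, x, y):
    """PC energy with the input clamped to x and the output clamped to y."""
    z = [x] + list(hidden) + [y]
    n = len(params)
    return sum(0.5 * jnp.sum((z[l + 1] - apply_layer(p, z[l], l == n - 1)) ** 2)
               for l, p in enumerate(params))


# dF/dz for every hidden layer (differentiates w.r.t. argument 0).
grad_hidden = jax.jit(jax.grad(energy))


def infer(params, x, y, dt=0.1, n_steps=100):
    """Approximate the activity equilibrium with fixed-step explicit Euler."""
    hidden, z = [], x
    for p in params[:-1]:  # feedforward initialisation
        z = apply_layer(p, z, last=False)
        hidden.append(z)
    for _ in range(n_steps):
        grads = grad_hidden(hidden, params, x, y)
        hidden = [h - dt * g for h, g in zip(hidden, grads)]  # z <- z - dt * dF/dz
    return hidden


params = init_params(jr.PRNGKey(0), [3, 100, 100, 3])
x = jnp.ones(3)
y = -x

E_init = energy(infer(params, x, y, n_steps=0), params, x, y)  # feedforward guess
hidden = infer(params, x, y)
E_inferred = energy(hidden, params, x, y)

# Weight update at the inferred activities: descend the same energy w.r.t. the
# parameters (plain SGD here; jpc.make_pc_step uses the optax optimiser it is given).
param_grads = jax.grad(energy, argnums=1)(hidden, params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, param_grads)
```

Fixed-step Euler keeps the sketch short; an adaptive or higher-order solver, as diffrax provides, trades extra gradient evaluations per step for accuracy and step-size control.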
A similar convenience function, `jpc.make_hpc_step`, is provided for updating the parameters of a hybrid PCN (Tschantz et al., 2023).
```python
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

# toy data
x = jnp.array([1., 1., 1.])
y = -x

# models
key = jr.PRNGKey(0)
subkeys = jr.split(key, 2)
layer_sizes = [3, 100, 100, 3]
generator = jpc.make_mlp(subkeys[0], layer_sizes, "tanh")
amortiser = jpc.make_mlp(subkeys[1], layer_sizes[::-1], "tanh")

# optimisers
gen_optim = optax.adam(1e-3)
amort_optim = optax.adam(1e-3)
gen_opt_state = gen_optim.init(
    (eqx.filter(generator, eqx.is_array), None)
)
amort_opt_state = amort_optim.init(
    eqx.filter(amortiser, eqx.is_array)
)

# perform one training step with hybrid PC
update_result = jpc.make_hpc_step(
    generator=generator,
    amortiser=amortiser,
    optims=[gen_optim, amort_optim],
    opt_states=[gen_opt_state, amort_opt_state],
    output=y,
    input=x
)

# updated models, optimisers, and losses
generator, amortiser = update_result["generator"], update_result["amortiser"]
optims, opt_states = update_result["optims"], update_result["opt_states"]
gen_loss, amort_loss = update_result["losses"]
```
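The sketch below illustrates the idea behind a hybrid step in plain JAX. It assumes, following the hybrid PC scheme of Tschantz et al. (2023), that the amortiser maps the output back to guesses of the generator's activities, that these guesses initialise iterative inference, and that the amortiser is then trained to regress the inferred activities. The function `hpc_step` and all helpers are hypothetical; `jpc.make_hpc_step` may differ in its losses, solver and update rules.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def init_params(key, sizes):
    """Random (W, b) pairs for an MLP with the given layer sizes."""
    keys = jr.split(key, len(sizes) - 1)
    return [(jr.normal(k, (m, n)) / jnp.sqrt(n), jnp.zeros(m))
            for k, n, m in zip(keys, sizes[:-1], sizes[1:])]


def apply_layer(p, z, last):
    h = p[0] @ z + p[1]
    return h if last else jnp.tanh(h)


def gen_energy(gen, hidden, x, y):
    """PC energy of the generator x -> y with both ends clamped."""
    z = [x] + list(hidden) + [y]
    n = len(gen)
    return sum(0.5 * jnp.sum((z[l + 1] - apply_layer(p, z[l], l == n - 1)) ** 2)
               for l, p in enumerate(gen))


def amortised_guess(amort, y):
    """Run the amortiser on y; its layer outputs, read in reverse, are
    guesses for the generator activities z_0, z_1, ..., z_{L-1}."""
    outs, z = [], y
    for l, p in enumerate(amort):
        z = apply_layer(p, z, l == len(amort) - 1)
        outs.append(z)
    return outs[::-1]


def amort_loss(amort, targets, y):
    """Squared error between the amortiser's guesses and the targets."""
    return sum(0.5 * jnp.sum((g - t) ** 2)
               for g, t in zip(amortised_guess(amort, y), targets))


grad_hidden = jax.jit(jax.grad(gen_energy, argnums=1))


def sgd(params, grads, lr):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def hpc_step(gen, amort, x, y, dt=0.1, n_steps=50, lr=1e-3):
    # 1. Amortised initialisation of the hidden activities (z_0 is clamped to x).
    hidden = amortised_guess(amort, y)[1:]
    # 2. Iterative inference: explicit Euler on dz/dt = -dF/dz.
    for _ in range(n_steps):
        grads = grad_hidden(gen, hidden, x, y)
        hidden = [h - dt * g for h, g in zip(hidden, grads)]
    # 3. Generator: descend the energy w.r.t. its weights at the inferred activities.
    gen_l, gen_grads = jax.value_and_grad(gen_energy)(gen, hidden, x, y)
    # 4. Amortiser: regress its guesses onto x and the inferred activities.
    amort_l, amort_grads = jax.value_and_grad(amort_loss)(amort, [x] + hidden, y)
    return sgd(gen, gen_grads, lr), sgd(amort, amort_grads, lr), (gen_l, amort_l)


k_gen, k_amort = jr.split(jr.PRNGKey(0))
sizes = [3, 100, 100, 3]
gen = init_params(k_gen, sizes)
amort = init_params(k_amort, sizes[::-1])
x = jnp.ones(3)
y = -x
gen, amort, (gen_loss, amort_loss_value) = hpc_step(gen, amort, x, y)
```

The amortiser gives a fast, single-pass guess, while the iterative inference refines it; training the amortiser on the refined activities is what lets its guesses improve over time.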