
Info

JPC provides two types of API depending on the use case:

  • a simple, high-level API that lets you train and test models with predictive coding in a few lines of code, and
  • a more advanced API offering greater flexibility as well as additional features (a sketch of this API is given at the end of this section).

Basic usage

At a high level, JPC provides a single convenience function jpc.make_pc_step to update the parameters of a neural network with PC.

import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import optax
import jpc

# toy data
x = jnp.array([1., 1., 1.])
y = -x

# define model and optimiser
key = jr.PRNGKey(0)
model = jpc.make_mlp(key, layer_sizes=[3, 100, 100, 3], act_fn="tanh")
optim = optax.adam(1e-3)
opt_state = optim.init(
    (eqx.filter(model, eqx.is_array), None)
)

# perform one training step with PC
update_result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=x
)

# updated model and optimiser
model = update_result["model"]
optim, opt_state = update_result["optim"], update_result["opt_state"]

As shown above, at a minimum jpc.make_pc_step takes a model, an optax optimiser and its state, and some data. The model needs to be compatible with PC updates in the sense that it is split into callable layers (see the example notebooks). Also note that the input is not needed for unsupervised training: jpc.make_pc_step can be used for classification and generation tasks, with both supervised and unsupervised training (again, see the example notebooks).
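To illustrate what "split into callable layers" means, the toy model above could also be written by hand as a list of Equinox layers. This is only a sketch meant to mirror what jpc.make_mlp constructs; the exact structure returned by jpc.make_mlp may differ.

import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

# a PC-compatible model is a list of callable layers, so that the
# energies and activity updates can be computed layer by layer
key = jr.PRNGKey(0)
k1, k2, k3 = jr.split(key, 3)
model = [
    eqx.nn.Sequential([
        eqx.nn.Linear(3, 100, key=k1),
        eqx.nn.Lambda(jnp.tanh)
    ]),
    eqx.nn.Sequential([
        eqx.nn.Linear(100, 100, key=k2),
        eqx.nn.Lambda(jnp.tanh)
    ]),
    eqx.nn.Linear(100, 3, key=k3),
]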

Under the hood, jpc.make_pc_step uses diffrax to solve the activity (inference) dynamics of PC. Many default arguments can be changed, including the ODE solver, and there is an option to record a variety of metrics such as loss, accuracy, and energies, as sketched below. See the docs for more details.
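For example, swapping the ODE solver and recording energies might look like the following. The keyword names here (ode_solver, record_energies) are assumptions based on the current docs, so check the jpc.make_pc_step API reference for the exact signature.

import diffrax

# keyword names are assumptions -- see the API reference for the exact signature
update_result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=x,
    ode_solver=diffrax.Heun(),  # alternative solver for the activity dynamics
    record_energies=True        # also record the layer energies
)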

A similar convenience function jpc.make_hpc_step is provided for updating the parameters of a hybrid PCN (Tschantz et al., 2023).

import jax.random as jr
import equinox as eqx
import optax
import jpc

# models (toy data x and y are reused from the example above)
key = jr.PRNGKey(0)
subkeys = jr.split(key, 2)
layer_sizes = [3, 100, 100, 3]
generator = jpc.make_mlp(subkeys[0], layer_sizes, "tanh")
amortiser = jpc.make_mlp(subkeys[1], layer_sizes[::-1], "tanh")

# optimisers
gen_optim = optax.adam(1e-3)
amort_optim = optax.adam(1e-3)
gen_opt_state = gen_optim.init(
    (eqx.filter(generator, eqx.is_array), None)
)
amort_opt_state = amort_optim.init(
    eqx.filter(amortiser, eqx.is_array)
)

# perform one training step with hybrid PC
update_result = jpc.make_hpc_step(
    generator=generator,
    amortiser=amortiser,
    optims=[gen_optim, amort_optim],
    opt_states=[gen_opt_state, amort_opt_state],
    output=y,
    input=x
)

# updated models and optimisers
generator, amortiser = update_result["generator"], update_result["amortiser"]
optims, opt_states = update_result["optims"], update_result["opt_states"]
gen_loss, amort_loss = update_result["losses"]

See the docs and the example notebook for more details.
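Finally, as a taste of the more advanced API mentioned at the top of this page: the convenience functions above compose lower-level routines that can also be called directly, for example to inspect the activities between inference and the parameter update. The sketch below assumes the function names jpc.init_activities_with_ffwd, jpc.solve_inference, and jpc.update_params; treat it as illustrative and consult the API reference for the exact signatures.

import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax
import jpc

# same toy setup as in the basic example above
x = jnp.array([1., 1., 1.])
y = -x
key = jr.PRNGKey(0)
model = jpc.make_mlp(key, layer_sizes=[3, 100, 100, 3], act_fn="tanh")
optim = optax.adam(1e-3)
opt_state = optim.init((eqx.filter(model, eqx.is_array), None))

# NOTE: the function names and signatures below are assumptions -- check the API docs

# (1) initialise the activities with a feedforward pass
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# (2) solve the activity (inference) dynamics to equilibrium
equilibrated_activities = jpc.solve_inference(
    params=(model, None),
    activities=activities,
    output=y,
    input=x
)

# (3) update the parameters at the equilibrated activities
update_result = jpc.update_params(
    params=(model, None),
    activities=equilibrated_activities,
    optim=optim,
    opt_state=opt_state,
    output=y,
    input=x
)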