Advanced usage
Advanced users can access all the underlying functions of `jpc.make_pc_step`,
as well as additional features. A custom PC training step looks like the following:
import jpc
# Step 1: a feedforward pass provides the starting activities.
activities = jpc.init_activities_with_ffwd(model=model, input=x)

# Step 2: run inference on the activities until they reach equilibrium.
equilibrated_activities = jpc.solve_inference(
    params=(model, None),
    activities=activities,
    output=y,
    input=x,
)

# Step 3: PC parameter update, evaluated at the equilibrated activities.
param_update_result = jpc.update_params(
    params=(model, None),
    activities=equilibrated_activities,
    optim=param_optim,
    opt_state=param_opt_state,
    output=y,
    input=x,
)

# Carry the updated model, optimiser and optimiser state into the next step.
model = param_update_result["model"]
param_optim = param_update_result["optim"]
param_opt_state = param_update_result["opt_state"]
# Variant of the training step in which inference is performed manually,
# one optimiser step at a time, instead of through `jpc.solve_inference`.
# `model`, `x`, `y` and the number of inference steps `T` are expected to be
# defined by the caller.
import optax

# Optimiser used to update the activities during inference.
activity_optim = optax.sgd(1e-3)

# 1. initialise activities
...

# 2. infer with gradient descent
activity_opt_state = activity_optim.init(activities)

for t in range(T):
    activity_update_result = jpc.update_activities(
        params=(model, None),
        activities=activities,
        optim=activity_optim,
        opt_state=activity_opt_state,
        output=y,
        input=x,
    )

    # updated activities and optimiser; these must be rebound inside the
    # loop so that each iteration continues from the previous step's result
    activities = activity_update_result["activities"]
    activity_optim = activity_update_result["optim"]
    activity_opt_state = activity_update_result["opt_state"]

# 3. update parameters at the activities' solution with PC
...