
Add jax-only local_cost_function machinery

Vicentini Filippo requested to merge PhilipVinc/grads into v3.0

Created by: PhilipVinc

This PR adds utilities to define local cost functions that come with their gradients (and everything else) already defined; these will be used in the future to write an unsupervised learning driver.

The API is something along the lines of:

from functools import partial
import jax

@partial(local_cost_function, static_argnums=0, batch_axes=(None, None, 0, 0, 0))
def local_energy(logpsi, pars, vp, mel, v):
    return jax.numpy.sum(mel * jax.numpy.exp(logpsi(pars, vp) - logpsi(pars, v)))
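My mental model of the decorator, as a minimal sketch: it vmaps the scalar kernel over the given batch axes and jits the result, with the log-wavefunction callable treated as a static argument. The body below is an assumption meant to illustrate the idea, not the PR's actual implementation:

import jax

def local_cost_function(fun, static_argnums=0, batch_axes=None):
    # Assumed behaviour: vectorise the per-sample kernel over the batched
    # arguments, then jit it, keeping the logpsi callable (argument 0) static.
    return jax.jit(jax.vmap(fun, in_axes=batch_axes), static_argnums=static_argnums)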

# usage
ma = ...
v = ...
vp, mels = op.get_conn_flattened(v)
loc_vals = local_energy(ma.jax_forward, ma.parameters, vp, mels, v)
_loc_vals, der_loc_vals = local_costs_and_grads_function(local_energy, ma.jax_forward, ma.parameters, vp, mels, v)

assert jax.numpy.all(loc_vals == _loc_vals)
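For reference, here is how I picture local_costs_and_grads_function on top of plain jax: roughly a vmap over value_and_grad of the un-batched kernel with respect to the parameters. The body is a sketch under that assumption, not the code in this PR:

import jax

def local_costs_and_grads_function(kernel, logpsi, pars, vp, mel, v):
    # Per-sample value and gradient w.r.t. the parameters (argument 1 of the
    # un-batched kernel). Note: jax.value_and_grad wants a real-valued output,
    # so complex local values would need holomorphic=True or a vjp instead.
    vag = jax.value_and_grad(kernel, argnums=1)
    # Broadcast logpsi and pars, batch over the samples.
    return jax.vmap(vag, in_axes=(None, None, 0, 0, 0))(logpsi, pars, vp, mel, v)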

I'm not happy about having to explicitly extract the two fields of the machine, but I have yet to find a way to work around that without triggering a lot of recompilation...
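To make the recompilation concern concrete, here is a self-contained illustration of standard jax.jit behaviour (not code from this PR): a jitted function recompiles whenever a static argument is a different object, which is why the stable ma.jax_forward callable and the traced ma.parameters pytree are passed separately instead of the machine itself.

from functools import partial
import jax

@partial(jax.jit, static_argnums=0)
def apply(f, x):
    return f(x)

f1 = lambda x: 2 * x
f2 = lambda x: 2 * x
apply(f1, 1.0)  # traces and compiles
apply(f1, 2.0)  # cache hit: same static argument
apply(f2, 1.0)  # retraces: a different callable, even though equivalent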

@gcarleo, tell me if this seems OK to you, and I'll wrap this up.
