Add jax-only local_cost_function machinery
Created by: PhilipVinc
This PR adds utilities to define local_cost_functions that will have their gradients and everything defined, which will be used in the future to write an unsupervised learnng driver.
The api is something along the lines of
@partial(local_cost_function, static_argnums=0, batch_axes=(None, None, 0, 0, 0))
def local_energy(logpsi, pars, vp, mel, v):
return jax.numpy.sum(mel * jax.numpy.exp(logpsi(pars, vp) - logpsi(pars, v)))
# usage
ma = ...
v = ...
vp, mels = op.get_conn_flattened(v)
loc_vals = local_energy(ma.jax_forward, ma.parameters, vp, mel, v)
_loc_vals, der_loc_vals = local_costs_and_grads_function(local_energy, ma.jax_forward, ma.parameters, vp, mel, v)
assert loc_vals == _loc_vals
I'm not happy about having to explicitly extract the two fields of machine, but i have to find a way to work around that without triggering a lot of recompilation...
@gcarleo, tell me if this seems ok for you, and i'll wrap this up.