Skip to content

[WIP] Add `nk.jax.jacobian(..., dense=True/False)` by excising all QGTJacobian code.

Vicentini Filippo requested to merge pv/jacobians into master

Created by: PhilipVinc

This PR moves the code used to compute the log-jacobian from nk.optimizer.qgt. to nk.jax and unifies the codes for the dense and pytree variants, as well as adding a cantered flag.

The idea is that this code can then be used to write the long-await MCState.log_jacobian function without having to re-write yet another time this code. Eventually I'd like to properly document nk.jax.jacobian because it's a fairly useful function (together with nk.jax.expect) to write custom operations, but I'd like to think a bit about the correct user-facing signature so I'll leave not officially exported for now.

This should change nothing for QGTJacobian** as the two implementations were already equivalent. The only (very minor) change is that the rescaling coming from diag_scale is not applied in the same jit-block as were the jacobian is computed, but in another one. However, I don't think that this should change anything...

Merge request reports