[WIP] Add `nk.jax.jacobian(..., dense=True/False)` by excising all QGTJacobian code.
Created by: PhilipVinc
This PR moves the code used to compute the log-jacobian from nk.optimizer.qgt.
to nk.jax
and unifies the codes for the dense and pytree variants, as well as adding a cantered
flag.
The idea is that this code can then be used to write the long-await MCState.log_jacobian
function without having to re-write yet another time this code. Eventually I'd like to properly document nk.jax.jacobian
because it's a fairly useful function (together with nk.jax.expect
) to write custom operations, but I'd like to think a bit about the correct user-facing signature so I'll leave not officially exported for now.
This should change nothing for QGTJacobian**
as the two implementations were already equivalent.
The only (very minor) change is that the rescaling
coming from diag_scale
is not applied in the same jit-block as were the jacobian is computed, but in another one. However, I don't think that this should change anything...