Add `MCState.log_gradient` implementation
Created by: femtobit
This is a first draft of a PR which adds an `MCState.log_gradient` function that, for given samples, returns the corresponding scores O_k(s) = ∂[ln ψ(s)]/∂θ_k.

This is done internally by calling `qgt_jacobian_pytree_logic.py::prepare_centered_oks` and is thus essentially the same code path `QGTJacobianPyTree` uses to compute its `.O` member. (This PR currently renames `prepare_centered_oks` to `prepare_log_gradients`, since centering is now optional.)
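For intuition, here is a minimal, unchunked sketch of the quantity being returned (not the actual implementation in this PR, which reuses the jacobian logic from `qgt_jacobian_pytree_logic.py`): the per-sample gradient of ln ψ with respect to every parameter leaf, written out for a non-holomorphic model with real parameters and complex output (like `RBMModPhase` below).

```python
import jax

def naive_log_gradients(apply_fn, pars, samples):
    # apply_fn(pars, s) -> ln psi(s) for a single sample s (complex scalar).
    # For real parameters, O_k(s) = d Re[ln psi]/d theta_k + i d Im[ln psi]/d theta_k,
    # so we differentiate the real and imaginary parts separately.
    jac_re = jax.vmap(jax.grad(lambda p, s: apply_fn(p, s).real), in_axes=(None, 0))(pars, samples)
    jac_im = jax.vmap(jax.grad(lambda p, s: apply_fn(p, s).imag), in_axes=(None, 0))(pars, samples)
    # recombine into complex leaves with a leading (n_samples,) axis
    return jax.tree_util.tree_map(lambda r, i: r + 1j * i, jac_re, jac_im)
```

Each leaf of the result has the parameter's shape with a leading sample axis, matching the shapes in the example below.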
`log_gradient` accepts a `mode` option (with the same meaning as for `QGTJacobian*`). There is one difference, though: `mode=complex` in `QGTJacobianPyTree` creates the log gradients `qgt.O` as real-valued arrays with stacked real and imaginary parts of the complex `O`. `log_gradient` avoids this (since I don't think returning that representation to users by default is a good choice), but this comes at the cost of adding yet another option to `prepare_log_gradients`.
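To illustrate the difference between the two layouts: converting the stacked real representation back to plain complex leaves (and vice versa) is a single `tree_map`. The exact axis along which `QGTJacobianPyTree` stacks the parts is an implementation detail; the convention below is only an assumption for the sketch.

```python
import jax
import jax.numpy as jnp

def stacked_to_complex(oks_stacked):
    # assumes leaves of shape (n_samples, 2, ...) with Re at index 0 and
    # Im at index 1 along axis 1 -- purely illustrative
    return jax.tree_util.tree_map(lambda o: o[:, 0] + 1j * o[:, 1], oks_stacked)

def complex_to_stacked(oks):
    # inverse direction: stack Re/Im into a single real-valued array
    return jax.tree_util.tree_map(
        lambda o: jnp.stack([o.real, o.imag], axis=1), oks
    )
```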
There are several things to address before this can be merged (like adding tests), but I'd like some feedback on this first draft (@inailuig @PhilipVinc @gcarleo @attila-i-szabo @danielalcalde):

- Does the general design make sense? Reusing code from `qgt_jacobian_*` seems to me like the best way to prevent code duplication (and we already support chunking etc. in `log_gradient`).
- Do we need to add an option to obtain the log gradients by generalizing `qgt_jacobian_dense_logic.py::prepare_centered_oks` in the same way this is done for the pytree version right now? That would probably be useful in the same cases where `QGTJacobianDense` is a good choice, right?
- There is probably a way to clean up all the nested `if` branches in `prepare_centered_oks`/`prepare_log_gradients` and make the logic less convoluted. Any suggestions (@inailuig)? (One possible direction is sketched after this list.)
- I am not convinced it is worth the additional effort (since these methods are internal anyway), but conceptually it could make more sense to move most of the Jacobian code to something like `netket.jax._jacobian_{dense,pytree}` (from `netket.optimizer.qgt.qgt_jacobian_{dense,pytree}_logic`), as most of it is not actually that QGT-specific.
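On the nested-`if` question, one purely hypothetical direction (the signature here is invented for illustration and does not match the actual helper): resolve `mode` into a single jacobian helper up front via a dispatch table, then apply optional steps like centering as a flat pipeline instead of branching inline.

```python
import jax

def _center(oks):
    # subtract the sample mean from every leaf: O_k -> O_k - <O_k>
    return jax.tree_util.tree_map(lambda o: o - o.mean(axis=0), oks)

def prepare_log_gradients(jacobian_fns, mode, pars, samples, *, center=False):
    # jacobian_fns: dict mapping mode name -> per-mode jacobian helper,
    # e.g. {"real": ..., "complex": ..., "holomorphic": ...}
    if mode not in jacobian_fns:
        raise ValueError(f"unknown mode: {mode!r}")
    oks = jacobian_fns[mode](pars, samples)
    if center:
        oks = _center(oks)
    return oks
```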
Example
```python
import netket as nk
import jax

L = 4
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
ma = nk.models.RBMModPhase(alpha=1, use_hidden_bias=False, dtype=float)
sa = nk.sampler.MetropolisLocal(hi, n_chains=16)
vs = nk.vqs.MCState(sa, ma, n_samples=1024)
```
```python
>>> print(jax.tree_map(jax.numpy.shape, vs.parameters))
FrozenDict({
    Dense_0: {
        kernel: (4, 4),
    },
    Dense_1: {
        kernel: (4, 4),
    },
})

>>> print(jax.tree_map(jax.numpy.shape, vs.log_gradient(vs.samples)))
FrozenDict({
    Dense_0: {
        kernel: (1024, 4, 4),
    },
    Dense_1: {
        kernel: (1024, 4, 4),
    },
})

>>> print(vs.log_gradient(vs.samples))
FrozenDict({
    Dense_0: {
        kernel: DeviceArray([[[-3.29284369e-04+0.j,  5.88890101e-04+0.j,
                               -1.15402368e-05+0.j, -1.53835420e-03+0.j],
                              [ 3.29284369e-04+0.j, -5.88890101e-04+0.j,
...
```