
Add `MCState.log_gradient` implementation

Vicentini Filippo requested to merge loggrads into localvals

Created by: femtobit

This is a first draft of a PR that adds an MCState.log_gradient method which, for given samples, returns the corresponding scores O_k(s) = ∂ ln ψ(s) / ∂θ_k.
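
Mathematically, this is just the per-sample Jacobian of the log-amplitude with respect to the parameters. For reference, here is a minimal, unoptimized sketch of that quantity computed directly with jax.grad/jax.vmap (assuming real parameters and a Flax-style apply_fn that returns the scalar log-amplitude for a single configuration; unlike the actual implementation there is no chunking, no MPI, and no mode handling):

import jax

def naive_log_gradients(apply_fn, params, samples):
    # log-amplitude of a single configuration as a function of the parameters
    def log_psi(p, s):
        return apply_fn({"params": p}, s)

    def score(p, s):
        # jax.grad needs a real scalar output, so differentiate the real and
        # imaginary parts of ln psi separately and recombine into O_k(s)
        grad_re = jax.grad(lambda q: log_psi(q, s).real)(p)
        grad_im = jax.grad(lambda q: log_psi(q, s).imag)(p)
        return jax.tree_map(lambda r, i: r + 1j * i, grad_re, grad_im)

    # one score pytree per sample, stacked along a leading sample axis
    return jax.vmap(score, in_axes=(None, 0))(params, samples)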

This is done internally by calling qgt_jacobian_pytree_logic.py::prepare_centered_oks, so it follows essentially the same code path that QGTJacobianPyTree uses to compute its .O member. (This PR currently renames prepare_centered_oks to prepare_log_gradients, since centering is now optional.)
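
The centering that is now optional is just the subtraction of the sample mean from each score leaf; a one-line sketch (center_oks is an illustrative name, not the PR's API):

import jax

def center_oks(oks):
    # ΔO_k(s) = O_k(s) - ⟨O_k⟩, averaging over the leading sample axis
    return jax.tree_map(lambda o: o - o.mean(axis=0), oks)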

log_gradient accepts a mode option (with the same meaning as for QGTJacobian*). There is one difference, though: with mode=complex, QGTJacobianPyTree stores the log gradients qgt.O as real-valued arrays in which the real and imaginary parts of the complex O are stacked. log_gradient avoids this (since I don't think returning that representation to users by default is a good choice), but this comes at the cost of adding yet another option to prepare_log_gradients.
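
To illustrate the difference between the two representations (helper names are hypothetical, and the stacking axis is assumed to be axis 1 for this sketch): QGTJacobianPyTree with mode=complex keeps each leaf as a real array with Re and Im stacked along an extra axis, while log_gradient returns the plain complex leaves. Converting back and forth would look like:

import jax
import jax.numpy as jnp

def to_stacked_real(oks):
    # complex leaf (n_samples, ...) -> real leaf (n_samples, 2, ...)
    return jax.tree_map(lambda o: jnp.stack([o.real, o.imag], axis=1), oks)

def to_complex(stacked):
    # inverse: real leaf (n_samples, 2, ...) -> complex leaf (n_samples, ...)
    return jax.tree_map(lambda o: o[:, 0] + 1j * o[:, 1], stacked)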

There are several things to address before this can be merged (like adding tests), but I'd like some feedback on this first draft (@inailuig @PhilipVinc @gcarleo @attila-i-szabo @danielalcalde):

  1. Does the general design make sense? Reusing code from qgt_jacobian_* seems to me like the best way to prevent code duplication (and we already support chunking etc. in log_gradient).
  2. Do we need to add an option to obtain the log gradients by generalizing qgt_jacobian_dense_logic.py::prepare_centered_oks in the same way this is done for the pytree version right now? That would probably be useful in the same cases where QGTJacobianDense is a good choice, right?
  3. There is probably a way to clean up all the nested if branches in prepare_centered_oks/prepare_log_gradients and make the logic less convoluted (one possible dispatch-table shape is sketched after this list). Any suggestions (@inailuig)?
  4. I am not convinced it is worth the additional effort (since these methods are internal anyway), but conceptually it could make more sense to move most of the Jacobian code from netket.optimizer.qgt.qgt_jacobian_{dense,pytree}_logic to something like netket.jax._jacobian_{dense,pytree}, as most of it is not actually that QGT-specific.
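
Regarding point 3, one possible shape for the cleanup (names, signatures, and function bodies below are placeholders, not a proposal for the actual API): dispatch on mode through a table and keep centering as a single optional post-processing step, instead of nesting the branches:

import jax

def _jac_real(apply_fn, params, samples): ...         # implementations elided
def _jac_complex(apply_fn, params, samples): ...
def _jac_holomorphic(apply_fn, params, samples): ...

_JACOBIAN_FUNS = {
    "real": _jac_real,
    "complex": _jac_complex,
    "holomorphic": _jac_holomorphic,
}

def prepare_log_gradients(apply_fn, params, samples, *, mode, center=False):
    jac_fun = _JACOBIAN_FUNS[mode]  # a KeyError doubles as mode validation
    oks = jac_fun(apply_fn, params, samples)
    if center:
        oks = jax.tree_map(lambda o: o - o.mean(axis=0), oks)
    return oks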

Example

import netket as nk
import jax

L = 4
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
ma = nk.models.RBMModPhase(alpha=1, use_hidden_bias=False, dtype=float)

sa = nk.sampler.MetropolisLocal(hi, n_chains=16)
vs = nk.vqs.MCState(sa, ma, n_samples=1024)
>>> print(jax.tree_map(jax.numpy.shape, vs.parameters))
FrozenDict({
    Dense_0: {
        kernel: (4, 4),
    },
    Dense_1: {
        kernel: (4, 4),
    },
})
>>> print(jax.tree_map(jax.numpy.shape, vs.log_gradient(vs.samples)))
FrozenDict({
    Dense_0: {
        kernel: (1024, 4, 4),
    },
    Dense_1: {
        kernel: (1024, 4, 4),
    },
})
>>> print(vs.log_gradient(vs.samples))
FrozenDict({
    Dense_0: {
        kernel: DeviceArray([[[-3.29284369e-04+0.j,  5.88890101e-04+0.j,
                       -1.15402368e-05+0.j, -1.53835420e-03+0.j],
                      [ 3.29284369e-04+0.j, -5.88890101e-04+0.j,
...
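
Since these are the same scores QGTJacobianPyTree builds its .O member from, the QGT itself could be estimated from the centered scores as S_kl = ⟨ΔO_k* ΔO_l⟩; a minimal sketch for a single flattened leaf (qgt_from_scores is an illustrative name, not part of the PR):

import jax.numpy as jnp

def qgt_from_scores(delta_oks):
    # delta_oks: centered scores of shape (n_samples, n_params), complex
    # sample-mean estimator of S_kl = ⟨ΔO_k* ΔO_l⟩
    return delta_oks.conj().T @ delta_oks / delta_oks.shape[0]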
