Fix gradient of expectation value for exact state with complex parameters

Created by: yannra

Hi all, I noticed that there was a small error in the computation of the gradient of expectation values for the exact state when complex parameters are used in the model. This PR fixes that.

The issue can, for example, be seen with the following little script, which compares the computed gradient with one obtained by numdifftools. (Yes, I know it is a bit strange to use numdifftools to verify gradients in jax-based code, but setting it up with numdifftools was quicker than making the expectation value evaluation differentiable with jax.)

import netket as nk
import numpy as np
import numdifftools as nd
import jax
from jax.nn.initializers import normal

g = nk.graph.Chain(6)
hi = nk.hilbert.Spin(0.5, N=g.n_nodes, total_sz=0.0)
ha = nk.operator.Heisenberg(hi, g, sign_rule=False)
model = nk.models.RBM(dtype=complex, use_visible_bias=False, use_hidden_bias=False, kernel_init=normal(stddev=1.))

vs = nk.vqs.ExactState(hi, model=model, seed=jax.random.PRNGKey(123))
vs_sampled = nk.vqs.MCState(nk.sampler.MetropolisExchange(hi, graph=g), model=model, n_samples=1000000)

""" We are doing a little bit of reshaping and stacking to map the complex
parameters to a flattened array of real values (two times the total number of params). This way,
we can evaluate the gradient with numdifftool and compare."""

kernel_shape = vs.parameters["Dense"]["kernel"].shape
def calculate_expect(parameters):
    kernel = parameters[:len(parameters)//2] + 1.j*parameters[len(parameters)//2:]
    vs.parameters = {"Dense" : {"kernel" : kernel.reshape(kernel_shape)}}
    return vs.expect(ha).mean.real

def calculate_grad_flattened(parameters):
    kernel = parameters[:len(parameters)//2] + 1.j*parameters[len(parameters)//2:]
    vs.parameters = {"Dense" : {"kernel" : kernel.reshape(kernel_shape)}}
    grad_flattened = vs.expect_and_grad(ha)[1]["Dense"]["kernel"].flatten()
    # the factor of 2 and the Re/Im stacking map the complex gradient onto
    # derivatives w.r.t. the real and imaginary parts of the parameters (see below)
    return 2 * np.concatenate((grad_flattened.real, grad_flattened.imag))

""" Just to ensure that the sign definitions are in-line with the rest of NetKet,
we also compute a sampled expectation value."""
def calculate_grad_sampled_flattened(parameters):
    kernel = parameters[:len(parameters)//2] + 1.j*parameters[len(parameters)//2:]
    vs_sampled.parameters = {"Dense" : {"kernel" : kernel.reshape(kernel_shape)}}
    grad_flattened = vs_sampled.expect_and_grad(ha)[1]["Dense"]["kernel"].flatten()
    return 2 * np.concatenate((grad_flattened.real, grad_flattened.imag))

init_pars = vs.parameters["Dense"]["kernel"].flatten()
init_pars_split = np.concatenate((init_pars.real, init_pars.imag))

# gradient as calculated by NetKet
grad_nk = calculate_grad_flattened(init_pars_split)
grad_nk_sampled = calculate_grad_sampled_flattened(init_pars_split)

# Gradient computed by numdifftools
grad_nd = nd.Gradient(calculate_expect)(init_pars_split)

print(max(abs(grad_nd - grad_nk))) # returns large values without the fix in this PR
print(np.mean(abs(grad_nk_sampled - grad_nk)**2)) # returns large values without the fix in this PR

The script above simply verifies the gradient by comparing the one obtained from vs.expect_and_grad(ha) with the numerical gradient of vs.expect(ha) computed with numdifftools. Without the fix in this pull request, the gradients are wrong when complex parameters are used, i.e. there is a significant deviation between the different gradients (which also leads to weird behaviour when optimizing the exact state with complex parameters).
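
Spelling the comparison out (in my notation): if $g$ denotes the complex gradient returned by expect_and_grad and $\theta = \theta_r + i\,\theta_i$, the script checks that

    2\,(\operatorname{Re} g,\ \operatorname{Im} g) \;=\; \left(\frac{\partial E}{\partial \theta_r},\ \frac{\partial E}{\partial \theta_i}\right),

i.e. the complex gradient, split into real and imaginary parts (and with the factor of 2 from the functions above), should reproduce the numerical derivatives with respect to the real and imaginary parts of the parameters.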

Incidentally, I realised that the gradient with respect to complex parameters is defined differently in NetKet than in jax, which might be a little confusing for people (happy to open an issue if this is wanted): NetKet returns the complex conjugate of the gradient that jax would usually return. In other words, if you do gradient descent with the jax gradients you need to complex-conjugate them first, whereas the NetKet version can be used directly. This discrepancy will probably also lead to issues when using the interface to evaluate gradients of non-hermitian operators (which presumably has never really been tested for models with complex parameters?), since that code path simply jax-differentiates the full expectation value evaluation. The fix in this PR explicitly assumes Hermitian operators, but I guess this is fine, as there is no clear definition of the gradient when the operator is non-hermitian and complex-valued models are used (since the expectation value itself can then be complex).
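
For reference, here is a rough sketch of what the jax-based cross-check mentioned above could look like: a brute-force full-sum energy built from hi.all_states() and ha.to_dense(), which jax can differentiate directly. It reuses the objects from the script above; the function full_sum_energy and the final comparison (including the factor of 2 used earlier) are just my illustration of the conjugation issue, not part of the tested script.

import jax.numpy as jnp

def full_sum_energy(params):
    # brute-force <H> over the full (constrained) Hilbert space, differentiable by jax
    sigma = hi.all_states()
    logpsi = model.apply({"params": params}, sigma)
    psi = jnp.exp(logpsi)
    H = jnp.asarray(ha.to_dense())
    return jnp.real(jnp.vdot(psi, H @ psi) / jnp.vdot(psi, psi))

pars = {"Dense": {"kernel": init_pars.reshape(kernel_shape)}}
vs.parameters = pars  # reset after the numdifftools sweep above

grad_jax = jax.grad(full_sum_energy)(pars)["Dense"]["kernel"].flatten()
grad_netket = vs.expect_and_grad(ha)[1]["Dense"]["kernel"].flatten()

# if the two conventions are related by complex conjugation as described above
# (up to the factor of 2 used earlier), this should be small
print(np.max(np.abs(np.conj(grad_jax) - 2 * grad_netket)))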

I hope this all makes sense, let me know if you require further info.

Best wishes, Yannic
