Correcting a mpi n_nodes multiplicative factor in nk.jax.expect
Created by: alleSini99
The gradient computed with function nk.jax.expect
misses a multiplicative factor given by the number of mpi ranks. The issue is in the function mpi_mean
inside the definition of f
within _expect_bwd
, because using it we already do the mean on the different ranks inside the function itself while we should do it afterwards. This results in a rescaling of the output gradient of 1/nk.utils.mpi.n_nodes
. Therefore, we substitute mpi_mean
with a simple mean over the samples and we perform the mpi mean over the gradients afterwards.
closes #1355