Skip to content

Correcting a mpi n_nodes multiplicative factor in nk.jax.expect

Created by: alleSini99

The gradient computed with function nk.jax.expect misses a multiplicative factor given by the number of mpi ranks. The issue is in the function mpi_mean inside the definition of f within _expect_bwd, because using it we already do the mean on the different ranks inside the function itself while we should do it afterwards. This results in a rescaling of the output gradient of 1/nk.utils.mpi.n_nodes. Therefore, we substitute mpi_mean with a simple mean over the samples and we perform the mpi mean over the gradients afterwards.

closes #1355

Merge request reports

Loading