Skip to content
Snippets Groups Projects

Correcting a mpi n_nodes multiplicative factor in nk.jax.expect

Created by: alleSini99

The gradient computed with function nk.jax.expect misses a multiplicative factor given by the number of mpi ranks. The issue is in the function mpi_mean inside the definition of f within _expect_bwd, because using it we already do the mean on the different ranks inside the function itself while we should do it afterwards. This results in a rescaling of the output gradient of 1/nk.utils.mpi.n_nodes. Therefore, we substitute mpi_mean with a simple mean over the samples and we perform the mpi mean over the gradients afterwards.

closes #1355

Merge request reports

Loading
Loading

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
  • Loading
  • Loading
  • Loading
  • Loading
Please register or sign in to reply
Loading