Skip to content

Rework netket.utils.mpi and sum_inplace

Vicentini Filippo requested to merge fix-mpi into master

Created by: PhilipVinc

sum_inplace is a thing left from the old numpy days where mpi was actually acting inplace. Now with jax it's not really the case anymore. Moreover, I find myself wanting to use a bunch more mpi operations (Allreduce with MPI.MAX, MPI.LOR, or with NORM2...) so I want to add those to a netket.utils.mpi module.

I would also like to make explicit the choice between the jax (mpi4jax) version and the non jax (mpi4py) version to avoid bugs like #655 . Also I would like to expose mpi4jax token's mechanism in case we'll need it in the future. Now we don't really need it, but maybe it might be useful.

(The token mechanism essentially means that all mpi4jax operations take and return an extra 'token', which is a fake object that forces mpi operations not to be reordered by jax, preventing deadlocks. We don't really need it for now, because all our mpi operations are executed in an order because of the data input/output, but who knows. maybe in the future it might be useful.)

For who reviews:

  • the first commit is the one with the new functionality and new mpi module.
  • The second commit is just boilerplate changes everywhere
  • The third commit adapts netket.stats and removes sum_inplace, the last commit is updates following this commit.

I'd like some feedback on this

Things that prompted this: The bug hunt for #655, making ExactSampler exploit MPI and fixing #651 (closed)

Merge request reports