Rework netket.utils.mpi and sum_inplace
Created by: PhilipVinc
sum_inplace is a thing left from the old numpy days where mpi was actually acting inplace. Now with jax it's not really the case anymore.
Moreover, I find myself wanting to use a bunch more mpi operations (Allreduce with MPI.MAX, MPI.LOR, or with NORM2...) so I want to add those to a netket.utils.mpi
module.
I would also like to make explicit the choice between the jax (mpi4jax) version and the non jax (mpi4py) version to avoid bugs like #655 . Also I would like to expose mpi4jax token's mechanism in case we'll need it in the future. Now we don't really need it, but maybe it might be useful.
(The token mechanism essentially means that all mpi4jax operations take and return an extra 'token', which is a fake object that forces mpi operations not to be reordered by jax, preventing deadlocks. We don't really need it for now, because all our mpi operations are executed in an order because of the data input/output, but who knows. maybe in the future it might be useful.)
For who reviews:
- the first commit is the one with the new functionality and new mpi module.
- The second commit is just boilerplate changes everywhere
- The third commit adapts netket.stats and removes
sum_inplace
, the last commit is updates following this commit.
I'd like some feedback on this
Things that prompted this: The bug hunt for #655, making ExactSampler exploit MPI and fixing #651 (closed)