fix MPI compilation deadlock
Created by: PhilipVinc
We have an MPI deadlock situation that is triggered in `MCState.expect_and_grad` (and other places), which essentially happens because:
- `nk.stats.total_size(x) = sum_inplace(x.size) = mpi4py.MPI.Allreduce` computes the total size of an array across all ranks. Even when called on a jax array, since the shape is static, this is dispatched to the mpi4py communicator.
- If the function calling `total_size` is jitted, then, since shapes are static, the `MPI.Allreduce` is only executed during compilation and never during execution (see the sketch after this list).
- `nk.stats.statistics`, which computes the error and variance, calls `total_size` several times.
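To make the trace-time collective concrete, here is a minimal sketch (not NetKet's actual code; the toy function names and the use of the pickle-based `allreduce` are just illustrative) of an MPI reduction over a static shape inside a jitted function:

```python
import jax
import jax.numpy as jnp
from mpi4py import MPI

comm = MPI.COMM_WORLD

def total_size(x):
    # x.size is a static Python int under jit, so this collective executes
    # while jax traces the function, not when the compiled function runs.
    return comm.allreduce(x.size, op=MPI.SUM)

@jax.jit
def mean_over_ranks(x):
    n_total = total_size(x)    # runs once, at trace/compile time
    return x.sum() / n_total   # traced, runs at every execution

x = jnp.ones((16, 4))
mean_over_ranks(x)  # later calls with the same shape hit the compilation
                    # cache: no Allreduce is performed at all
```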
So..
- `MCState.expect(..., x)` is called with the samples computed on every rank. Since all ranks have the same number of samples, this is always compiled at the same time on all ranks.
- `MCState.expect_and_grad(..., x, xp)` is called with the same samples `x`, but with a different `xp` on every rank. Thanks to several optimisations (and exacerbated by my recent optimisation #623), the shape of `xp` depends on the runtime value of the samples `x`; in particular, it depends on the maximum number of connected elements among all samples `x`.
- Imagine the case where the samples on rank 0 all have only 1 connected element, while one sample on rank 1 has 2 connected elements.
  - Initially this compiles two different functions on the two ranks. This is all fine.
  - Now imagine that rank 1 gets the same samples as rank 0: the shape of `xp` has changed, so it will recompile `expect_and_grad`, but rank 0 is not recompiling.
  - Since rank 1 is recompiling, it calls `total_size -> MPI.Allreduce` during compilation, but rank 0 is not compiling, and therefore we have the deadlock (see the reproducer sketch below).
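A minimal reproducer sketch of this pattern (hypothetical shapes and a toy jitted function, not NetKet's actual call sequence):

```python
import jax
import jax.numpy as jnp
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

@jax.jit
def f(xp):
    # collective executed at trace time only (the shape is static under jit)
    n_total = comm.allreduce(xp.size, op=MPI.SUM)
    return xp.sum() / n_total

# call 1: the two ranks see different shapes, but *both* trace, so the two
# Allreduce calls match and compilation succeeds on both ranks.
f(jnp.ones((10, 2 if rank == 1 else 1)))

# call 2: rank 1's shape changes to (10, 1), so rank 1 retraces and enters
# the Allreduce; rank 0 reuses its cached executable and never does.
# Run with `mpirun -np 2 ...` this hangs: that is the deadlock.
f(jnp.ones((10, 1)))
```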
--
This PR implements a cheap fix: in `total_size` (and therefore in `statistics`) we assume that all ranks have the same shape, so that this compile-time reduction is not necessary.
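Concretely, under the equal-shape assumption the global size can be computed from purely local information; a minimal sketch of the idea (illustrative, not the exact NetKet implementation):

```python
from mpi4py import MPI

n_nodes = MPI.COMM_WORLD.Get_size()

def total_size(x):
    # Assumes x has the same shape on every rank, so the global size is
    # just the local size times the number of ranks: no collective is
    # needed, and this is safe to call while only some ranks are tracing.
    return x.size * n_nodes
```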
In the future, I think we should do two things:
- Re-generalise this, by (maybe) making `n_samples` or `n_dof` an argument to be specified for `statistics` and `var`.
- Re-factor the MPI code. In particular, I would like to remove `sum_inplace`, because it hides the fact that sometimes we use mpi4py and other times mpi4jax, and it prevents us from using tokens if we ever need them for mpi4jax.