fix MPI compilation deadlock

Vicentini Filippo requested to merge fix-deadlock into master

Created by: PhilipVinc

We have an MPI deadlock situation that is triggered in mcstate.expect_and_grad (and other places). It essentially happens because:

  • nk.stats.total_size(x) = sum_inplace(x.size) = mpi4py.MPI.Allreduce computes the total size of an array across all ranks. Even when called on a jax array, since the shape is static this is dispatched to the mpi4py communicator.

    • If the function calling total_size is jitted, then, since shapes are static, MPI.Allreduce will only be called during compilation (tracing) and never during execution.
  • nk.stats.statistics, which computes the error and variance, calls total_size several times.
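A minimal pure-jax sketch of why this matters (fake_allreduce is a hypothetical stand-in for the real mpi4py collective): any plain Python call placed inside a jitted function runs only while jax traces the function, i.e. once per input shape, and never again on cached executions.

```python
import jax
import jax.numpy as jnp

trace_calls = []  # records each time the Python-side "collective" actually runs

def fake_allreduce(value):
    # Stand-in for mpi4py.MPI.Allreduce: an ordinary Python call, invisible to jax.
    trace_calls.append(value)
    return value

@jax.jit
def normalized_sum(x):
    # x.size is a static Python int, so this "reduction" happens at trace time only.
    n = fake_allreduce(x.size)
    return jnp.sum(x) / n

x = jnp.ones((4,))
normalized_sum(x)  # first call with this shape: traces, so fake_allreduce runs
normalized_sum(x)  # same shape: cached executable, fake_allreduce does NOT run
print(len(trace_calls))  # → 1
```

With a real MPI.Allreduce in place of fake_allreduce, the collective only fires when a rank happens to be compiling, which is exactly the hazard described below.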

So:

  • MCState.expect(..., x) is called with the samples computed on every rank. Since all ranks have the same number of samples, it is always compiled at the same time on all ranks.
  • MCstate.expect_and_grad(..., x, xp) is called with the same samples x, but with a different xp on every rank. Thanks to several optimisations (and exacerbated by my recent optimisation #623), the shape of xp depends on the runtime value of the samples x. In particular, it depends on the maximum number of connected elements among all samples x.
    • Imagine the case where the samples on rank 0 all have only 1 connected element, while one sample on rank 1 has 2 connected elements.
    • Initially this compiles two different functions on the two ranks. This is all fine.
    • Now imagine that rank 1 later gets the same samples as rank 0. The shape of its xp has changed, so rank 1 will recompile expect_and_grad, but rank 0 is not recompiling.
    • Since rank 1 is recompiling, it calls total_size -> MPI.Allreduce during compilation, but rank 0 is not compiling, and therefore we have a deadlock.
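The shape-driven recompilation in the scenario above can be reproduced on a single process (expect_and_grad_like is a hypothetical stand-in; the counter increments only at trace time):

```python
import jax
import jax.numpy as jnp

traces = 0  # counts how many times jax (re)compiles the function

@jax.jit
def expect_and_grad_like(xp):
    global traces
    traces += 1  # trace-time side effect: runs once per distinct input shape
    return jnp.sum(xp)

expect_and_grad_like(jnp.ones((8, 2)))  # e.g. rank saw at most 2 connected elements
expect_and_grad_like(jnp.ones((8, 1)))  # later batch: max 1 -> new shape, retrace
expect_and_grad_like(jnp.ones((8, 2)))  # shape already seen: cached, no retrace
print(traces)  # → 2
```

If such a retrace on one rank contains an MPI.Allreduce, that rank blocks in the collective while the other ranks, which hit their compilation cache, never enter it: the deadlock.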

--

This PR implements a cheap fix: in total_size (and therefore statistics), assume that all ranks hold arrays of the same shape, so that this compile-time reduction is not necessary.
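A sketch of the idea behind the fix (names are illustrative, not the actual NetKet code): if every rank is assumed to hold the same local shape, the global size is just the local size times the rank count, computable without any collective and hence safe at trace time.

```python
import jax.numpy as jnp

n_nodes = 4  # hypothetical: number of MPI ranks, known statically on every rank

def total_size(x):
    # Cheap fix: assume all ranks hold an array of the same shape, so the
    # global size is local size * rank count. No Allreduce is issued, so
    # nothing collective can fire during compilation on a single rank.
    return x.size * n_nodes

print(total_size(jnp.ones((8, 3))))  # → 96
```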

In the future, I think we should do two things:

  • Re-generalise this, by (maybe) making n_samples or n_dof an explicit argument of statistics and var.
  • Re-factor the MPI code. In particular I would like to remove sum_inplace, because it hides the fact that we sometimes use mpi4py and other times mpi4jax, and it prevents us from using tokens if ever needed for mpi4jax.