fix mpi test
Created by: PhilipVinc
This PR fixes two things:
- an mpi test broken by recent jax changes (how they split rng keys)
- The fact that until https://github.com/google/jax/issues/11916 is fixed, some tests will be failing on recent versions of jax. As the failing tests are quite exotic, this should not be a big problem