MPI+jax=mpi4jax
Created by: PhilipVinc
Use mpi4jax to jit through MPI calls in jax-jitted code. This allows SR to work with jax machines in many node setups.
The dependency is optional: if jax is installed, but only 1 process is used, nothing is required. If more than 1 process is used with jax machines, at SR construction an error is thrown instructing the user to pip install mpi4jax
.
This complicates the test matrix, however, as we should be running MPI tests with and without mpi4jax, which I still have to do. Missing from this PR are tests and maybe some informational message when someone first runs netket