Skip to content

MPI+jax=mpi4jax

Vicentini Filippo requested to merge PhilipVinc/mmagimmagia into v3.0

Created by: PhilipVinc

Use mpi4jax to jit through MPI calls in jax-jitted code. This allows SR to work with jax machines in many node setups.

The dependency is optional: if jax is installed, but only 1 process is used, nothing is required. If more than 1 process is used with jax machines, at SR construction an error is thrown instructing the user to pip install mpi4jax.

This complicates the test matrix, however, as we should be running MPI tests with and without mpi4jax, which I still have to do. Missing from this PR are tests and maybe some informational message when someone first runs netket

Merge request reports