Make a CPU-Pmappable metropolis sampler
Created by: PhilipVinc
Our Metropolis samplers saturate a single CPU but cannot exploit multiple CPUs, even though Metropolis sampling is trivially parallelizable. This is particularly interesting together with the SVD/Cholesky solvers coming with the QGT PR, because those benefit from having few MPI ranks with lots of CPU cores.
This implements a new MetropolisSamplerPmap, which is a 100% drop-in replacement for MetropolisSampler but uses jax.pmap.
Using it on CPU requires setting the environment variable
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=X"
where X is the number of JAX devices to run on the CPU. A reasonable starting point might be one device per core. Using it on GPU would distribute the sampling across multiple GPUs, but I'm not sure this is useful.
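To illustrate the idea, here is a minimal sketch of running independent Metropolis chains across host devices with jax.pmap. It is not the MetropolisSamplerPmap code from this PR: the target distribution (a standard normal), the names log_prob and metropolis_chain, and the device count of 4 are all assumptions chosen for the example. The flag has to be set before jax is first imported.

```python
import os

# Must be set before the first `import jax`; 4 is an arbitrary example value.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp


def log_prob(x):
    # Hypothetical toy target: unnormalized log-density of a standard normal.
    return -0.5 * x**2


def metropolis_chain(key, n_steps=1000):
    # One random-walk Metropolis chain, written with lax.scan so it compiles once.
    def step(carry, _):
        x, key = carry
        key, k_prop, k_acc = jax.random.split(key, 3)
        x_new = x + jax.random.normal(k_prop)  # symmetric proposal
        log_ratio = log_prob(x_new) - log_prob(x)
        accept = jnp.log(jax.random.uniform(k_acc)) < log_ratio
        x = jnp.where(accept, x_new, x)
        return (x, key), x

    _, samples = jax.lax.scan(step, (jnp.zeros(()), key), None, length=n_steps)
    return samples


# One independent chain per local device, each with its own PRNG key.
n_dev = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)
samples = jax.pmap(metropolis_chain)(keys)
print(samples.shape)  # (n_dev, n_steps)
```

Each device runs its own chain with no communication between them, which is why the CPU cores can be used without adding MPI ranks.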
(All changes required for this PR are in the MetropolisPmap file; the other changes are just a cleanup of the sampler interface and are not strictly required.)