Skip to content

Make a CPU-Pmappable metropolis sampler

Vicentini Filippo requested to merge pv/sample/pmap into master

Created by: PhilipVinc

Our (mertopolis) samplers saturate one single CPU but cannot exploit multiple CPUs. Though metropolis is trivially parallizable. This is particularly interesting together with the SVD/Cholesky solvers coming with the QGT pr, because those benefit from having few mpi ranks with lots of cpu cores...

This implements a new MetropolisSamplerPmap which is a 100% dropin replacement for MetropolisSampler, but uses jax.pmap.

Using it on CPU requires setting the environment varible

os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=X"

where X is the number of jax devices run on CPU. A good approximation might be using 1 cpu per core. Using it on GPU would distribute the sampling across multiple GPUs but i'm not sure this is usefull.

(All changes required for this PR are in the MetropolisPmap file, the other changes are just a cleanup of the sampler interface, but not really required)

Merge request reports