[experimental] Add initial support for jax multi-device and multi-process environments
Created by: inailuig
Fully relies on the trivial automatic parallelization over the chains using shared global jax.Arrays.
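Roughly, the idea looks like the following minimal sketch (not the code in this PR; the mesh axis name, shapes and log_psi are made up for illustration):

```python
# Minimal sketch of the idea: shard the batch of chains across all available
# devices and let jit parallelize over them automatically.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_chains, n_sites = 16, 8  # hypothetical sizes; n_chains should be a multiple of the device count
mesh = Mesh(jax.devices(), axis_names=("chains",))
sharding = NamedSharding(mesh, P("chains"))  # shard the leading (chain) axis

# the samples are a single global jax.Array, distributed over the devices
samples = jax.device_put(jnp.zeros((n_chains, n_sites)), sharding)

@jax.jit
def log_psi(samples):
    # placeholder for the model; any jitted function applied to the sharded
    # array runs in parallel on the devices holding the chain shards
    return jnp.sum(samples, axis=-1)

out = log_psi(samples)  # the output inherits the sharding over the chain axis
```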
Does the following:
- (>1 host) broadcast the parameters ~~e.g. using jax.experimental.multihost_utils.broadcast_one_to_all~~ (not necessary anymore as we enforce the same seed on all processes)
- turn the sampler state (just the \sigma) into a global device array, ~~e.g. using jax.make_array_from_single_device_arrays~~ using jit with out_shardings; if we have hilbert.random written in jax we can jit that instead (see the first sketch after this list)
- wrap the call to the numba operator so that the xp and mels are shared arrays, just like the samples inputted to it (~~e.g. using jax.device_put and jax.make_array_from_single_device_arrays~~ using shard_map), or use operators written in jax (see the second sketch after this list)
- extract the results from the shared arrays before printing/saving
- use jax_threefry_partitionable=True (their new rng impl)
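For the sampler-state step, a minimal sketch of how jit with out_shardings can produce \sigma directly as a global sharded array, together with the threefry flag (init_sigma is a hypothetical stand-in for a hilbert.random written in jax; names and shapes are made up):

```python
# Sketch: produce the sampler state as a global sharded array by jitting its
# initializer with out_shardings, and enable the new partitionable rng impl.
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_threefry_partitionable", True)  # new rng impl

mesh = Mesh(jax.devices(), axis_names=("chains",))
sharding = NamedSharding(mesh, P("chains"))

n_chains, n_sites = 16, 8  # n_chains assumed to be a multiple of the device count

@partial(jax.jit, out_shardings=sharding)
def init_sigma(key):
    # hypothetical stand-in for a hilbert.random written in jax
    return jax.random.bernoulli(key, shape=(n_chains, n_sites)).astype(jnp.int8)

sigma = init_sigma(jax.random.PRNGKey(0))  # already a global device array
```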
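For the operator step, a rough sketch of how shard_map keeps xp and mels sharded like the samples (get_conn_local is a placeholder; the real code would invoke the numba operator on the local shard, e.g. through a host callback, or use an operator written in jax):

```python
# Rough sketch: wrap a per-shard get_conn with shard_map so that xp and mels
# come out sharded in the same way as the samples.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("chains",))
sharding = NamedSharding(mesh, P("chains"))

n_chains, n_sites = 16, 8  # n_chains assumed to be a multiple of the device count
sigma = jax.device_put(jnp.zeros((n_chains, n_sites), dtype=jnp.int8), sharding)

def get_conn_local(sigma_local):
    # acts only on the shard stored on this device
    xp = sigma_local[:, None, :]                   # (n_local, n_conn, n_sites)
    mels = jnp.ones(sigma_local.shape[:1] + (1,))  # (n_local, n_conn)
    return xp, mels

get_conn_sharded = shard_map(
    get_conn_local,
    mesh=mesh,
    in_specs=P("chains"),
    out_specs=(P("chains"), P("chains")),
)
xp, mels = get_conn_sharded(sigma)  # both share the chain sharding of the samples
```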
Supports both multiple local devices and jax.distributed (the latter is only available on gpu and tpu).
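For the multi-process case, the setup looks roughly like this (the coordinator address, process count and process id are placeholders and would normally come from the job scheduler; on some schedulers jax.distributed.initialize() can also auto-detect them):

```python
# Hedged sketch of the multi-process (gpu/tpu) setup.
import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # hypothetical host:port of process 0
    num_processes=2,
    process_id=0,
)

print(jax.devices())        # all devices across every process
print(jax.local_devices())  # only the devices attached to this process
```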
On gpu it internally uses the NVIDIA NCCL library for communication, so it should be quite efficient.
On cpu one can use multiple devices backed by threads via --xla_force_host_platform_device_count=XX; multi-node is not available yet on cpu.
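For example, to get 8 threaded cpu devices (the flag must be set before jax is imported):

```python
# Request 8 cpu devices backed by threads; set the flag before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
print(jax.devices())  # 8 cpu devices
```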
Uses grpc for setting up the communication. On a cluster, if you get seemingly unrelated grpc errors coming from unrelated ip addresses, or it does not work at all, this might be caused by incompatible http_proxy/no_proxy lists containing wildcards, which grpc is not able to parse; it then tries to send the traffic through the proxy even when it shouldn't. Unset the http_proxy and https_proxy env variables, or set no_proxy/grpc_no_proxy correctly by hand (see https://grpc.github.io/grpc/cpp/md_doc_environment_variables.html).
My initial benchmarks on gpu showed it's competitive with mpi.