Skip to content

[experimental] Add initial support for jax multi-device and multi-process environments

Vicentini Filippo requested to merge github/fork/inailuig/pjit into master

Created by: inailuig

Fully relies on the trivial automatic parallelization over the chains using shared global jax.Array's.

Does the following:

  • (>1 host) broadcast the parameters e.g. using jax.experimental.multihost_utils.broadcast_one_to_all Not necessary anymore as we enforce the same seed on all processes.
  • turn the sampler state (just the \sigma) into a global device array e.g. using jax.make_array_from_single_device_arrays using jit with out_shardings. If we have hilbert.random written in jax we can jit that instead.
  • wrap the call to the numba operator so that the xp and mels are shared arrays just like the samples inputted to it were (e.g. using jax.device_put and jax.make_array_from_single_device_arrays using shard_map), or use operators written in jax
  • extract the results from the shared arrays before printing/saving
  • use jax_threefry_partitionable=True (their new rng impl)

Supports both multiple local devices as well as jax.distributed (the latter is only available on gpu and tpu) On gpu it Internally uses the nvidia NCCL library for communication, so it should be quite efficient. On cpu one can use multiple devices with threads via --xla_force_host_platform_device_count=XX, multi-node is not available yet.

Uses grpc for setting up the communication. On the cluster, if you get seemingly unrelated grpc errors coming from unrelated ip addresses, or it does not work at all, this might be because of incompatible http_proxy no_proxy lists with wildcards, which grpc is not able to parse, and thus tries to send the traffic through the proxy even if it shouldnt. unset http_proxy and https_proxy env variables, or set the no_proxy/grpc_no_proxy correctly by hand (see https://grpc.github.io/grpc/cpp/md_doc_environment_variables.html)

My initial benchmarks on gpu showed it's competitive with mpi.

Merge request reports

Loading