[experimental] Add initial support for jax multi-device and multi-process environments
Loading
Created by: inailuig
Fully relies on the trivial automatic parallelization over the chains using shared global jax.Array's.
Does the following:
Supports both multiple local devices as well as jax.distributed (the latter is only available on gpu and tpu)
On gpu it Internally uses the nvidia NCCL library for communication, so it should be quite efficient.
On cpu one can use multiple devices with threads via --xla_force_host_platform_device_count=XX
, multi-node is not available yet.
Uses grpc for setting up the communication. On the cluster, if you get seemingly unrelated grpc errors coming from unrelated ip addresses, or it does not work at all, this might be because of incompatible http_proxy no_proxy lists with wildcards, which grpc is not able to parse, and thus tries to send the traffic through the proxy even if it shouldnt. unset http_proxy and https_proxy env variables, or set the no_proxy/grpc_no_proxy correctly by hand (see https://grpc.github.io/grpc/cpp/md_doc_environment_variables.html)
My initial benchmarks on gpu showed it's competitive with mpi.