jit a few internal functions to make them work with global device arrays
Created by: inailuig
i.e. arrays which are isinstance(x, jax.Array) and not x.is_fully_addressable
.
See https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration