Avoid errors when calling .get_conn_flattened with jax array
Created by: PhilipVinc
Right now if we call .get_conn_flattened
with a jax array we have a crash due to a numba bug https://github.com/numba/numba/issues/6980
In this PR we convert everything to numpy before calling numba.
This is useful as when we will remove legacy all states will mostly be jax and not numpy...