Skip to content

Avoid errors when calling .get_conn_flattened with jax array

Vicentini Filippo requested to merge pv/numba into master

Created by: PhilipVinc

Right now if we call .get_conn_flattened with a jax array we have a crash due to a numba bug https://github.com/numba/numba/issues/6980

In this PR we convert everything to numpy before calling numba.

This is useful as when we will remove legacy all states will mostly be jax and not numpy...

Merge request reports