Fix jax 0.2 bug
Created by: PhilipVinc
Jax recently released v0.2, which enables new cool omnistaging stuff but it broke this code for some reason.
I think the double jitting was confusing a bit jax.
This should not affect performance, as the external block is jitted anyway, and it's backward compatible.