Skip to content

Fix jax 0.2 bug

Vicentini Filippo requested to merge PhilipVinc/jax_bugfix into v3.0

Created by: PhilipVinc

Jax recently released v0.2, which enables new cool omnistaging stuff but it broke this code for some reason.

I think the double jitting was confusing a bit jax.

This should not affect performance, as the external block is jitted anyway, and it's backward compatible.

Merge request reports