Add option to disable JIT in Jax machines
Created by: femtobit
I've found it to be useful for debugging and development purposes to turn off jax.jit
in machines sometimes. Thus, this PR adds an option enable_jit
to netket.machine.Jax
.
Beside making it possible to put a quick print
statement in the machine code, this more importantly allows one to write proof-of-concept implementations that don't need to be jit
-able right away. Since some things that are natural in normal Python require more effort to implement in JAX (e.g., control flow, ragged arrays), this can be helpful to test an idea before optimizing it for JAX.
Possible downside: Since MetropolisHastings
decides whether to JIT its kernel based on the machine in this PR, the kernel is now recompiled every time the sampler is recreated. I think this is not much of a problem, as recompilation also happens, e.g., when the chain size changes or similar. I don't expect realistic code that re-creates the sampler often enough to see a performance impact to exists.