Skip to content

Add option to disable JIT in Jax machines

Vicentini Filippo requested to merge github/fork/femtobit/enable-jit into master

Created by: femtobit

I've found it to be useful for debugging and development purposes to turn off jax.jit in machines sometimes. Thus, this PR adds an option enable_jit to netket.machine.Jax.

Beside making it possible to put a quick print statement in the machine code, this more importantly allows one to write proof-of-concept implementations that don't need to be jit-able right away. Since some things that are natural in normal Python require more effort to implement in JAX (e.g., control flow, ragged arrays), this can be helpful to test an idea before optimizing it for JAX.

Possible downside: Since MetropolisHastings decides whether to JIT its kernel based on the machine in this PR, the kernel is now recompiled every time the sampler is recreated. I think this is not much of a problem, as recompilation also happens, e.g., when the chain size changes or similar. I don't expect realistic code that re-creates the sampler often enough to see a performance impact to exists.

Merge request reports