JIT-compile the ODE time step functions
Created by: femtobit
This PR makes the time-stepping routines of (adaptive and fixed-step) RK integrators fully compatible with jax.jit
and enables it.
Our MCState.expect_and_grad
code is not jit-able yet. Therefore, as a workaround, we use jax.experimental.host_callback
in order to have the compiled solver code call back into the pure Python ODE function for now.
Enabling JIT shows a clear speed-up when used with jit-able ODE functions, like those test_solvers.py
. These are timings for the tests on my machine, before and after:
# before
207.35s call test/dynamics/test_solvers.py::test_adaptive_solver[RK12]
7.25s call test/dynamics/test_solvers.py::test_adaptive_solver[RK23]
1.45s call test/dynamics/test_solvers.py::test_adaptive_solver[RK45]
0.13s call test/dynamics/test_solvers.py::test_ode_solver[RK4]
0.08s call test/dynamics/test_solvers.py::test_ode_solver[Midpoint]
0.08s call test/dynamics/test_solvers.py::test_ode_solver[Heun]
0.07s call test/dynamics/test_solvers.py::test_ode_solver[Euler]
# after
1.53s call test/dynamics/test_solvers.py::test_adaptive_solver[RK12]
0.41s call test/dynamics/test_solvers.py::test_adaptive_solver[RK45]
0.34s call test/dynamics/test_solvers.py::test_adaptive_solver[RK23]
0.13s call test/dynamics/test_solvers.py::test_ode_solver[RK4]
0.08s call test/dynamics/test_solvers.py::test_ode_solver[Heun]
0.08s call test/dynamics/test_solvers.py::test_ode_solver[Midpoint]
0.07s call test/dynamics/test_solvers.py::test_ode_solver[Euler]
In t-VMC, the complexity of sampling and solving the TDVP equation is more relevant and we need the host_callback
, but at least on my machine I still see a slight speed-up.
Note that I needed to move the name: str
attribute out of the tableau class, because str
is not a JAX-compatible type and it is nice to be able to pass around the tableau in jit-ed functions.