Skip to content

JIT-compile the ODE time step functions

Vicentini Filippo requested to merge tevo-jit into master

Created by: femtobit

This PR makes the time-stepping routines of (adaptive and fixed-step) RK integrators fully compatible with jax.jit and enables it.

Our MCState.expect_and_grad code is not jit-able yet. Therefore, as a workaround, we use jax.experimental.host_callback in order to have the compiled solver code call back into the pure Python ODE function for now.

Enabling JIT shows a clear speed-up when used with jit-able ODE functions, like those test_solvers.py. These are timings for the tests on my machine, before and after:

# before
207.35s call     test/dynamics/test_solvers.py::test_adaptive_solver[RK12]
7.25s call     test/dynamics/test_solvers.py::test_adaptive_solver[RK23]
1.45s call     test/dynamics/test_solvers.py::test_adaptive_solver[RK45]
0.13s call     test/dynamics/test_solvers.py::test_ode_solver[RK4]
0.08s call     test/dynamics/test_solvers.py::test_ode_solver[Midpoint]
0.08s call     test/dynamics/test_solvers.py::test_ode_solver[Heun]
0.07s call     test/dynamics/test_solvers.py::test_ode_solver[Euler]
# after
1.53s call     test/dynamics/test_solvers.py::test_adaptive_solver[RK12]
0.41s call     test/dynamics/test_solvers.py::test_adaptive_solver[RK45]
0.34s call     test/dynamics/test_solvers.py::test_adaptive_solver[RK23]
0.13s call     test/dynamics/test_solvers.py::test_ode_solver[RK4]
0.08s call     test/dynamics/test_solvers.py::test_ode_solver[Heun]
0.08s call     test/dynamics/test_solvers.py::test_ode_solver[Midpoint]
0.07s call     test/dynamics/test_solvers.py::test_ode_solver[Euler]

In t-VMC, the complexity of sampling and solving the TDVP equation is more relevant and we need the host_callback, but at least on my machine I still see a slight speed-up.

Note that I needed to move the name: str attribute out of the tableau class, because str is not a JAX-compatible type and it is nice to be able to pass around the tableau in jit-ed functions.

Merge request reports