Add time evolution driver with jax-based ODE solver
Created by: femtobit
This PR adds an implementation of t-VMC based on the synthesis of @PhilipVinc's and my ideas for t-VMC propagation using jax
. This is added to the netket.experimental
submodule for the moment, to give us more freedom to continue improving the API (in particular of the included ODE solver routines).
This PR contains two main parts:
- A jax-based implementation of Runge-Kutta integration schemes (both adaptive and with fixed step size) in
netket.experimental.dynamics.runge_kutta
. These are not tied to t-VMC or even variational quantum states, but can in principle be applied to any explicit ODE defined as a Python functiondx/dt = f(x, t)
. The code is written with the goal of making the whole solver jitable byjax
(although this is not enabled in this PR, since our VMC code is not yet jitable.) - The
netket.experimental.TDVP
driver, which supports t-VMC propagation of pure NQS (that is,MCState
) based on either the real-time or imaginary-time Schrödinger equation. Lindbladian real-time dynamics for NDMs usingMCMixedState
are also supported.
Basic usage of the driver looks like this:
integrator = nkx.dynamics.RK23(dt=0.01, adaptive=True, rtol=1e-3, atol=1e-3)
te = nkx.TDVP(
hamiltonian,
variational_state=vstate,
integrator=integrator,
)
te.run(
T=1.0,
out="example_dynamics",
tstops=np.linspace(0.0, 1.0, 101, endpoint=True),
)
See the included example (Examples/Dynamics/ising1d.py
) for a complete demo.
There are also tests under test/dynamics
that show some of the parts of this PR in action.
Closes #289 (closed), closes #435 (closed).