Skip to content

Add time evolution driver with jax-based ODE solver

Vicentini Filippo requested to merge dh/tevo-jax into master

Created by: femtobit

This PR adds an implementation of t-VMC based on the synthesis of @PhilipVinc's and my ideas for t-VMC propagation using jax. This is added to the netket.experimental submodule for the moment, to give us more freedom to continue improving the API (in particular of the included ODE solver routines).

This PR contains two main parts:

  1. A jax-based implementation of Runge-Kutta integration schemes (both adaptive and with fixed step size) in netket.experimental.dynamics.runge_kutta. These are not tied to t-VMC or even variational quantum states, but can in principle be applied to any explicit ODE defined as a Python function dx/dt = f(x, t). The code is written with the goal of making the whole solver jitable by jax (although this is not enabled in this PR, since our VMC code is not yet jitable.)
  2. The netket.experimental.TDVP driver, which supports t-VMC propagation of pure NQS (that is, MCState) based on either the real-time or imaginary-time Schrödinger equation. Lindbladian real-time dynamics for NDMs using MCMixedState are also supported.

Basic usage of the driver looks like this:

integrator = nkx.dynamics.RK23(dt=0.01, adaptive=True, rtol=1e-3, atol=1e-3)
te = nkx.TDVP(
    hamiltonian,
    variational_state=vstate,
    integrator=integrator,
)
te.run(
    T=1.0,
    out="example_dynamics",
    tstops=np.linspace(0.0, 1.0, 101, endpoint=True),
)

See the included example (Examples/Dynamics/ising1d.py) for a complete demo. There are also tests under test/dynamics that show some of the parts of this PR in action.

Closes #289 (closed), closes #435 (closed).

Merge request reports