ode4jax experiment
Created by: PhilipVinc
cc @femtobit
import jax
from jax import numpy as jnp
from netket.experimental import ode4jax
def ode(t, x, **_):
return x
tspan = (0.0, 1.0)
u0 = jnp.array(0.5)
prob = ode4jax.ODEProblem(ode, tspan, u0)
solver = ode4jax.Euler()
# easy one (not jittable)
sol = ode4jax.solve(prob, solver, dt=0.02, saveat=51)
# or
integrator = ode4jax.init(prob, solver, dt=0.02, saveat=51)
jstep = jax.jit(ode4jax.step)
while integrator.t < tspan[1]:
integrator = jstep(integrator)