Skip to content

ode4jax experiment

Vicentini Filippo requested to merge pv/tevo-jax into tevo-jax

Created by: PhilipVinc

cc @femtobit

import jax
from jax import numpy as jnp

from netket.experimental import ode4jax


def ode(t, x, **_):
    return x

tspan = (0.0, 1.0)
u0 = jnp.array(0.5)

prob = ode4jax.ODEProblem(ode, tspan, u0)

solver = ode4jax.Euler()

# easy one (not jittable)
sol = ode4jax.solve(prob, solver, dt=0.02, saveat=51)

# or 
integrator = ode4jax.init(prob, solver, dt=0.02, saveat=51)
jstep = jax.jit(ode4jax.step)
while integrator.t < tspan[1]:
    integrator = jstep(integrator)

Merge request reports