Skip to content

Operators written purely in Jax

Vicentini Filippo requested to merge github/fork/inailuig/jax_operators into master

Created by: inailuig

Currently the operators are written using numba for the very valid reason that the array shapes (number of connected elements) are in general data-dependent which is something not (yet?) supported by jax. These operators can only be evaluated on the cpu, and are called either outside of a jax jit block or via numba4jax which has some overhead.

This PR adds a Ising operator with padding written in jax and a branchless, jax-jittable MetropolisSampler (I wanted some way to test the operator...) that runs fully on the gpu.

For Paulistrings the same can also be done easily (apart from the constructor).

Merge request reports