Skip to content

Speed up JaxSR and simplify it

Vicentini Filippo requested to merge jax_sr into v3.0

Created by: gcarleo

I sped up JaxSR using a couple of well placed jit decorators + some luck with exchanging the order of one matmul. I have removed the non iterative solvers, since they perform quite poorly. The situation might change as they become more stable in jax.

In my tests on the standard ising benchmark, JaxSr is now about 40% faster than the corresponding Numpy version, whereas the pure gradient descent (no SR) is about 60% faster.

This version is still purely serial, since there is an issue with jitting calls to MPI (that should be solvable, but maybe @PhilipVinc knows better).

Merge request reports