Stochastic Reconfiguration in Jax
Created by: chrisrothUT
Hi everyone,
I fixed the code so that stochastic reconfiguration works with Jax. Right now the only SR solver that is implemented with Jax types is "jaxcg" which does conjugate gradient with Jax. I'll post a speed comparison shortly.
Unfortunately, the exact solvers in Jax are causing SIGBUS errors for systems of more than 10 electrons, so I have left their SciPy implementations currently.
Now that Jax is working on CPUs, I will try to get it working on a GPU. Soon we may be able to train deep networks with stochastic reconfiguration.