Skip to content

Stochastic Reconfiguration in Jax

Vicentini Filippo requested to merge github/fork/chrisrothUT/jax_sr into v3.0

Created by: chrisrothUT

Hi everyone,

I fixed the code so that stochastic reconfiguration works with Jax. Right now the only SR solver that is implemented with Jax types is "jaxcg" which does conjugate gradient with Jax. I'll post a speed comparison shortly.

Unfortunately, the exact solvers in Jax are causing SIGBUS errors for systems of more than 10 electrons, so I have left their SciPy implementations currently.

Now that Jax is working on CPUs, I will try to get it working on a GPU. Soon we may be able to train deep networks with stochastic reconfiguration.

Merge request reports