Fast Jax backend: sampler and jacobians for real and complex networks
Created by: gcarleo
After experimenting with several backends (PyTorch, TensorFlow, Jax), I decided to try to improve our Jax backend. The main reasons are:
- Jax has support for complex-valued networks...
- Jax has a very nice (at least for my taste!) functional interface, with auto-vectorization and possibly the most advanced AD interface on the market, including jacobians, hessians, jacobian-vector and vector-jacobian products, etc.
- It runs on GPU and TPU
- It can be jit-compiled
Basically, I wrote a general Jax-based sampler with arbitrary transition kernels, largely inspired by this blog post and this code (thank you @rlouf!). For real-valued machines this is very fast, and it essentially does not suffer from the overhead that our numpy sampler has when running a small number of chains, thus rivaling our old pure C++ implementation. Not to mention that this can run on GPU. A minimal sketch of the pattern is shown below.
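To illustrate the idea (this is a sketch, not the actual NetKet code): a Metropolis sweep written as a pure function of a generic transition kernel, vmapped over chains and unrolled with `lax.scan` so the whole loop can be jit-compiled. All names here (`logpsi`, `kernel`, `flip_kernel`) are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def metropolis_sampler(logpsi, kernel, params, key, x0, n_steps):
    # x0: (n_chains, n_sites) batch of initial configurations
    n_chains = x0.shape[0]

    def log_prob(x):
        # |psi(x)|^2 up to normalization, from the (possibly complex) log-amplitude
        return 2.0 * jax.vmap(logpsi, in_axes=(None, 0))(params, x).real

    def step(carry, key):
        x, logp = carry
        key_prop, key_acc = jax.random.split(key)
        # propose a move for every chain in parallel
        xp = jax.vmap(kernel)(jax.random.split(key_prop, n_chains), x)
        logp_p = log_prob(xp)
        # standard Metropolis acceptance for a symmetric proposal
        accept = jax.random.uniform(key_acc, (n_chains,)) < jnp.exp(logp_p - logp)
        x = jnp.where(accept[:, None], xp, x)
        logp = jnp.where(accept, logp_p, logp)
        return (x, logp), x

    (x, _), samples = jax.lax.scan(step, (x0, log_prob(x0)), jax.random.split(key, n_steps))
    return samples  # (n_steps, n_chains, n_sites)

# an example transition kernel: flip one random spin of a +/-1 configuration
def flip_kernel(key, x):
    i = jax.random.randint(key, (), 0, x.shape[-1])
    return x.at[i].set(-x[i])

# logpsi, kernel and n_steps are compile-time constants
sample = jax.jit(metropolis_sampler, static_argnums=(0, 1, 5))

# usage (illustrative):
# samples = sample(my_logpsi, flip_kernel, params, jax.random.PRNGKey(0), x0, 1000)
```

Since the whole chain update is a single jitted scan, there is no Python-level per-step overhead, which is where the small-chain-count speedup comes from.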
In the process I also improved our interface to Jax machines and simplified it. Most importantly, `der_log` is now computed using a `vmap` and it is very fast; the old implementation using `jacfwd` was incredibly slow. It should also be considered that numpy uses multi-threading on my laptop, whereas Jax does not at the moment. Also, numpy is hard to scale to deep networks and doesn't have AD...
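The pattern is essentially the following (a sketch, not the actual implementation; `der_log` is written here as a free function rather than a machine method): take the gradient of a single-sample evaluation and let `vmap` batch it over the Monte Carlo samples, instead of building the batched Jacobian with `jacfwd`:

```python
import jax

def der_log(logpsi, params, samples):
    # gradient of log(psi) for a single sample; for holomorphic complex-parameter
    # machines one can use jax.grad(logpsi, holomorphic=True) instead of taking .real
    grad_single = jax.grad(lambda p, x: logpsi(p, x).real, argnums=0)
    # vmap batches the per-sample backward pass over the whole Monte Carlo batch
    return jax.vmap(grad_single, in_axes=(None, 0))(params, samples)
```

With many parameters and comparatively few samples, one reverse-mode pass per sample is far cheaper than forward-mode differentiation over all parameters, which is consistent with the speedup observed here.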
Overall, before starting this work, the example here was taking about 30 minutes on my laptop. It is now down to about a minute and a half... still a bit slower than the numpy counterpart, but most likely there is room for further improvement.
Everything can be used as a drop-in replacement, for example:

```python
import netket as nk

# hi is a Hilbert space defined as usual
ma = nk.machine.JaxRbm(hi, alpha=1, dtype=complex)
ma.init_random_parameters(sigma=0.01, seed=1232)

# Jax Sampler
sa = nk.sampler.JaxMetropolisLocal(machine=ma, n_chains=16)
```
I think that now, for the first time since we started thinking about using backends, we have a clear path towards general and efficient use of modern frameworks for variational applications with neural-network quantum states.
To have a great 3.0 release, I think we would need at minimum to do some further improvements to the Jax backend:
- Remove unnecessary casting to numpy types, which is now becoming the bottleneck when computing local estimators
- Avoid casting machine parameters to numpy types, and (re-)enable the use of Jax optimizers (see also #392 (closed) and previous discussions with @femtobit)
- Write a Jax version of SR (stochastic reconfiguration); a rough sketch follows this list
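To make the last point concrete, here is one possible shape of a Jax SR step, assuming the per-sample log-derivatives are available as a dense matrix (`sr_update` is a hypothetical name, not an existing NetKet function):

```python
import jax.numpy as jnp

def sr_update(o, e_loc, diag_shift=0.01):
    # o: (n_samples, n_params) per-sample log-derivatives O_k(x) = d log psi(x) / d p_k
    # e_loc: (n_samples,) local energies
    n = o.shape[0]
    oc = o - o.mean(axis=0)                         # centered log-derivatives
    s = oc.conj().T @ oc / n                        # S = <O* O> - <O*><O>
    f = oc.conj().T @ (e_loc - e_loc.mean()) / n    # F = <O* E_loc> - <O*><E_loc>
    s = s + diag_shift * jnp.eye(s.shape[0], dtype=s.dtype)  # regularization
    return jnp.linalg.solve(s, f)                   # parameter update dp = S^-1 F
```

A dense solve is just the simplest choice here; for large parameter counts an iterative solver applied to a matrix-free S would be the natural next step.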
All of this is relatively easy now that we have all-Python implementations...
Let me know if you have thoughts on some of these points!