Speed up `QGTOnTheFly`
Created by: attila-i-szabo
This PR introduces a more efficient way of handling JVPs and VJPs in `QGTOnTheFly`.
Current changes
- `mat_vec` in the logic file is replaced with `mat_vec_fun`, which takes care of the logic included in `DeltaO...` and `mat_vec` and returns a matrix-vector product function that doesn't perform any differentiation (a sketch of this structure follows the list).
- `solve` is augmented with a function `_fast_solve` that calls this `mat_vec_fun` only once and runs the solver with the output. We need to figure out how to switch between this and `_solve`: the current approach won't work because the VMC object won't adjust the `fast` parameter. We might want to add it as a boolean parameter to `QGTOnTheFlyT`?
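For concreteness, here is a minimal sketch of the structure described above. The names follow the PR text (`mat_vec_fun`, `_fast_solve`), but the signatures, the `apply_fun`/`diag_shift` arguments and the omission of centring are illustrative assumptions, not the actual NetKet implementation:

```python
import jax
import jax.numpy as jnp


def mat_vec_fun(apply_fun, params, samples, diag_shift):
    """Return a closure v -> (S + diag_shift) v that performs no differentiation.

    The network is linearised exactly once; the returned function only
    evaluates the cached linear maps. (Centring of O is omitted here.)
    """
    forward = lambda p: apply_fun(p, samples)

    # Heavy step: trace/linearise the network. Done once per call to mat_vec_fun.
    _, jvp_fun = jax.linearize(forward, params)   # v -> O v
    _, vjp_fun = jax.vjp(forward, params)         # w -> O^T w

    n_samples = samples.shape[0]

    def mat_vec(v):
        # S v ~ (1/N) O^dag (O v), plus a diagonal shift for regularisation.
        w = jvp_fun(v)
        (res,) = vjp_fun(w.conj())
        return jax.tree_util.tree_map(
            lambda r, x: r.conj() / n_samples + diag_shift * x, res, v
        )

    return mat_vec


def _fast_solve(solve_fun, apply_fun, params, samples, grad, diag_shift=0.01):
    # Build the matrix-vector product once; the iterative solver then calls
    # the cheap closure as many times as it needs.
    mv = mat_vec_fun(apply_fun, params, samples, diag_shift)
    return solve_fun(mv, grad)
```

Any iterative solver that accepts a callable in place of an explicit matrix, e.g. `jax.scipy.sparse.linalg.cg`, could be passed as `solve_fun`.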
The observation is that `QGTOnTheFlyT.__matmul__` calls `jax.jvp` and `jax.vjp` every time a matrix-vector product is needed, sometimes more than once. These are heavy functions because they have to linearise the network. By contrast, the function produced by `jax.vjp` (and by `jax.linearize` in place of `jax.jvp`) is just a chain of linear transformations, which is much cheaper to evaluate than computing a JVP or VJP from scratch. This means that if we can cache the matrix-vector product function in terms of these linear functions, `solve` with an iterative solver (which needs many matrix-vector products) should speed up.
Rewriting `QGTOnTheFlyT` in terms of such a cached function doesn't work well, however, because we can't jit over the cached function (see this branch and this question for more details; the code in that branch is much slower than the upstream one). This means that I leave the broad structure unchanged: `mat_vec` and `solve` take PyTrees with the parameters of the vstate etc. and return a PyTree, so they can be jitted as a whole.
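Concretely (still in terms of the hypothetical sketch above), the linear closures built by `mat_vec_fun` never cross a jit boundary because the jit wraps the entire solve:

```python
import jax


def make_jitted_solve(apply_fun, diag_shift=0.01):
    # Continuing the sketch above: the closures returned by jax.linearize and
    # jax.vjp stay inside the traced function rather than being passed as
    # arguments into a separately jitted solver.
    @jax.jit
    def solve(params, samples, grad):
        mv = mat_vec_fun(apply_fun, params, samples, diag_shift)
        x, _ = jax.scipy.sparse.linalg.cg(mv, grad)
        return x

    return solve
```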
Benchmarks
See here. The bottom line is that the code in this PR speeds up `solve` by about 20% for dense networks and by a factor of 2 for GCNNs. This suggests that jit already does a pretty good job of eliminating the overhead I was worried about, but we can still improve on it, especially when the computational graph is complicated.
Further ideas
- We still perform the same linearisation twice, once in `linearize`, once in `vjp`. We could save on this by writing one as the `jax.linear_transpose` of the other. Would this be an actual improvement?
- The same idea could remove `O_mean` and the associated complications in calculating the centred S matrix. If `jvp_fun` is a linear function, so is `lambda v: subtract_mean(jvp_fun(v))`, so in principle we can `linear_transpose` that too. This linear transpose would be the VJP with the centred Jacobian, without having to worry about the dtypes of the network arguments, meaning we can get the centred S matrix with a single linearisation step. The only place where this can go wrong is `subtract_mean`, which calls MPI. @PhilipVinc would you expect this to work? (A sketch of this idea follows the list.)
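A hedged sketch of the second idea; `centred_jacobian_funs` is a hypothetical helper, and `subtract_mean` below is a plain local stand-in for NetKet's MPI-aware mean subtraction:

```python
import jax
import jax.numpy as jnp


def subtract_mean(x):
    # Local stand-in for the MPI-aware mean subtraction in NetKet.
    return x - jnp.mean(x, axis=0, keepdims=True)


def centred_jacobian_funs(apply_fun, params, samples):
    """Build v -> DeltaO v and its transpose from a single linearisation."""
    _, jvp_fun = jax.linearize(lambda p: apply_fun(p, samples), params)

    # Centring composed with the (linear) jvp is still a linear function...
    def centred_jvp(v):
        return subtract_mean(jvp_fun(v))

    # ...so it can be transposed directly, giving the VJP with the centred
    # Jacobian without a second linearisation.
    centred_vjp = jax.linear_transpose(centred_jvp, params)

    return centred_jvp, centred_vjp
```

Whether `jax.linear_transpose` can handle the MPI call inside the real `subtract_mean` is exactly the open question above.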
Tests
I'll fix the tests as soon as the changes arising from the ideas above are settled. (It's not yet clear which of the logic functions will survive, and in what form.)