Speed up `QGTOnTheFly`
Created by: attila-i-szabo
This PR introduces a more effective way to handle JVPs and VJPs in `QGTOnTheFly`.
Current changes
- `mat_vec` in the logic file is replaced with `mat_vec_fun`, which takes care of the logic included in `DeltaO...` and `mat_vec` and returns a matrix-vector product function that won't perform any differentiation.
- `solve` is augmented with a function `_fast_solve` that calls this `mat_vec_fun` only once and runs the solver with the output. We need to figure out how to switch between this and `_solve`: the current approach won't work as the VMC object won't adjust the `fast` parameter. We might want to add it as a boolean parameter to `QGTOnTheFlyT`?
The observation is that `QGTOnTheFlyT::__matmul__` calls `jax.jvp` and `jax.vjp` every time a matrix-vector product is needed, sometimes more than once. These are heavy functions because they have to linearise the network. By contrast, the functions produced by `jax.vjp` (and by `jax.linearize` in place of `jax.jvp`) are just a bunch of linear transformations, which are faster to evaluate than JVPing or VJPing from scratch. This means that if we can cache the matrix-vector product function in terms of these linear functions, `solve` with an iterative solver (which needs a bunch of matrix-vector products) should speed up.
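A minimal sketch of this caching idea (not the PR's actual code; `log_psi`, `params`, `samples`, `model.apply` and `rhs` are hypothetical names, and real parameters are assumed to sidestep complex-conjugation conventions):

```python
import jax
import jax.numpy as jnp

def build_mat_vec(log_psi, params, samples):
    forward = lambda p: log_psi(p, samples)

    # Both calls linearise the network once, up front...
    out, jvp_fun = jax.linearize(forward, params)   # v -> O v
    _, vjp_fun = jax.vjp(forward, params)           # w -> O^T w
    n_samples = out.shape[0]

    # ...so this closure is pure linear algebra, no re-differentiation.
    def mat_vec(v):
        w = jvp_fun(v)                    # O v
        (res,) = vjp_fun(w / n_samples)   # O^T O v / N  (uncentred S·v)
        return res

    return mat_vec

# mat_vec = build_mat_vec(model.apply, params, samples)   # hypothetical names
# x, _ = jax.scipy.sparse.linalg.cg(mat_vec, rhs)         # iterative solve
```

With such a closure the solver only ever re-evaluates linear maps; the network itself is differentiated once per `solve` call.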
Rewriting `QGTOnTheFlyT` in terms of such a cached function doesn't work well, however, because we can't `jit` over the cached function (see this branch and this question for more details; the code in that branch is way slower than the upstream one). This means that I leave the broad structure unchanged: `mat_vec` and `solve` take PyTrees with the parameters of the `vstate` etc. and return a PyTree, so they can be jitted as a whole.
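As a hedged illustration of that caveat (a toy example, not code from this PR or NetKet): the linear maps returned by `jax.linearize`/`jax.vjp` are Python closures, so they can only cross a `jit` boundary as static arguments, and every optimisation step produces a fresh closure, forcing a re-trace each time.

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=0)
def apply_twice(mat_vec, b):
    # Stand-in for a solver body that reuses mat_vec several times.
    return mat_vec(mat_vec(b))

for step in range(3):
    params = jnp.full((3,), float(step))
    _, jvp_fun = jax.linearize(lambda p: jnp.sin(p), params)
    # jvp_fun is a new Python object every iteration, so passing it as a
    # static argument makes jit re-trace and recompile on each call.
    apply_twice(jvp_fun, jnp.ones(3))
```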
Benchmarks
See here. The bottom line is that the code in this PR speeds up `solve` by about 20% for dense networks and by a factor of 2 for GCNNs. This suggests that `jit` does a pretty good job at eliminating the overhead I'm worrying about, but we can improve on it, especially if the computational graph is complicated.
Further ideas
- We still perform the same linearisation twice, once in `linearize`, once in `vjp`. We could save on this by writing one as the `jax.linear_transpose` of the other. Would this be an actual improvement?
- The same idea could remove `O_mean` and the associated complications in calculating the centred S matrix. If `jvp_fun` is a linear function, so is `lambda v: subtract_mean(jvp_fun(v))`, so in principle we can `linear_transpose` that too. This linear transpose would be the VJP with the centred Jacobian, without having to worry about the dtypes of the network arguments, meaning we can get the centred S matrix with a single linearisation step (see the sketch after this list). The only place where this can go wrong is `subtract_mean`, which calls MPI. @PhilipVinc, would you expect this to work?
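For what it's worth, here is a hedged sketch of that second idea (again not PR code; `log_psi`, `params` and `samples` are hypothetical names, `jnp.mean` stands in for the MPI-aware `subtract_mean`, and real parameters are assumed because `jax.linear_transpose` gives the plain transpose rather than the conjugate transpose):

```python
import jax
import jax.numpy as jnp

def centred_mat_vec(log_psi, params, samples):
    # Single linearisation of the network.
    out, jvp_fun = jax.linearize(lambda p: log_psi(p, samples), params)

    # Subtracting the mean of a linear map is still linear...
    def centred_jvp(v):
        w = jvp_fun(v)              # O v
        return w - jnp.mean(w)      # stand-in for NetKet's MPI-aware subtract_mean

    # ...so it can be transposed directly; params only supplies shapes/dtypes.
    transpose = jax.linear_transpose(centred_jvp, params)
    n_samples = out.shape[0]

    def mat_vec(v):
        w = centred_jvp(v)                 # (O - <O>) v
        (res,) = transpose(w / n_samples)  # transpose, not conjugate transpose:
        return res                         # complex parameters need extra conj care

    return mat_vec
```

If this works out, the single `jax.linearize` call would replace both the `jvp`/`vjp` pair and the explicit `O_mean` bookkeeping.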
Tests
I'll fix the tests as soon as any changes due to the ideas above are done. (It's not clear which of the logic functions survive and in what form.)