Skip to content

Speed up `QGTOnTheFly`

Vicentini Filippo requested to merge github/fork/attila-i-szabo/fly-fast3 into master

Created by: attila-i-szabo

This PR introduces a more effective way to handle JVPs and VJPs in QGTOnTheFly.

Current changes

  • mat_vec in the logic file is replaced with mat_vec_fun that takes care of the logic included in DeltaO... and mat_vec and returns a matrix-vector product function that won't perform any differentation
  • solve is augmented with a function _fast_solve that calls this mat_vec_fun only once and runs the solver with the output. We need to figure out how to switch between this and _solve: the current approach won't work as the VMC object won't adjust the fast parameter. We might want to add it as a boolean parameter to QGTOnTheFLyT?

The observation is that QGTOnTheFlyT::__matmul__ calls jax.jvp and jax.vjp every time a matrix-vector product is needed, sometimes more than once. These are heavy functions because they have to linearise the network. By contrast, the function produced by vjp (and jax.linearize in place of jvp) is a bunch of linear transformations, which is faster to evaluate than jvp or VJPing from scratch. This means that if we can cache the matrix-vector product function in terms of these linear functions, solve with an iterative solver (which needs a bunch of matrix-vector products) should speed up.

Rewriting QGTOnTheFlyT in terms of such a cached function doesn't work well, however, because we can't jit over the cached function (see this branch and this question for more details, the code in that branch is way slower than the upstream one). This mean that I leave the broad structure unchanged: mat_vec and solve takes PyTrees with parameters of the vstate etc. and returns a PyTree, so they can be jitted as a whole.

Benchmarks

See here. The bottom line is that the code in this PR speeds up solve by about 20% for dense networks and by a factor of 2 for GCNNs. This suggests that jit does a pretty good job at eliminating the overhead I'm worrying about, but we can improve on it, especially if the computational graph is complicated.

Further ideas

  1. We still perform the same linearisation twice, once in linearize, once in vjp. We could save on this by writing one as the jax.linear_transpose of the other. Would this be an actual improvement?
  2. The same idea could remove O_mean and the associated complications in calculating the centred S matrix. If jvp_fun is a linear function, so is lambda v: subtract_mean(jvp_fun(v)), so in principle we can linear_transpose that too. This linear transpose be the VJP with the centred Jacobian, without having to worry about the dtypes of the network arguments, meaning we can get the centred S matrix with a single linearisation step. The only place where this can go wrong is subtract_mean which calls MPI. @PhilipVinc would you expect this to work?

Tests

I'll fix the tests as soon as any changes due to the ideas above are done. (It's not clear which of the logic functions survive and in what form.)

Merge request reports