
Speed up `QGTOnTheFly` (take 2)

Vicentini Filippo requested to merge github/fork/attila-i-szabo/QGTOnTheFly into master

Created by: attila-i-szabo

This PR builds on the discussion in #781 to make the logic of QGTOnTheFly more performant.

  • All the heavy lifting in class QGTOnTheFlyT is done by the Callable member mat_vec, which performs the matrix-vector product on pytrees.
  • This function is generated by the jitted function mat_vec_factory as jax.tree_util.Partial(mat_vec, jvp_fn), where jvp_fn is the linear function returned by jax.linearize, a pytree encoding a linear transformation. This allows us to perform the linearisation work only once per SR step, and to optimise the way it's done once and for all.
  • In the logic file I chose to keep mat_vec separate from the factory, both for aesthetic reasons and because jitting the Partial in solve and @ might be more efficient if it always wraps the same function rather than one constructed anew inside mat_vec_factory on each call.
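
The pattern in the bullets above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the toy model log_psi, the helper apply, the restriction to real parameters, and the use of jax.linear_transpose to get the transposed Jacobian product are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_map


def log_psi(params, samples):
    # Hypothetical toy model with real parameters: one log-amplitude per sample.
    return jnp.log(jnp.cosh(samples @ params["w"] + params["b"]))


def mat_vec(jvp_fn, vec, diag_shift):
    # Computes (S + diag_shift * I) vec, where S = Jc^T Jc / n_samples and
    # Jc is the Jacobian of log_psi centred over the samples.
    # Assumption: J^T is obtained by transposing the linearised map.
    vjp_fn = jax.linear_transpose(jvp_fn, vec)
    w = jvp_fn(vec)                  # J vec, shape (n_samples,)
    w = (w - w.mean()) / w.size      # centre over samples and normalise
    (res,) = vjp_fn(w)               # J^T w, same pytree structure as vec
    return tree_map(lambda r, x: r + diag_shift * x, res, vec)


def mat_vec_factory(params, samples):
    # Linearise once for the given parameters and samples. The returned
    # Partial is a pytree, so it can be passed to a jitted function.
    _, jvp_fn = jax.linearize(lambda p: log_psi(p, samples), params)
    return Partial(mat_vec, jvp_fn)


@jax.jit
def apply(mv, vec, diag_shift):
    # Hypothetical stand-in for the jitted call sites (solve and @).
    return mv(vec, diag_shift)


samples = jax.random.normal(jax.random.PRNGKey(0), (16, 4))
params = {"w": jnp.linspace(-0.5, 0.5, 4), "b": jnp.array(0.1, dtype=jnp.float32)}
vec = {"w": jnp.ones(4), "b": jnp.array(2.0, dtype=jnp.float32)}

mv = mat_vec_factory(params, samples)  # linearisation happens here, once
out = apply(mv, vec, 0.01)             # can be called repeatedly, e.g. by an iterative solver
```

Because mat_vec is a module-level function, every Partial produced by the factory wraps the same callable and differs only in the data carried by jvp_fn.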

I also fixed the handling of the deprecated centered argument: if it's passed to the upstream version, it breaks the constructor of QGTOnTheFlyT.

Outstanding issues:

  • I've commented out all the tests relevant for QGTOnTheFly from test/optimizer/test_qgt_logic.py as the logic has completely changed, and I'm not sure what new tests would be appropriate for it.
  • I'll implement benchmarks to compare the speed of this implementation to the upstream one. My cluster is down for the weekend, so I'll report back on this later.
  • Once I can run the benchmarks, I want to check whether jitting mat_vec in the logic file would improve performance (I suspect not).
