Speed up `QGTOnTheFly` (take 2)
Created by: attila-i-szabo
This PR builds on the discussion in #781 to make the logic of QGTOnTheFly
more performant.
- All the heavy lifting in class
QGTOnTheFlyT
is done by theCallable
membermat_vec
which performs the matrix-vector product for pytrees - This function is generated by the jitted function
mat_vec_factory
asjax.Partial(mat_vec, jvp_fn)
, wherejvp_fn
is the output ofjax.linearize
, a pytree encoding a linear transformation. This allows us to perform the linearisation work only once per SR step, and optimise the way it's done once and for all. - In the logic file I chose to separate
mat_vec
from the factory both for aesthetic reasons and because jitting thePartial
insolve
and@
might be more efficient if it's always the same function, not constructed each time insidemat_vec_factory
I also fixed the deprecation of centered
; if it's passed to the upstream version, it breaks the constructor of QGTOnTheFlyT
Outstanding issues:
- I've commented out all the tests relevant for
QGTOnTheFly
fromtest/optimizer/test_qgt_logic.py
as the logic has completely changed, and I'm not sure what new tests would be appropriate for it. - I'll implement benchmarks to compare the speed of this implementation to the upstream one. My cluster is down for the weekend, so I'll report back on this later.
- When I can run the benchmarks, I want to check whether jitting
mat_vec
in the logic file would improve performance (I suspect not)