[RFC] Autonomous schedules

Created by: attila-i-szabo

We sometimes have hyperparameters that we'd like to change during training, but they aren't learning rates (so they don't fit into full optax optimizers), and it isn't convenient to store them in the object that uses them either. The best example is diag_shift in the QGT: the literature suggests lowering its value as SR goes on, but a QGT object lives for a single iteration, so it has no memory of this evolution. In #1060, we'd have double the problem, as the pseudo-Hessian object would have to tune both diag_shift and a learning rate that no longer fits the optax paradigm.

This PR solves this problem by introducing a decorator that turns a schedule (i.e. a function mapping a step count to a parameter value) into an "autonomous schedule" (i.e. a function with internal state that can be called without arguments and returns schedule(0), schedule(1), ... on successive calls). E.g.:

import netket as nk

@nk.optimizer.autonomous
def schedule(x):
    return 3*x

# Each call advances the internal step counter by one.
schedule()  # returns 0
schedule()  # returns 3
schedule()  # returns 6
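
For readers who want to see the idea without NetKet, here is a minimal pure-Python sketch of such a decorator. It is an illustration only, not the implementation in this PR: the standalone `autonomous` function below is a hypothetical stand-in for `nk.optimizer.autonomous`, and it assumes the internal state is just a step counter.

```python
import functools
import itertools


def autonomous(schedule):
    """Hypothetical stand-in: wrap a step-count -> value schedule
    so that it can be called without arguments."""
    counter = itertools.count()  # yields 0, 1, 2, ... across calls

    @functools.wraps(schedule)
    def wrapper():
        # Feed the next step count to the wrapped schedule.
        return schedule(next(counter))

    return wrapper


@autonomous
def schedule(x):
    return 3 * x


first_three = [schedule() for _ in range(3)]
print(first_three)  # [0, 3, 6]
```

Keeping the counter in a closure means every decorated schedule has its own independent state, so two schedules never share a step count.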

The QGT "smart constructors" are also modified to accept such autonomous schedules for diag_shift in addition to numbers. If a function is passed to them, they call it without arguments to obtain the value of diag_shift for that iteration. It would be nice to place this logic closer to the definition of LinearOperator, but I can't think of a way to do this with a dataclass (without breaking the contract by using setattr).
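
The dispatch described above can be sketched roughly as follows. This is a hedged illustration, not the code in this PR: `resolve_diag_shift` and `decaying_shift` are hypothetical names, and the sketch assumes that any callable passed as diag_shift is an autonomous schedule taking no arguments.

```python
import itertools


def resolve_diag_shift(diag_shift):
    """Hypothetical helper: return a plain number from either a number
    or a zero-argument autonomous schedule."""
    if callable(diag_shift):
        # Calling the schedule also advances its internal state.
        return diag_shift()
    return diag_shift


_steps = itertools.count()


def decaying_shift():
    # Hypothetical autonomous schedule: halve the shift at every call.
    return 0.1 * 0.5 ** next(_steps)


fixed = resolve_diag_shift(0.01)
shifts = [resolve_diag_shift(decaying_shift) for _ in range(3)]
print(fixed, shifts)
```

A constructor that resolves the value once per call hands a plain number to the short-lived object, so the schedule's state lives outside the QGT and survives between iterations.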
