[RFC] Autonomous schedules
Created by: attila-i-szabo
We sometimes have hyperparameters that we'd like to change during training, but that aren't learning rates (so they don't fit into full optax optimizers) and that can't conveniently be stored in the object that uses them. The best example is diag_shift in QGT: the literature suggests lowering its value as SR goes on, but the QGT object lives for a single iteration, so it has no memory of such evolution. In #1060, we'd have double the problem, as the pseudo-Hessian object would have to tune both diag_shift and a learning rate that no longer fits in the optax paradigm.
This PR solves this problem by introducing a decorator that turns a schedule (i.e. a step count → parameter function) into an "autonomous schedule" (i.e. a function with internal state that can be called without arguments and returns schedule(0), schedule(1), ... on successive calls). E.g.:
import netket as nk

@nk.optimizer.autonomous
def schedule(x):
    return 3 * x

schedule()  # returns 0
schedule()  # returns 3
schedule()  # returns 6
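For discussion, here is a minimal plain-Python sketch of how such a decorator could behave. It is not the implementation in this PR: the standalone name autonomous and the closure-based counter are assumptions made for illustration, and the sketch ignores any interaction with jit or serialization.

```python
import functools
import itertools


def autonomous(schedule):
    """Turn a step-count -> value schedule into a zero-argument callable.

    Illustrative sketch only; the decorator proposed in the PR may differ.
    """
    counter = itertools.count()  # yields 0, 1, 2, ... across successive calls

    @functools.wraps(schedule)
    def wrapper():
        # Each call consumes the next step index and evaluates the schedule there.
        return schedule(next(counter))

    return wrapper


@autonomous
def schedule(x):
    return 3 * x


values = [schedule() for _ in range(3)]
print(values)  # [0, 3, 6]
```

Keeping the counter in a closure means every decorated schedule has its own independent step count, which is also the main caveat: the state is hidden and cannot be reset or inspected from outside without extra API.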
The QGT "smart constructors" are also modified to accept such autonomous schedules for diag_shift in addition to numbers. If a function is passed to them, they call it without arguments to obtain the value of diag_shift for that iteration. It would be nice to place this logic closer to the definition of LinearOperator, but I can't think of a way to do this with a dataclass (without breaking the contract by using setattr).
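The number-or-callable handling in the smart constructors could look roughly like the sketch below. The names resolve_diag_shift, make_qgt and decaying_shift are hypothetical and exist only for this illustration; make_qgt is a stand-in that returns a dict, not a real QGT object.

```python
import itertools


def resolve_diag_shift(diag_shift):
    """Return a plain value: call diag_shift if it is callable, else pass it through."""
    return diag_shift() if callable(diag_shift) else diag_shift


def make_qgt(diag_shift=0.01):
    # Hypothetical stand-in for a QGT smart constructor: the schedule is
    # resolved once, at construction time, so the object only stores a number.
    return {"diag_shift": resolve_diag_shift(diag_shift)}


_steps = itertools.count()


def decaying_shift():
    # Hypothetical autonomous schedule: halves the shift on every call.
    return 0.1 * 0.5 ** next(_steps)


fixed = make_qgt(0.01)             # a plain number is used unchanged
first = make_qgt(decaying_shift)   # first iteration: schedule evaluated at step 0
second = make_qgt(decaying_shift)  # second iteration: schedule evaluated at step 1
```

Because the callable is resolved before the object is built, the dataclass itself never needs to hold or mutate schedule state.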