[RFC] Redesigning SR interface (3.0 or later?)
Created by: PhilipVinc
@gcarleo and I were recently discussing the need to also redesign the interface used to access the S matrix and to perform Stochastic Reconfiguration (natural gradient).
The aim is easier extensibility (so that anyone can write their own version of the S matrix) and composability (playing well with solvers from jax/scipy and others, ideally without having to wrap them as is needed now).
This issue aims to discuss two items:
- Whether we should do this redesign for v3.0 and when to deprecate the old interface, or whether we should put it on hold for v3.1 or v3.2.
- The design of the new system.
To recap, the current interface is the following:
- An SR object holds settings on the type of S matrix representation used (dense, lazy-onthefly, lazy-jacobian...) and the algorithm used to solve Sx=F. Every object corresponds to exactly one set of choices (lazy-onthefly+cg, lazy-onthefly+gmres, etc.). In netket v3.0b1 there is only one possible representation of the S matrix, so this was not a big issue.
  - As shown in PR #648, this design has issues: adding a new type of S matrix representation (lazy-jacobian in the PR) requires duplicating all the types corresponding to solvers, which is inconvenient.
  - The natural thing to do would be to split the S matrix representation from the solver used.
- The SR object can build the S matrix itself if it's given a variational state: `S = sr.create_S(state)`.
- The S matrix keeps a reference to the SR object, so that you can do `S.solve(F)` and it will use the parameters from the SR object.
For example:
```python
sr = nk.optimizer.sr.LazyCG(diag_shift=0.01, maxiters=100)
S = vstate.quantum_geometric_tensor(sr)  # equivalent to S = sr.create_S(vstate)
_, F = vstate.expect_and_grad(ham)
x, info = S.solve(F)  # uses the inner field S.sr for the parameters and the solve function
```
The system described above, where all configuration is stored in a single structure, makes it rather easy to use SR inside a driver, as we can simply pass
```python
gs = nk.VMC(ham, optim, variational_state=vstate, sr=sr)
```
and the driver will internally use this object as shown above.
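For comparison with the proposal, the coupling in the current design can be sketched in plain NumPy. This is a minimal sketch under assumptions: `SRSettingsSketch` and `DenseSSketch` are hypothetical stand-ins rather than NetKet classes, S is built as O^† O / N_s from a matrix `O` of centered log-derivatives instead of from a variational state, and a dense direct solve replaces the iterative solver.

```python
import numpy as np


class SRSettingsSketch:
    """Hypothetical stand-in for an SR object such as nk.optimizer.sr.LazyCG.

    It stores both the S-matrix settings and the solver settings.
    """

    def __init__(self, diag_shift=0.01):
        self.diag_shift = diag_shift

    def create_S(self, O):
        # NetKet would take a variational state here; we take a matrix O of
        # centered log-derivatives with shape (n_samples, n_params) instead.
        return DenseSSketch(O, sr=self)


class DenseSSketch:
    def __init__(self, O, sr):
        self.matrix = O.conj().T @ O / O.shape[0]
        self.sr = sr  # back-reference: solve() reads its settings from here

    def solve(self, F):
        n = self.matrix.shape[0]
        shifted = self.matrix + self.sr.diag_shift * np.eye(n)
        # A dense direct solve stands in for the iterative solver
        return np.linalg.solve(shifted, F), None


rng = np.random.default_rng(0)
O = rng.normal(size=(50, 4))
O -= O.mean(axis=0)  # center the log-derivatives
sr = SRSettingsSketch(diag_shift=0.01)
S = sr.create_S(O)
x, info = S.solve(np.ones(4))
```

Because the SR object both chooses the representation and carries the solver settings, each (representation, solver) pair needs its own SR class, which is the duplication described above.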
--
Tentative new design:
Goal: be able to use scipy/jax solvers out of the box.
- Every kind of S matrix has its own constructor, e.g. `SMatrixOnTheFly(vstate)`, `SMatrixJacobian(vstate)`, etc., which can be used explicitly.
- A helper function `SMatrix(SMatrixType, vstate) = SMatrixType(vstate)` is provided. If `SMatrixType` is not passed, some sensible default representation that always works is used.
- `vstate.quantum_geometric_tensor()` will now take the `SMatrix` type as input and relay the call to that type. If no type is passed, a sensible default is used.
- All S matrices must have the methods `__matmul__` and `__call__`, supporting both PyTrees and dense vectors, so that the matrix can be passed to sparse solvers.
With this, the following should work:
```python
S = SMatrixOnTheFly(vstate)
x, info = jax.scipy.sparse.linalg.cg(S, F, maxiter=100)
```
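A minimal NumPy sketch of an S matrix satisfying the proposed protocol is shown below. `SMatrixJacobianSketch`, the `SMatrix` helper and `ravel_dict` are hypothetical; PyTrees are modeled as flat dicts of arrays (NetKet would use JAX PyTrees and a ravel utility instead), and S is applied lazily as S v = O^† (O v) / N_s from centered log-derivatives O. `jax.scipy.sparse.linalg.cg` accepts only a callable or an array as its operator, which is why `__call__` is part of the protocol alongside `__matmul__`.

```python
import numpy as np


def ravel_dict(tree):
    """Flatten a dict of arrays (a toy PyTree) into one vector, plus the inverse map."""
    keys = sorted(tree)
    shapes = [np.shape(tree[k]) for k in keys]
    sizes = [int(np.prod(shape)) for shape in shapes]
    flat = np.concatenate([np.ravel(tree[k]) for k in keys])

    def unravel(vec):
        out, start = {}, 0
        for key, shape, size in zip(keys, shapes, sizes):
            out[key] = np.reshape(vec[start:start + size], shape)
            start += size
        return out

    return flat, unravel


class SMatrixJacobianSketch:
    """Lazy S matrix: S v = O^dagger (O v) / n_samples, never formed explicitly."""

    def __init__(self, O):
        self.O = O  # centered log-derivatives, shape (n_samples, n_params)

    def _matvec(self, v):
        return self.O.conj().T @ (self.O @ v) / self.O.shape[0]

    def __matmul__(self, v):
        if isinstance(v, dict):  # PyTree input: ravel, apply, unravel
            flat, unravel = ravel_dict(v)
            return unravel(self._matvec(flat))
        return self._matvec(np.asarray(v))  # dense vector input

    # Solvers such as jax.scipy.sparse.linalg.cg call A(x), so alias __call__
    __call__ = __matmul__


def SMatrix(SMatrixType, O):
    """Helper from the proposal: fall back to a default type when none is given."""
    if SMatrixType is None:
        SMatrixType = SMatrixJacobianSketch
    return SMatrixType(O)


rng = np.random.default_rng(1)
O = rng.normal(size=(40, 3))
O -= O.mean(axis=0)
S = SMatrix(None, O)
grad = {"a": np.array([1.0, 2.0]), "b": np.array(3.0)}
S_grad = S @ grad  # a dict with the same keys and shapes as grad
```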
A question arises: we often regularise the S matrix with a shift on the diagonal. To support this under the new API, we should also require all implementations of the S matrix to support `__add__(self, x: Number)`, keeping the diagonal shift in memory (for a lazy representation) or simply adding it to the diagonal of the dense matrix (for a dense representation).
Some implementations might even support `__mul__` or other conditionings.
We will then be able to do
```python
S = SMatrixOnTheFly(vstate)
S = S + 0.01  # equivalent to S.diag_shift += 0.01
x, info = jax.scipy.sparse.linalg.cg(S, F, maxiter=100)
```
(Note: this is inconsistent with the numpy API, where adding a number to a matrix adds it to every entry, but it is consistent with our implementation of LocalOperators, where adding a number to a local operator only adds it to the diagonal.)
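The lazy diagonal shift can be sketched as follows. `LazySMatrixSketch` is hypothetical and only accepts scalars; here `__add__` returns a new object instead of mutating `diag_shift` in place, a design choice that keeps S immutable and tends to suit JAX transformations. The last lines contrast this with NumPy broadcasting, which adds the number to every entry.

```python
import numpy as np


class LazySMatrixSketch:
    """Lazy S matrix that stores the diagonal shift instead of materialising it."""

    def __init__(self, O, diag_shift=0.0):
        self.O = O  # centered log-derivatives, shape (n_samples, n_params)
        self.diag_shift = diag_shift

    def __add__(self, shift):
        if not np.isscalar(shift):
            return NotImplemented
        # Return a new object; the original S keeps its own shift
        return LazySMatrixSketch(self.O, self.diag_shift + shift)

    __radd__ = __add__  # so that 0.01 + S also works

    def __matmul__(self, v):
        v = np.asarray(v)
        return self.O.conj().T @ (self.O @ v) / self.O.shape[0] + self.diag_shift * v

    __call__ = __matmul__


rng = np.random.default_rng(2)
O = rng.normal(size=(30, 2))
O -= O.mean(axis=0)
S = LazySMatrixSketch(O)
S_shifted = S + 0.01  # applies S v + 0.01 v, i.e. S + 0.01 * identity

# NumPy semantics for comparison: the number is broadcast to every entry
A = np.array([[2.0, 1.0], [1.0, 3.0]])
A_numpy = A + 0.5              # [[2.5, 1.5], [1.5, 3.5]]
A_shift = A + 0.5 * np.eye(2)  # [[2.5, 1.0], [1.0, 3.5]]
```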
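We could even get this to work with scipy (not jax) sparse solvers if we also implement `S.shape` to report the number of parameters, together with a `matvec` method (scipy's `aslinearoperator` only accepts custom objects that expose both).
```python
import scipy.sparse.linalg

S = SMatrixOnTheFly(vstate)
# assuming S.shape == (vstate.n_parameters, vstate.n_parameters) and S.matvec exists
S = S + 0.01  # equivalent to S.diag_shift += 0.01
F_dense, F_unravel = nk.jax.ravel(F)
x_dense, info = scipy.sparse.linalg.cg(S, F_dense, maxiter=100)
x = F_unravel(x_dense)  # back to a PyTree
```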
or even
```python
S = SMatrixOnTheFly(vstate)
S = S + 0.01  # equivalent to S.diag_shift += 0.01
Sm1 = np.linalg.pinv(S.to_dense())
x = Sm1 @ F_dense
```
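Both routes need more than `__matmul__`. The sketch below illustrates them under stated assumptions: `LazySMatrixSketch` is hypothetical, `to_dense` materialises S by applying it to each basis vector (n_params applications, so only practical for small models), and S is wrapped explicitly in `scipy.sparse.linalg.LinearOperator`, which is one way to satisfy scipy without adding a `matvec` method to S itself.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg


class LazySMatrixSketch:
    """Lazy S matrix with a diagonal shift, plus the extras scipy and pinv need."""

    def __init__(self, O, diag_shift=0.0):
        self.O = O  # centered log-derivatives, shape (n_samples, n_params)
        self.diag_shift = diag_shift
        self.shape = (O.shape[1], O.shape[1])  # (n_parameters, n_parameters)
        self.dtype = O.dtype

    def __matmul__(self, v):
        return self.O.conj().T @ (self.O @ v) / self.O.shape[0] + self.diag_shift * v

    def to_dense(self):
        # Column j of the dense matrix is S @ e_j
        eye = np.eye(self.shape[1], dtype=self.dtype)
        return np.stack([self @ eye[:, j] for j in range(self.shape[1])], axis=1)


rng = np.random.default_rng(3)
O = rng.normal(size=(100, 5))
O -= O.mean(axis=0)
S = LazySMatrixSketch(O, diag_shift=0.01)
F_dense = rng.normal(size=5)

# Route 1: iterative solve through an explicit LinearOperator wrapper
S_op = LinearOperator(S.shape, matvec=lambda v: S @ v, dtype=S.dtype)
x_cg, info = cg(S_op, F_dense, maxiter=100)  # info == 0 means converged

# Route 2: materialise S and apply its pseudo-inverse
x_pinv = np.linalg.pinv(S.to_dense()) @ F_dense
```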
So all seems great! The only thing we need to think about is how to make all this play well with the driver API.
How to support this? We could accept two kwargs in the drivers:
- `S_type: Optional[SMatrixType]`: the type of the S matrix you want to use, which should support `S = S_type(vstate)`.
- `SR_solver: Optional[Callable]`: the function used to solve the linear system Sx=F. It must have the signature `SR_solver(S: SMatrix, F: PyTree, **kwargs) -> Tuple[PyTree, Any]`, returning the solution `x` and some solver `info`.
If neither is passed (both are None), SR is not used. If only one of the two is passed, we use SR and the unspecified kwarg falls back to a default. Internally we could do something like
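```python
def __init__(self, *args, S_type=None, SR_solver=None, **kwargs):
    # SR is disabled only if neither kwarg is given
    self.use_sr = not (S_type is None and SR_solver is None)
    if self.use_sr and S_type is None:
        S_type = default_S_type  # placeholder for a sensible default
    if self.use_sr and SR_solver is None:
        SR_solver = default_SR_solver  # placeholder for a sensible default
    self.S_type = S_type
    self.SR_solver = SR_solver
    ...
```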
```python
def _forward_and_backward(self):
    """
    Computes the energy gradient and, if SR is enabled, the SR update.
    """
    self.state.reset()

    # Compute the local energy estimator and average energy
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)

    if self.use_sr:
        self._S = self.S_type(self.state)
        # use the previous solution as an initial guess to speed up the solution of the linear system
        x0 = self._dp if not self.sr_restart else None
        self._dp, self._sr_info = self.SR_solver(self._S, self._loss_grad, x0=x0)
    else:
        self._dp = self._loss_grad

    return self._dp
```
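Any callable with the `SR_solver` signature fits this scheme. As a hedged illustration, here is a hypothetical plain-NumPy conjugate-gradient solver (`cg_solver` is not an existing NetKet or scipy function) that accepts the `x0` passed by the driver, so a warm start from the previous update can save iterations. It assumes S is Hermitian positive definite and callable on dense vectors, and that F has already been raveled.

```python
import numpy as np


def cg_solver(S, F, x0=None, maxiter=100, tol=1e-8):
    """Hypothetical SR_solver with signature (S, F, **kwargs) -> (x, info).

    Plain conjugate gradient; assumes S is Hermitian positive definite and
    callable on dense vectors, and that F has already been raveled.
    """
    F = np.asarray(F)
    # Warm start: begin from the previous solution if the driver passes one
    x = np.zeros_like(F) if x0 is None else np.array(x0, dtype=F.dtype)
    r = F - S(x)
    p = r.copy()
    rs = np.vdot(r, r).real
    target = tol * np.linalg.norm(F)
    n_iter = 0
    while n_iter < maxiter and np.sqrt(rs) > target:
        Sp = S(p)
        alpha = rs / np.vdot(p, Sp).real
        x = x + alpha * p
        r = r - alpha * Sp
        rs_new = np.vdot(r, r).real
        p = r + (rs_new / rs) * p
        rs = rs_new
        n_iter += 1
    return x, {"iterations": n_iter, "residual_norm": float(np.sqrt(rs))}


rng = np.random.default_rng(4)
B = rng.normal(size=(20, 4))
A = B.T @ B / 20 + 0.01 * np.eye(4)  # symmetric positive definite
F = rng.normal(size=4)


def S(v):
    # any callable S works; here it is a dense matrix-vector product
    return A @ v


x, info = cg_solver(S, F, maxiter=100)
# warm start from an (almost) exact previous solution: no iterations needed
x_warm, info_warm = cg_solver(S, F, x0=np.linalg.solve(A, F))
```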
For users to specify solver kwargs as they do now, they would need to consult the documentation of that solver and bind the kwargs with `functools.partial`. Example:
```python
from functools import partial

from jax.scipy.sparse.linalg import gmres

SR_solver = partial(gmres, maxiter=300, restart=10)

# use the default S matrix type
gs = nk.VMC(ham, optim, variational_state=vstate, SR_solver=SR_solver)
# or
gs = nk.VMC(ham, optim, variational_state=vstate, SR_solver=SR_solver, S_type=nk.optimizer.sr.SJacobian)
```
However, how do we include the diagonal shift? One would have to do
```python
from functools import partial

from jax.scipy.sparse.linalg import gmres


def srsolver(S, F, **kwargs):
    return gmres(S + 0.01, F, **kwargs)


SR_solver = partial(srsolver, maxiter=300, restart=10)
```
which is not too clean...