Unverified Commit 1f80e7d5 authored by Filippo Vicentini's avatar Filippo Vicentini Committed by GitHub

Add `pinv` solver, based on jax.numpy.linalg.pinv (#1525)

For completeness, this adds a solver based on jax's `pinv`.
The docstring suggests that `pinv_smooth` performs better in general.
parent 6d10c148
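As a rough usage sketch (assuming the current NetKet interface, where `nk.optimizer.SR` accepts a `solver` argument; this example is not part of the commit), the new solver can be selected like this:

import netket as nk

# Stochastic Reconfiguration preconditioner using the new pinv solver.
sr = nk.optimizer.SR(diag_shift=0.01, solver=nk.optimizer.solver.pinv)

# The docstrings suggest the eigh-based pinv_smooth solver instead:
sr_smooth = nk.optimizer.SR(diag_shift=0.01, solver=nk.optimizer.solver.pinv_smooth)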
@@ -79,6 +79,7 @@ And the following dense solvers for Stochastic Reconfiguration:
solver.cholesky
solver.LU
solver.pinv
solver.pinv_smooth
solver.solve
solver.svd
......
from .solvers import cholesky, LU, solve, svd, pinv_smooth
from .solvers import cholesky, LU, solve, svd, pinv, pinv_smooth
from netket.utils import _hide_submodules
......
@@ -36,6 +36,19 @@ def pinv_smooth(A, b, rcond=1e-14, rcond_smooth=1e-14, x0=None):
\tilde\lambda_i^{-1}=\frac{\lambda_i^{-1}}{1+\big(\epsilon\frac{\lambda_\text{max}}{\lambda_i}\big)^6}
.. note::
In general, we found that this custom implementation of
the pseudo-inverse outperforms
jax's :func:`~jax.numpy.linalg.pinv`. This might be
because :func:`~jax.numpy.linalg.pinv` internally calls
:obj:`~jax.numpy.linalg.svd`, while this solver internally
uses :obj:`~jax.numpy.linalg.eigh`.
For that reason, we suggest you use this solver instead of
:obj:`~netket.optimizer.solver.pinv`.
Args:
A: LinearOperator (matrix)
b: vector or Pytree
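To make the smoothing formula above concrete, here is a minimal, illustrative JAX sketch of an eigh-based smoothed pseudo-inverse solve (the function name `smoothed_pinv_solve` is hypothetical; this is not the exact NetKet implementation):

import jax.numpy as jnp

def smoothed_pinv_solve(A, b, rcond=1e-14, rcond_smooth=1e-14):
    # Hermitian eigendecomposition of the dense matrix A.
    evals, evecs = jnp.linalg.eigh(A)
    lam_max = jnp.max(jnp.abs(evals))
    # Hard cutoff: eigenvalues below rcond * lam_max are dropped.
    inv_evals = jnp.where(jnp.abs(evals) > rcond * lam_max, 1.0 / evals, 0.0)
    # Smooth suppression 1 / (1 + (rcond_smooth * lam_max / lambda_i)^6),
    # matching the formula in the docstring above.
    inv_evals = inv_evals / (1.0 + (rcond_smooth * lam_max / evals) ** 6)
    # Apply the regularised pseudo-inverse to b in the eigenbasis.
    return evecs @ (inv_evals * (evecs.conj().T @ b))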
@@ -67,6 +80,49 @@ def pinv_smooth(A, b, rcond=1e-14, rcond_smooth=1e-14, x0=None):
return unravel(x), None
def pinv(A, b, rcond=1e-12, x0=None):
"""
Solve the linear system using jax's implementation of the
pseudo-inverse.
Internally it calls :func:`~jax.numpy.linalg.pinv`, which
uses an :obj:`~jax.numpy.linalg.svd` decomposition with
the same value of **rcond**.
.. note::
In general, we found that our custom implementation of
the pseudo-inverse,
:func:`netket.optimizer.solver.pinv_smooth` (which
internally uses Hermitian diagonalization), outperforms
jax's :func:`~jax.numpy.linalg.pinv`.
For that reason, we suggest you use
:func:`~netket.optimizer.solver.pinv_smooth` instead of
:obj:`~netket.optimizer.solver.pinv`.
The diagonal shift on the matrix can be 0 and the
**rcond** variable can be used to truncate small
eigenvalues.
Args:
A: the matrix A in Ax=b
b: the vector b in Ax=b
rcond: cutoff for small singular values; singular values smaller
than rcond times the largest one are treated as zero
"""
del x0
A = A.to_dense()
b, unravel = tree_ravel(b)
A_inv = jnp.linalg.pinv(A, rcond=rcond, hermitian=True)
x = jnp.dot(A_inv, b)
return unravel(x), None
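To illustrate the **rcond** truncation described in the docstring (this toy example is not part of the commit):

import jax.numpy as jnp

# A nearly singular Hermitian matrix: eigenvalues 1 and 1e-15.
A = jnp.array([[1.0, 0.0], [0.0, 1e-15]])
b = jnp.array([1.0, 1.0])

# With rcond=1e-12, the tiny eigenvalue falls below rcond * lambda_max
# and is truncated, so the pseudo-inverse stays finite.
A_inv = jnp.linalg.pinv(A, rcond=1e-12, hermitian=True)
x = A_inv @ b  # approximately [1.0, 0.0]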
def svd(A, b, rcond=None, x0=None):
"""
Solve the linear system using Singular Value Decomposition.
......
@@ -39,7 +39,8 @@ solvers["svd"] = nk.optimizer.solver.svd
solvers["cholesky"] = nk.optimizer.solver.cholesky
solvers["LU"] = nk.optimizer.solver.LU
solvers["solve"] = nk.optimizer.solver.solve
solvers["eigh"] = nk.optimizer.solver.pinv_smooth
solvers["pinv"] = nk.optimizer.solver.pinv
solvers["pinv_smooth"] = nk.optimizer.solver.pinv_smooth
dtypes = {"float": float, "complex": complex}
......