diff --git a/Test/Optimizer/test_sr.py b/Test/Optimizer/test_sr.py index 113c5a54a5080506a360a1252badbbd15b878645..47d3a564625152b3f108b5bbea1674e9de941b25 100644 --- a/Test/Optimizer/test_sr.py +++ b/Test/Optimizer/test_sr.py @@ -23,23 +23,3 @@ def test_svd_threshold(): match="The svd_threshold option is available only for non-sparse solvers.", ): SR(use_iterative=True, svd_threshold=1e-3) - - a = np.diag([1e0 + 0j, 1e-3, 1e-6]) - b = np.array([1.0 + 0j, 1.0, 1.0]) - - def SR_with_threshold(t): - return SR(lsq_solver="SVD", svd_threshold=t, diag_shift=0, is_holomorphic=True) - - def solve(sr, a, b): - a1 = np.sqrt(a) * np.sqrt(a.shape[0]) - out = sr.compute_update(a1, b) - return out - - sr = SR_with_threshold(1e-1) - assert np.allclose(solve(sr, a, b), [1.0, 0.0, 0.0]) - - sr = SR_with_threshold(1e-4) - assert np.allclose(solve(sr, a, b), [1.0, 1e3, 0.0]) - - sr = SR_with_threshold(1e-7) - assert np.allclose(solve(sr, a, b), [1.0, 1e3, 1e6]) diff --git a/netket/vmc_common.py b/netket/vmc_common.py index 444cb79c83ae04ada6a05a484c1a06dbd1bad5be..a21c941aac904739924506f44e15d34c0c49599d 100644 --- a/netket/vmc_common.py +++ b/netket/vmc_common.py @@ -25,7 +25,6 @@ if jax_available: from jax.tree_util import tree_flatten, tree_unflatten, tree_map import jax.numpy as jnp - def shape_for_sr(grads, jac): r"""Reshapes grads and jax from tree like structures to arrays if jax_available @@ -43,7 +42,6 @@ if jax_available: return grads, jac - def shape_for_update(update, shape_like): r"""Reshapes grads from array to tree like structure if neccesary for update