How to use NumpySampler with VMC.estimate (or MCState.expect)?
*Created by: SemyonSinchenko*

Hello! When I try `VMC.estimate` with `MetropolisSamplerNumpy`, I get an error:

```python
  File "main.py", line 85, in <module>
    res = model.get_results()
  File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
    [
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
    return tree_map(self._estimate_stats, observables)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
    return self.state.expect(observable)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
    mels,
jax._src.traceback_util.FilteredStackTrace: TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "main.py", line 85, in <module>
    res = model.get_results()
  File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
    [
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
    return tree_map(self._estimate_stats, observables)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/tree_util.py", line 162, in tree_map
    return treedef.unflatten(map(f, leaves))
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
    return self.state.expect(observable)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
    mels,
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 330, in cache_miss
    _check_arg(arg)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 2209, in _check_arg
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.
```

Could you please explain the right way to use `estimate` (or `MCState.expect`) with NumPy samplers?
NetKet version: `3.0b1.post5`

P.S. The error occurs only in `estimate`; `VMC.run` works correctly!

P.P.S. With `JAX` samplers `estimate` also runs without errors, but the results look strange and are significantly worse than the results from `NetKet` 2.1 with the same parameters...
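The last frames of the traceback show `jax/api.py` (`_check_arg`) rejecting the sampler object when it is passed into a jitted function called from `MCState.expect`. The sketch below reproduces that failure mode in plain JAX, assuming (the traceback does not show this directly) that the sampler reaches `jax.jit` as an ordinary, traced positional argument. `ToySampler`, `mean_local_values`, `try_traced_call` and `static_call` are hypothetical names for illustration only, not NetKet API. It also shows the usual JAX remedy for non-array arguments, `static_argnums`; whether NetKet itself should handle NumPy samplers this way is not something this sketch establishes.

```python
import jax
import jax.numpy as jnp


class ToySampler:
    """Hypothetical stand-in for a plain Python sampler object (not a JAX type)."""

    def __init__(self, n_chains):
        self.n_chains = n_chains

    # Value-based hash/eq so equal samplers reuse the same compiled function
    # when the sampler is passed as a static argument to jax.jit.
    def __hash__(self):
        return hash(("ToySampler", self.n_chains))

    def __eq__(self, other):
        return isinstance(other, ToySampler) and self.n_chains == other.n_chains


def mean_local_values(sampler, x):
    # The sampler only contributes Python-level configuration (n_chains).
    return jnp.sum(x) / sampler.n_chains


def try_traced_call():
    """Pass the sampler as a traced argument: jax.jit raises a TypeError."""
    f = jax.jit(mean_local_values)
    try:
        f(ToySampler(4), jnp.ones(8))
    except TypeError as err:
        return str(err)
    return None


def static_call():
    """Mark the sampler as static: jax.jit treats it as a compile-time constant."""
    f = jax.jit(mean_local_values, static_argnums=0)
    return f(ToySampler(4), jnp.ones(8))


if __name__ == "__main__":
    print(try_traced_call())     # JAX's "not a valid JAX type"-style error text
    print(float(static_call()))  # 8 ones / 4 chains
```

The exact wording of the error differs between JAX versions, but in each case the cause is the same: a non-array, non-pytree Python object passed as a traced argument to a jitted function.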