How to use NumpySampler with VMC.estimate (or MCState.expect)?
Created by: SemyonSinchenko
Hello!
When I call VMC.estimate with a MetropolisSamplerNumpy sampler, I get the following error:
File "main.py", line 85, in <module>
res = model.get_results()
File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
[
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
return tree_map(self._estimate_stats, observables)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
return self.state.expect(observable)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
mels,
jax._src.traceback_util.FilteredStackTrace: TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "main.py", line 85, in <module>
res = model.get_results()
File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
[
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
return tree_map(self._estimate_stats, observables)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/tree_util.py", line 162, in tree_map
return treedef.unflatten(map(f, leaves))
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
return self.state.expect(observable)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
mels,
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 330, in cache_miss
_check_arg(arg)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 2209, in _check_arg
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.
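For reference, the final TypeError looks like the generic error JAX raises when a jitted function receives an argument that is not an array (or a pytree of arrays). Below is a minimal sketch in plain JAX that triggers the same class of error. It does not use NetKet at all: `FakeSampler` and `scaled_sum` are hypothetical names made up for illustration, and the exact wording of the message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


class FakeSampler:
    """Hypothetical stand-in for a non-JAX sampler object (not a NetKet class)."""

    def __init__(self, n_chains):
        self.n_chains = n_chains


def scaled_sum(sampler, x):
    # Only a plain Python attribute of the sampler is used here.
    return sampler.n_chains * jnp.sum(x)


x = jnp.arange(3.0)
sampler = FakeSampler(n_chains=64)

# Passing the object as a regular (traced) argument of a jitted function
# fails, because JAX cannot convert it to an array.
try:
    jax.jit(scaled_sum)(sampler, x)
    error_message = None
except TypeError as err:
    error_message = str(err)

# Marking the argument as static avoids tracing it. This relies on the
# object being hashable (here: the default identity-based hash).
result = jax.jit(scaled_sum, static_argnums=0)(sampler, x)
```

This only illustrates the mechanism behind the message. Whether estimate is supposed to pass the Numpy sampler as a static argument, or avoid passing it to a jitted function altogether, is exactly what I am unsure about.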
Could you please tell me the right way to use the estimate (or MCState.expect) method with Numpy samplers?
NetKet version:
3.0b1.post5
P.S. The error occurs only in estimate. VMC.run works correctly!
P.P.S. With JAX samplers, estimate works correctly too, but the results are strange and significantly worse than those from NetKet 2.1 with the same parameters...