How to use NumpySampler with VMC.estimate (or MCState.expect)?
Created by: SemyonSinchenko
Hello!
When I call VMC.estimate with MetropolisSamplerNumpy, I get the following error:
File "main.py", line 85, in <module>
res = model.get_results()
File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
[
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
return tree_map(self._estimate_stats, observables)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
return self.state.expect(observable)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
mels,
jax._src.traceback_util.FilteredStackTrace: TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "main.py", line 85, in <module>
res = model.get_results()
File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
[
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
return tree_map(self._estimate_stats, observables)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/tree_util.py", line 162, in tree_map
return treedef.unflatten(map(f, leaves))
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
return self.state.expect(observable)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
mels,
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 330, in cache_miss
_check_arg(arg)
File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 2209, in _check_arg
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.
Could you please tell me the right way to use the estimate (or MCState.expect) methods with NumPy samplers?
NetKet version: 3.0b1.post5
P.S. The error occurs only in estimate. VMC.run works correctly!
P.P.S. With JAX samplers it works too, but the results are strange and significantly worse than the results from NetKet 2.1 with the same parameters...