
How to use NumpySampler with VMC.estimate (or MCState.expect)?

Created by: SemyonSinchenko

Hello! When trying VMC.estimate with MetropolisSamplerNumpy, I got the following error:

  File "main.py", line 85, in <module>
    res = model.get_results()
  File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
    [
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
    return tree_map(self._estimate_stats, observables)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
    return self.state.expect(observable)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
    mels,
jax._src.traceback_util.FilteredStackTrace: TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "main.py", line 85, in <module>
    res = model.get_results()
  File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 95, in get_results
    [
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 295, in estimate
    return tree_map(self._estimate_stats, observables)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/tree_util.py", line 162, in tree_map
    return treedef.unflatten(map(f, leaves))
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 109, in _estimate_stats
    return self.state.expect(observable)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/variational/mc_state.py", line 407, in expect
    mels,
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 330, in cache_miss
    _check_arg(arg)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 2209, in _check_arg
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument 'MetropolisSamplerNumpy(rule = LocalRuleNumpy(), n_chains = 64, machine_power = 2, n_sweeps = 8, dtype = <class 'numpy.float32'>)' of type <class 'netket.sampler.metropolis_numpy.MetropolisSamplerNumpy'> is not a valid JAX type.
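The final TypeError is what jax.jit raises when a jit-compiled function receives a positional argument that is neither an array nor a registered pytree. The sketch below reproduces that failure mode with a hypothetical sampler class and shows the generic JAX remedy of marking such an argument static with static_argnums. The names HypotheticalNumpySampler and chain_averaged_mean are illustrative only, not NetKet API. That NetKet 3.0b1's MCState.expect passes the sampler into a jitted function in exactly this way is an assumption inferred from the trace (jax/api.py, _check_arg), so a fix along these lines would be a library-side change rather than something the caller controls.

```python
import jax
import jax.numpy as jnp


class HypotheticalNumpySampler:
    """Hypothetical stand-in for a sampler: a plain Python object that is
    NOT registered as a JAX pytree (like MetropolisSamplerNumpy in the trace)."""

    def __init__(self, n_chains):
        self.n_chains = n_chains

    def __repr__(self):
        return f"HypotheticalNumpySampler(n_chains={self.n_chains})"


def chain_averaged_mean(sampler, mels):
    # Reshape flat local values into (n_chains, samples_per_chain);
    # sampler.n_chains must be a concrete Python int for this reshape.
    per_chain = mels.reshape(sampler.n_chains, -1)
    return per_chain.mean(axis=1).mean()


# Sampler passed as a traced argument: jit cannot turn it into an array,
# so calling this raises TypeError before any tracing happens.
jit_traced = jax.jit(chain_averaged_mean)

# Sampler marked static: jit treats it as a compile-time constant.
# Static arguments must be hashable; plain objects hash by identity.
jit_static = jax.jit(chain_averaged_mean, static_argnums=0)

if __name__ == "__main__":
    sampler = HypotheticalNumpySampler(n_chains=4)
    mels = jnp.arange(8.0)
    try:
        jit_traced(sampler, mels)
    except TypeError as err:
        print("TypeError:", err)
    print(jit_static(sampler, mels))
```

Marking an argument static makes JAX recompile whenever a differently hashed object is passed, which is acceptable when the same sampler object is reused across calls.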

Could you please explain the right way to use the estimate (or MCState.expect) methods with NumPy samplers? NetKet version: 3.0b1.post5

P.S. The error occurs only in estimate; VMC.run works correctly!

P.P.S. With the JAX samplers estimate works too, but the results are strange and significantly worse than those from NetKet 2.1 with the same parameters...
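For comparing numbers between NetKet 2.1 and 3.0, it may help to recall what estimate / MCState.expect report: a Monte Carlo expectation value is the sample mean of the local estimator O_loc(σ) over the Markov chains, plus an error bar. Below is a minimal NumPy sketch of that estimator, assuming the local values are already available as an array of shape (n_chains, n_samples_per_chain). mc_estimate is a hypothetical helper, not NetKet's implementation; NetKet's own statistics routine is more elaborate.

```python
import numpy as np


def mc_estimate(local_values):
    """Naive Monte Carlo estimate from local estimator values.

    local_values: array of shape (n_chains, n_samples_per_chain).
    Returns (mean, error_of_mean). Averaging within each chain first and
    then treating chain means as independent partially accounts for
    autocorrelation inside a chain.
    """
    local_values = np.asarray(local_values)
    chain_means = local_values.mean(axis=1)
    n_chains = chain_means.shape[0]
    mean = chain_means.mean()
    # Standard error of the mean across (assumed independent) chains.
    error = chain_means.std(ddof=1) / np.sqrt(n_chains)
    return mean, error
```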