Switch the order of axes in the output of the samplers (#1502)
Previously the samplers were outputting arrays of shape
`(n_samples_per_chain, n_chains, ...)`, an artefact of the jax scan
being used to do the sampling, which was inconsistent with e.g.
`nk.stats.statistics` which expects `(n_chains, n_samples_per_chain,
...)`.
This PR swaps the order of the axes in the output of the samplers to
`(n_chains, n_samples_per_chain, ...)`, so that we have a consistent
order everywhere.
---------
Co-authored-by:
Filippo Vicentini <filippovicentini@gmail.com>
Showing
- CHANGELOG.md 3 additions, 1 deletionCHANGELOG.md
- netket/experimental/sampler/metropolis_pmap.py 2 additions, 2 deletionsnetket/experimental/sampler/metropolis_pmap.py
- netket/jax/_expect.py 2 additions, 2 deletionsnetket/jax/_expect.py
- netket/sampler/autoreg.py 2 additions, 2 deletionsnetket/sampler/autoreg.py
- netket/sampler/base.py 1 addition, 1 deletionnetket/sampler/base.py
- netket/sampler/exact.py 2 additions, 2 deletionsnetket/sampler/exact.py
- netket/sampler/metropolis.py 2 additions, 1 deletionnetket/sampler/metropolis.py
- netket/sampler/metropolis_numpy.py 3 additions, 0 deletionsnetket/sampler/metropolis_numpy.py
- netket/vqs/mc/mc_state/expect.py 5 additions, 6 deletionsnetket/vqs/mc/mc_state/expect.py
- netket/vqs/mc/mc_state/expect_forces.py 4 additions, 4 deletionsnetket/vqs/mc/mc_state/expect_forces.py
- netket/vqs/mc/mc_state/expect_grad.py 4 additions, 4 deletionsnetket/vqs/mc/mc_state/expect_grad.py
- netket/vqs/mc/mc_state/state.py 1 addition, 3 deletionsnetket/vqs/mc/mc_state/state.py
- test/sampler/test_sampler.py 4 additions, 4 deletionstest/sampler/test_sampler.py
- test/stats/test_stats.py 2 additions, 2 deletionstest/stats/test_stats.py
- test/variational/test_variational.py 6 additions, 6 deletionstest/variational/test_variational.py
- test/variational/test_variational_mixed.py 14 additions, 14 deletionstest/variational/test_variational_mixed.py
Loading
Please register or sign in to comment