Clemens Giuliani authored
Previously, the samplers returned arrays of shape
`(n_samples_per_chain, n_chains, ...)`. This was an artefact of using a
JAX scan to perform the sampling, and it was inconsistent with, e.g.,
`nk.stats.statistics`, which expects `(n_chains, n_samples_per_chain, ...)`.
This PR swaps the order of the axes in the samplers' output to
`(n_chains, n_samples_per_chain, ...)`, so that the same axis order is
used everywhere.
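As an illustration of where the old layout comes from, here is a minimal NumPy sketch. It assumes a toy sampler that advances all chains by one step per iteration and stacks each step's output along a new leading axis, which is how `jax.lax.scan` stacks its per-step outputs. The functions `sample_scan_layout` and `to_chain_major` are hypothetical and are not NetKet's API or its actual implementation.

```python
import numpy as np


def sample_scan_layout(n_chains, n_samples_per_chain, n_sites, seed=0):
    """Hypothetical toy sampler (not NetKet code).

    Each loop iteration advances every chain by one step and records the
    states. Stacking the per-step results along axis 0 mimics how
    jax.lax.scan stacks its outputs, giving
    shape (n_samples_per_chain, n_chains, n_sites).
    """
    rng = np.random.default_rng(seed)
    state = rng.choice([-1, 1], size=(n_chains, n_sites))
    per_step = []
    for _ in range(n_samples_per_chain):
        # Toy "move": flip one randomly chosen site in every chain.
        idx = rng.integers(n_sites, size=n_chains)
        state = state.copy()  # keep previously recorded steps unchanged
        state[np.arange(n_chains), idx] *= -1
        per_step.append(state)
    return np.stack(per_step, axis=0)


def to_chain_major(samples):
    """Swap the first two axes:
    (n_samples_per_chain, n_chains, ...) -> (n_chains, n_samples_per_chain, ...).
    """
    return np.swapaxes(samples, 0, 1)


old = sample_scan_layout(n_chains=4, n_samples_per_chain=10, n_sites=6)
new = to_chain_major(old)
print(old.shape)  # (10, 4, 6)
print(new.shape)  # (4, 10, 6)
```

After the swap, `new[c]` is the full trajectory of chain `c` (`new[c, t]` equals `old[t, c]`). This is the layout that a per-chain statistics routine expecting `(n_chains, n_samples_per_chain, ...)` indexes into.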

---------

Co-authored-by: Filippo Vicentini <filippovicentini@gmail.com>
5ccd622e