Switch the order of axes in the output of the samplers
Created by: inailuig
Previously the samplers were outputting arrays of shape (n_samples_per_chain, n_chains, ...)
, an artefact of the jax scan being used to do the sampling, which was inconsistent with e.g. nk.stats.statistics
which expects (n_chains, n_samples_per_chain, ...)
.
This PR swaps the order of the axes in the output of the samplers to (n_chains, n_samples_per_chain, ...)
, so that we have a consistent order everywhere.