Skip to content

Switch the order of axes in the output of the samplers

Created by: inailuig

Previously the samplers were outputting arrays of shape (n_samples_per_chain, n_chains, ...), an artefact of the jax scan being used to do the sampling, which was inconsistent with e.g. nk.stats.statistics which expects (n_chains, n_samples_per_chain, ...). This PR swaps the order of the axes in the output of the samplers to (n_chains, n_samples_per_chain, ...), so that we have a consistent order everywhere.

Merge request reports