[RFC] Make n_chains set the total number of chains across all MPI processes
Created by: PhilipVinc
Recently we had several people confused by the fact that MPI does not particularly improve performance. There are two issues:
- They don't read the documentation (and we don't have a page on MPI)
- Our n_chains is a rank-local property, so that if you increase number of MPI ranks you get more chains. However the number of samples is kept fixed.
Point 1) can be solved with better docs.
Point 2) is about inconsistency with the way we set n_samples. I propose to change the bahviour of n_chains so that it sets the number of chains globally according to the formula
n_chains_per_rank = n_chains_per_rank = max(
int(np.ceil(n_chains / mpi.n_nodes)), 1
)
One can still specify n_chains_per_rank
if he so desires.
This is just a skeleton implementaiton (though it should mostly work). As fixing tests everywhere to use everywhere n_chains_per_rank instead of n_chains will take some time, i'll finish this PR only if we get consensus on this.
Note that it will be a fairly breaking change in the behaviour (though it won't technically break code)