Chunking for expect/expect_and_grad
Created by: PhilipVinc
This is an alternative implementation to #830, taking into account some of the feedback that was raised in that PR.
The implementation is much simpler and builds upon a `nk.jax.vmap_batched` that is semantically equivalent to `vmap` but uses `scan` under the hood to evaluate one batch at a time.
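To make the idea concrete, here is a minimal sketch of a `vmap`-like wrapper that evaluates one batch at a time with `jax.lax.scan`. It is not the code in this PR: the name `vmap_batched_sketch`, its signature, and the restriction that the batch size divides the number of inputs are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def vmap_batched_sketch(f, batch_size):
    """Hypothetical helper: map f over axis 0 of x, one batch at a time."""
    vmapped = jax.vmap(f)

    def wrapped(x):
        n = x.shape[0]
        n_batches, remainder = divmod(n, batch_size)
        if remainder != 0:
            # Assumption of this sketch only: no handling of a ragged last batch.
            raise ValueError("batch_size must divide the number of inputs")

        # Reshape (n, ...) into (n_batches, batch_size, ...).
        x_batched = x.reshape((n_batches, batch_size) + x.shape[1:])

        def step(carry, x_batch):
            # The carry is unused; scan stacks the per-batch outputs.
            return carry, vmapped(x_batch)

        _, y_batched = jax.lax.scan(step, None, x_batched)

        # Merge the two leading axes back into a single axis of length n.
        return y_batched.reshape((n,) + y_batched.shape[2:])

    return wrapped


def f(x):
    return jnp.sum(x**2)


xs = jnp.arange(12.0).reshape(6, 2)
full = jax.vmap(f)(xs)
batched = vmap_batched_sketch(f, batch_size=2)(xs)
```

Because `scan` traces `step` only once, the compiled program should contain a single copy of the batch computation, and only one batch of intermediates needs to be alive at a time.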
The limitation that the minibatch size must be a multiple of the number of chains is also relaxed. The minibatch size should now be an integer divisor of the total number of samples per MPI rank, and the implementation attempts to achieve that by increasing the chain length. This requirement is not mandatory and could be relaxed, and I agree that with a prime number of chains it might blow up the chain length. However, respecting it leads to much shorter compilation times and slightly better performance, so I am open to alternatives.
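As a rough illustration of the divisibility constraint, the sketch below computes the smallest chain length, no shorter than the requested one, for which the number of samples per rank is a multiple of the minibatch size. The function name and the rounding rule are assumptions for illustration and may differ from what this PR actually does.

```python
import math


def adjusted_chain_length(n_chains_per_rank, chain_length, batch_size):
    """Hypothetical helper: smallest length >= chain_length such that
    n_chains_per_rank * length is divisible by batch_size."""
    # The chain length must be a multiple of this step for the product
    # n_chains_per_rank * length to be divisible by batch_size.
    step = batch_size // math.gcd(n_chains_per_rank, batch_size)
    # Round chain_length up to the next multiple of step.
    return -(-chain_length // step) * step


# 16 chains share a factor of 16 with the batch size: no change is needed.
friendly = adjusted_chain_length(16, 100, 64)

# 7 chains share no factor with 64, so the length must be a multiple of 64.
prime = adjusted_chain_length(7, 100, 64)
```

In the second case the chain length grows from 100 to 128, which is the blow-up mentioned above for chain counts that share no factors with the minibatch size.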
I'd like some feedback and someone to benchmark this a bit (I do not have the bandwidth right now).
TODO: make all operator types work with batching, make `vmap_batched` autodiffable (@inailuig), and write docs.
Credit for a lot of the batching logic goes to @inailuig who definitely was not tied to a chair and obliged to do this under duress. This PR also contains a commit from his PR #912