Chunking for expect/grad
Created by: PhilipVinc
This is #918 rebased on master
i'd like one last comment from @inailuig and @attila-i-szabo on the chunk size logic.
We have the limitation that if chunk_size < n_samples
it should divide evenly n_samples
(getting rid of this limitation can be done, but is annoying and not a priority now). So if chunk_size > n_samples
any value is accepted, if chunk_size < n_samples
then an error is thrown if it is not an even divider.
This is an alternative implementation to #830, taking into account some of the feedback that was raised in that PR.
The implementation is much simpler and builds upon a nk.jax.vmap_batched that is semantically equivalent to vmap but uses scan under the hood to evaluate one batch at the time.
I'd like some feedback and someone to benchmark this a bit (I do not have the bandwith right now).
TODO: Make all operator types work with batching, make vmaop_batched autodiffable (@inailuig). docs
Credit for a lot of the batching logic goes to @inailuig who definitely was not tied to a chair and obliged to do this under duress. This PR also contains a commit from his PR #912