Skip to content

Chunking for expect/grad

Vicentini Filippo requested to merge pv/chunking into master

Created by: PhilipVinc

This is #918 rebased on master

i'd like one last comment from @inailuig and @attila-i-szabo on the chunk size logic. We have the limitation that if chunk_size < n_samples it should divide evenly n_samples (getting rid of this limitation can be done, but is annoying and not a priority now). So if chunk_size > n_samples any value is accepted, if chunk_size < n_samples then an error is thrown if it is not an even divider.


This is an alternative implementation to #830, taking into account some of the feedback that was raised in that PR.

The implementation is much simpler and builds upon a nk.jax.vmap_batched that is semantically equivalent to vmap but uses scan under the hood to evaluate one batch at the time.

I'd like some feedback and someone to benchmark this a bit (I do not have the bandwith right now).

TODO: Make all operator types work with batching, make vmaop_batched autodiffable (@inailuig). docs

Credit for a lot of the batching logic goes to @inailuig who definitely was not tied to a chair and obliged to do this under duress. This PR also contains a commit from his PR #912

Merge request reports