Batched stuff
Created by: inailuig
This PR contains several utilities that go in nk.jax
that replace standard jax
functionality in order to reduce memory consumption through batching.
Those are mainly nk.jax.vjp_batched
, nk.jax.vmap_batched
(and nk.jax.scan_append/accum
on top of which those are built).
Both of those support a subset of the original jax
API, plus a new argument, batch_size
, and they will essentially loop through the input data in blocks of batch_size
-elements along the batch_axis
argument.
This bounds the memory consumption. Some rough support for AD through those new operations is implemented but it is not complete.
Those APIs are highly experimental, use them at your own risk.
They will be used to implement batching of expect
and of QGT
.
Ideally one day we will move them to another package, like utils4jax
.