Fast autoregressive sampling
Created by: wdphy16
Fast autoregressive sampling is described in Ramachandran et al. To generate one sample using an autoregressive network, we need to evaluate the network N times, where N is the number of input sites. But each step changes only one input site, so we can cache unchanged intermediate results and avoid repeated computation.
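As a minimal sketch of the idea (not NetKet's actual implementation, and with illustrative names): for a causal 1D convolution with kernel size `K`, the output at site `i` depends only on the last `K` inputs, so a rolling cache of size `K` is enough to compute one new output site per sampling step:

```python
import jax.numpy as jnp

def fast_conv_step(cache, new_input, kernel):
    """Compute one output site of a causal 1D conv using a rolling cache.

    cache: (K,) array of the last K inputs seen so far (oldest first)
    new_input: scalar input at the current site
    kernel: (K,) convolution weights
    """
    # Shift the cache left by one and append the new input
    cache = jnp.concatenate([cache[1:], jnp.array([new_input])])
    # One output site is just a dot product with the cached window,
    # instead of re-running the convolution over all N sites
    out = jnp.dot(kernel, cache)
    return cache, out

K = 3
kernel = jnp.array([0.2, 0.5, 0.3])
cache = jnp.zeros(K)
for x in [1.0, 2.0, 3.0]:
    cache, out = fast_conv_step(cache, x, kernel)
```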
I haven't found an existing JAX implementation, so I'm implementing it using Flax's concept of module variables.
Changes to the public API of `netket.nn.ARNN`:

- `conditionals` does not take `cache` as an argument, and does not return the updated `cache`. Instead, `cache` is passed in `variables` when calling `apply`, as sketched below.
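Here is a hedged sketch of how a cache can live in a Flax variable collection and be threaded through `apply` via `mutable`, which is the mechanism module variables provide; the module, collection, and variable names are illustrative, not NetKet's actual classes:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class CachedLayer(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        # Create (or look up) an entry in the "cache" variable collection
        cache = self.variable("cache", "inputs", jnp.zeros, (self.features,))
        cache.value = x  # overwrite the cached inputs with the current ones
        return cache.value.sum()

layer = CachedLayer()
variables = layer.init(jax.random.PRNGKey(0), jnp.ones(4))
# The cache is passed inside `variables`; marking the collection as mutable
# makes `apply` return the updated cache alongside the output:
out, mutated = layer.apply(variables, jnp.arange(4.0), mutable=["cache"])
variables = {**variables, **mutated}  # carry the updated cache to the next step
```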
Things to do:

- VMC training.
- Remove `σ` and `cache` in `ARDirectSamplerState`. The cache is only relevant inside the autoregressive sampling procedure; outside it (e.g. when evaluating the observable), we only need the configurations and psi, not the cache.
- Do some profiling and make it fast.
- Add an example with time comparison.
- Implement 2D conv.
- `FastMaskedDense1D` does not support JIT yet, because it involves slicing the cached inputs and the weights with a dynamic shape. Maybe there is a way to statically JIT it for each `index` (see the sketch after this list). But it's not a big problem, because the theoretical speedup of fast autoregressive sampling with dense layers is much smaller than that with conv layers, and I guess no one really uses dense layers here.
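As an illustrative sketch of the JIT issue (not the actual layer code): under `jit`, slicing with a traced index fails because the result shape would depend on runtime data. Marking the index as a static argument is one way to compile a specialization per index, at the cost of recompiling for each of the N sites:

```python
import jax
import jax.numpy as jnp
from functools import partial

# Under plain `jax.jit`, `cache[:index]` raises an error when `index` is a
# traced value, since the slice shape would depend on runtime data.
# Declaring `index` static sidesteps this: it becomes a Python int at trace
# time, so the slice shape is known, and XLA compiles one version per index.

@partial(jax.jit, static_argnums=1)
def window_sum(cache, index):
    return cache[:index].sum()

cache = jnp.arange(8.0)
print(window_sum(cache, 3))  # compiles a specialization for index=3
print(window_sum(cache, 4))  # recompiles for the new static index
```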