
Fast autoregressive sampling

Vicentini Filippo requested to merge github/fork/wdphy16/autoreg into master

Created by: wdphy16

Fast autoregressive sampling is described in Ramachandran et al. To generate one sample with an autoregressive network, we need to evaluate the network N times, where N is the number of input sites. But only one input site changes between successive evaluations, so we can cache the unchanged intermediate results and avoid repeated computation.
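
To make the payoff concrete, here is a hypothetical toy sketch (not this MR's code) of the caching idea for a single causal 1D convolution: the output at site i depends only on the last K inputs, so a rolling buffer gives O(K) work per newly sampled site instead of a full pass over the chain.

```python
import jax
import jax.numpy as jnp

K = 3                           # kernel size (arbitrary for the sketch)
w = jnp.array([0.2, 0.3, 0.5])  # dummy kernel weights

def step(cache, x_i):
    # Shift the buffer left by one and append the newly sampled input.
    cache = jnp.roll(cache, -1).at[-1].set(x_i)
    # The conv output at the current site is a dot product with the buffer.
    return cache, jnp.dot(w, cache)

inputs = jnp.arange(5.0)
_, outputs = jax.lax.scan(step, jnp.zeros(K), inputs)
# outputs[i] equals the zero-padded causal convolution of `inputs` at
# site i, but each site costs only O(K) work.
```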

I haven't found an existing JAX implementation, so I'm implementing it using Flax's concept of module variables.

Changes to the public API of netket.nn.ARNN:

  • conditionals no longer takes cache as an argument or returns the updated cache; instead, cache is passed in variables when calling apply (see the sketch below).
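
For reference, this is roughly how Flax's variable collections support that pattern. The module and variable names below are illustrative, not the actual netket.nn.ARNN API:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative toy module (not netket.nn.ARNN): state lives in a
# 'cache' variable collection instead of being threaded through the
# method's arguments and return values.

class CachedLayer(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        # Register a mutable variable in the 'cache' collection,
        # initialized to zeros on the first call.
        cache = self.variable('cache', 'inputs', jnp.zeros, (self.features,))
        cache.value = cache.value + x   # update the cached state
        return cache.value

layer = CachedLayer()
x = jnp.ones(4)
variables = layer.init(jax.random.PRNGKey(0), x)

# `mutable=['cache']` lets apply() modify the cache collection; the
# updated collection is returned alongside the output and can be fed
# back in on the next call.
y, mutated = layer.apply(variables, x, mutable=['cache'])
variables = {**variables, **mutated}
```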

Things to do:

  • VMC training.
  • Remove σ and cache from ARDirectSamplerState. The cache is only relevant inside the autoregressive sampling procedure; outside it (e.g., when evaluating an observable), we only need the configurations and psi, not the cache.
  • Do some profiling and make it fast.
  • Add an example with time comparison.
  • Implement 2D conv.
  • FastMaskedDense1D does not support JIT yet, because it slices the cached inputs and the weights with a dynamic shape. It might be possible to JIT a statically specialized version for each index. This is not a big problem, though: the theoretical speedup of fast autoregressive sampling with dense layers is much smaller than with conv layers, and dense layers are probably rarely used here anyway (see the sketch after this list).
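
To illustrate the JIT limitation in the last item: under jit, slice sizes must be static, so an expression like cache[:index] fails when index is a traced value. One generic workaround, shown below as an assumption rather than what this MR will necessarily do, is to keep the full-size computation and mask out the sites at or beyond index:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch; not necessarily how FastMaskedDense1D will be fixed.
@jax.jit
def masked_contraction(cache, weights, index):
    # cache, weights: shape (N,). Only cache[:index] should contribute,
    # so zero out the tail instead of slicing to a dynamic length.
    mask = jnp.arange(cache.shape[0]) < index
    return jnp.dot(jnp.where(mask, cache, 0.0), weights)

cache = jnp.arange(4.0)
weights = jnp.ones(4)
print(masked_contraction(cache, weights, 2))  # -> 1.0, using cache[0:2] only
```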
