Automatic Backend Dispatch for Sampler
Created by: gcarleo
I am introducing an automatic dispatch mechanics for samplers, so that it is not necessary to do things like sampler.jax.MetropolisLocal
. If one calls samplers.MetropolisLocal
with a Jax machine the correct sampler is automatically dispatched.
The mechanism is based on functools.singledispatch
, and it basically amounts to declaring transition kernels in netket/sampler/_kernels.py
. Registering a new kernel for a given backend is a localized operation that does not require changing files besides those in the dedicated module folder.
Also:
I am adding sampler.jax.MetropolisExchange
The numpy version of ExactSampler
works already fine with jax machines, so there is no need to rewrite it. The other samplers based on operators (CustomSampler
and HamiltonianSampler
) are trickier to implement in Jax, unless we define native local operators... we should maybe do that at some point, but definitely not for 3.0)