Skip to content

Automatic Backend Dispatch for Sampler

Vicentini Filippo requested to merge more_jax_samplers into v3.0

Created by: gcarleo

I am introducing an automatic dispatch mechanics for samplers, so that it is not necessary to do things like sampler.jax.MetropolisLocal. If one calls samplers.MetropolisLocal with a Jax machine the correct sampler is automatically dispatched.

The mechanism is based on functools.singledispatch, and it basically amounts to declaring transition kernels in netket/sampler/_kernels.py. Registering a new kernel for a given backend is a localized operation that does not require changing files besides those in the dedicated module folder.

Also:

I am adding sampler.jax.MetropolisExchange

The numpy version of ExactSampler works already fine with jax machines, so there is no need to rewrite it. The other samplers based on operators (CustomSampler and HamiltonianSampler) are trickier to implement in Jax, unless we define native local operators... we should maybe do that at some point, but definitely not for 3.0)

Merge request reports