Abstract Python class for samplers
Created by: gcarleo
To better integrate with other backends (PyTorch, TensorFlow, etc.), we need to define abstract Python classes for the common objects and functionality used by NetKet's high-level Python drivers. This issue starts a discussion about the sampler API. Below is a proposal; feedback is appreciated, especially in light of the requirements of other backends. @noamwies @orsharir
from abc import ABC, abstractmethod


class Sampler(ABC):
    r"""
    Abstract base class for NetKet samplers.

    A `Sampler` generates quantum numbers $$s = s_1\dots s_N$$ distributed
    according to the probability (up to normalization)
    $$P(s_1\dots s_N) \propto F(\Psi(s_1\dots s_N)),$$
    where F is an arbitrary function. By default, $$F(X) = |X|^2$$.

    Samplers can generate multiple independent samples at once, one per chain
    (see the `sweep` method). The current state of the sampler (i.e. the value
    of $$s = s_1\dots s_N$$ for all the chains) is stored in `self.current_sample`.
    """

    @abstractmethod
    def seed(self, base_seed):
        """Seeds the random number generator used by the sampler."""
        pass

    @abstractmethod
    def reset(self, init_random=False):
        """Resets the state of the sampler, including the acceptance-rate statistics,
        and optionally initializes the sampled visible units at random.

        Args:
            init_random: bool, optional
                If True, the quantum numbers (visible units) are initialized at random;
                otherwise their current values are preserved. (default is False)
        """
        pass

    @abstractmethod
    def sweep(self):
        """Performs a sampling sweep. Typically, a single sweep consists of an
        extensive number of local moves, i.e. a number proportional to the
        system size N.
        """
        pass

    @property
    @abstractmethod
    def current_sample(self):
        """Returns a 2D numpy array of the currently sampled configurations, with
        shape (N_c, N_v), where N_c is the number of independent chains being
        sampled and N_v is the size of the visible vector being sampled.
        """
        pass

    @property
    @abstractmethod
    def current_state(self):
        """The current sampling state of the sampler: a pair (visible, log_val),
        where log_val is the result of machine.log_val(visible).
        """
        pass

    @property
    @abstractmethod
    def n_chains(self):
        """Number of independent chains being sampled."""
        pass

    @property
    @abstractmethod
    def machine(self):
        """The machine, Psi(s), used for the sampling."""
        pass

    @property
    @abstractmethod
    def machine_func(self):
        r"""function(complex): Gets and sets the function F used for sampling.

        By default $$|\Psi(s)|^2$$ is sampled; in general, the sampled
        distribution is proportional to $$F(\Psi(s))$$.
        """
        pass

    @machine_func.setter
    @abstractmethod
    def machine_func(self, func):
        pass
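As an illustration of how a backend might fill in this interface, here is a minimal NumPy sketch of a local Metropolis sampler. Everything in it is an assumption for illustration, not part of the proposal: the toy `ProductStateMachine` (with log Psi(s) = a * sum_i s_i over spins s_i in {-1, +1}), the class name `LocalMetropolisSampler`, its constructor arguments, and the extra `acceptance` property are all hypothetical. Each sweep proposes N_v single-spin flips per chain and accepts each one with probability min(1, F(Psi(s'))/F(Psi(s))). For the default F(X) = |X|^2 this ratio equals exp(2 Re[log Psi(s') - log Psi(s)]), which is one reason keeping log_val in `current_state` is convenient: the value for the current configuration never has to be recomputed. The class is written standalone so the block runs on its own; in practice it would subclass `Sampler`.

```python
import numpy as np


class ProductStateMachine:
    """Hypothetical toy machine: log Psi(s) = a * sum_i s_i, with s_i in {-1, +1}."""

    def __init__(self, a=0.3):
        self.a = a

    def log_val(self, v):
        # v has shape (N_c, N_v); returns one complex log-amplitude per chain.
        return (self.a * v.sum(axis=1)).astype(complex)


class LocalMetropolisSampler:
    """Hypothetical concrete sampler following the proposed interface.

    Standalone so it runs on its own; in practice it would subclass Sampler.
    """

    def __init__(self, machine, n_visible, n_chains=4, seed=None):
        self._machine = machine
        self._n_visible = n_visible
        self._n_chains = n_chains
        self._machine_func = lambda psi: np.abs(psi) ** 2  # default F(X) = |X|^2
        self.seed(seed)
        self.reset(init_random=True)

    def seed(self, base_seed):
        self._rng = np.random.default_rng(base_seed)

    def reset(self, init_random=False):
        # Clear acceptance statistics; optionally draw random spin configurations.
        self._n_accepted = 0
        self._n_proposed = 0
        if init_random:
            self._v = self._rng.choice([-1, 1], size=(self._n_chains, self._n_visible))
        self._log_val = self._machine.log_val(self._v)

    def sweep(self):
        chains = np.arange(self._n_chains)
        # One sweep = N_v single-spin-flip proposals per chain (extensive in N_v).
        for _ in range(self._n_visible):
            sites = self._rng.integers(0, self._n_visible, size=self._n_chains)
            proposal = self._v.copy()
            proposal[chains, sites] *= -1
            new_log_val = self._machine.log_val(proposal)
            # Metropolis ratio F(Psi(s')) / F(Psi(s)), reusing the stored log_val.
            ratio = (self._machine_func(np.exp(new_log_val))
                     / self._machine_func(np.exp(self._log_val)))
            accept = self._rng.random(self._n_chains) < ratio
            self._v[accept] = proposal[accept]
            self._log_val[accept] = new_log_val[accept]
            self._n_accepted += int(accept.sum())
            self._n_proposed += self._n_chains

    @property
    def current_sample(self):
        return self._v

    @property
    def current_state(self):
        return self._v, self._log_val

    @property
    def n_chains(self):
        return self._n_chains

    @property
    def machine(self):
        return self._machine

    @property
    def machine_func(self):
        return self._machine_func

    @machine_func.setter
    def machine_func(self, func):
        self._machine_func = func

    @property
    def acceptance(self):
        # Hypothetical extra: fraction of accepted proposals since the last reset.
        return self._n_accepted / max(self._n_proposed, 1)
```

For example, `LocalMetropolisSampler(ProductStateMachine(a=0.3), n_visible=6, n_chains=3, seed=1)` followed by a few `sweep()` calls gives a `current_sample` of shape (3, 6). Exponentiating log_val is only safe for small toy machines; a production sampler would typically evaluate the acceptance ratio in log space to avoid overflow.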