Source code for netket.vqs.mc.mc_state.state

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

import jax
from jax import numpy as jnp

import flax
from flax import serialization

from netket import jax as nkjax
from netket import nn
from netket.stats import Stats
from netket.operator import AbstractOperator
from netket.sampler import Sampler, SamplerState
from netket.utils import (
    maybe_wrap_module,
    deprecated,
    warn_deprecation,
    mpi,
    wrap_afun,
    wrap_to_support_scalar,
)
from netket.utils.types import PyTree, SeedT, NNInitFunc
from netket.optimizer import LinearOperator
from netket.optimizer.qgt import QGTAuto

from netket.vqs.base import VariationalState, expect, expect_and_grad


def compute_chain_length(n_chains, n_samples):
    if n_samples <= 0:
        raise ValueError("Invalid number of samples: n_samples={}".format(n_samples))

    chain_length = int(np.ceil(n_samples / n_chains))

    n_samples_new = chain_length * n_chains
    n_samples_per_rank_new = n_samples_new // mpi.n_nodes

    if n_samples_new != n_samples:
        n_samples_per_rank = n_samples // mpi.n_nodes
        warnings.warn(
            f"n_samples={n_samples} ({n_samples_per_rank} per MPI rank) does not "
            f"divide n_chains={n_chains}, increased to {n_samples_new} "
            f"({n_samples_per_rank_new} per MPI rank)"
        )

    return chain_length


def check_chunk_size(n_samples, chunk_size):
    n_samples_per_rank = n_samples // mpi.n_nodes

    if chunk_size is not None:
        if chunk_size < n_samples_per_rank and n_samples_per_rank % chunk_size != 0:
            raise ValueError(
                f"chunk_size={chunk_size}`<`n_samples_per_rank={n_samples_per_rank}, "
                "chunk_size is not an integer fraction of `n_samples_per rank`. This is"
                "unsupported. Please change `chunk_size` so that it divides evenly the"
                "number of samples per rank or set it to `None` to disable chunking."
            )


def _is_power_of_two(n: int) -> bool:
    return (n != 0) and (n & (n - 1) == 0)


@partial(jax.jit, static_argnums=0)
def jit_evaluate(fun: Callable, *args):
    """
    call `fun(*args)` inside of a `jax.jit` frame.

    Args:
        fun: the hashable callable to be evaluated.
        args: the arguments to the function.
    """
    return fun(*args)


class MCState(VariationalState):
    """Variational State for a Variational Neural Quantum State.

    The state is sampled according to the provided sampler.
    """

    # model: Any
    # """The model"""
    model_state: Optional[PyTree]
    """An Optional PyTree encoding a mutable state of the model that is not trained."""

    _sampler: Sampler
    """The sampler used to sample the Hilbert space."""
    sampler_state: SamplerState
    """The current state of the sampler."""

    _chain_length: int = 0
    """Length of the Markov chain used for sampling configurations."""
    _n_discard_per_chain: int = 0
    """Number of samples discarded at the beginning of every Markov chain."""

    _samples: Optional[jnp.ndarray] = None
    """Cached samples obtained with the last sampling."""

    _init_fun: Callable = None
    """The function used to initialise the parameters and model_state."""
    _apply_fun: Callable = None
    """The function used to evaluate the model."""

    _chunk_size: Optional[int] = None

[docs] def __init__( self, sampler: Sampler, model=None, *, n_samples: int = None, n_samples_per_rank: Optional[int] = None, n_discard: Optional[int] = None, # deprecated n_discard_per_chain: Optional[int] = None, chunk_size: Optional[int] = None, variables: Optional[PyTree] = None, init_fun: NNInitFunc = None, apply_fun: Callable = None, sample_fun: Callable = None, seed: Optional[SeedT] = None, sampler_seed: Optional[SeedT] = None, mutable: bool = False, training_kwargs: Dict = {}, ): """ Constructs the MCState. Args: sampler: The sampler model: (Optional) The model. If not provided, you must provide init_fun and apply_fun. n_samples: the total number of samples across chains and processes when sampling (default=1000). n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be specified together with n_samples (default=None). n_discard_per_chain: number of discarded samples at the beginning of each monte-carlo chain (default=0 for exact sampler, and n_samples/10 for approximate sampler). parameters: Optional PyTree of weights from which to start. seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one. sampler_seed: rng seed used to initialise the sampler. Defaults to a random one. mutable: Dict specifing mutable arguments. Use it to specify if the model has a state that can change during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation (default=False) init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has a non-standard init method. variables: Optional initial value for the variables (parameters and model state) of the model. apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defafults to `model.apply(variables, σ)`. specify only if your network has a non-standard apply method. sample_fun: Optional function used to sample the state, if it is not the same as `apply_fun`. training_kwargs: a dict containing the optionaal keyword arguments to be passed to the apply_fun during training. Useful for example when you have a batchnorm layer that constructs the average/mean only during training. n_discard: DEPRECATED. Please use `n_discard_per_chain` which has the same behaviour. """ super().__init__(sampler.hilbert) # Init type 1: pass in a model if model is not None: # extract init and apply functions # Wrap it in an HashablePartial because if two instances of the same model are provided, # model.apply and model2.apply will be different methods forcing recompilation, but # model and model2 will have the same hash. _, model = maybe_wrap_module(model) self._model = model self._init_fun = nkjax.HashablePartial( lambda model, *args, **kwargs: model.init(*args, **kwargs), model ) self._apply_fun = wrap_to_support_scalar( nkjax.HashablePartial( lambda model, pars, x, **kwargs: model.apply(pars, x, **kwargs), model, ) ) elif apply_fun is not None: self._apply_fun = wrap_to_support_scalar(apply_fun) if init_fun is not None: self._init_fun = init_fun elif variables is None: raise ValueError( "If you don't provide variables, you must pass a valid init_fun." ) self._model = wrap_afun(apply_fun) else: raise ValueError( "Must either pass the model or apply_fun, otherwise how do you think we" "gonna evaluate the model?" ) # default argument for n_samples/n_samples_per_rank if n_samples is None and n_samples_per_rank is None: # get the first multiple of sampler.n_chains above 1000 to avoid # printing a warning on construction n_samples = int(np.ceil(1000 / sampler.n_chains) * sampler.n_chains) elif n_samples is not None and n_samples_per_rank is not None: raise ValueError( "Only one argument between `n_samples` and `n_samples_per_rank`" "can be specified at the same time." ) if n_discard is not None and n_discard_per_chain is not None: raise ValueError( "`n_discard` has been renamed to `n_discard_per_chain` and deprecated." "Specify only `n_discard_per_chain`." ) elif n_discard is not None: warn_deprecation( "`n_discard` has been renamed to `n_discard_per_chain` and deprecated." "Please update your code to use `n_discard_per_chain`." ) n_discard_per_chain = n_discard if sample_fun is not None: self._sample_fun = sample_fun else: self._sample_fun = self._apply_fun self.mutable = mutable self.training_kwargs = flax.core.freeze(training_kwargs) if variables is not None: self.variables = variables else: self.init(seed, dtype=sampler.dtype) if sampler_seed is None and seed is not None: key, key2 = jax.random.split(nkjax.PRNGKey(seed), 2) sampler_seed = key2 self._sampler_seed = sampler_seed self.sampler = sampler if n_samples is not None: self.n_samples = n_samples else: self.n_samples_per_rank = n_samples_per_rank self.n_discard_per_chain = n_discard_per_chain self.chunk_size = chunk_size
[docs] def init(self, seed=None, dtype=None): """ Initialises the variational parameters of the variational state. """ if self._init_fun is None: raise RuntimeError( "Cannot initialise the parameters of this state" "because you did not supply a valid init_function." ) if dtype is None: dtype = self.sampler.dtype key = nkjax.PRNGKey(seed) dummy_input = jnp.zeros((1, self.hilbert.size), dtype=dtype) variables = jit_evaluate(self._init_fun, {"params": key}, dummy_input) self.variables = variables
@property def model(self) -> Optional[Any]: """Returns the model definition of this variational state. This field is optional, and is set to `None` if the variational state has been initialized using a custom function. """ return self._model @property def sampler(self) -> Sampler: """The Monte Carlo sampler used by this Monte Carlo variational state.""" return self._sampler @sampler.setter def sampler(self, sampler: Sampler): if not isinstance(sampler, Sampler): raise TypeError( "The sampler should be a subtype of netket.sampler.Sampler, but {} is not.".format( type(sampler) ) ) # Save the old `n_samples` before the new `sampler` is set. # `_chain_length == 0` means that this `MCState` is being constructed. if self._chain_length > 0: n_samples_old = self.n_samples self._sampler = sampler self.sampler_state = self.sampler.init_state( self.model, self.variables, seed=self._sampler_seed ) # Update `n_samples`, `n_samples_per_rank`, and `chain_length` according # to the new `sampler.n_chains`. # If `n_samples` is divisible by the new `sampler.n_chains`, it will be # unchanged; otherwise it will be rounded up. # If the new `n_samples_per_rank` is not divisible by `chunk_size`, a # `ValueError` will be raised. # `_chain_length == 0` means that this `MCState` is being constructed. if self._chain_length > 0: self.n_samples = n_samples_old self.reset() @property def n_samples(self) -> int: """The total number of samples generated at every sampling step.""" return self.chain_length * self.sampler.n_chains @n_samples.setter def n_samples(self, n_samples: int): chain_length = compute_chain_length(self.sampler.n_chains, n_samples) self.chain_length = chain_length @property def n_samples_per_rank(self) -> int: """The number of samples generated on one MPI rank at every sampling step.""" return self.chain_length * self.sampler.n_chains_per_rank @n_samples_per_rank.setter def n_samples_per_rank(self, n_samples_per_rank: int): self.n_samples = n_samples_per_rank * mpi.n_nodes @property def chain_length(self) -> int: """ Length of the markov chain used for sampling configurations. If running under MPI, the total samples will be n_nodes * chain_length * n_batches. """ return self._chain_length @chain_length.setter def chain_length(self, chain_length: int): if chain_length <= 0: raise ValueError(f"Invalid chain length: chain_length={chain_length}") n_samples = chain_length * self.sampler.n_chains check_chunk_size(n_samples, self.chunk_size) self._chain_length = chain_length self.reset() @property def n_discard_per_chain(self) -> int: """ Number of discarded samples at the beginning of the markov chain. """ return self._n_discard_per_chain @n_discard_per_chain.setter def n_discard_per_chain(self, n_discard_per_chain: Optional[int]): if n_discard_per_chain is not None and n_discard_per_chain < 0: raise ValueError( "Invalid number of discarded samples: n_discard_per_chain={}".format( n_discard_per_chain ) ) # don't discard if the sampler is exact if self.sampler.is_exact: if n_discard_per_chain is not None and n_discard_per_chain > 0: warnings.warn( "An exact sampler does not need to discard samples. Setting n_discard_per_chain to 0." ) n_discard_per_chain = 0 self._n_discard_per_chain = ( int(n_discard_per_chain) if n_discard_per_chain is not None else self.n_samples // 10 ) # TODO: deprecate @property def n_discard(self) -> int: """ DEPRECATED: Use `n_discard_per_chain` instead. Number of discarded samples at the beginning of the markov chain. """ warn_deprecation( "`n_discard` has been renamed to `n_discard_per_chain` and deprecated." "Please update your code to use `n_discard_per_chain`." ) return self.n_discard_per_chain @n_discard.setter def n_discard(self, val) -> int: warn_deprecation( "`n_discard` has been renamed to `n_discard_per_chain` and deprecated." "Please update your code to use `n_discard_per_chain`." ) self.n_discard_per_chain = val @property def chunk_size(self) -> int: """ Suggested *maximum size* of the chunks used in forward and backward evaluations of the Neural Network model. If your inputs are smaller than the chunk size this setting is ignored. This can be used to lower the memory required to run a computation with a very high number of samples or on a very large lattice. Notice that inputs and outputs must still fit in memory, but the intermediate computations will now require less memory. This option comes at an increased computational cost. While this cost should be negligible for large-enough chunk sizes, don't use it unless you are memory bound! This option is an hint: only some operations support chunking. If you perform an operation that is not implemented with chunking support, it will fall back to no chunking. To check if this happened, set the environment variable `NETKET_DEBUG=1`. """ return self._chunk_size @chunk_size.setter def chunk_size(self, chunk_size: Optional[int]): # disable chunks if it is None if chunk_size is None: self._chunk_size = None return if chunk_size <= 0: raise ValueError("Chunk size must be a positive integer. ") if not _is_power_of_two(chunk_size): warnings.warn( "For performance reasons, we suggest to use a power-of-two chunk size." ) check_chunk_size(self.n_samples, chunk_size) self._chunk_size = chunk_size
[docs] def reset(self): """ Resets the sampled states. This method is called automatically every time that the parameters/state is updated. """ self._samples = None
[docs] def sample( self, *, chain_length: Optional[int] = None, n_samples: Optional[int] = None, n_discard_per_chain: Optional[int] = None, ) -> jnp.ndarray: """ Sample a certain number of configurations. If one among chain_leength or n_samples is defined, that number of samples are gen erated. Otherwise the value set internally is used. Args: chain_length: The length of the markov chains. n_samples: The total number of samples across all MPI ranks. n_discard_per_chain: Number of discarded samples at the beginning of the markov chain. """ if n_samples is None and chain_length is None: chain_length = self.chain_length else: if chain_length is None: chain_length = compute_chain_length(self.sampler.n_chains, n_samples) if self.chunk_size is not None: check_chunk_size(chain_length * self.sampler.n_chains, self.chunk_size) if n_discard_per_chain is None: n_discard_per_chain = self.n_discard_per_chain self.sampler_state = self.sampler.reset( self.model, self.variables, self.sampler_state ) if self.n_discard_per_chain > 0: _, self.sampler_state = self.sampler.sample( self.model, self.variables, state=self.sampler_state, chain_length=n_discard_per_chain, ) self._samples, self.sampler_state = self.sampler.sample( self.model, self.variables, state=self.sampler_state, chain_length=chain_length, ) return self._samples
@property def samples(self) -> jnp.ndarray: """ Returns the set of cached samples. The samples returnede are guaranteed valid for the current state of the variational state. If no cached parameters are available, then they are sampled first and then cached. To obtain a new set of samples either use :ref:`reset` or :ref:`sample`. """ if self._samples is None: self.sample() return self._samples
[docs] def log_value(self, σ: jnp.ndarray) -> jnp.ndarray: """ Evaluate the variational state for a batch of states and returns the logarithm of the amplitude of the quantum state. For pure states, this is :math:`log(<σ|ψ>)`, whereas for mixed states this is :math:`log(<σr|ρ|σc>)`, where ψ and ρ are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left and right-acting states (row and column) are obtained as :code:`σr=σ[::,0:N]` and :code:`σc=σ[::,N:]`. Given a batch of inputs (Nb, N), returns a batch of outputs (Nb,). """ return jit_evaluate(self._apply_fun, self.variables, σ)
# override to use chunks
[docs] def expect(self, : AbstractOperator) -> Stats: r"""Estimates the quantum expectation value for a given operator O. In the case of a pure state $\psi$, this is $<O>= <Psi|O|Psi>/<Psi|Psi>$ otherwise for a mixed state $\rho$, this is $<O> = \Tr[\rho \hat{O}/\Tr[\rho]$. Args: Ô: the operator O. Returns: An estimation of the quantum expectation value <O>. """ return expect(self, , self.chunk_size)
# override to use chunks
[docs] def expect_and_grad( self, : AbstractOperator, *, mutable: Optional[Any] = None, use_covariance: Optional[bool] = None, ) -> Tuple[Stats, PyTree]: r"""Estimates both the gradient of the quantum expectation value of a given operator O. Args: Ô: the operator Ô for which we compute the expectation value and it's gradient mutable: Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. This is used to mutate the state of the model while you train it (for example to implement BatchNorm. Consult `Flax's Module.apply documentation <https://flax.readthedocs.io/en/latest/_modules/flax/linen/module.html#Module.apply>`_ for a more in-depth exaplanation). use_covariance: whever to use the covariance formula, usually reserved for hermitian operators, ⟨∂logψ Oˡᵒᶜ⟩ - ⟨∂logψ⟩⟨Oˡᵒᶜ⟩ Returns: An estimation of the quantum expectation value <O>. An estimation of the average gradient of the quantum expectation value <O>. """ if mutable is None: mutable = self.mutable return expect_and_grad( self, , use_covariance, self.chunk_size, mutable=mutable )
[docs] @deprecated("Use MCState.log_value(σ) instead.") def evaluate(self, σ: jnp.ndarray) -> jnp.ndarray: """ DEPRECATED: use log_value instead. """ return self.log_value(σ)
[docs] def quantum_geometric_tensor( self, qgt_T: LinearOperator = QGTAuto() ) -> LinearOperator: r"""Computes an estimate of the quantum geometric tensor G_ij. This function returns a linear operator that can be used to apply G_ij to a given vector or can be converted to a full matrix. Args: qgt_T: the optional type of the quantum geometric tensor. By default it's automatically selected. Returns: nk.optimizer.LinearOperator: A linear operator representing the quantum geometric tensor. """ return qgt_T(self)
[docs] def to_array(self, normalize: bool = True) -> jnp.ndarray: return nn.to_array( self.hilbert, self._apply_fun, self.variables, normalize=normalize )
def __repr__(self): return ( "MCState(" + "\n hilbert = {},".format(self.hilbert) + "\n sampler = {},".format(self.sampler) + "\n n_samples = {},".format(self.n_samples) + "\n n_discard_per_chain = {},".format(self.n_discard_per_chain) + "\n sampler_state = {},".format(self.sampler_state) + "\n n_parameters = {})".format(self.n_parameters) ) def __str__(self): return ( "MCState(" + "hilbert = {}, ".format(self.hilbert) + "sampler = {}, ".format(self.sampler) + "n_samples = {})".format(self.n_samples) ) # serialization def serialize_MCState(vstate): state_dict = { "variables": serialization.to_state_dict(vstate.variables), "sampler_state": serialization.to_state_dict(vstate.sampler_state), "n_samples": vstate.n_samples, "n_discard_per_chain": vstate.n_discard_per_chain, } return state_dict def deserialize_MCState(vstate, state_dict): import copy new_vstate = copy.copy(vstate) new_vstate.reset() new_vstate.variables = serialization.from_state_dict( vstate.variables, state_dict["variables"] ) new_vstate.sampler_state = serialization.from_state_dict( vstate.sampler_state, state_dict["sampler_state"] ) new_vstate.n_samples = state_dict["n_samples"] new_vstate.n_discard_per_chain = state_dict["n_discard_per_chain"] return new_vstate serialization.register_serialization_state( MCState, serialize_MCState, deserialize_MCState, )