Source code for netket.models.autoreg

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from math import sqrt
from typing import Any, Callable, Iterable, Tuple, Union

import jax
from flax import linen as nn
from jax import numpy as jnp
from jax.nn.initializers import zeros
from plum import dispatch

from netket.hilbert import Fock, Qubit, Spin
from netket.hilbert.homogeneous import HomogeneousHilbert
from netket.nn import MaskedConv1D, MaskedConv2D, MaskedDense1D
from netket.nn.masked_linear import default_kernel_init
from netket.utils.types import Array, DType, NNInitFunc


class AbstractARNN(nn.Module):
    """
    Base class for autoregressive neural networks.

    Subclasses must implement the methods `__call__` and `conditionals`.
    They can also override `_conditional` to implement caching for fast
    autoregressive sampling. See :ref:`netket.nn.FastARNNConv1D` for an example.

    They must also implement the field `machine_pow`,
    which specifies the exponent to normalize the outputs of `__call__`.
    """

    hilbert: HomogeneousHilbert
    """the Hilbert space. Only homogeneous unconstrained Hilbert spaces are supported."""

    # machine_pow: int = 2   Must be defined on subclasses

    def __post_init__(self):
        super().__post_init__()

        if not isinstance(self.hilbert, HomogeneousHilbert):
            raise ValueError(
                f"Only homogeneous Hilbert spaces are supported by ARNN, but hilbert is a {type(self.hilbert)}."
            )

        if self.hilbert.constrained:
            raise ValueError("Only unconstrained Hilbert spaces are supported by ARNN.")

    def _conditional(self, inputs: Array, index: int) -> Array:
        """
        Computes the conditional probabilities for a site to take a given value.

        It should only be called successively with indices 0, 1, 2, ...,
        as in the autoregressive sampling procedure.

        Args:
          inputs: configurations with dimensions (batch, Hilbert.size).
          index: index of the site.

        Returns:
          The probabilities with dimensions (batch, Hilbert.local_size).
        """
        return self.conditionals(inputs)[:, index, :]

    @abc.abstractmethod
    def conditionals(self, inputs: Array) -> Array:
        """
        Computes the conditional probabilities for each site to take each value.

        Args:
          inputs: configurations with dimensions (batch, Hilbert.size).

        Returns:
          The probabilities with dimensions (batch, Hilbert.size, Hilbert.local_size).

        Examples:

          >>> import pytest; pytest.skip("skip automated test of this docstring")
          >>>
          >>> p = model.apply(variables, σ, method=model.conditionals)
          >>> print(p[2, 3, :])
          [0.3 0.7]

          # For the site with index 3 of the sample with index 2 in the batch,
          # it takes probability 0.3 to be spin down (local state index 0),
          # and probability 0.7 to be spin up (local state index 1).
        """
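

# A minimal sketch (not part of NetKet) of the subclass contract documented in
# `AbstractARNN` above: an ARNN whose conditionals are uniform over the local
# Hilbert space. The name `_UniformARNNSketch` is purely illustrative; the
# concrete models below (ARNNDense, ARNNConv1D, ARNNConv2D) are the real
# implementations to follow.
class _UniformARNNSketch(AbstractARNN):
    machine_pow: int = 2

    def conditionals(self, inputs: Array) -> Array:
        # Every site takes every local value with equal probability.
        if inputs.ndim == 1:
            inputs = jnp.expand_dims(inputs, axis=0)
        return jnp.full(
            (inputs.shape[0], self.hilbert.size, self.hilbert.local_size),
            1 / self.hilbert.local_size,
        )

    def __call__(self, inputs: Array) -> Array:
        # log|psi| of the uniform state, consistent with `conditionals` under
        # the convention |psi(s)|^machine_pow = prod_i p(s_i | s_{<i}).
        if inputs.ndim == 1:
            inputs = jnp.expand_dims(inputs, axis=0)
        log_psi_site = -jnp.log(self.hilbert.local_size) / self.machine_pow
        return jnp.full((inputs.shape[0],), self.hilbert.size * log_psi_site)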


class ARNNDense(AbstractARNN):
    """Autoregressive neural network with dense layers."""

    layers: int
    """number of layers."""
    features: Union[Iterable[int], int]
    """number of features in each layer. If a single number is given,
    all layers except the last one will have the same number of features."""
    activation: Callable[[Array], Array] = jax.nn.selu
    """the nonlinear activation function between hidden layers (default: selu)."""
    use_bias: bool = True
    """whether to add a bias to the output (default: True)."""
    dtype: DType = jnp.float64
    """the dtype of the computation (default: float64)."""
    precision: Any = None
    """numerical precision of the computation, see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """initializer for the weights."""
    bias_init: NNInitFunc = zeros
    """initializer for the biases."""
    machine_pow: int = 2
    """exponent to normalize the outputs of `__call__`."""

    def setup(self):
        if isinstance(self.features, int):
            features = [self.features] * (self.layers - 1) + [self.hilbert.local_size]
        else:
            features = self.features
        assert len(features) == self.layers
        assert features[-1] == self.hilbert.local_size

        self._layers = [
            MaskedDense1D(
                features=features[i],
                exclusive=(i == 0),
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for i in range(self.layers)
        ]

    def conditionals(self, inputs: Array) -> Array:
        return _conditionals(self, inputs)

    def __call__(self, inputs: Array) -> Array:
        return _call(self, inputs)


class ARNNConv1D(AbstractARNN):
    """Autoregressive neural network with 1D convolution layers."""

    layers: int
    """number of layers."""
    features: Union[Iterable[int], int]
    """number of features in each layer. If a single number is given,
    all layers except the last one will have the same number of features."""
    kernel_size: int
    """length of the convolutional kernel."""
    kernel_dilation: int = 1
    """dilation factor of the convolution kernel (default: 1)."""
    activation: Callable[[Array], Array] = jax.nn.selu
    """the nonlinear activation function between hidden layers (default: selu)."""
    use_bias: bool = True
    """whether to add a bias to the output (default: True)."""
    dtype: DType = jnp.float64
    """the dtype of the computation (default: float64)."""
    precision: Any = None
    """numerical precision of the computation, see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """initializer for the weights."""
    bias_init: NNInitFunc = zeros
    """initializer for the biases."""
    machine_pow: int = 2
    """exponent to normalize the outputs of `__call__`."""

    def setup(self):
        if isinstance(self.features, int):
            features = [self.features] * (self.layers - 1) + [self.hilbert.local_size]
        else:
            features = self.features
        assert len(features) == self.layers
        assert features[-1] == self.hilbert.local_size

        self._layers = [
            MaskedConv1D(
                features=features[i],
                kernel_size=self.kernel_size,
                kernel_dilation=self.kernel_dilation,
                exclusive=(i == 0),
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for i in range(self.layers)
        ]

    def conditionals(self, inputs: Array) -> Array:
        return _conditionals(self, inputs)

    def __call__(self, inputs: Array) -> Array:
        return _call(self, inputs)


class ARNNConv2D(AbstractARNN):
    """Autoregressive neural network with 2D convolution layers."""

    layers: int
    """number of layers."""
    features: Union[Iterable[int], int]
    """number of features in each layer. If a single number is given,
    all layers except the last one will have the same number of features."""
    kernel_size: Tuple[int, int]
    """shape of the convolutional kernel `(h, w)`. Typically, `h = w // 2 + 1`."""
    kernel_dilation: Tuple[int, int] = (1, 1)
    """a sequence of 2 integers, giving the dilation factor to apply
    in each spatial dimension of the convolution kernel (default: (1, 1))."""
    activation: Callable[[Array], Array] = jax.nn.selu
    """the nonlinear activation function between hidden layers (default: selu)."""
    use_bias: bool = True
    """whether to add a bias to the output (default: True)."""
    dtype: DType = jnp.float64
    """the dtype of the computation (default: float64)."""
    precision: Any = None
    """numerical precision of the computation, see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """initializer for the weights."""
    bias_init: NNInitFunc = zeros
    """initializer for the biases."""
    machine_pow: int = 2
    """exponent to normalize the outputs of `__call__`."""

    def setup(self):
        self.L = int(sqrt(self.hilbert.size))
        assert self.L**2 == self.hilbert.size

        if isinstance(self.features, int):
            features = [self.features] * (self.layers - 1) + [self.hilbert.local_size]
        else:
            features = self.features
        assert len(features) == self.layers
        assert features[-1] == self.hilbert.local_size

        self._layers = [
            MaskedConv2D(
                features=features[i],
                kernel_size=self.kernel_size,
                kernel_dilation=self.kernel_dilation,
                exclusive=(i == 0),
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for i in range(self.layers)
        ]

    def conditionals(self, inputs: Array) -> Array:
        return _conditionals(self, inputs)

    def __call__(self, inputs: Array) -> Array:
        return _call(self, inputs)


def _normalize(log_psi: Array, machine_pow: int) -> Array:
    """
    Normalizes log_psi so that the amplitudes have norm 1 along the last axis
    (the L2 norm when `machine_pow = 2`).
    """
    return log_psi - 1 / machine_pow * jax.scipy.special.logsumexp(
        machine_pow * log_psi.real, axis=-1, keepdims=True
    )


def _conditionals_log_psi(model: AbstractARNN, inputs: Array) -> Array:
    """
    Computes the log of the conditional wave-functions for each site if it takes each value.

    See `AbstractARNN.conditionals`.
    """
    inputs = _reshape_inputs(model, inputs)
    x = jnp.expand_dims(inputs, axis=-1)

    for i in range(model.layers):
        if i > 0:
            x = model.activation(x)
        x = model._layers[i](x)

    x = x.reshape((x.shape[0], -1, x.shape[-1]))
    log_psi = _normalize(x, model.machine_pow)
    return log_psi


def _conditionals(model: AbstractARNN, inputs: Array) -> Array:
    """
    Computes the conditional probabilities for each site to take each value.

    See `AbstractARNN.conditionals`.
    """
    if inputs.ndim == 1:
        inputs = jnp.expand_dims(inputs, axis=0)

    log_psi = _conditionals_log_psi(model, inputs)
    p = jnp.exp(model.machine_pow * log_psi.real)
    return p


def _call(model: AbstractARNN, inputs: Array) -> Array:
    """Returns log_psi."""

    if inputs.ndim == 1:
        inputs = jnp.expand_dims(inputs, axis=0)

    idx = _local_states_to_numbers(model.hilbert, inputs)
    idx = jnp.expand_dims(idx, axis=-1)

    log_psi = _conditionals_log_psi(model, inputs)
    log_psi = jnp.take_along_axis(log_psi, idx, axis=-1)
    log_psi = log_psi.reshape((inputs.shape[0], -1)).sum(axis=1)
    return log_psi


@dispatch
def _reshape_inputs(model: ARNNConv2D, inputs: Array) -> Array:  # noqa: F811
    return inputs.reshape((inputs.shape[0], model.L, model.L))


@dispatch
def _reshape_inputs(model: Any, inputs: Array) -> Array:  # noqa: F811
    return inputs


@dispatch
def _local_states_to_numbers(hilbert: Spin, x: Array) -> Array:  # noqa: F811
    numbers = (x + hilbert.local_size - 1) / 2
    numbers = jnp.asarray(numbers, jnp.int32)
    return numbers


@dispatch
def _local_states_to_numbers(  # noqa: F811
    hilbert: Union[Fock, Qubit], x: Array
) -> Array:
    numbers = jnp.asarray(x, jnp.int32)
    return numbers


@dispatch
def _local_states_to_numbers(hilbert: Any, x: Array) -> Array:  # noqa: F811
    raise NotImplementedError(
        f"_local_states_to_numbers is not implemented for hilbert {type(hilbert)}."
    )
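

# A minimal usage sketch (not part of NetKet): construct an ARNNDense on a small
# spin chain and evaluate its log-amplitudes and conditional probabilities.
# The hyperparameters (4 sites, 2 layers, 8 features, a batch of 3 samples) are
# purely illustrative.
if __name__ == "__main__":
    hilbert = Spin(s=1 / 2, N=4)
    model = ARNNDense(hilbert=hilbert, layers=2, features=8)

    key = jax.random.PRNGKey(0)
    σ = hilbert.random_state(key, size=3)  # shape (batch, hilbert.size)
    variables = model.init(key, σ)

    log_psi = model.apply(variables, σ)  # shape (batch,)
    p = model.apply(variables, σ, method=model.conditionals)
    # p has shape (batch, hilbert.size, hilbert.local_size)
    print(log_psi.shape, p.shape)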