Source code for netket.nn.masked_linear

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Tuple

import flax
from flax import linen as nn
from jax import lax
from jax import numpy as jnp
from jax.nn.initializers import lecun_normal, zeros

from netket.utils.types import Array, DType, NNInitFunc

default_kernel_init = lecun_normal()


def wrap_kernel_init(kernel_init, mask):
    """Wraps `kernel_init` to correct for the mask.

    The masked entries are zeroed out and the remaining ones are rescaled, so
    that the total weight variance matches that of the unmasked LeCun normal
    initialization.
    """

    corr = jnp.sqrt(mask.size / mask.sum())

    def wrapped_kernel_init(*args):
        return corr * mask * kernel_init(*args)

    return wrapped_kernel_init
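
# A brief, hedged illustration of the variance correction (the shape and key
# below are only for demonstration; `import jax` is assumed):
#
#   mask = jnp.triu(jnp.ones((4, 4)))              # keeps 10 of 16 entries
#   init = wrap_kernel_init(default_kernel_init, mask)
#   w = init(jax.random.PRNGKey(0), (4, 4))        # LeCun normal, then masked
#   # w is zero below the diagonal; the surviving entries are rescaled by
#   # sqrt(16 / 10) so the total weight variance is preserved.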


class MaskedDense1D(nn.Module):
    """1D linear transformation module with mask for autoregressive NN."""

    features: int
    """number of output features, should be the last dimension."""
    exclusive: bool
    """True if an output element does not depend on the input element at the same index."""
    use_bias: bool = True
    """whether to add a bias to the output (default: True)."""
    dtype: DType = jnp.float64
    """the dtype of the computation (default: float64)."""
    precision: Any = None
    """numerical precision of the computation, see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """initializer for the weight matrix."""
    bias_init: NNInitFunc = zeros
    """initializer for the bias."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked linear transformation to the inputs.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The transformed data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)
        inputs = jnp.asarray(inputs, dtype)

        is_single_input = False
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        batch, size, in_features = inputs.shape
        inputs = inputs.reshape((batch, size * in_features))

        mask = jnp.ones((size, size), dtype=self.dtype)
        mask = jnp.triu(mask, self.exclusive)
        mask = jnp.kron(
            mask, jnp.ones((in_features, self.features), dtype=self.dtype)
        )

        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, mask),
            (size * in_features, size * self.features),
            self.dtype,
        )

        mask = jnp.asarray(mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        y = lax.dot(inputs, mask * kernel, precision=self.precision)

        y = y.reshape((batch, size, self.features))

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param(
                "bias", self.bias_init, (size, self.features), self.dtype
            )
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
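
# A minimal usage sketch (illustrative shapes; `import jax` is assumed),
# following the standard Flax init/apply workflow:
#
#   layer = MaskedDense1D(features=4, exclusive=True)
#   x = jnp.zeros((8, 16, 2))                      # (batch, length, features)
#   params = layer.init(jax.random.PRNGKey(0), x)
#   y = layer.apply(params, x)                     # shape (8, 16, 4)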

class MaskedConv1D(nn.Module):
    """1D convolution module with mask for autoregressive NN."""

    features: int
    """number of convolution filters."""
    kernel_size: int
    """length of the convolutional kernel."""
    kernel_dilation: int
    """dilation factor of the convolution kernel."""
    exclusive: bool
    """True if an output element does not depend on the input element at the same index."""
    feature_group_count: int = 1
    """if specified, divides the input features into groups (default: 1)."""
    use_bias: bool = True
    """whether to add a bias to the output (default: True)."""
    dtype: DType = jnp.float64
    """the dtype of the computation (default: float64)."""
    precision: Any = None
    """numerical precision of the computation, see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """initializer for the convolutional kernel."""
    bias_init: NNInitFunc = zeros
    """initializer for the bias."""

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked convolution to the inputs.

        For a 1D convolution, no explicit mask is needed: causality is
        enforced by shifting the input (when `exclusive`) and applying the
        appropriate zero padding on the left.

        Args:
          inputs: input data with dimensions (batch, length, features).

        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)
        inputs = jnp.asarray(inputs, dtype)

        kernel_size = self.kernel_size - self.exclusive
        dilation = self.kernel_dilation

        is_single_input = False
        if inputs.ndim == 2:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = (
            kernel_size,
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param("kernel", self.kernel_init, kernel_shape, self.dtype)
        kernel = jnp.asarray(kernel, dtype)

        if self.exclusive:
            inputs = inputs[:, :-dilation, :]

        # Zero padding on the left so the output keeps the input length
        y = jnp.pad(
            inputs,
            (
                (0, 0),
                ((kernel_size - (not self.exclusive)) * dilation, 0),
                (0, 0),
            ),
        )

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            y,
            kernel,
            window_strides=(1,),
            padding="VALID",
            lhs_dilation=(1,),
            rhs_dilation=(dilation,),
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
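
# A minimal usage sketch (illustrative shapes; `import jax` is assumed). The
# output has the same length as the input thanks to the left padding:
#
#   layer = MaskedConv1D(
#       features=4, kernel_size=3, kernel_dilation=1, exclusive=True
#   )
#   x = jnp.zeros((8, 16, 2))                      # (batch, length, features)
#   params = layer.init(jax.random.PRNGKey(0), x)
#   y = layer.apply(params, x)                     # shape (8, 16, 4)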

class MaskedConv2D(nn.Module):
    """2D convolution module with mask for autoregressive NN."""

    features: int
    """number of convolution filters."""
    kernel_size: Tuple[int, int]
    """shape of the convolutional kernel `(h, w)`. Typically, `h = w // 2 + 1`."""
    kernel_dilation: Tuple[int, int]
    """a sequence of 2 integers, giving the dilation factor to
    apply in each spatial dimension of the convolution kernel."""
    exclusive: bool
    """True if an output element does not depend on the input element at the same index."""
    feature_group_count: int = 1
    """if specified, divides the input features into groups (default: 1)."""
    use_bias: bool = True
    """whether to add a bias to the output (default: True)."""
    dtype: DType = jnp.float64
    """the dtype of the computation (default: float64)."""
    precision: Any = None
    """numerical precision of the computation, see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """initializer for the convolutional kernel."""
    bias_init: NNInitFunc = zeros
    """initializer for the bias."""

    def setup(self):
        kernel_h, kernel_w = self.kernel_size

        # In the last row of the kernel, zero out the entries to the right of
        # the center (and the center itself when `exclusive`), so that each
        # output site only depends on previous sites in raster-scan order.
        mask = jnp.ones((kernel_h, kernel_w, 1, 1), dtype=self.dtype)
        mask = mask.at[-1, kernel_w // 2 + (not self.exclusive) :].set(0)
        self.mask = mask

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """
        Applies a masked convolution to the inputs.

        Args:
          inputs: input data with dimensions (batch, width, height, features).

        Returns:
          The convolved data.
        """
        dtype = jnp.promote_types(inputs.dtype, self.dtype)
        inputs = jnp.asarray(inputs, dtype)

        kernel_h, kernel_w = self.kernel_size
        dilation_h, dilation_w = self.kernel_dilation
        ones = (1, 1)

        is_single_input = False
        if inputs.ndim == 3:
            is_single_input = True
            inputs = jnp.expand_dims(inputs, axis=0)

        in_features = inputs.shape[-1]
        assert in_features % self.feature_group_count == 0
        kernel_shape = self.kernel_size + (
            in_features // self.feature_group_count,
            self.features,
        )
        kernel = self.param(
            "kernel",
            wrap_kernel_init(self.kernel_init, self.mask),
            kernel_shape,
            self.dtype,
        )

        mask = jnp.asarray(self.mask, dtype)
        kernel = jnp.asarray(kernel, dtype)

        # Zero padding
        y = jnp.pad(
            inputs,
            (
                (0, 0),
                ((kernel_h - 1) * dilation_h, 0),
                (kernel_w // 2 * dilation_w, (kernel_w - 1) // 2 * dilation_w),
                (0, 0),
            ),
        )

        dimension_numbers = flax.linen.linear._conv_dimension_numbers(inputs.shape)
        y = lax.conv_general_dilated(
            y,
            mask * kernel,
            window_strides=ones,
            padding="VALID",
            lhs_dilation=ones,
            rhs_dilation=self.kernel_dilation,
            dimension_numbers=dimension_numbers,
            feature_group_count=self.feature_group_count,
            precision=self.precision,
        )

        if is_single_input:
            y = y.squeeze(axis=0)

        if self.use_bias:
            bias = self.param("bias", self.bias_init, (self.features,), self.dtype)
            bias = jnp.asarray(bias, dtype)
            y = y + bias

        return y
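
# A minimal usage sketch (illustrative shapes; `import jax` is assumed). With
# `kernel_size=(h, w)` where `h = w // 2 + 1`, the spatial dimensions are
# preserved:
#
#   layer = MaskedConv2D(
#       features=4, kernel_size=(2, 3), kernel_dilation=(1, 1), exclusive=True
#   )
#   x = jnp.zeros((8, 10, 10, 2))                  # (batch, width, height, features)
#   params = layer.init(jax.random.PRNGKey(0), x)
#   y = layer.apply(params, x)                     # shape (8, 10, 10, 4)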