Source code for netket.driver.steady_state

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp

from netket.operator import Squared, AbstractSuperOperator
from netket.vqs import MCMixedState
from netket.utils import warn_deprecation
from netket.optimizer import (
    identity_preconditioner,
    PreconditionerT,
)

from .vmc_common import info
from .abstract_variational_driver import AbstractVariationalDriver


class SteadyState(AbstractVariationalDriver):
    """
    Steady-state driver minimizing L^†L.

    The driver wraps the Lindbladian in :class:`Squared`, and at every step
    estimates that operator's expectation value and gradient on the
    variational state, then passes the gradient through the preconditioner.
    """

    def __init__(
        self,
        lindbladian,
        optimizer,
        *args,
        variational_state: MCMixedState = None,
        preconditioner: PreconditionerT = None,
        sr: PreconditionerT = None,
        sr_restart: bool = None,
        **kwargs,
    ):
        """
        Initializes the driver class.

        Args:
            lindbladian: The Lindbladian of the system. Must be an instance of
                `AbstractSuperOperator`.
            optimizer: Determines how optimization steps are performed given the
                bare energy gradient.
            *args: Positional arguments forwarded to `MCMixedState` when
                `variational_state` is not given.
            variational_state: The mixed variational state to optimize. If None,
                one is built as `MCMixedState(*args, **kwargs)`.
            preconditioner: Determines which preconditioner to use for the loss gradient.
                This must be a tuple of `(object, solver)` as documented in the section
                `preconditioners` in the documentation. The standard preconditioner
                included with NetKet is Stochastic Reconfiguration. By default, no
                preconditioner is used and the bare gradient is passed to the optimizer.
            sr: Deprecated alias of `preconditioner`. Cannot be passed together
                with `preconditioner`.
            sr_restart: Deprecated. If given, it is written to the
                `solver_restart` attribute of the preconditioner.
            **kwargs: Keyword arguments forwarded to `MCMixedState` when
                `variational_state` is not given.

        Raises:
            TypeError: If `lindbladian` is not an `AbstractSuperOperator`.
            ValueError: If both `sr` and `preconditioner` are given, or if
                `sr_restart` is given without any preconditioner.
        """
        if variational_state is None:
            variational_state = MCMixedState(*args, **kwargs)

        if not isinstance(lindbladian, AbstractSuperOperator):
            raise TypeError("The first argument must be a super-operator")

        if sr is not None:
            if preconditioner is not None:
                raise ValueError(
                    "sr is deprecated in favour of preconditioner kwarg. You should not pass both"
                )
            preconditioner = sr
            warn_deprecation(
                (
                    "The `sr` keyword argument is deprecated in favour of `preconditioner`. "
                    "Please update your code to `SteadyState(.., preconditioner=your_sr)`"
                )
            )

        if sr_restart is not None:
            if preconditioner is None:
                raise ValueError(
                    "sr_restart only makes sense if you have a preconditioner/SR."
                )
            preconditioner.solver_restart = sr_restart
            warn_deprecation(
                (
                    "The `sr_restart` keyword argument is deprecated in favour of specifying "
                    "`solver_restart` in the constructor of the SR object. "
                    "Please update your code to `SteadyState(.., preconditioner=nk.optimizer.SR(..., solver_restart=True/False))`"
                )
            )

        # move as kwarg once deprecations are removed
        if preconditioner is None:
            preconditioner = identity_preconditioner

        super().__init__(variational_state, optimizer, minimized_quantity_name="LdagL")

        self._lind = lindbladian
        # Operator whose expectation value is actually minimized.
        self._ldag_l = Squared(lindbladian)

        self.preconditioner = preconditioner

        self._dp = None
        self._S = None
        self._sr_info = None

    def _forward_and_backward(self):
        """
        Computes the preconditioned parameter update for one optimization step.

        Resets the variational state, estimates the statistics and gradient of
        L^†L (stored in `self._loss_stats` and `self._loss_grad`), applies the
        preconditioner to the gradient and stores the result in `self._dp`.

        Returns:
            The update `self._dp`, a pytree with the same structure as
            `self.state.parameters`.
        """
        self.state.reset()

        # Compute the L^†L estimator statistics and the associated gradient
        self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ldag_l)

        # With the identity preconditioner this is just self._loss_grad
        self._dp = self.preconditioner(self.state, self._loss_grad)

        # If parameters are real, then take only real part of the gradient
        # (if it's complex). `tree_map` accepts several trees; the former
        # `jax.tree_multimap` alias has been removed from JAX.
        self._dp = jax.tree_util.tree_map(
            lambda x, target: (x if jnp.iscomplexobj(target) else x.real),
            self._dp,
            self.state.parameters,
        )

        return self._dp

    @property
    def ldagl(self):
        """
        Return the statistics of the L^†L expectation value that were
        computed by the most recent call to `_forward_and_backward`.
        """
        return self._loss_stats

    def __repr__(self):
        return (
            "SteadyState("
            + f"\n step_count = {self.step_count},"
            + f"\n state = {self.state})"
        )

    def info(self, depth=0):
        """
        Returns a multi-line description of the driver and its components.

        Args:
            depth: Nesting level, used to compute the indentation of the
                component lines and passed (incremented by one) to `info`
                for each component.

        Returns:
            A string made of `str(self)` followed by one line each for the
            Lindbladian, the optimizer and the preconditioner.
        """
        lines = [
            "{}: {}".format(name, info(obj, depth=depth + 1))
            for name, obj in [
                ("Lindbladian ", self._lind),
                ("Optimizer ", self._optimizer),
                # The preconditioner is stored as `self.preconditioner`;
                # this class never assigns a `self.sr` attribute.
                ("SR solver ", self.preconditioner),
            ]
        ]
        return "\n{}".format(" " * 3 * (depth + 1)).join([str(self)] + lines)