Source code for netket.experimental.driver.tdvp

# Copyright 2020, 2021  The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence, Union
import warnings

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb
import numpy as np
from tqdm import tqdm

import netket as nk
from netket.driver import AbstractVariationalDriver
from netket.driver.abstract_variational_driver import _to_iterable
from netket.driver.vmc_common import info
from netket.jax import HashablePartial
from netket.logging.json_log import JsonLog
from netket.operator import AbstractOperator
from netket.optimizer import LinearOperator
from netket.optimizer.qgt import QGTAuto
from netket.utils import mpi
from netket.utils.dispatch import dispatch
from netket.utils.types import PyTree
from netket.vqs import VariationalState, VariationalMixedState, MCState

from netket.experimental.dynamics import RKIntegratorConfig
from netket.experimental.dynamics._rk_solver import euclidean_norm, maximum_norm


class TDVP(AbstractVariationalDriver):
    """
    Variational time evolution based on the time-dependent variational principle which,
    when used with Monte Carlo sampling via :ref:`~netket.vqs.MCState`, is the time-dependent VMC
    (t-VMC) method.
    """

    def __init__(
        self,
        operator: AbstractOperator,
        variational_state: VariationalState,
        integrator: RKIntegratorConfig,
        *,
        t0: float = 0.0,
        propagation_type="real",
        qgt: LinearOperator = None,
        linear_solver=None,
        linear_solver_restart: bool = False,
        error_norm: Union[str, Callable] = "euclidean",
    ):
        r"""
        Initializes the time evolution driver.

        Args:
            operator: The generator of the dynamics (Hamiltonian for pure states,
                Lindbladian for density operators). Either an operator, or a callable
                with signature :code:`operator(t: float) -> AbstractOperator`.
            variational_state: The variational state.
            integrator: Configuration of the algorithm used for solving the ODE.
            t0: Initial time at the start of the time evolution.
            propagation_type: Determines the equation of motion: "real" for the
                real-time Schrödinger equation (SE), "imag" for the imaginary-time SE.
            qgt: The QGT specification. Defaults to :code:`QGTAuto(solver=linear_solver)`.
            linear_solver: The solver for solving the linear system determining the time evolution.
                Defaults to :code:`nk.optimizer.solver.svd`.
            linear_solver_restart: If False (default), the last solution of the linear system
                is used as initial value in subsequent steps.
            error_norm: Norm function used to calculate the error with adaptive integrators.
                Can be either "euclidean" for the standard L2 vector norm :math:`w^\dagger w`,
                "maximum" for the maximum norm :math:`\max_i |w_i|`
                or "qgt", in which case the scalar product induced by the QGT :math:`S` is used
                to compute the norm :math:`\Vert w \Vert^2_S = w^\dagger S w` as suggested
                in PRL 125, 100503 (2020).
                Additionally, it is possible to pass a custom function with signature
                :code:`norm(x: PyTree) -> float`
                which maps a PyTree of parameters :code:`x` to the corresponding norm.
                Note that norm is used in jax.jit-compiled code.

        Raises:
            ValueError: If :code:`propagation_type` is not supported for the given
                state, or if :code:`error_norm` is neither a callable nor one of the
                recognized names.
        """
        self._t0 = t0

        if linear_solver is None:
            linear_solver = nk.optimizer.solver.svd
        if qgt is None:
            qgt = QGTAuto(solver=linear_solver)

        super().__init__(
            variational_state, optimizer=None, minimized_quantity_name="Generator"
        )

        self._generator_repr = repr(operator)
        if isinstance(operator, AbstractOperator):
            # A time-independent operator is wrapped so that the generator always
            # has the signature generator(t) -> operator.
            op = operator.collect()
            self._generator = lambda _: op
        else:
            self._generator = operator

        self.propagation_type = propagation_type
        if isinstance(variational_state, VariationalMixedState):
            # assuming Lindblad Dynamics
            # TODO: support density-matrix imaginary time evolution
            if propagation_type == "real":
                self._loss_grad_factor = 1.0
            else:
                raise ValueError(
                    "only real-time Lindblad evolution is supported for "
                    "mixed states"
                )
        else:
            if propagation_type == "real":
                self._loss_grad_factor = -1.0j
            elif propagation_type == "imag":
                self._loss_grad_factor = -1.0
            else:
                raise ValueError("propagation_type must be one of 'real', 'imag'")

        self.qgt = qgt
        self.linear_solver = linear_solver
        self.linear_solver_restart = linear_solver_restart

        self._dw = None  # type: PyTree
        self._last_qgt = None

        if callable(error_norm):
            # A user-supplied norm function is used as-is.
            pass
        elif error_norm == "euclidean":
            error_norm = euclidean_norm
        elif error_norm == "maximum":
            error_norm = maximum_norm
        elif error_norm == "qgt":
            w = self.state.parameters
            norm_dtype = nk.jax.dtype_real(nk.jax.tree_dot(w, w))
            # QGT norm is called via host callback since it accesses the driver
            error_norm = lambda x: hcb.call(
                HashablePartial(qgt_norm, self),
                x,
                result_shape=jax.ShapeDtypeStruct((), norm_dtype),
            )
        else:
            raise ValueError(
                "error_norm must be a callable or one of 'euclidean', 'qgt', 'maximum'."
            )

        self._odefun = HashablePartial(odefun_host_callback, self.state, self)
        self._integrator = integrator(
            self._odefun,
            t0,
            self.state.parameters,
            norm=error_norm,
        )

        self._stop_count = 0
        self._postfix = {}

    @property
    def integrator(self):
        """
        The underlying integrator which computes the time steps.
        """
        return self._integrator

    @property
    def generator(self) -> Callable:
        """
        The generator of the dynamics as a function with signature
            generator(t: float) -> AbstractOperator
        """
        return self._generator

    def advance(self, T: float):
        """
        Advance the time propagation by :code:`T` to :code:`self.t + T`.

        Args:
            T: Length of the integration interval.
        """
        for _ in self.iter(T):
            pass

    def iter(self, T: float, *, tstops: Optional[Sequence[float]] = None):
        """
        Returns a generator which advances the time evolution for an interval
        of length :code:`T`, stopping at :code:`tstops`.

        Args:
            T: Length of the integration interval.
            tstops: A sequence of stopping times, each within the interval
                :code:`[self.t, self.t + T]`, at which this method will stop and yield.
                By default, a stop is performed after each time step (at potentially
                varying step size if an adaptive integrator is used).

        Yields:
            The current time :code:`self.t` at each stop.
        """
        yield from self._iter(T, tstops)

    def _iter(
        self,
        T: float,
        tstops: Optional[Sequence[float]] = None,
        callback: Callable = None,
    ):
        """
        Implementation of :code:`iter`. This method accepts an additional `callback`
        object, which is called (without arguments) after every accepted step.

        Raises:
            ValueError: If any entry of :code:`tstops` lies outside
                :code:`[self.t, self.t + T]`.
            RuntimeError: If the integrator reports an error after a step attempt.
        """
        t_end = self.t + T
        if tstops is not None and (
            np.any(np.less(tstops, self.t)) or np.any(np.greater(tstops, t_end))
        ):
            # Report the actual upper bound t + T of the admissible range.
            raise ValueError(
                f"All tstops must be in range [t, t + T]=[{self.t}, {t_end}]"
            )

        if tstops is not None and len(tstops) > 0:
            tstops = np.sort(tstops)
            always_stop = False
        else:
            tstops = []
            always_stop = True

        while self.t < t_end:
            if always_stop or (
                len(tstops) > 0
                and (np.isclose(self.t, tstops[0]) or self.t > tstops[0])
            ):
                self._stop_count += 1
                yield self.t
                tstops = tstops[1:]

            # Retry until the integrator accepts a step.
            step_accepted = False
            while not step_accepted:
                if not always_stop and len(tstops) > 0:
                    # Do not step past the next requested stopping time.
                    max_dt = tstops[0] - self.t
                else:
                    max_dt = None
                step_accepted = self._integrator.step(max_dt=max_dt)
                if self._integrator.errors:
                    raise RuntimeError(
                        f"RK solver: {self._integrator.errors.message()}"
                    )
                elif self._integrator.warnings:
                    warnings.warn(
                        f"RK solver: {self._integrator.warnings.message()}",
                        UserWarning,
                    )
            self._step_count += 1
            # optionally call callback
            if callback:
                callback()

        # Yield one last time if the remaining tstop is at t_end
        if (always_stop and np.isclose(self.t, t_end)) or (
            len(tstops) > 0 and np.isclose(tstops[0], t_end)
        ):
            yield self.t

    def run(
        self,
        T,
        out=None,
        obs=None,
        *,
        tstops=None,
        show_progress=True,
        callback=None,
    ):
        """
        Runs the time evolution.

        At every stop the observables are estimated, the callbacks are executed and
        the loggers are invoked. The loggers are returned at the end of this function
        so that the results can be inspected without reading the output files.

        Args:
            T: The integration time period.
            out: A logger object, or an iterable of loggers, to be used to store simulation log and data.
                If this argument is a string, it will be used as output prefix for the standard
                JSON logger (:ref:`netket.logging.JsonLog`, opened in overwriting mode).
                If None (default), no logger is used.
            obs: An iterable containing the observables that should be computed.
            tstops: A sequence of stopping times, each within the interval :code:`[self.t, self.t + T]`,
                at which the driver will stop and perform estimation of observables, logging, and execute
                the callback function. By default, a stop is performed after each time step (at potentially
                varying step size if an adaptive integrator is used).
            show_progress: If true displays a progress bar (default=True)
            callback: Callable or list of callable callback functions to be executed at each
                stopping time. Each is called as :code:`callback(step, log_data, driver)`;
                a falsy return value requests that the evolution stops.

        Returns:
            The loggers that were used (an empty tuple if there were none).
        """
        if obs is None:
            obs = {}

        if callback is None:
            callback = lambda *_args, **_kwargs: True

        # Only node 0 logs and shows the progress bar; all other nodes get no loggers.
        if self._mynode == 0:
            if out is None:
                loggers = ()
            # if out is a path, create an overwriting Json Log for output
            elif isinstance(out, str):
                loggers = (JsonLog(out, "w"),)
            else:
                loggers = _to_iterable(out)
        else:
            loggers = tuple()
            show_progress = False

        callbacks = _to_iterable(callback)
        callback_stop = False

        with tqdm(
            total=np.asarray(self.t + T),
            disable=not show_progress,
            unit_scale=True,
        ) as pbar:
            first_step = True

            # We need a closure to pass to self._iter in order to update the progress bar even if
            # there are no tstops
            def update_progress_bar():
                # Reset the timing of tqdm after the first step to ignore compilation time
                nonlocal first_step
                if first_step:
                    first_step = False
                    pbar.unpause()

                pbar.n = np.asarray(self._integrator.t)
                self._postfix["n"] = self.step_count
                pbar.set_postfix(self._postfix)
                pbar.refresh()

            for step in self._iter(T, tstops=tstops, callback=update_progress_bar):
                log_data = self.estimate(obs)
                self._postfix = {"n": self.step_count}
                # if the cost-function is defined then report it in the progress bar
                if self._loss_stats is not None:
                    self._postfix.update(
                        {
                            self._loss_name: str(self._loss_stats),
                        }
                    )
                    log_data[self._loss_name] = self._loss_stats
                pbar.set_postfix(self._postfix)

                # Execute callbacks before loggers because they can append to log_data.
                # Every callback is run even if an earlier one requested a stop.
                for cb in callbacks:
                    if not cb(step, log_data, self):
                        callback_stop = True

                for logger in loggers:
                    logger(self.step_value, log_data, self.state)

                if len(callbacks) > 0:
                    # Stop on all nodes as soon as any node requested it.
                    if mpi.mpi_any(callback_stop):
                        break
                update_progress_bar()

            # Final update so that it shows up filled.
            update_progress_bar()

        # flush at the end of the evolution so that final values are saved to
        # file
        for logger in loggers:
            logger.flush(self.state)

        return loggers

    def _log_additional_data(self, obs, step):
        # Store the current time alongside the logged observables.
        obs["t"] = self.t

    @property
    def _default_step_size(self):
        # Essentially means
        return None

    @property
    def step_value(self):
        # The value passed to the loggers as the step is the current time.
        return self.t

    @property
    def dt(self):
        """Current time step."""
        return self._integrator.dt

    @property
    def t(self):
        """Current time."""
        return self._integrator.t

    @t.setter
    def t(self, t):
        # Keep the dtype of the integrator's current time. The dtype is read
        # explicitly rather than passing the time value itself as `dtype`.
        current_dtype = jnp.asarray(self._integrator.t).dtype
        self._integrator.t = jnp.array(t, dtype=current_dtype)

    @property
    def t0(self):
        """
        The initial time set when the driver was created.
        """
        return self._t0

    def __repr__(self):
        return f"{type(self).__name__}(step_count={self.step_count}, t={self.t})"

    def info(self, depth=0):
        # Multi-line description of the driver: its repr followed by one indented
        # line for each of the generator, integrator, linear solver and state.
        lines = [
            "{}: {}".format(name, info(obj, depth=depth + 1))
            for name, obj in [
                ("generator ", self._generator_repr),
                ("integrator ", self._integrator),
                ("linear solver ", self.linear_solver),
                ("state ", self.state),
            ]
        ]
        return "\n{}".format(" " * 3 * (depth + 1)).join([str(self)] + lines)

    def ode(self, t=None, w=None):
        r"""
        Evaluates the TDVP equation of motion

        .. math::

            G(w) \dot w = \gamma F(w, t)

        where :math:`G(w)` is the QGT, :math:`F(w, t)` the gradient of :code:`self.generator`
        and :math:`\gamma` one of
        :math:`\gamma = -1` (imaginary-time dynamics for :code:`MCState`),
        :math:`\gamma = -i` (real-time dynamics for :code:`MCState`), or
        :math:`\gamma = 1` (real-time dynamics for :code:`MCMixedState`).

        Args:
            t: Time (defaults to :code:`self.t`).
            w: Variational parameters (defaults to :code:`self.state.parameters`).

        Returns:
            The time-derivative :math:`\dot w`.
        """
        if t is None:
            t = self.t
        if w is None:
            w = self.state.parameters
        return self._odefun(t, w)
def qgt_norm(driver: TDVP, x: PyTree):
    r"""
    Computes the norm induced by the QGT :math:`S`, i.e, :math:`\sqrt{x^\dagger S x}`.

    The QGT used is the one stored on the driver as :code:`driver._last_qgt`.

    Args:
        driver: The TDVP driver holding the last computed QGT.
        x: PyTree of parameters whose norm is computed.

    Returns:
        The square root of the real part of :math:`x^\dagger S x`.
    """
    y = driver._last_qgt @ x  # pylint: disable=protected-access
    xc_dot_y = nk.jax.tree_dot(nk.jax.tree_conj(x), y)
    return jnp.sqrt(jnp.real(xc_dot_y))


@dispatch
def odefun(state, driver, t, w, **kwargs):  # pylint: disable=unused-argument
    # Fallback overload: reached when no overload exists for this state type.
    raise NotImplementedError(f"odefun not implemented for {type(state)}")


@dispatch
def odefun(state: MCState, driver: TDVP, t, w, *, stage=0):  # noqa: F811
    # pylint: disable=protected-access
    # Right-hand side of the TDVP equation for an MCState: sets the parameters to
    # `w`, estimates the generator and its gradient at time `t`, and solves the
    # QGT linear system for the parameter derivative, which is returned.

    state.parameters = w
    state.reset()

    driver._loss_stats, driver._loss_grad = state.expect_and_grad(
        driver.generator(t),
        use_covariance=True,
    )
    # Scale the gradient by the factor selecting the equation of motion.
    driver._loss_grad = jax.tree_util.tree_map(
        lambda x: driver._loss_grad_factor * x, driver._loss_grad
    )

    qgt = driver.qgt(driver.state)
    if stage == 0:  # TODO: This does not work with FSAL.
        # Remember the QGT of the first stage; qgt_norm reads it from the driver.
        driver._last_qgt = qgt

    # Unless restarts are requested, warm-start the solver from the last solution.
    initial_dw = None if driver.linear_solver_restart else driver._dw
    driver._dw, _ = qgt.solve(driver.linear_solver, driver._loss_grad, x0=initial_dw)
    return driver._dw


def odefun_host_callback(state, driver, *args, **kwargs):
    """
    Calls odefun through a host callback in order to make the rest of the
    ODE solver jit-able.

    Args:
        state: The variational state, forwarded to :code:`odefun`.
        driver: The driver, forwarded to :code:`odefun`.
        *args: Positional arguments forwarded to :code:`odefun`.
        **kwargs: Keyword arguments forwarded to :code:`odefun`.

    Returns:
        The result of :code:`odefun`, declared to have the same shapes and dtypes
        as :code:`state.parameters`.
    """
    # The callback result is declared to have the same structure as the parameters.
    result_shape = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype),
        state.parameters,
    )

    return hcb.call(
        lambda args_and_kw: odefun(state, driver, *args_and_kw[0], **args_and_kw[1]),
        # pack args and kwargs together, since host_callback passes a single argument:
        (args, kwargs),
        result_shape=result_shape,
    )