Source code for netket.logging.json_log

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import orjson
import time

import os
from os import path as _path
import numpy as np

from flax import serialization

from .runtime_log import RuntimeLog


def _exists_json(prefix):
    """Return True if a log or parameter file already exists for `prefix`.

    Checks for `prefix + ".log"` first and then `prefix + ".mpack"`.
    """
    return any(_path.exists(prefix + extension) for extension in (".log", ".mpack"))


def default(obj):
    """Fallback serializer handed to `orjson.dumps` through `default=`.

    Conversions are tried in this order:

    - objects exposing `to_json()` or `to_dict()` are converted by calling it;
    - complex numpy arrays become `{"real": ..., "imag": ...}` holding the
      real and imaginary parts (still as numpy arrays);
    - other numpy arrays become a python scalar (0-d) or a list (1-d);
    - objects carrying a `_device` attribute are converted with `np.array`;
    - python `complex` numbers become `{"real": ..., "imag": ...}`.

    Raises:
        TypeError: for non-complex numpy arrays with two or more dimensions
            and for any object matching none of the cases above.
    """
    # Objects that know how to serialize themselves take precedence.
    if hasattr(obj, "to_json"):
        return obj.to_json()
    if hasattr(obj, "to_dict"):
        return obj.to_dict()

    if isinstance(obj, np.ndarray):
        # Complex data is split into its two real-valued components.
        if np.issubdtype(obj.dtype, np.complexfloating):
            return {"real": obj.real, "imag": obj.imag}
        if obj.ndim == 0:
            return obj.item()
        if obj.ndim == 1:
            return obj.tolist()
        raise TypeError

    if hasattr(obj, "_device"):
        return np.array(obj)

    if isinstance(obj, complex):
        return {"real": obj.real, "imag": obj.imag}

    raise TypeError


class JsonLog(RuntimeLog):
    """
    Json Logger, that can be passed with keyword argument `logger` to Monte
    Carlo drivers in order to serialize the output data of the simulation.

    If the model state is serialized, then it is serialized using the msgpack
    protocol of flax. For more information on how to de-serialize the output,
    see `here <https://flax.readthedocs.io/en/latest/flax.serialization.html>`_.
    The target of the serialization is the variational state itself.

    Data is serialized to json as several nested dictionaries. You can
    deserialize by simply calling :code:`json.load(open(filename))`.
    Logged expectation values will be captured inside histories objects, so
    they will have a subfield `iter` with the iterations at which that
    quantity has been computed, then `Mean` and others.
    Complex numbers are logged as dictionaries
    :code:`{'real':list, 'imag':list}`.
    """

    def __init__(
        self,
        output_prefix: str,
        mode: str = "write",
        save_params_every: int = 50,
        write_every: int = 50,
        save_params: bool = True,
        autoflush_cost: float = 0.005,
    ):
        """
        Construct a Json Logger.

        Args:
            output_prefix: the name of the output files before the extension
            save_params_every: every how many iterations should machine
                parameters be flushed to file
            write_every: every how many iterations should data be flushed to
                file
            mode: Specify the behaviour in case the file already exists at
                this output_prefix. Options are
                - `[w]rite`: (default) overwrites file if it already exists;
                - `[x]` or `fail`: fails if file already exists;
                `[a]ppend` is recognized but rejected, as it is no longer
                supported.
            save_params: bool flag indicating whether variables of the
                variational state should be serialized at some interval. The
                output file is overwritten every time variables are saved
                again.
            autoflush_cost: Maximum fraction of runtime that can be dedicated
                to serializing data. Defaults to 0.005 (0.5 per cent)

        Raises:
            ValueError: if `mode` is not recognized, if it is the unsupported
                append mode, or if `mode` is `fail` and an output file already
                exists at `output_prefix`.
        """
        super().__init__()

        # Expand the single-letter shorthands for mode.
        if mode == "w":
            mode = "write"
        elif mode == "a":
            mode = "append"
        elif mode == "x":
            mode = "fail"

        if mode not in ("write", "append", "fail"):
            raise ValueError(
                "Mode not recognized: should be one of `[w]rite`, `[a]ppend` or "
                "`[x]`(fail)."
            )

        if mode == "append":
            raise ValueError("Append mode is no longer supported.")

        file_exists = _exists_json(output_prefix)

        if file_exists and mode == "fail":
            raise ValueError(
                "Output file already exists. Either delete it manually or "
                "change `output_prefix`."
            )

        # Create the directory that will hold the output files, if any.
        dir_name = _path.dirname(output_prefix)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)

        self._prefix = output_prefix
        self._file_mode = mode

        self._write_every = write_every
        self._save_params_every = save_params_every
        self._old_step = 0
        # Number of calls since the log / the parameters were last written.
        self._steps_notflushed_write = 0
        self._steps_notflushed_pars = 0
        self._save_params = save_params
        self._files_open = [output_prefix + ".log", output_prefix + ".mpack"]

        self._autoflush_cost = autoflush_cost
        # Wall-clock time at which the last log flush started, and how long
        # that flush took.
        self._last_flush_time = time.time()
        self._last_flush_runtime = 0.0

        # Cumulative time spent writing the log and the parameters.
        self._flush_log_time = 0.0
        self._flush_pars_time = 0.0

    def __call__(self, step, item, variational_state):
        """
        Log `item` at iteration `step` and flush to disk when due.

        The data is first stored through the parent class. The log file is
        then rewritten if `write_every` calls have passed since the last
        flush, or if the previous flush was cheap enough relative to the time
        elapsed since (see `autoflush_cost`). The parameters are written every
        `save_params_every` calls.
        """
        old_step = self._old_step
        super().__call__(step, item, variational_state)

        # Check if the time from the last flush is higher than the maximum
        # allowed runtime cost of flushing. The comparison is written as a
        # product rather than `runtime / elapsed < cost` so that a zero
        # elapsed time (coarse clock resolution, e.g. on Windows) cannot
        # raise ZeroDivisionError.
        elapsed_time = time.time() - self._last_flush_time
        flush_anyway = self._last_flush_runtime < self._autoflush_cost * elapsed_time

        # NOTE(review): `step == old_step - 1` triggers a flush when the step
        # is one less than the previous one; the intent is not evident from
        # this file -- confirm against the drivers.
        if (
            self._steps_notflushed_write % self._write_every == 0
            or step == old_step - 1
            or flush_anyway
        ):
            self._flush_log()
        if (
            self._steps_notflushed_pars % self._save_params_every == 0
            or step == old_step - 1
        ):
            self._flush_params(variational_state)

        self._old_step = step
        self._steps_notflushed_write += 1
        self._steps_notflushed_pars += 1

    def _flush_log(self):
        """Rewrite `<prefix>.log` with the logged data and time the write."""
        self._last_flush_time = time.time()
        with open(self._prefix + ".log", "wb") as outfile:
            outfile.write(orjson.dumps(self.data, default=default))
            self._steps_notflushed_write = 0

        # Time how long flushing data takes.
        self._last_flush_runtime = time.time() - self._last_flush_time
        self._flush_log_time += self._last_flush_runtime

    def _flush_params(self, variational_state):
        """
        Serialize `variational_state.variables` to `<prefix>.mpack`.

        Does nothing when the logger was constructed with
        `save_params=False`.
        """
        if not self._save_params:
            return

        _time = time.time()

        binary_data = serialization.to_bytes(variational_state.variables)
        with open(self._prefix + ".mpack", "wb") as outfile:
            outfile.write(binary_data)
            self._steps_notflushed_pars = 0

        self._flush_pars_time += time.time() - _time

    def flush(self, variational_state=None):
        """
        Writes to file the content of this logger.

        Args:
            variational_state: optionally also writes the parameters of the
                machine.
        """
        self._flush_log()

        if variational_state is not None:
            self._flush_params(variational_state)

    def __repr__(self):
        _str = f"JsonLog('{self._prefix}', mode={self._file_mode}, "
        _str = _str + f"autoflush_cost={self._autoflush_cost})"
        _str = _str + "\n  Runtime cost:"
        _str = _str + f"\n  \tLog:    {self._flush_log_time}"
        _str = _str + f"\n  \tParams: {self._flush_pars_time}"
        return _str