Add support for a single-file HDF5 log
Created by: femtobit
[This issue is part of UnitaryHack and comes with a bounty of $75.]
Context
NetKet simulation drivers support output of the current state of an optimization, as well as expectation values of observables and custom data, via the classes provided in `netket.logging`.
Currently, NetKet has two main logging implementations:
- `JsonLog`, which is the standard logger and writes log data to a JSON file (and also saves a regularly overwritten snapshot of the current network parameters as a MessagePack file).
- `StateLog`, which saves intermediate network parameters as a separate file for each step (`1.mpack`, `2.mpack`, `3.mpack`, etc.) to a folder or ZIP file.
While these work, it would make data handling and interoperability with other tools easier to also support writing simulation output to a single file in the commonly used HDF5 format via `h5py`.
Implementation notes
To resolve this issue, the following should be implemented:
- A new logger class `netket.logging.HDF5Log`, which writes both the information currently contained in `JsonLog` and the network parameters (at each step, or every certain number of steps, as can be configured in `StateLog`) to an HDF5 file specified by the user.
- The logger needs to be compatible with the current NetKet logging interface, so that it can be used in place of `JsonLog` and `StateLog` in NetKet drivers. (For this PR, this means compatibility with the current use of the loggers in `netket.driver.AbstractVariationalDriver`.) Specifically:
  - The logger needs to implement `__call__(self, step, log_data, state)`, which is called at each optimization step and writes the provided data to the HDF5 log file (see the sketch after this list).
  - Since the number of steps (i.e., the number of calls to `__call__`) is not known before the start of the simulation, the logger must support appending data to the log at every step (therefore, the datasets within the HDF5 file need to be resized as necessary).
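For concreteness, here is a minimal sketch (not a final implementation) of such a logger built on `h5py` resizable datasets; the `_append` helper and the constructor signature are hypothetical, and this sketch only handles scalars and arrays (the type dispatch for `Stats` objects is sketched further below):

```python
import h5py
import numpy as np


def _append(group, name, value):
    """Append `value` along the first axis of `group[name]`, creating the dataset on first use."""
    value = np.asarray(value)
    if name not in group:
        # resizable along the first (step) axis
        group.create_dataset(
            name,
            shape=(0,) + value.shape,
            maxshape=(None,) + value.shape,
            dtype=value.dtype,
        )
    dset = group[name]
    dset.resize(dset.shape[0] + 1, axis=0)
    dset[-1] = value


class HDF5Log:
    """Minimal sketch of the proposed logger (not the real NetKet class)."""

    def __init__(self, path):
        self._file = h5py.File(path, "w")

    def __call__(self, step, log_data, state):
        for key, value in log_data.items():
            grp = self._file.require_group("data").require_group(key)
            _append(grp, "iters", step)
            _append(grp, "values", value)  # scalars / NumPy / JAX arrays only in this sketch
        self._file.flush()
```

The key point is that each dataset is created with `maxshape=(None, ...)`, so it can be grown with `Dataset.resize` as new steps arrive.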
Note that `log_data` is a dictionary mapping a name to a specific logged quantity. The values can be of several different types; `HDF5Log` should support scalar numbers, NumPy/JAX arrays, and `netket.stats.Stats` objects.
- Scalars should be stored as a group with one dataset `values` of shape `(n_steps,)` containing the logged values and another dataset `iters` of the same shape containing the value of `step` at which each entry was logged (compare the `JsonLog` output).
- Arrays should be stored in the same way, but with `values` having shape `(n_steps, *array_shape)`.
- `netket.stats.Stats` objects are essentially dataclasses with the fields `mean`, `variance`, `error_of_mean`, `tau_corr`, and `R_hat`. For a stats object, the log file should contain a group with each field as a separate dataset (plus an `iters` dataset as above); a type-dispatch sketch follows this list.
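One possible way to dispatch on the entry type, reusing the hypothetical `_append` helper from the sketch above:

```python
from netket.stats import Stats

_STATS_FIELDS = ("mean", "variance", "error_of_mean", "tau_corr", "R_hat")


def _write_entry(group, step, value):
    """Write one `log_data` entry into its group (hypothetical helper)."""
    _append(group, "iters", step)
    if isinstance(value, Stats):
        # one (n_steps,) dataset per statistics field
        for field in _STATS_FIELDS:
            _append(group, field, getattr(value, field))
    else:
        # scalars and NumPy/JAX arrays: dataset of shape (n_steps, *array_shape)
        _append(group, "values", value)
```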
NetKet stores network parameters as JAX pytrees whose leaves are complex-valued or real-valued arrays. The `HDF5Log` should store a flattened version (as returned by `netket.jax.tree_ravel`) in a single dataset of shape `(n_steps, n_parameters)`.
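A sketch of how the parameters could be written, assuming `netket.jax.tree_ravel` returns a pair `(flat_array, unravel_fn)` in the spirit of `jax.flatten_util.ravel_pytree`, and that the `state` passed to the logger exposes its pytree via `state.parameters`:

```python
import numpy as np
import netket as nk


def _write_parameters(group, step, state):
    """Log the flattened network parameters of a variational state (hypothetical helper)."""
    flat_params, _unravel = nk.jax.tree_ravel(state.parameters)
    _append(group, "iters", step)
    # accumulates into a single dataset of shape (n_steps, n_parameters)
    _append(group, "values", np.asarray(flat_params))
```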
Here is an example layout, showing what a resulting HDF5 log file should contain after 1001 logging steps for a network with 256 variational parameters:
| Name | Data type | Shape |
|------|-----------|-------|
| **Network parameters** | | |
| `/parameters/iters` | int | `(1001,)` |
| `/parameters/values` | complex128 | `(1001, 256)` |
| **`log_data` entry `Energy` of type `netket.stats.Stats`** | | |
| `/data/Energy/iters` | int | `(1001,)` |
| `/data/Energy/mean` | complex128 | `(1001,)` |
| `/data/Energy/error_of_mean` | float64 | `(1001,)` |
| `/data/Energy/variance` | float64 | `(1001,)` |
| `/data/Energy/tau_corr` | float64 | `(1001,)` |
| `/data/Energy/R_hat` | float64 | `(1001,)` |
| **`log_data` entry `S_x` of type `netket.stats.Stats`** | | |
| `/data/S_x/iters` | int | `(1001,)` |
| `/data/S_x/mean` | complex128 | `(1001,)` |
| `/data/S_x/error_of_mean` | float64 | `(1001,)` |
| `/data/S_x/variance` | float64 | `(1001,)` |
| `/data/S_x/tau_corr` | float64 | `(1001,)` |
| `/data/S_x/R_hat` | float64 | `(1001,)` |
| **`log_data` entry `acceptance` of type float64** | | |
| `/data/acceptance/iters` | int | `(1001,)` |
| `/data/acceptance/values` | float64 | `(1001,)` |
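Such a file can then be read back directly with `h5py`, for example (the file name is hypothetical):

```python
import h5py

with h5py.File("log.h5", "r") as f:
    iters = f["data/Energy/iters"][:]        # int, shape (1001,)
    energy_mean = f["data/Energy/mean"][:]   # complex128, shape (1001,)
    params = f["parameters/values"][:]       # complex128, shape (1001, 256)
```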
Note that for a normal `netket.VMC` run there is a lot of redundancy in the `.../iters` arrays (they will all be equal and of the form `[0, 1, ..., n_steps - 1]`). We accept this overhead both for compatibility with the existing `JsonLog` output and for the added flexibility it provides for custom logging at subsets of steps.