Skip to content
Snippets Groups Projects
Unverified Commit 31730d4d authored by Filippo Vicentini's avatar Filippo Vicentini Committed by GitHub
Browse files

Make Callback objects serialisable (needed for checkpointing..) (#1516)

Makes them fully data class compliant by moving some cache definition
into the class declaration.
parent e0da0db6
No related branches found
No related tags found
No related merge requests found
......@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
from netket.utils import struct
@dataclass
# Mark this class a NetKet dataclass so that it can automatically be serialized by Flax.
@struct.dataclass(_frozen=False)
class EarlyStopping:
"""A simple callback to stop NetKet if there are no more improvements in the training.
based on `driver._loss_name`."""
based on `driver._loss_name`.
"""
min_delta: float = 0.0
"""Minimum change in the monitored quantity to qualify as an improvement."""
......@@ -38,13 +41,15 @@ class EarlyStopping:
monitor: str = "mean"
"""Loss statistic to monitor. Should be one of 'mean', 'variance', 'sigma'."""
def __post_init__(self):
self._best_val: float = np.infty
"""Stores the best loss seen so far"""
self._best_iter: int = 0
"""Stores the iteration at which we've seen the best loss so far"""
self._best_patience_counter: int = 0
"""Stores the iteration at which we've seen the best loss so far"""
# The quantities below are internal and should not be edited directly
# by the user
_best_val: float = np.infty
"""Best value of the loss observed up to this iteration. """
_best_iter: int = 0
"""Iteration at which the `_best_val` was observed."""
_best_patience_counter: int = 0
"""Stores the iteration at which we've seen the best loss so far"""
def __call__(self, step, log_data, driver):
"""
......
......@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
from netket.utils import struct
@dataclass
# Mark this class a NetKet dataclass so that it can automatically be serialized by Flax.
@struct.dataclass(_frozen=False)
class InvalidLossStopping:
"""A simple callback to stop NetKet if there are no more improvements in the training.
based on `driver._loss_name`."""
......@@ -28,8 +29,8 @@ class InvalidLossStopping:
patience: Union[int, float] = 0
"""Number of epochs with invalid loss after which training will be stopped."""
def __post_init__(self):
self._last_valid_iter = 0
_last_valid_iter: int = 0
"""Last valid iteration, to check against patience"""
def __call__(self, step, log_data, driver):
"""
......
......@@ -12,23 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import time
from netket.utils import struct
# Mark this class a NetKet dataclass so that it can automatically be serialized by Flax.
@struct.dataclass(_frozen=False)
class Timeout:
"""A simple callback to stop NetKet after some time has passed."""
"""A simple callback to stop NetKet after some time has passed.
def __init__(self, timeout):
"""
Constructs a new Timeout object that monitors whether a driver has been training
for more than a given timeout in order to hard stop training.
This callback monitors whether a driver has been training for more
than a given timeout in order to hard stop training.
"""
Args:
timeout: Number of seconds to wait before hard stopping training.
"""
assert timeout > 0
self.__timeout = timeout
self.__init_time = None
timeout: float
"""Number of seconds to wait before the training will be stopped."""
_init_time: Optional[float] = None
"""
Internal field storing the time at which the first iteration has been
performed.
"""
def __post_init__(self):
if not self.timeout > 0:
raise ValueError("`timeout` must be larger than 0.")
def reset(self):
"""Resets the initial time of the training"""
......@@ -49,10 +59,10 @@ class Timeout:
Note:
This callback does not make use of `step`, `log_data` nor `driver`.
"""
if self.__init_time is None:
self.__init_time = time.time()
if self._init_time is None:
self._init_time = time.time()
if time.time() - self._init_time >= self.timeout:
return False
else:
print(time.time() - self.__init_time)
if time.time() - self.__init_time >= self.__timeout:
return False
return True
return True
......@@ -145,8 +145,6 @@ def local_operators_to_pauli_strings(hilbert, operators, acting_on, constant, dt
weights = []
if len(operators) > 0:
mats = map(_convert_to_dense, operators)
# maximum number of non-identity operators
n_max = max(list(map(len, acting_on)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment