Skip to content

NetKet Dataclass

Vicentini Filippo requested to merge dataclass into master

Created by: PhilipVinc

This PR adds a NetKet flavoured Dataclass. This is equivalent to flax dataclass but gives us a bit extra flexibility that we often needed. This was motivated by #706 and some gripes we had for quite some time. See the docstring below

    Decorator creating a NetKet-flavour dataclass.
    This behaves as a flax dataclass, that is a Frozen python dataclass, with a twist!
    See their documentation for standard behaviour.

    The new functionalities added by NetKet are:
     - it is possible to define a method `__pre_init__(*args, **kwargs) -> Tuple[Tuple,Dict]` that processes the arguments
       and keyword arguments provided to the dataclass constructor. This allows to deprecate argument
       names and add some logic to customize the constructors.
       This function should return a tuple of the edited `(args, kwargs)`. If inheriting from other classes it is reccomended
       (though not mandated) to call the same method in parent classes.
       The function should return arguments and keyword arguments that will match the standard dataclass constructor.
       The function can also not be called in some internal cases, so it should not be a strict requirement to execute it.

     - Cached Properties. It is possible to mark properties of a netket dataclass with `@property_cached`. This will make the
       property behave as a standard property, but it's value is cached and reset every time a dataclass is manipulated.
       Cached properties can be part of the flattened pytree or not. See :ref:`netket.utils.struct.property_cached` for more info.

    Optinal Args:
        init_doc: the docstring for the init method. Otherwise it's inherited from `__pre_init__`.

In this PR I switch Sampler over to this dataclass because they do a bit a mess with the constructor. This will allow to actually make #706 work.

Also, I switch over to this dataclass the semigroup (so subclasses should now subclass this) so that cached properties behave better (before we were recomputing cached properties every time we entered jit, now we can compute them only once provided they are computed before entering jit). There is an utility method to precompute all cached things at some point. This will have to be tuned in the future but for now it's a strict improvement.

This should be rebased on master and not merged so that we can revert individual commits if necessary.

Merge request reports