Add basic permutation-invariant RBMSymm implementation

Vicentini Filippo requested to merge github/fork/femtobit/nk3-symm into nk3

Created by: femtobit

Recently, I have been interested in trying out the permutation-invariant RBM (RbmSpinSymm in NetKet 2.x) within the new nk3 architecture. I have therefore started writing a flax-based implementation, and since this model is still missing from nk3, I am opening this PR as a basis for discussion.

This is still a work in progress and I have some open jax/flax related questions, so I'd appreciate any feedback on how to improve the code (@PhilipVinc @gcarleo).

Implementation

There are two implementations of RBMSymm in this draft PR, which can be selected using RBMSymm(..., implementation=impl) where impl is one of the following:

  1. impl="scan" evaluates the wave-function expression from the original publication (i.e., Eq. S14 in Carleo & Troyer, Science 355, 2017) on the fly by passing the loop over symmetry operations to jax.lax.scan. This works, though it is relatively slow. https://github.com/netket/netket/blob/16186e885f3798d5f4fa6f1ef9cf499e6698ea9a/netket/models/rbm.py#L192-L201
  2. impl="expand" first computes the full weights from the independent parameters and then applies the normal RBM transformation ∑_j ln cosh(W s + b)_j. Interestingly, this is faster than scan in its current form, but still about a factor of 2 slower than the standard RBM. There should be a way to cache the result of W = expand(Wsymm), etc., so that it only needs to be recomputed when the parameters change. I've briefly discussed this question of caching with @PhilipVinc, but I have not found a solution that works here so far. https://github.com/netket/netket/blob/16186e885f3798d5f4fa6f1ef9cf499e6698ea9a/netket/models/rbm.py#L272-L286 Note that this version currently only works correctly for N_symm = N_sites. This can be fixed by changing the way the full parameters are built to match the previous NetKet versions, but as far as I can see this is harder to implement without explicit loops, so I've left it open for now.
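
To make the relation between the two variants concrete, here is a minimal NumPy sketch (not the code in this PR). It assumes the symmetric ansatz has the form log ψ(s) = a ∑_i s_i + ∑_f ∑_g ln cosh(b_f + ∑_i w_{f,i} s_{g(i)}), with translations of a 1D ring as the symmetry operations. All function names are hypothetical, and whether the resulting parameter layout matches NetKet 2.x is not checked here. The same indexing should carry over to jax.numpy.

```python
import numpy as np

def ring_translations(n_sites):
    """Row t is the permutation i -> (i + t) mod n_sites (hypothetical helper)."""
    sites = np.arange(n_sites)
    return np.stack([(sites + t) % n_sites for t in range(n_sites)])

def log_psi_on_the_fly(w_symm, b_symm, a_symm, perms, s):
    """Loop over symmetry operations, in the spirit of the scan variant."""
    total = a_symm * s.sum()
    for perm in perms:
        theta = b_symm + w_symm @ s[perm]  # shape (n_features,)
        total += np.log(np.cosh(theta)).sum()
    return total

def expand(w_symm, b_symm, perms):
    """Build the full weights so that W_full @ s reproduces every w_symm @ s[perm]."""
    inverse = np.argsort(perms, axis=1)        # inverse of each permutation
    w_full = w_symm[:, inverse]                # shape (n_features, n_perms, n_sites)
    w_full = w_full.reshape(-1, perms.shape[1])
    b_full = np.repeat(b_symm, perms.shape[0]) # one copy of b_f per permutation
    return w_full, b_full

def log_psi_expanded(w_symm, b_symm, a_symm, perms, s):
    """Expand once, then apply the ordinary RBM formula."""
    w_full, b_full = expand(w_symm, b_symm, perms)
    return a_symm * s.sum() + np.log(np.cosh(w_full @ s + b_full)).sum()
```

Both functions should agree up to floating-point error for any set of permutations passed in, and the result is unchanged when s is translated by an element of the group. In this sketch the expansion is a single gather and contains no explicit loop over permutations.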

The initial jit call is also really slow in both cases. I think the compilation time scales with the number of translations, and thus with the lattice size, which is unfortunate as it may lead to a long initial delay for larger lattices with many symmetry operations.

Non-hashable fields on models

Due to the use of HashablePartial in the MCState, module fields need to be hashable as well, which is a problem when storing a list or array of permutations as in

@dataclass
class RBMSymm:
    permutations: List
    ...

This draft PR circumvents this by using this hack: https://github.com/netket/netket/blob/16186e885f3798d5f4fa6f1ef9cf499e6698ea9a/netket/models/rbm.py#L303-L308 and then storing

@dataclass
class RBMSymm:
    permutations: FakeHashArray

but that can't be the proper solution. Any suggestions? Also, do we need to return a different hash for different permutations? It seems unlikely for the permutations to change without changing the MCState and recompiling everything altogether.
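
One possible alternative, sketched here only as a basis for discussion, is to store the permutations as a nested tuple of ints, which is hashable by content, and convert it to an array where it is used. The sketch uses a plain frozen dataclass as a stand-in for the model; SymmConfig and freeze_permutations are hypothetical names, and none of this has been tested with flax or MCState.

```python
from dataclasses import dataclass
from typing import Tuple

def freeze_permutations(perms):
    """Convert a nested list or array of site indices into a tuple of tuples."""
    return tuple(tuple(int(i) for i in row) for row in perms)

@dataclass(frozen=True)
class SymmConfig:
    """Hypothetical stand-in for a model with a permutations field."""
    permutations: Tuple[Tuple[int, ...], ...]
    alpha: int = 1
```

Because the dataclass is frozen and compares by value, Python generates a content-based __hash__, so two instances built from the same permutations hash equally and different permutations compare unequal.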

Miscellaneous

  1. RBMSymm is showcased in the re-added Heisenberg1d/ example.
  2. Since using only the translation symmetries, instead of all graph automorphisms, is a common scenario, this PR adds a Grid.periodic_translations() method that returns only those.
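
For illustration, the translation permutations of a fully periodic hypercubic grid can be enumerated as below. This is a standalone NumPy sketch under the assumption that sites are numbered in row-major order and that every direction is periodic; the free function periodic_translations here is hypothetical and is not the Grid.periodic_translations() method added in this PR.

```python
import itertools
import numpy as np

def periodic_translations(extent):
    """Return all translations of a periodic grid as rows of site permutations.

    extent is a sequence of side lengths, e.g. (4,) or (2, 3).
    """
    extent = tuple(extent)
    axes = tuple(range(len(extent)))
    # Label the sites 0..N-1 in row-major order.
    sites = np.arange(int(np.prod(extent))).reshape(extent)
    perms = []
    # One translation per combination of shifts along each axis.
    for shift in itertools.product(*(range(n) for n in extent)):
        perms.append(np.roll(sites, shift, axis=axes).ravel())
    return np.stack(perms)
```

The first row is the identity, and the number of rows equals the number of sites, which is much smaller than the full automorphism group of the grid.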

To do before this PR is ready

  • Probably decide on one implementation to use. The one with jax.lax.scan works but is relatively slow; the expand one needs to work with the general (N_symm != N) case and should be optimized to re-compute the full weights only when needed.
  • Get rid of the FakeHashArray workaround.
  • Remove code duplication between RBM and RBMSymm implementations.
  • Proper docs / comments.