Add basic permutation-invariant RBMSymm implementation
Created by: femtobit
Recently, I have been interested in trying out the permutation-invariant RBM (RbmSpinSymm
in Netket 2.x) within the new nk3
architecture. Therefore, I have started writing a flax-based implementation and since this is still missing from nk3
I am opening this PR as a basis for discussion.
This is still a work in progress and I have some open jax/flax related questions, so I'd appreciate any feedback on how to improve the code (@PhilipVinc @gcarleo).
Implementation
There are two implementations of RBMSymm
in this draft PR, which can be selected using RBMSymm(..., implementation=impl)
where impl
is one of the following:
-
impl="scan"
for an implementation that evaluates on-the-fly the equation of the wave function as described in the original publication (i.e., Eq. S14 in Carleo & Troyer, Science 355, 2017) by passing the loop tojax.lax.scan
. This works, though it is relatively slow. https://github.com/netket/netket/blob/16186e885f3798d5f4fa6f1ef9cf499e6698ea9a/netket/models/rbm.py#L192-L201 -
impl="expand"
which first computes the full weights from the independent parameters and then just applies the normal RBM transformation of∑_j ln cosh(W s + b)_j
. This is faster thanscan
in the current form, interestingly, but still about a factor of 2 slower then the standardRBM
. There should be a way to cache the result ofW = expand(Wsymm)
, etc., so that it only needs to be recomputed on a change of parameters. I've briefly discussed this question of caching with @PhilipVinc, but I have not found a solution that works here so far. https://github.com/netket/netket/blob/16186e885f3798d5f4fa6f1ef9cf499e6698ea9a/netket/models/rbm.py#L272-L286 Note that this version currently only works correctly forN_symm = N_sites
. This can be fixed by changing the way the full parameters are built to match the previous NetKet versions, but this is harder to implement without explicit loops, as far as I can see, so I've left open for now.
The initial jit
call is also really slow for both cases - I think that compilation time scales with the number of translations and thus lattice, which is unfortunate as it may lead to a long initial delay for larger lattices with many symmetry operations.
Non-hashable fields on models
Due to the use of HashablePartial
in the MCState
, module fields need to be hashable as well, which is a problem when storing a list or array of permutations as in
@dataclass
class RBMSymm:
permutations: List
...
This draft PR circumvents this by using this hack: https://github.com/netket/netket/blob/16186e885f3798d5f4fa6f1ef9cf499e6698ea9a/netket/models/rbm.py#L303-L308 and then storing
@dataclass
class RBMSymm:
permutations: FakeHashArray
but that can't be the proper solution. Any suggestions? Also, do we need to return a different hash for different permutations? It seems unlikely for the permutations to change without changing the MCState
and recompiling everything altogether.
Miscellaneous
-
RBMSymm
is showcased in the re-addedHeisenberg1d/
example. - Since using only the translation symmetries, instead of all graph automorphisms, is a common scenario, this PR adds a
Grid.periodic_translations()
method that returns only those.
To do before this PR is ready
-
Probably decide on one implementation to use. The one with jax.lax.scan
works but is relatively slow; theexpand
one needs to work with the general (N_symm != N
) case and should be optimized to re-compute the full weights only when needed. -
Get rid of the FakeHashArray
workaround. -
Remove code duplication between RBM
andRBMSymm
implementations. -
Proper docs / comments.