Add permutation-invariant DenseSymm layer and RBMSymm
Created by: femtobit
Reopening of #555.

This PR now adds a `DenseSymm` layer that can be used as follows:
https://github.com/netket/netket/blob/a37fa4a931882b03b76dbae18991bd66cf727469/netket/models/rbm.py#L330-L338
`DenseSymm` is implemented, as per @PhilipVinc's suggestion, by computing (once) a symmetrizer tensor that is then applied as
https://github.com/netket/netket/blob/a37fa4a931882b03b76dbae18991bd66cf727469/netket/nn/linear.py#L282
which should do the same thing as the old C++ code:
https://github.com/netket/netket/blob/9b2a22b23dc307e396027ef99fd73b519824c102/Sources/Machine/rbm_spin_symm.cc#L224-L229
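The idea behind the symmetrizer tensor can be sketched as follows. This is a NumPy-only illustration with shapes and names I made up for clarity, not NetKet's actual code: a reduced kernel is expanded into the full symmetrized kernel by applying each permutation of the symmetry group to the site axis, which is what the old C++ loop above did weight by weight.

```python
import numpy as np

def symmetrize_kernel(kernel, perms):
    """Expand a reduced kernel of shape (features, n_sites) into the full
    symmetrized kernel of shape (features * n_symm, n_sites) by applying
    each site permutation in `perms` to the site axis.

    Shapes and the function name are illustrative, not NetKet's convention.
    """
    # perms: (n_symm, n_sites) integer array, one site permutation per row
    full = np.stack([kernel[:, p] for p in perms], axis=0)  # (n_symm, features, n_sites)
    return full.reshape(-1, kernel.shape[-1])

# Tiny example: 2 features on a 3-site ring with cyclic translations
kernel = np.arange(6.0).reshape(2, 3)
perms = np.array([[0, 1, 2], [1, 2, 0], [2, 0, 1]])
W = symmetrize_kernel(kernel, perms)
print(W.shape)  # (6, 3)
```

In the actual layer, the analogous contraction is done inside the forward pass, so only the reduced kernel is stored as a parameter.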
Performance looks good on my machine, running the Heisenberg 1D example:

```
# With RBMSymm(..., alpha=4)
100%|█| 300/300 [00:25<00:00, 11.72it/s, Energy=-35.611 ± 0.011 [σ²=0.125, R̂=0.9990]]]
# With non-symmetrized RBM(..., alpha=4)
100%|█| 300/300 [00:17<00:00, 17.37it/s, Energy=-35.217 ± 0.055 [σ²=3.099, R̂=1.0020]]]
```
Still a bit slower than `RBM`, but not even by a factor of 2, and there is probably room for optimization. Also, the compilation time is now pretty good.
Some open questions:
- Should `DenseSymm` have a user-friendly constructor (similar to `RBMSymm` right now)? Currently, one has to pass a lambda returning a `DeviceArray` of symmetries, otherwise weird errors happen.
- We can probably remove the `expand` implementation of `RBMSymm`. Should we also remove the `lax.scan` one? It is slower than the symmetrizer one, but it is still a somewhat educational example, IMHO, given that it pretty directly implements the formula from the original Science paper. (And maybe it will be beneficial for very large systems?)
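For reference, what the `lax.scan` variant computes is the permutation-symmetrized RBM log-amplitude, with the hidden-unit sum accumulated one group element at a time. A minimal sketch (parameter names and shapes are my assumptions, not NetKet's actual implementation):

```python
import jax
import jax.numpy as jnp

def rbm_symm_logpsi(params, sigma, perms):
    """Hedged sketch of a symmetric RBM log-amplitude via lax.scan:

        log psi(s) = a * sum_i s_i + sum_g sum_f log cosh(b_f + W_f . P_g s)

    where P_g runs over the symmetry-group permutations in `perms`.
    """
    a, b, W = params  # a: scalar, b: (features,), W: (features, n_sites)

    def body(acc, perm):
        # Activations for one symmetry-transformed configuration
        theta = b + W @ sigma[perm]
        return acc + jnp.sum(jnp.log(jnp.cosh(theta))), None

    acc, _ = jax.lax.scan(body, 0.0, perms)
    return a * jnp.sum(sigma) + acc
```

The scan keeps memory constant in the group size, which is why it might still pay off for very large systems, at the price of a sequential loop.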
Open TODOs, which I'll get to soon:
- Remove code duplication for the `RBM`, now that we have `DenseSymm`.
- Write proper docs.

Not necessarily in this PR:

- Add proper tests that run VMC on our models (some jax-related errors do not occur when initializing the machine, but only during VMC, since more code is jit-ed then).