Skip to content

Add permutation-invariant DenseSymm layer and RBMSymm

Vicentini Filippo requested to merge github/fork/femtobit/nk3-symm into master

Created by: femtobit

Reopening of #555.

This PR now adds a DenseSymm layer that can be used as follows: https://github.com/netket/netket/blob/a37fa4a931882b03b76dbae18991bd66cf727469/netket/models/rbm.py#L330-L338

DenseSymm is implemented, as per @PhilipVinc's suggestion, by (once) computing a symmerizer tensor that is then applied as https://github.com/netket/netket/blob/a37fa4a931882b03b76dbae18991bd66cf727469/netket/nn/linear.py#L282 which should do the same thing as the old C++ code: https://github.com/netket/netket/blob/9b2a22b23dc307e396027ef99fd73b519824c102/Sources/Machine/rbm_spin_symm.cc#L224-L229

Performance looks good on my machine: Running the Heisenberg 1D example,

# With RBMSymm(..., alpha=4)
100%|█| 300/300 [00:25<00:00, 11.72it/s, Energy=-35.611 ± 0.011 [σ²=0.125, R̂=0.9990]]]
# With non-symmetrized RBM(..., alpha=4)
100%|█| 300/300 [00:17<00:00, 17.37it/s, Energy=-35.217 ± 0.055 [σ²=3.099, R̂=1.0020]]]

Still a bit slower than RBM, but not even by a factor of 2 and there is probably room for optimization. Also, the compilation time is now pretty good.

Some open questions:

  • Should DenseSymm have a user-friendly constructor (similar to RBMSymm right now)? Right now, one has to pass a lambda returning a DeviceArray of symmetries, otherwise weird errors happen.
  • We can probably remove the expand implementation of RBMSymm. Should we also remove the lax.scan one? It is slower than the symmetrizer one but still a somewhat educational example, IMHO, given that it pretty directly implements the formula from the original Science paper. (And maybe it will be beneficial for very large systems?)

Open TODOs, which I'll get to soon:

  • Remove code duplication for the RBM, now that we have DenseSymm.
  • Write proper docs. Not necessarily in this PR:
  • Add proper tests that run VMC on our models (some jax-related errors do not occur when initializing the machine but during VMC, since more code is jit-ed then).

Merge request reports