Skip to content

Use explicit group structure for graph symmetries

Vicentini Filippo requested to merge github/fork/femtobit/symmetry-groups into master

Created by: femtobit

In #596 we've discussed some ideas how to best implement new symmetries in an extensible fashion.

As I mentioned there, I think it would be great to not just have a set of methods that return permutation indices but something with more information on the structure.

Therefore, in this PR, I propose making the group structure more explicit by working with abstract symmetry operations (which are essentially wrapped functions applying those operations to an input array with some magic to make combining them convenient) at least until the indices are actually needed.

This has several benefits:

  • We no longer have to spend time remembering which combination of translations [ 8, 6, 7, 11, 9, 10, 2, 0, 1, 5, 3, 4] corresponds to (which is even more fun when we combine more symmetry operations like in the space group). It's T(2, 1) (in a [4, 3] grid), which is much more readable.
  • We can combine symmetries easily with some syntactic sugar, e.g., g.translations(dim=0) @ g.translations(dim=1, period=2) (see below).
  • Adding new symmetries (like @chrisrothUT is doing in #596) just requires to create a sequence of the correct symmetry operations and returning those wrapped in a SymmGroup. Combination with other symmetries is then already taken care of. (Caveat: As long as it is a direct product - more combilated structures would require more work, but that should be doable in this framework as well if really needed.)

While this may look a bit over-engineered, it is not actually too much additional code and working with symmetries (and adding new ones) becomes much more convenient IMHO.

(Note that this is orthogonal to the question of whether symmetries should be implemented in Grid or Lattice - I left it in Grid here because that code is already there. The same thing could be done in `Lattice.)

Below are some usage examples:

Using the symmetries

>>> import netket as nk
>>> g = nk.graph.Chain(4)
>>> G = g.translations()
>>> G  # a collection of symmetry operations
SymmGroup(
  Id(),
  T(1,),
  T(2,),
  T(3,)
)

We can get the explicit permutation indices if we need them as before:

>>> G.to_array()
array([[0, 1, 2, 3],
       [3, 0, 1, 2],
       [2, 3, 0, 1],
       [1, 2, 3, 0]])

and also directly apply the group action to any array (i.e., spin configurations):

>>> hi = nk.hilbert.Spin(s=1/2, N=g.n_nodes)
>>> G(hi.random_state())
array([[ 1.,  1., -1., -1.],
       [-1.,  1.,  1., -1.],
       [-1., -1.,  1.,  1.],
       [ 1., -1., -1.,  1.]])

The key feature is that we can combine the symmetries via @ (corresponding to a direct products of the groups):

>>> g = nk.graph.Grid([4, 3])
>>> g.translations(dim=0, period=2) @ g.translations(dim=1)
SymmGroup(
  Id(),
  T(0, 1),
  T(0, 2),
  T(2, 0),
  T(2, 1),
  T(2, 2)
)
>>> np.asarray((g.translations(dim=0, period=2) @ g.translations(dim=1)))
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 2,  0,  1,  5,  3,  4,  8,  6,  7, 11,  9, 10],
       [ 1,  2,  0,  4,  5,  3,  7,  8,  6, 10, 11,  9],
       [ 6,  7,  8,  9, 10, 11,  0,  1,  2,  3,  4,  5],
       [ 8,  6,  7, 11,  9, 10,  2,  0,  1,  5,  3,  4],
       [ 7,  8,  6, 10, 11,  9,  1,  2,  0,  4,  5,  3]])

It is now much easier to check, however, that the combination of symmetries one has chosen is indeed the desired one (as the group structure can be shown explicitly instead of users having to reverse engineer it from looking at indices).

For convenience, SymmGroup can be passed directly to RBMSymm etc.

ma = nk.models.RBMSymm(g.translations(), ...)  # still works

Adding new symmetries

The symmetry group elements are essentially Callables that perform the corresponding permutation on an input array. For example, if we want to define the operation of inverting the indices on a 1D chain (i.e., [1, 2, 3] -> [3, 2, 1]), we can do this as follows:

>>> g = nk.graph.Chain(3)
>>> from netket.utils.semigroup import Identity, NamedElement
>>> from netket.graph import SymmGroup
>>> def inversion(seq):
        return seq[::-1]
>>> G = SymmGroup([Identity(), NamedElement("Inv", inversion)], graph=g)
>>> G
SymmGroup(
  Id(),
  Inv()
)
>>> g.translations() @ G
SymmGroup(
  Id(),
  Inv(),
  T(1,),
  T(1,) @ Inv(),
  T(2,),
  T(2,) @ Inv()
)
>>> (g.translations() @ G).to_array()
array([[0, 1, 2],
       [2, 1, 0],
       [2, 0, 1],
       [0, 2, 1],
       [1, 2, 0],
       [1, 0, 2]])

Merge request reports