
Equivariant Convolutions

Vicentini Filippo requested to merge github/fork/chrisrothUT/equivariant into master

Created by: chrisrothUT

Hi everybody,

I now have a working implementation of G-CNNs that operates on a generic permutation array. This PR looks like a lot of changes because it also includes #600 and #611, so I'll keep it a draft until those are resolved. Since there are a lot of changes, to summarize: the main additions are graph.group_algebra() and graph.inverse() for computing the kernel mapping, nn.DenseEquivariant(), and models.GCNN(). I also added test_GCNN, which is similar to test_rbmsymm.
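To make the kernel mapping concrete, here is a minimal sketch (not the code in this PR) of the group convolution that DenseEquivariant performs, assuming the tables from graph.group_algebra() and graph.inverse() are combined into a product_table whose entry [g, h] is the index of g⁻¹h:

```python
import jax.numpy as jnp


def group_conv(x, kernel, product_table):
    """Sketch of a dense group convolution.

    x:             [batch, n_symm, in_features]
    kernel:        [n_symm, in_features, out_features]
    product_table: [n_symm, n_symm], entry [g, h] = index of g^-1 h
    """
    # Indexing the kernel with the product table applies the weights for g^-1 h
    # between input element g and output element h, which is what makes the
    # layer equivariant under the symmetry group.
    expanded = kernel[product_table]  # [n_symm, n_symm, in_features, out_features]
    return jnp.einsum("bgi,ghio->bho", x, expanded)
```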

Like RBMSymm, it can take an AbstractGraph as an argument:

```python
graph = nk.graph.Lattice(
    basis_vectors=[[1.0, 0.0], [-1.0 / 2.0, np.sqrt(3) / 2]],
    extent=[2, 2],
    atoms_coord=[[0.0, 0.0], [-1.0 / 4.0, np.sqrt(3) / 4.0], [1.0 / 2.0, 0.0]],
)

ma = nk.models.GCNN(permutations=graph, layers=4, features=4)
```

And we're off! This model is a four-layer GCNN on the 12-site Kagome lattice. The first layer is DenseSymm, and the other three are DenseEquivariant, whose kernel has shape [n_symm=48, features, features] and which convolves over poses based on their relative symmetry orientations. Like RBMSymm, this produces identical outputs for each automorphism of the lattice, i.e. this test will pass:

```python
hi = nk.hilbert.Spin(s=1 / 2, N=graph.n_nodes)
pars = ma.init(nk.jax.PRNGKey(), hi.random_state(1))

v = hi.random_state(3)
vals = [ma.apply(pars, v[..., p]) for p in graph.automorphisms()]

for val in vals:
    assert jnp.allclose(val, vals[0])
```

A few questions/comments:

  • What would be a good test for the layer DenseEquivariant? (One possible equivariance check is sketched after this list.)
  • I call reshape twice in DenseEquivariant; perhaps there's a way to call it only once?
  • Right now this only implements a dense convolution, but it would be good to implement sparser convolutions. This requires information about the symmetry group, not just a list of automorphisms.
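Regarding the first question, one possibility is to check equivariance directly: permuting the input along the group axis by the left action of an element g should permute the output in the same way. This is sketched below against the group_conv helper above rather than the layer's actual interface, using a toy Z₄ product table just for illustration:

```python
import numpy as np
import jax.numpy as jnp

n_symm, feat = 4, 3
# Toy cyclic group Z_4: product_table[g, h] = index of g^-1 h = (h - g) mod 4.
product_table = np.array(
    [[(h - g) % n_symm for h in range(n_symm)] for g in range(n_symm)]
)

x = np.random.normal(size=(2, n_symm, feat))
kernel = np.random.normal(size=(n_symm, feat, feat))
y = group_conv(x, kernel, product_table)

for g in range(n_symm):
    perm = product_table[g]  # left action of g on the group axis
    # Equivariance: acting on the input with g permutes the output the same way.
    assert jnp.allclose(group_conv(x[:, perm, :], kernel, product_table), y[:, perm, :])
```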
