Equivariant Convolutions
Created by: chrisrothUT
Hi everybody,
I now have a working implementation of G-CNNs that works with a generic permutation array. This PR seems like a lot of changes because it also includes #600 and #611, so I'll keep it a draft until those are resolved. Since there's a lot of changes the main things I've added are graph.group_algebra() and graph.inverse() for computing the kernel mapping, nn.DenseEquivariant(), and models.GCNN(). I also added test_GCNN which is similar to test_rbmsymm.
Like RBMSymm, it can take an AbstractGraph as an argument:
graph = nk.graph.Lattice(basis_vectors=[[1.,0.],[-1./2.,np.sqrt(3)/2]],extent=[2,2],atoms_coord=[[0.,0.],[-1./4.,np.sqrt(3)/4.],[1./2.,0.]])
ma = nk.models.GCNN( permutations = graph, layers = 4, features = 4, )
And we're off! This model would be a four layer GCNN on the 12-site Kagome lattice. The first layer is DenseSymm, and the other three are DenseEquivariant which has kernel=[n_symm=48,features,features] and convolves over poses based on their relative symmetry orientations. Like RBMSymm, this will produce identical outputs for each automorphism of the lattice, i.e. this test will pass:
hi = nk.hilbert.Spin(s=1 / 2, N=graph.n_nodes) pars = ma.init(nk.jax.PRNGKey(), hi.random_state(1))
v = hi.random_state(3) vals = [ma.apply(pars, v[..., p]) for p in graph.automorphisms()]
for val in vals: assert jnp.allclose(val, vals[0])
A few questions/comments:
- What would be a good test for the layer DenseEquivariant?
- I call reshape twice in DenseEquivariant, perhaps there's a way to only call it once?
- Right now this only implements a dense convolution, but it would be good to implement sparser convolutions. This requires information about the symmetry group, not just a list of automorphisms.