`DenseSymm` supports input features
Created by: attila-i-szabo
This PR adds the optional argument in_features
to all implementations of DenseSymm
. This is useful if the input has several features per lattice site, e.g., it is given in a one-hot representation. By default, in_features=1
, in which case no extra feature dimension is needed in the input (this is the old behaviour of DenseSymm
). Tests are added for both usages.
Implementation details were remodelled based on DenseEquivariant
, e.g. the symmetrizer_col
function was removed in favour of the cleaner indexing solution used there.
A quasi-bug in parity-symmetric GCNN is also fixed. With use_bias=True
, it used to add a bias to both the equivariant
and equivariant_flip
layers. However, these are all added to the same terms, so one of them is redundant. The network output is still valid, but having two variables that do the exact same thing is wasteful and introduces singularities in the QGT. I've removed the bias from equivariant_flip
, which solves this problem.
Before merging, I'd also like to clean up the initialisation of these layers/networks, which is a bit of a mess now. At this point, all kernels are of the shape [out_features, in_features, n_symm or n_sites]
, so partial(lecun_normal, in_axis=1, out_axis=0)
seems like the sensible initialiser for every layer type that we offer. However, we define some complicated logic for a default_densesymm_initializer
, which is basically the same and it's not clear when it would be used; also, we randomly mix in lecun_normal
without adjusting for the axes layout and the rough approximation unit_normal_scaling
... Also, now that jax.nn.initializers
(at least since v0.2.21) normalises complex numbers properly, do we still need to offer our version?
Or should I leave this to a next PR?