Fix GCNN initializers
Created by: attila-i-szabo
Currently, the initializers used for GCNNs are a bit of a mess. We define several different initializers (some of which are manifestly incorrect, e.g. a `jax.nn.lecun_normal` that refers to the wrong axes), all of which are overridden in the GCNN classes by a custom variance-scaled normal distribution, which is in turn a factor of sqrt(2) too large for complex numbers. This only works because `jax.random.normal` is now defined to halve the variance of the real and imaginary parts for complex dtypes.
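For context, a quick illustration of the dtype behaviour in question, using only standard JAX calls (not NetKet's initializer code):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# For complex dtypes, jax.random.normal halves the variance of the real and
# imaginary parts, so the *total* variance is 1, just as in the real case.
z = jax.random.normal(key, (200_000,), dtype=jnp.complex64)
print(jnp.var(z.real), jnp.var(z.imag), jnp.var(z))  # ~0.5, ~0.5, ~1.0

# LeCun scaling targets Var[w] = 1 / fan_in, independently of dtype.
fan_in, fan_out = 64, 16
w = jax.nn.initializers.lecun_normal()(key, (fan_in, fan_out), jnp.complex64)
print(jnp.var(w) * fan_in)  # ~1.0
```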
This PR clears this up.
- The one remaining initialiser in `nk.nn` is dropped: it is internal, and is now available in all versions of JAX that we support. This leaves only deprecations in the `initializers.py` file.
- A single `default_gcnn_initializer` is defined in the same file as `DenseSymm` and `DenseEquivariant`, and is also exported into `GCNN`: this is the plain JAX `lecun_normal` instantiated with the correct axes, the one we ought to use with SELU (a sketch follows this list).
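A minimal sketch of what such an initializer could look like; the axis indices here are placeholders, not necessarily the layout NetKet uses for `DenseSymm`/`DenseEquivariant` kernels:

```python
from jax.nn import initializers

# Sketch only: the in/out axes below are placeholders.  The point is that
# fan-in/fan-out must be named explicitly rather than relying on the default
# (-2, -1), which refers to the wrong axes for symmetrised GCNN kernels.
default_gcnn_initializer = initializers.lecun_normal(in_axis=1, out_axis=0)

# lecun_normal(...) is shorthand for
#   initializers.variance_scaling(1.0, "fan_in", "truncated_normal", ...),
# i.e. Var[w] = 1 / fan_in, the scaling appropriate for SELU.
```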
Variance scaling for masks is also fixed by defining a `scaled_mask` whose norm is always the same as that of a mask full of ones.
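A sketch of that rescaling under the assumption of a 0/1 mask (the helper name is taken from the description above; its exact signature in the PR may differ):

```python
import jax.numpy as jnp

def scaled_mask(mask):
    """Rescale a 0/1 mask so its norm matches that of an all-ones mask.

    Sketch only: scaling the kept entries by sqrt(size / n_kept) keeps
    sum(mask**2) == mask.size, so masking does not shrink the effective
    variance seen through a variance-scaled kernel initializer.
    """
    mask = jnp.asarray(mask, dtype=jnp.float32)
    return mask * jnp.sqrt(mask.size / jnp.sum(mask))

# Example: keeping 4 of 8 entries turns each kept 1 into sqrt(2) ≈ 1.414.
print(scaled_mask(jnp.array([1, 0, 1, 0, 1, 0, 1, 0])))
```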
Possible TODOs:
- Check the initialiser of `RBMSymm`. I have no intuition as to whether `lecun_normal` is appropriate for that too, so please weigh in; we can probably do better than the current `normal(0.1)`.
- More generally, define similar initialisers for non-GCNN layers like `Dense`.