
Fix GCNN initializers

Vicentini Filippo requested to merge github/fork/attila-i-szabo/initializer into master

Created by: attila-i-szabo

Currently, the initializers used for GCNNs are a bit of a mess. We define several different initializers (some of which are manifestly incorrect, e.g. a jax.nn.lecun_normal that refers to the wrong axes), all of which are overridden in the GCNN classes by a custom variance-scaled normal distribution, which is in turn a factor of sqrt(2) too large for complex numbers. This only works because jax.random.normal is now defined to halve the variance of the real and imaginary parts for complex dtypes.
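
For reference, a small self-contained check of that behaviour (plain JAX, not NetKet code): for complex dtypes, jax.random.normal draws samples whose real and imaginary parts each carry variance 1/2, so an extra factor of sqrt(2) on top of that doubles the intended total variance.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (100_000,), dtype=jnp.complex64)

# Real and imaginary parts each have variance ~0.5, total variance ~1.0,
# so multiplying by sqrt(2) would push the total variance to ~2.0.
print(jnp.var(z.real), jnp.var(z.imag), jnp.var(z))
```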

This PR clears this up.

  • The one remaining initialiser in nk.nn is dropped: it duplicated a JAX-internal function that is now available in all versions of JAX we support. This leaves only deprecations in the initializers.py file.
  • A single default_gcnn_initializer is defined in the same file as DenseSymm and DenseEquivariant and is also exported into GCNN: it is the plain JAX lecun_normal instantiated with the correct axes, which is the one we ought to use with SELU (see the sketch below).
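
As a rough illustration (not the exact code in this PR; the kernel layout and axis choice are assumptions), the new default is just jax.nn.initializers.lecun_normal pointed at the right axes of the group-structured kernel:

```python
import jax
from jax.nn.initializers import lecun_normal

# Hypothetical kernel layout: (out_features, in_features, n_symm).
# Fan-in is then computed over the input-feature and symmetry axes,
# which is what "instantiated with the correct axes" refers to.
default_gcnn_initializer = lecun_normal(in_axis=(1, 2), out_axis=0)

key = jax.random.PRNGKey(0)
kernel = default_gcnn_initializer(key, (4, 3, 8))  # shape (4, 3, 8)
```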

Variance scaling for masks is also fixed by defining a scaled_mask whose norm is always the same as that of a mask full of 1s.
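
A minimal sketch of what that means (the helper name and exact normalisation are assumptions, not necessarily the code in this PR): rescale a 0/1 mask so that its L2 norm matches that of an all-ones mask of the same size, so masking does not shrink the effective variance of the kernel.

```python
import jax.numpy as jnp

def scaled_mask(mask):
    """Hypothetical helper: rescale a 0/1 mask so its L2 norm equals
    that of an all-ones mask with the same number of entries."""
    mask = jnp.asarray(mask, dtype=jnp.float32)
    return mask * jnp.sqrt(mask.size / jnp.sum(mask))

m = scaled_mask(jnp.array([1.0, 0.0, 1.0, 0.0]))
print(jnp.linalg.norm(m))  # 2.0, same as the norm of jnp.ones(4)
```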

Possible TODOs:

  • Check the initialiser of RBMSymm. I have no intuition as to whether lecun_normal is appropriate there too, so please weigh in; we can probably do better than the current normal(0.1).
  • More generally, define similar initialisers for non-GCNN layers like Dense.
