Symmetries in RBMSymm can (must) be HashableArrays
Created by: VolodyaCO
When passing symmetries to RBMSymm
as a custom array containing index permutations, the type of symmetries
is np.ndarray
, which is non-hashable. Thus, nk.jax.HashablePartial
raises a hashing error when trying to hash the arguments of the RBMSymm
model. However, when passing symmetries to RBMSymm
as a nk.utils.HashableArray
, the logic in the nk.nn.DenseSymm
function does not recognise the symmetries. This PR fixes this. However, it is an ugly fix.
What I would like to do is for RBMSymm
to initialise automatically the RBMSymm.symmetries
attribute to be a HashableArray
, even if the user passes a np.ndarray
. We can't do this conversion in the setup
method of RBMSymm
because setup
is like a lazy init, thus not fixing the problem. If you come up with a way to do this, I think it's a much nicer solution than what I'm offering in this PR.