Skip to content

Symmetries in RBMSymm can (must) be HashableArrays

Created by: VolodyaCO

When passing symmetries to RBMSymm as a custom array containing index permutations, the type of symmetries is np.ndarray, which is non-hashable. Thus, nk.jax.HashablePartial raises a hashing error when trying to hash the arguments of the RBMSymm model. However, when passing symmetries to RBMSymm as a nk.utils.HashableArray, the logic in the nk.nn.DenseSymm function does not recognise the symmetries. This PR fixes this. However, it is an ugly fix.

What I would like to do is for RBMSymm to initialise automatically the RBMSymm.symmetries attribute to be a HashableArray, even if the user passes a np.ndarray. We can't do this conversion in the setup method of RBMSymm because setup is like a lazy init, thus not fixing the problem. If you come up with a way to do this, I think it's a much nicer solution than what I'm offering in this PR.

Merge request reports