Support new PRNGKey structural interface
Created by: PhilipVinc
Jax is currently experimenting with a new interface where the PRNGKey is a pytree instead of an array, in order to make their code more readable and safe, and also in order to support different kinds of PRNGs (They recently merged a new PRNG generator called RGB). https://github.com/google/jax/pull/6899
Anyhow, those changes are hidden for the time being, unless you use export JAX_ENABLE_CUSTOM_PRNG=1
becaues I asked them not to break the API immediately.
The changes in this PR are needed to make NetKet work with Jax when they will switch to this interface.