Skip to content

Support new PRNGKey structural interface

Vicentini Filippo requested to merge github/fork/PhilipVinc/pv/prng into master

Created by: PhilipVinc

Jax is currently experimenting with a new interface where the PRNGKey is a pytree instead of an array, in order to make their code more readable and safe, and also in order to support different kinds of PRNGs (They recently merged a new PRNG generator called RGB). https://github.com/google/jax/pull/6899

Anyhow, those changes are hidden for the time being, unless you use export JAX_ENABLE_CUSTOM_PRNG=1 becaues I asked them not to break the API immediately.

The changes in this PR are needed to make NetKet work with Jax when they will switch to this interface.

Merge request reports