Add PRNGSeq type and some tests
Created by: femtobit
In order to make it a bit more convenient to derive PRNGKey
s from one initial key, this PR adds a PRNGSeq
type (inspired by Haiku's PRNGSequence
, though it does not share the implementation -- I think this simpler implementation should be sufficient for our purposes).
>>> hi = netket.hilbert.Spin(s=1/2, N=10)
>>> rseq = netket.jax.PRNGSeq(12)
>>> hi.random_state(next(rseq))
DeviceArray([ 1., -1., -1., -1., -1., -1., 1., -1., -1., -1.], dtype=float32)
>>> nk.variational.MCState(..., seed=next(rseq))
(This is a PR on branch nk3
, if it is merged into master
soon it can of course be retargeted.)
This also bumps the required version of mpi4jax>=0.2.10
, since the test of mpi_split
added here was failing for me without it.