Skip to content

Add PRNGSeq type and some tests

Vicentini Filippo requested to merge github/fork/femtobit/nk3-prngseq into nk3

Created by: femtobit

In order to make it a bit more convenient to derive PRNGKeys from one initial key, this PR adds a PRNGSeq type (inspired by Haiku's PRNGSequence, though it does not share the implementation -- I think this simpler implementation should be sufficient for our purposes).

>>> hi = netket.hilbert.Spin(s=1/2, N=10)
>>> rseq = netket.jax.PRNGSeq(12)
>>> hi.random_state(next(rseq))
DeviceArray([ 1., -1., -1., -1., -1., -1.,  1., -1., -1., -1.], dtype=float32)
>>> nk.variational.MCState(..., seed=next(rseq))

(This is a PR on branch nk3, if it is merged into master soon it can of course be retargeted.)

This also bumps the required version of mpi4jax>=0.2.10, since the test of mpi_split added here was failing for me without it.

Merge request reports