remove support for jax 0.3 / Python 3.7
Created by: PhilipVinc
Jax dropped support for Python 3.7 (which is officially unsupported by the numerical ecosystem since December). PR #1461 for a fix requires recent versions of jax >=0.4, and I don't have the time to invest in finding workarounds to fix the bug with older jax versions.
As we have almost no users who install netket on python 3.7 (on average we have 1 download per day, vs 35 on python 3.9 and 40, and jax already dropped support for python 3.7, I would like to go down this way.
In any case, more modern tricks with jax.array and auto partitioning will require it so...