Fix JaxSR
Created by: PhilipVinc
Apparently the SR implementation for jax by @chrisrothUT did not work with R->C machines, because of jitting through non-pure functions.
This PR rewrites a big chunk of JaxSR to make it only use pure functions, and incidentally now works with MPI.
It also changes the way we initialise SR classes: instead of setting has_complex_parameters
like we were doing before, I created a function sr.setup(machine)
that sets this up (and methods to flatten/unflatten).
I'd like to add a few more test to check this case, but otherwise this can already be reviewed.
There is also a tiny bug fix in the last commit: now if stats.error_of_mean
is a nan we don't crash upon display.