Skip to content

Fix JaxSR

Vicentini Filippo requested to merge PhilipVinc/fix_jaxsr into v3.0

Created by: PhilipVinc

Apparently the SR implementation for jax by @chrisrothUT did not work with R->C machines, because of jitting through non-pure functions.

This PR rewrites a big chunk of JaxSR to make it only use pure functions, and incidentally now works with MPI.

It also changes the way we initialise SR classes: instead of setting has_complex_parameters like we were doing before, I created a function sr.setup(machine) that sets this up (and methods to flatten/unflatten).

I'd like to add a few more test to check this case, but otherwise this can already be reviewed.

There is also a tiny bug fix in the last commit: now if stats.error_of_mean is a nan we don't crash upon display.

Merge request reports