[bug] nk.jax.expect is off by a factor of mpi.n_rank in the backward pass
Created by: PhilipVinc
It's only a constant multiplicative factor (the gradient comes out scaled by `mpi.n_rank`), but it is still wrong.
Discovered by Alessandro Sinibaldi.
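The snippet below is a hypothetical illustration of one way such a factor can arise. It is not taken from the NetKet source, and the actual cause of this bug may differ. MPI ranks are simulated with a leading array axis, and `local_values`, `N_RANK`, `buggy_grad` and `fixed_grad` are illustrative names invented here. The forward pass averages over all samples, but the backward pass combines the per-rank gradients of the *local* means with a sum-allreduce instead of a mean. This reproduces a gradient that is exactly `N_RANK` times too large.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: samples are split evenly across N_RANK simulated
# MPI ranks (leading axis = rank). No real MPI communication happens here.
N_RANK = 4


def local_values(theta, x):
    # Hypothetical per-sample quantity whose mean we differentiate.
    return jnp.sin(theta * x)


def global_mean(theta, xs):
    # Forward pass: mean over ranks of per-rank means, which equals the
    # mean over all samples because every rank holds the same number.
    return jnp.mean(jax.vmap(lambda x: jnp.mean(local_values(theta, x)))(xs))


def per_rank_grads(theta, xs):
    # Each simulated rank differentiates the mean of its *local* samples.
    return jax.vmap(
        lambda x: jax.grad(lambda t: jnp.mean(local_values(t, x)))(theta)
    )(xs)


def buggy_grad(theta, xs):
    # Combining with a sum-allreduce counts every rank's contribution in
    # full, so the result is N_RANK times the true gradient.
    return jnp.sum(per_rank_grads(theta, xs))


def fixed_grad(theta, xs):
    # Averaging over ranks (sum / N_RANK) recovers the true gradient.
    return jnp.sum(per_rank_grads(theta, xs)) / N_RANK


xs = jnp.linspace(0.1, 2.0, 32).reshape(N_RANK, 8)
theta = 0.7
true_grad = jax.grad(global_mean)(theta, xs)
print(buggy_grad(theta, xs) / true_grad)  # ratio is N_RANK (up to float error)
```

In this sketch the fix is simply dividing by the number of ranks in the backward pass, which matches the issue's observation that only a constant multiplicative factor is wrong.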