Fix Squared Operator
Created by: PhilipVinc
Cherry picked from #1065 so that i separate the two changes in two different PRs.
Computing the gradient of operators that use nkjax.expect
instead of the covariance formula (such as SquaredOperator
) also had a wrong factor of 2 for C->C models.
This is due to the weird way that Jax handles complex differentiation.
This PR now fixes it, and a test is added to check the gradient wrt finite differences.
Also, this PR moves out into a common test file the finite difference functions so that they can be used for other tests.