Skip to content

Fix Squared Operator

Vicentini Filippo requested to merge pv/sq2 into master

Created by: PhilipVinc

Cherry picked from #1065 so that i separate the two changes in two different PRs.

Computing the gradient of operators that use nkjax.expect instead of the covariance formula (such as SquaredOperator) also had a wrong factor of 2 for C->C models. This is due to the weird way that Jax handles complex differentiation.

This PR now fixes it, and a test is added to check the gradient wrt finite differences.

Also, this PR moves out into a common test file the finite difference functions so that they can be used for other tests.

Merge request reports