SR onthefly: avoid converting to real
Created by: inailuig
For `centered=True`, the mean of the gradients ⟨Oₖ⟩ has to be precomputed in order to define the centered log-wavefunction (i.e. the function whose gradient is ΔOₖ). For non-homogeneous (and also non-holomorphic ℂ→ℂ) parameters this is currently done by converting all parameters to real ones.
This can be avoided by storing the gradients Oₖ (more specifically, the mean ⟨Oₖ⟩) as two (pytrees of) complex numbers, representing the top and bottom rows of the corresponding 2x2 real Jacobian.
In the following I put together a short sketch of a proof that `centered=False` is correct for non-holomorphic ℂ→ℂ (in fact it works universally), while deriving what this PR does (and what the next one for jacobianpytree will do).
setup
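Following the Autodiff Cookbook one can consider a general complex function

$$f(x + iy) = u(x, y) + i\,v(x, y).$$

The corresponding (vector-valued) real function is given by

$$g(x, y) = \begin{pmatrix} u(x, y) \\ v(x, y) \end{pmatrix}.$$

Its Jacobian is a real 2x2 matrix

$$J = \begin{pmatrix} \partial_x u & \partial_y u \\ \partial_x v & \partial_y v \end{pmatrix}.$$

One can represent this Jacobian with two complex numbers by doing the transformation

$$J_u = \partial_x u + i\,\partial_y u, \qquad J_v = \partial_x v + i\,\partial_y v,$$

which correspond to the Jacobians of the real part u and imaginary part v of f.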
VJP
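Jax's complex vjp of f with a complex number $z = c + id$ is defined as

$$\begin{pmatrix} c & -d \end{pmatrix} J \begin{pmatrix} 1 \\ -i \end{pmatrix} = c\,J_u^* - d\,J_v^*,$$

and it is easy to see that one can calculate $J_u^*$ and $J_v^*$ by backpropagating $1$ and $-i$ respectively.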
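As a quick numerical check of this statement, the following JAX sketch compares the two backpropagated values against the rows of the real 2x2 Jacobian. The test function `f`, the helper `real_jacobian`, and the evaluation point are illustrative assumptions and not part of this PR.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Illustrative non-holomorphic C -> C function (an assumption for this sketch).
    return jnp.conj(z) * z + 1j * jnp.sin(z) + jnp.conj(z) ** 2


def real_jacobian(fun, z):
    # 2x2 real Jacobian of (x, y) -> (u, v), where fun(x + iy) = u + iv.
    def g(xy):
        out = fun(xy[0] + 1j * xy[1])
        return jnp.stack([jnp.real(out), jnp.imag(out)])

    return jax.jacfwd(g)(jnp.stack([jnp.real(z), jnp.imag(z)]))


z0 = jnp.asarray(0.3 + 0.7j, dtype=jnp.complex64)

# Rows of the real Jacobian, packed into one complex number each.
J = real_jacobian(f, z0)
Ju = J[0, 0] + 1j * J[0, 1]
Jv = J[1, 0] + 1j * J[1, 1]

# Backpropagate 1 and -i through the complex function.
_, f_vjp = jax.vjp(f, z0)
(Ju_conj,) = f_vjp(jnp.asarray(1.0, dtype=jnp.complex64))
(Jv_conj,) = f_vjp(jnp.asarray(-1.0j, dtype=jnp.complex64))

print(jnp.allclose(Ju_conj, jnp.conj(Ju), atol=1e-4))
print(jnp.allclose(Jv_conj, jnp.conj(Jv), atol=1e-4))
```

Both comparisons rely only on `jax.vjp` and `jax.jacfwd`; no conversion of the parameter to a pair of reals is needed to obtain the two rows.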
real VJP
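The corresponding real vjp is given by

$$\begin{pmatrix} c & d \end{pmatrix} J = \begin{pmatrix} c\,\partial_x u + d\,\partial_x v & c\,\partial_y u + d\,\partial_y v \end{pmatrix},$$

and its corresponding complex number is $c\,J_u + d\,J_v$. Up to complex conjugations it is equivalent to the complex vjp above.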
The correct (in terms of the real 2x2 matrix) vjp can thus be implemented with a complex jax vjp by conjugating both z and the result, i.e. when denoting the complex jax-vjp as $\mathrm{vjp}_{\mathrm{jax}}(z)$ one calculates the correct one as $\mathrm{vjp}_{\mathrm{jax}}(z^*)^*$. The inverse of this statement is that the complex jax-vjp can be defined as $\mathrm{vjp}(z^*)^*$, where $\mathrm{vjp}$ denotes the correct one, which is clear from the fact that the transformation is trivially self-inverse. Effectively this transformation means that the correct vjp corresponds to a complex vector product with the Hermitian transpose of the Jacobian, since $(z^* J)^* = z J^*$ (up to transposition).
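By doing some rearrangements the correct vjp in terms of $J_u$ and $J_v$ can be expressed as

$$\mathrm{Re}(z)\,J_u + \mathrm{Im}(z)\,J_v.$$

When implemented with the complex vjp's from above this becomes

$$\mathrm{Re}(z)\,\left(J_u^*\right)^* + \mathrm{Im}(z)\,\left(J_v^*\right)^*,$$

where $J_u^*$ and $J_v^*$ are the results of backpropagating $1$ and $-i$.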
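The equivalence of the three ways of obtaining the real vjp can be checked numerically. The sketch below is only an illustration: the function `f`, its real counterpart `g`, and the chosen numbers are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Illustrative non-holomorphic C -> C function (an assumption for this sketch).
    return jnp.conj(z) * z + 1j * jnp.sin(z) + jnp.conj(z) ** 2


def g(xy):
    # Real counterpart (x, y) -> (u, v) of f.
    out = f(xy[0] + 1j * xy[1])
    return jnp.stack([jnp.real(out), jnp.imag(out)])


xy0 = jnp.array([0.3, 0.7], dtype=jnp.float32)   # evaluation point (x, y)
z0 = (xy0[0] + 1j * xy0[1]).astype(jnp.complex64)
cd = jnp.array([1.5, -0.4], dtype=jnp.float32)   # cotangent (c, d)
z = (cd[0] + 1j * cd[1]).astype(jnp.complex64)   # same cotangent as c + id

# Reference: real vjp (c, d) @ J, packed into one complex number.
_, g_vjp = jax.vjp(g, xy0)
(row,) = g_vjp(cd)
reference = row[0] + 1j * row[1]

# Variant 1: conjugate the cotangent and the result of the complex jax vjp.
_, f_vjp = jax.vjp(f, z0)
(tmp,) = f_vjp(jnp.conj(z))
via_conjugation = jnp.conj(tmp)

# Variant 2: combine the two rows obtained by backpropagating 1 and -i.
(Ju_conj,) = f_vjp(jnp.asarray(1.0, dtype=jnp.complex64))
(Jv_conj,) = f_vjp(jnp.asarray(-1.0j, dtype=jnp.complex64))
via_rows = jnp.real(z) * jnp.conj(Ju_conj) + jnp.imag(z) * jnp.conj(Jv_conj)

print(reference, via_conjugation, via_rows)
```

All three values should agree up to floating-point error.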
JVP
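Jax's complex jvp of f with a complex number $z = c + id$ is defined as

$$\begin{pmatrix} 1 & i \end{pmatrix} J \begin{pmatrix} c \\ d \end{pmatrix};$$

by taking the real parts and expanding the terms inside it can be shown that this is equivalent to

$$\mathrm{Re}\left[J_u^*\,z\right] + i\,\mathrm{Re}\left[J_v^*\,z\right].$$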
real JVP
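The equivalent real jvp is given as

$$J \begin{pmatrix} c \\ d \end{pmatrix} = \begin{pmatrix} c\,\partial_x u + d\,\partial_y u \\ c\,\partial_x v + d\,\partial_y v \end{pmatrix},$$

which is exactly the same as the complex jax jvp above. Thus the complex jax-jvp corresponds to the correct (in terms of the real 2x2 matrix) real jvp.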
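A small JAX sketch of this equivalence follows. The function `f`, its real counterpart `g`, and the numerical values are assumptions chosen for illustration only.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Illustrative non-holomorphic C -> C function (an assumption for this sketch).
    return jnp.conj(z) * z + 1j * jnp.sin(z) + jnp.conj(z) ** 2


def g(xy):
    # Real counterpart (x, y) -> (u, v) of f.
    out = f(xy[0] + 1j * xy[1])
    return jnp.stack([jnp.real(out), jnp.imag(out)])


xy0 = jnp.array([0.3, 0.7], dtype=jnp.float32)   # evaluation point (x, y)
z0 = (xy0[0] + 1j * xy0[1]).astype(jnp.complex64)
cd = jnp.array([1.5, -0.4], dtype=jnp.float32)   # tangent (c, d)
z = (cd[0] + 1j * cd[1]).astype(jnp.complex64)   # same tangent as c + id

# Complex jax jvp.
_, jvp_complex = jax.jvp(f, (z0,), (z,))

# Real jvp J @ (c, d), packed into one complex number.
_, jvp_real = jax.jvp(g, (xy0,), (cd,))
jvp_real_packed = jvp_real[0] + 1j * jvp_real[1]

# Same quantity from the two rows obtained by backpropagating 1 and -i.
_, f_vjp = jax.vjp(f, z0)
(Ju_conj,) = f_vjp(jnp.asarray(1.0, dtype=jnp.complex64))
(Jv_conj,) = f_vjp(jnp.asarray(-1.0j, dtype=jnp.complex64))
jvp_rows = jnp.real(Ju_conj * z) + 1j * jnp.real(Jv_conj * z)

print(jvp_complex, jvp_real_packed, jvp_rows)
```

Unlike the vjp, no conjugation is needed here: the complex jvp already agrees with the real one.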
holomorphic
For a holomorphic wavefunction the Cauchy-Riemann equations

$$\partial_x u = \partial_y v, \qquad \partial_y u = -\partial_x v$$

are satisfied.
Adding i times the second equation to the first one it follows that Jᵤ = -i Jᵥ. This means it is enough to pre-compute Jᵤ.
VJP
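With Jᵥ* = -i Jᵤ* (i.e. Jᵥ = i Jᵤ), the correct vjp simplifies to

$$\mathrm{Re}(z)\,J_u + \mathrm{Im}(z)\,J_v = \left(\mathrm{Re}(z) + i\,\mathrm{Im}(z)\right) J_u = z\,J_u.$$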
JVP
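With Jᵥ* = -i Jᵤ*, the jvp simplifies to

$$\mathrm{Re}\left[J_u^*\,z\right] + i\,\mathrm{Re}\left[-i\,J_u^*\,z\right] = \mathrm{Re}\left[J_u^*\,z\right] + i\,\mathrm{Im}\left[J_u^*\,z\right] = J_u^*\,z.$$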
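The holomorphic shortcuts can be checked with a few lines of JAX. This is a sketch under assumptions: the holomorphic test function `f` and all numerical values are picked for illustration.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Illustrative holomorphic function (an assumption for this sketch).
    return z * jnp.exp(z)


z0 = jnp.asarray(0.3 + 0.7j, dtype=jnp.complex64)   # evaluation point
z = jnp.asarray(1.5 - 0.4j, dtype=jnp.complex64)    # (co)tangent c + id

_, f_vjp = jax.vjp(f, z0)
(Ju_conj,) = f_vjp(jnp.asarray(1.0, dtype=jnp.complex64))
(Jv_conj,) = f_vjp(jnp.asarray(-1.0j, dtype=jnp.complex64))

# Cauchy-Riemann in the form used above: conj(Jv) = -i conj(Ju).
cr_holds = jnp.allclose(Jv_conj, -1j * Ju_conj, atol=1e-4)

# Correct (real-matrix) vjp: conjugate the cotangent and the result ...
(tmp,) = f_vjp(jnp.conj(z))
vjp_correct = jnp.conj(tmp)
# ... which for a holomorphic function only needs Ju.
vjp_short = z * jnp.conj(Ju_conj)

# The jvp likewise only needs Ju.
_, jvp_jax = jax.jvp(f, (z0,), (z,))
jvp_short = Ju_conj * z

print(cr_holds, vjp_correct, vjp_short, jvp_jax, jvp_short)
```

Only one backward pass (with cotangent 1) is needed in this case, since the second row follows from the first.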
- For ℝ→ℂ one can do the same as for ℂ→ℂ; in this case Jᵤ and Jᵥ, as well as the input vector of the jvp and the output of the vjp, are real, so most of the conjugations simply won't have any effect.
- The same holds for the real parts of pytrees with mixed (non-homogeneous) ℝ and ℂ parameters.
- For ℂ→ℝ one can do the same as for ℂ→ℂ while dropping all terms with v/Jᵥ, since v≡0.
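The ℝ→ℂ and ℂ→ℝ cases from the list above can be sketched in JAX as follows. The two test functions `f_rc` and `f_cr` and the numerical values are hypothetical and serve only as an illustration.

```python
import jax
import jax.numpy as jnp


# R -> C: both rows are real, so the conjugations have no effect.
def f_rc(x):
    return x * jnp.exp(1j * x)


x0 = jnp.asarray(0.8, dtype=jnp.float32)
out_rc, vjp_rc = jax.vjp(f_rc, x0)
(Ju,) = vjp_rc(jnp.asarray(1.0, dtype=out_rc.dtype))     # du/dx
(Jv,) = vjp_rc(jnp.asarray(-1.0j, dtype=out_rc.dtype))   # dv/dx

t = jnp.asarray(0.5, dtype=jnp.float32)                  # real tangent
_, jvp_rc = jax.jvp(f_rc, (x0,), (t,))
jvp_rc_rows = Ju * t + 1j * Jv * t


# C -> R: v is identically zero, so only the Ju terms remain.
def f_cr(z):
    return jnp.abs(z) ** 2 + jnp.real(z)


z0 = jnp.asarray(0.3 + 0.7j, dtype=jnp.complex64)
out_cr, vjp_cr = jax.vjp(f_cr, z0)
(Ju_conj,) = vjp_cr(jnp.asarray(1.0, dtype=out_cr.dtype))

tz = jnp.asarray(1.5 - 0.4j, dtype=jnp.complex64)        # complex tangent
_, jvp_cr = jax.jvp(f_cr, (z0,), (tz,))
jvp_cr_rows = jnp.real(Ju_conj * tz)

print(Ju, Jv, jvp_rc, jvp_rc_rows, Ju_conj, jvp_cr, jvp_cr_rows)
```

For `f_rc` the expected rows are du/dx = cos x - x sin x and dv/dx = sin x + x cos x; for `f_cr` the backpropagated value is 2 conj(z) + 1.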
This PR implements the computation of the means by doing vjp's with 1 and -i, as well as the dot product (similar to a jvp), which is needed for correctly defining the centered log-wavefunction.
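One possible reading of this construction is sketched below in JAX. It is not the code of this PR: the toy model `logpsi`, the helper names `mean_dot` and `centered_logpsi`, and the random inputs are all hypothetical, and the cotangents are scaled by 1/N so that the vjp's directly return sample means.

```python
import jax
import jax.numpy as jnp


def logpsi(theta, sigma):
    # Hypothetical non-holomorphic toy model: complex theta (P,), real sigma (N, P).
    a = sigma @ theta
    return jnp.sin(a) + 0.5 * jnp.conj(a) ** 2


N, P = 8, 3
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
sigma = jax.random.normal(k1, (N, P))
theta = 0.3 * (jax.random.normal(k2, (P,)) + 1j * jax.random.normal(k3, (P,)))
tangent = jax.random.normal(k4, (P,)) + 1j * jax.random.normal(k5, (P,))


def f(th):
    return logpsi(th, sigma)


# Means of the two rows over the samples, via vjp's with 1/N and -i/N.
out, f_vjp = jax.vjp(f, theta)
ones = jnp.ones(N, dtype=out.dtype) / N
(Ju_conj_mean,) = f_vjp(ones)
(Jv_conj_mean,) = f_vjp(-1j * ones)


def mean_dot(th):
    # jvp-like dot product of the mean rows with the parameters.
    re = jnp.real(jnp.sum(Ju_conj_mean * th))
    im = jnp.real(jnp.sum(Jv_conj_mean * th))
    return re + 1j * im


def centered_logpsi(th):
    # Subtract the mean contribution; the scalar broadcasts over all samples.
    return f(th) - mean_dot(th)


# The jvp of the centered function should have zero sample mean.
_, df = jax.jvp(f, (theta,), (tangent,))
_, dmean = jax.jvp(mean_dot, (theta,), (tangent,))
_, dcentered = jax.jvp(centered_logpsi, (theta,), (tangent,))

print(jnp.mean(df), dmean, jnp.mean(dcentered))
```

In this sketch the parameters stay complex throughout; the non-holomorphic dependence is handled by keeping the two mean rows separately instead of splitting `theta` into real and imaginary parts.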
The next PR will do the same for qgt-jacobian-pytree, i.e. using Jᵤ and Jᵥ.
I have implemented it already here, however iirc the scale-invariant regularization still needs some work.
Also we should maybe re-discuss the need/meaning of having both mode and holomorphic parameters in jacobian-pytree, and analogously add a holomorphic parameter to qgt_onthefly.py (which is only needed when centered=True).
Alternatively, we should check whether `centered=False` really has numerical issues. If it does not, we could remove `centered=True`, which would make this PR (but not the next one) obsolete.