Skip to content

SR onthefly: avoid converting to real

Vicentini Filippo requested to merge github/fork/inailuig/sr-otf-noconv-4 into master

Created by: inailuig

For centered=True the mean of the gradients ⟨Oₖ⟩ has to be precomputed in order to define the centered log-wavefunction (i.e. the function whose gradient is ΔOₖ). For non-homogeneous (and also non-holomorphic ℂ→ℂ) parameters this is currently done by converting all parameters to real ones.

This can be avoided by storing the gradients Oₖ ( more specifically the mean ⟨Oₖ⟩) or as two (pytrees of) complex numbers, representing the top and bottom row of the corresponding 2x2 real jacobian.

In the following I put together a short sketch for proving that the centered=False is correct for non-holomorphic ℂ→ℂ (actually it works universally), while deriving what this PR is doing (and what the next one for jacobianpytree will do).


setup

Following the Autodiff Cookbook one can consider a general complex function f: \mathbb{C} \to \mathbb{C}

f(x+i y) = u(x,y) + i v(x,y)

The corresponding (vector-valued) real function g: \mathbb{R}^2 \to \mathbb{R}^2 is given by

g(x,y) = \begin{bmatrix} u(x,y) \\ v(x,y) \end{bmatrix}

Its jacobian is a real 2x2 matrix

J_g = \begin{bmatrix} \partial_x u & \partial_y u \\ \partial_x v & \partial_y v \end{bmatrix}

One can represent this jacobian with two complex numbers by doing the transformation

\begin{bmatrix} \partial_x u & \partial_y u \\ \partial_x v & \partial_y v \end{bmatrix}
\mapsto 
\begin{bmatrix} \partial_x u + i \partial_y u \\ \partial_x v +i \partial_y v \end{bmatrix}
\eqqcolon
\begin{bmatrix} J_u \\ J_v \end{bmatrix}

which correspond to the jacobian of the real part u and immaginary part v of f.

VJP

Jax' complex vjp of f with a complex number z = c+i d is defined as

\begin{bmatrix} c & -d \end{bmatrix}
J_g
\begin{bmatrix} 1  \\  {-i}  \end{bmatrix}
= \cdots
=  c (\partial_x  u- i \partial_y u) -  d (\partial_x  v- i \partial_y v)
= c J_u^* - d J_v^*

and it is easy to see that one can calculate J_u^* and J_v^* by backpropagating z = c + i d = 1 and z = c + i d = -i respectively.

real VJP

The corresponding real vjp is given by

\begin{bmatrix} c & d \end{bmatrix} J_g = \cdots = \begin{bmatrix} c \partial_x u + d \partial_x v & c \partial_y u + d \partial_y v \end{bmatrix}

and it's corresponding complex number is  (c \partial_x u + d \partial_x v) + i  (c \partial_y u + d \partial_y v). Up to complex conjugations it is equivalent to the complex vjp above.

The correct (in terms of the real 2x2 matrix) vjp can thus be implemented with a complex jax vjp by conjugating both z and the result, i.e. when denoting the complex jax-vjp as z \star J one calculates the correct one as z J = (z^* \star J)^*. The inverse of this statement is that the complex jax-vjp can be defined as z \star J \coloneqq (z^* J)^*, which is clear from fact that the transformation (z J) \mapsto (z^* J)^* is trivially self-inverse. Effectively this transformation means that the correct vjp corresponds to a complex vector product with the hermitian tanspose of the jacobian since J^H z = (z^H J)^H
(up to transposition).

By doing some rearrangements the correct vjp in terms of J_u and J_v can be expressed as c J_u + d J_v = \text{Re}\{z\}\ J_u + \text{Im}\{z\}\ J_v

When implemented with the complex z \star J vjp's from above this becomes \text{Re}\{z\}\ J_u + \text{Im}\{z\}\ J_v = (\text{Re}\{z^*\}\star J_u + \text{Im}\{z^*\}\star J_v)^* = \text{Re}\{z\}\star J_u^* - \text{Im}\{z\}\star J_v^*

JVP

Jax' complex jvp of f with a complex number z = c+i d is defined as

\begin{bmatrix} 1 & i \end{bmatrix}
J_g
\begin{bmatrix} c  \\  d  \end{bmatrix}
= \cdots
= (c \partial_x + d \partial_y) (u+iv)
= (c \partial_x + d \partial_y) u + i (c \partial_x + d \partial_y) v

by taking the real parts and expanding the terms inside it can be shown that this is equivalent to

\text{Re} \{J_u^* \ z \} + i \text{Re}\{J_v^* \  z \}

real JVP

The equivalent real jvp is given as

\begin{bmatrix} c & d \end{bmatrix}
J_g
= \cdots
= (c \partial_x+ d \partial_y) \begin{bmatrix} u \\ v \end{bmatrix}

which is exactly the same as the complex jax jvp above. Thus the complex jax-jvp corresponds to the correct (in terms of the real 2x2 matrix) real jvp.

holomorphic

for a holomorphic wavefunction the Cauchy-Riemann equations
\begin{aligned}
\partial_x u &= \partial_y v \\
\partial_y u &= -\partial_x v
\end{aligned}
are satisfied.

Adding i times the second equation to the first one it follows that Jᵤ = -i Jᵥ. This means it is enough to pre-compute Jᵤ.

VJP

with Jᵥ* = - i Jᵤ*:

\text{Re}\{z\}\star J_u^* + \text{Im}\{z\}\star J_v^* = 
\text{Re}\{z\}\star J_u^* + i \text{Im}\{z\}\star J_u^* = z \star J_u^*

JVP

with Jᵥ* = - i Jᵤ*:

\text{Re} \{J_u^* \ z \} + i \text{Re}\{J_v^* \  z \} = 
\text{Re} \{J_u^* \ z \} + i \text{Im}\{J_u^* \  z \} = J_u^* \  z


  • For ℝ→ℂ one can do the same as for ℂ→ℂ, in this case Jᵤ and Jᵥ as well as the input vector of the vjp and output of the vjp are real, so most of the conjugations simply won't have any effect.

  • Same for the real parts of pytrees with mixed (non-homogeneous) ℝ&ℂ parameters

  • For ℂ→ℝ one can do the same as for ℂ→ℂ while dropping all terms with v/Jᵥ since v≡0.


This PR Implements the computation of the means \langle J_R^* \rangle, \langle J_I^* \rangle by doing vjp's with 1 and -i,
as well as the dot-product (similarly to a jvp)
\theta \cdot \langle J \rangle \coloneqq \text{Re}\{\theta\}\cdot \langle J_u^* \rangle - \text{Im}\{\theta\}\cdot \langle J_v^*\rangle
which is needed for correctly defining the centered log-wavefunction.


The next PR will be doing the same for qgt-jacobian-pytree, i.e. using \text{Re} \{J_u^* \ z \} + i \text{Re}\{J_v^* \  z \} and \text{Re}\{z\}\star J_u^* + \text{Im}\{z\}\star J_v^*.

I have implemented it already here, however iirc the scale-invariant regularization still needs some work.


Also we should maybe re-discuss the need/meaning of having both mode and holomorphic parameters in jacobian-pytree, and analogously add a holomorphic parameter to qgt_onthefly.py (which is only needed when centered=True).

Alternatively we should check wether centered=False really has numerical issues, which, if it hasn't would allow us to remove centered=True and make this (but not the next) PR obsolete.

Merge request reports