Remove `is_holomorphic` and replace with `has_complex_weights` (+ R->C bug fix)
Created by: PhilipVinc
After having harassed @gcarleo several times about our usage of is_holomorphic
, I finally did a PR to fix (in my opinion) it.
In essence:
- In netket our usage of
is_holomorphic
was to check if the function had real parameters, to dispatch to a more efficient SR version. - While R->R functions are not holomorphic, R->C functions are holomorphic, even if the Cauchy conditions have to be reformulated (my favourite read is this).
- If we have a C->C function that is not holomorphic netket breaks down anyways so...
I propose to change is_holomorphic
to has_complex_weights
. Now all machines specify their dtype
(either float or complex) at construction.
Machines also specify their output type, outdtype
. this defaults to the same as the dtype
, so if you don't specify it everything will stil work.
The main advantage in knowing the input and output type is that now we can correctly treat R->C wavefunctions during optimisation, and restrict the update to only the real part of the gradient. Previously this would fail (and was untested).
This is also used to speed up some jax gradients and vjp (from a previous PR).
This pr also:
- Introduces a new Jax Pure machine that is R->C, and tests that der_log and vjp works
- I also add a bunch of density matrix machines
- I fixed a bug with der log of jax R->C wavefunction that I missed in my last pr.