Jax improvements + remove outdtype
Created by: PhilipVinc
This is extracted from #490, as it is more general.
This PR does mainly two things:
-
Remove all the logic to select the correct R->R, R->C, C->C vjp/grad kernel from JaxMachine and puts it in the jvp kernel itself.
- Before, when declaring a jax machine you had to declare the (homogeneous) type of all parameters as well as the expected output type. With this PR, both are inferred automatically.
- An added benefit is that our custom
vjp
now supports any arbitrary mix of real and complex parameters! - All this inference logic is used inside
@jit
blocks, so it has no runtime cost.
-
Remove all machine member
outdtype
, as it is no longer needed (as described above, it's inferred when needed).
An extra change is in the way netket.vmc_common.tree_map
is defined: if there is no jax, then we still use our definition, but if jax is present we use jax's version, which is slightly more general and will allow us to work with Flax and other frameworks.
While this PR does not yet allow us to use arbitrary weights, it brings us almost at the finish line.