Skip to content

Jax improvements + remove outdtype

Vicentini Filippo requested to merge PhilipVinc/jaximpro into master

Created by: PhilipVinc

This is extracted from #490, as it is more general.

This PR does mainly two things:

  • Remove all the logic to select the correct R->R, R->C, C->C vjp/grad kernel from JaxMachine and puts it in the jvp kernel itself.

    • Before, when declaring a jax machine you had to declare the (homogeneous) type of all parameters as well as the expected output type. With this PR, both are inferred automatically.
    • An added benefit is that our custom vjp now supports any arbitrary mix of real and complex parameters!
    • All this inference logic is used inside @jit blocks, so it has no runtime cost.
  • Remove all machine member outdtype, as it is no longer needed (as described above, it's inferred when needed).

An extra change is in the way netket.vmc_common.tree_map is defined: if there is no jax, then we still use our definition, but if jax is present we use jax's version, which is slightly more general and will allow us to work with Flax and other frameworks.

While this PR does not yet allow us to use arbitrary weights, it brings us almost at the finish line.

Merge request reports