Improve generated code by nkjax.tree_to_real
Created by: PhilipVinc
This commit changes the transpose of nkjax.tree_to_real.
Before this commit, it was autogenerated using jax.linear_transpose, but jax was failing to infer that the transpose of (x.real, x.imag)
is x
(see discussion in google/jax#8816 ). With this PR we use explicitly jax.lax.complex and manually construct the transpose, which gives a speedup.
This gives a minor speedup to JacobianPyTree
TODO: add a test