Skip to content

Improve generated code by nkjax.tree_to_real

Vicentini Filippo requested to merge pv/impro into master

Created by: PhilipVinc

This commit changes the transpose of nkjax.tree_to_real. Before this commit, it was autogenerated using jax.linear_transpose, but jax was failing to infer that the transpose of (x.real, x.imag) is x (see discussion in google/jax#8816 ). With this PR we use explicitly jax.lax.complex and manually construct the transpose, which gives a speedup.

This gives a minor speedup to JacobianPyTree

TODO: add a test

Merge request reports