Skip to content
Snippets Groups Projects

Improve generated code by nkjax.tree_to_real

Open Vicentini Filippo requested to merge pv/impro into master

Created by: PhilipVinc

This commit changes the transpose of nkjax.tree_to_real. Before this commit, it was autogenerated using jax.linear_transpose, but jax was failing to infer that the transpose of (x.real, x.imag) is x (see discussion in google/jax#8816 ). With this PR we use explicitly jax.lax.complex and manually construct the transpose, which gives a speedup.

This gives a minor speedup to JacobianPyTree

TODO: add a test

Merge request reports

Loading
Loading

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
  • Loading
  • Loading
  • Loading
  • Loading
  • Loading
  • Loading
  • Loading
Please register or sign in to reply
Loading