Skip to content

[NFC] refactor and Move around jax vjp and grad code

Vicentini Filippo requested to merge jax_stuff into master

Created by: PhilipVinc

This code moves out the grad (perex) and vjp code that handle R->R, R->C and C->C code from the jaxMachine constructor to another file.

This simply makes easier to work more on the JaxMachine code in the future. It should also reduce the amount of re-compilation, as now the jitted functions are global and not specific to every jax macchine.

Otherwise, this changes nothing

Merge request reports