[NFC] refactor and Move around jax vjp and grad code
Created by: PhilipVinc
This code moves out the grad
(perex) and vjp
code that handle R->R, R->C and C->C code from the jaxMachine
constructor to another file.
This simply makes easier to work more on the JaxMachine
code in the future.
It should also reduce the amount of re-compilation, as now the jitted functions are global and not specific to every jax macchine.
Otherwise, this changes nothing