SR with precomputed gradient pytrees
Created by: inailuig
This so far only adds the necessary logic, adding the interfaces around it should be straightforward though.
- The gradients (oks) are stored as a pytree where the first dim of each leaf is k (of size n_samples)
- Reduces to the usual two matrix-vector multiplications if a flattened model and parameters (and a flattened gradient matrix) are passed, all the tree maps should be optimized away by the jit complier
Works with inhomogeneous parameters (e.g. a mix of complex and real) just like the onthefly code does- works with R->R, holomprphic C->C and R->C
- non-homogenous parameters and non-holomorphic C->C requires converting to homogeneous real parameters
- is MPI enabled