
SR with precomputed gradients

Created by: attila-i-szabo

I have implemented the SR algorithm with precomputed gradients. This is less elegant than the lazy vjp/jvp-based implementation in LazySMatrix, but has practical advantages:

  • It has minimal memory overhead: precomputing the gradients requires the same amount of memory as a single vjp pass, because the samples on which the neural network is evaluated are independent, so each one can be processed on its own to obtain one row of the Jacobian. (I.e., instead of looping through vjp with vmap as jacrev does, we can loop through grad.) Storing the matrix of gradients is guaranteed to take less memory than the backward pass itself: backpropagation has to store a lot of internal information about the forward pass, which in a deep network takes up many times the memory needed for the gradients (in my experiments, 40 MB of gradients vs. 1.2 GB for backpropagation). The bottom line is that if there is enough memory for a single vjp, there is enough memory for this too.
  • It yields a massive speedup: the gradient matrix can again be calculated in the same asymptotic time as a single vjp (it is of course somewhat slower because vmap is used, but not by large factors); afterwards, we only need matrix-vector multiplications, which are far faster than a full backpropagation through a complex neural network. In my experiments (on 20 CPUs), this reduced the time of an SR step from 8x that of an Adam step to about 1.2x.
  • It allows for regularising the S matrix in a scale-invariant way by factoring out the magnitude of its diagonal elements, as described in Becca & Sorella, p. 143. This has minimal overhead and can be crucial for heterogeneous networks whose gradients may have very different magnitudes in different parts of the network. (A sketch of the whole scheme is given after this list.)
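
The following is a minimal sketch, not the code of this merge request, of how such a precomputed gradient matrix can be built and used. It assumes the R2R case with a real-valued log-amplitude `log_psi(params, sample)`, samples batched along the leading axis, and illustrative names (`diag_shift`, `rescale_shift`) that need not match the final interface.

```python
# Minimal sketch (illustrative only) of SR with a precomputed gradient matrix.
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def sr_matvec_factory(log_psi, params, samples, diag_shift=0.01, rescale_shift=False):
    flat_params, unravel = ravel_pytree(params)

    def grad_single(sample):
        # One row of the Jacobian: the gradient of log_psi at a single sample.
        g = jax.grad(lambda p: log_psi(p, sample))(params)
        return ravel_pytree(g)[0]

    # vmap over samples: same asymptotic cost as one backward pass per sample,
    # and only the (n_samples, n_params) gradient matrix is stored afterwards.
    O = jax.vmap(grad_single)(samples)        # shape (n_samples, n_params)
    O = O - O.mean(axis=0)                    # centre the gradients

    n = samples.shape[0]
    diag = jnp.einsum("ki,ki->i", O, O) / n   # diagonal of S = O^T O / n

    def matvec(v):
        # S v = O^T (O v) / n: two cheap matrix-vector products
        Sv = O.T @ (O @ v) / n
        if rescale_shift:
            # scale-invariant regularisation: shift proportional to diag(S)
            return Sv + diag_shift * diag * v
        return Sv + diag_shift * v

    return matvec, flat_params, unravel
```

The resulting `matvec` closure can be handed to an iterative solver such as `jax.scipy.sparse.linalg.cg` to solve the SR linear system without ever materialising the full S matrix.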

Unfortunately, calculating the full Jacobian cannot be done as dtype-agnostically as a vjp, so some information about the network structure needs to be passed. This is done through the jacobian parameter, which is implemented for the values "R2R" (both the network parameters and the wave function are real), "R2C" (real parameters, complex wave function), "holomorphic" (a holomorphic function of complex parameters), and None (which falls back to a LazySMatrix). A sketch of the three modes is given after the list below.

  • The parameter could use a better name, but jacobian also signals that it is the switch between this code and the original implementation
  • It may be possible to automate this choice, although it's not trivial (e.g., R2R is the right choice even if the wave function has negative entries, which makes the output complex...)
  • A few more cases might be implemented, e.g. non-holomorphic C2C, although that is just sugar for R2C...
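
For illustration, here is a hedged sketch of how one row of the Jacobian could be computed in each of the three modes; `log_psi` and the helper itself are hypothetical and only mirror the description above, they are not the code of this merge request.

```python
# Illustrative sketch of the three Jacobian modes for a single sample.
import jax

def single_row(log_psi, params, sample, mode):
    f = lambda p: log_psi(p, sample)
    if mode == "R2R":
        # real parameters, real log-amplitude: one ordinary backward pass
        return jax.grad(f)(params)
    if mode == "R2C":
        # real parameters, complex log-amplitude: differentiate the real and
        # imaginary parts separately and recombine them
        gr = jax.grad(lambda p: f(p).real)(params)
        gi = jax.grad(lambda p: f(p).imag)(params)
        return jax.tree_util.tree_map(lambda a, b: a + 1j * b, gr, gi)
    if mode == "holomorphic":
        # holomorphic function of complex parameters: a single complex gradient
        return jax.grad(f, holomorphic=True)(params)
    raise ValueError(f"unknown jacobian mode: {mode}")
```

The R2C branch uses two backward passes per sample for clarity; a real implementation could fuse them, but the simple form makes the dtype handling explicit.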

The boolean parameter rescale_shift specifies whether the scale-invariant regularisation should be used.
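
Purely for illustration, a call combining the two new options might look like the line below; this is a hypothetical invocation, and the actual constructor and argument names should be taken from the merged interface.

```python
# Hypothetical usage only; consult the final API for the real signature.
sr = SR(jacobian="R2C", rescale_shift=True)
```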
