Speed up DenseSymm by removing dense matvec product
Created by: femtobit
The `symmetrizer` in `DenseSymm` is a very sparse matrix. While jax does not yet support sparse operations, we can just do this product manually (and more efficiently) in this case.
In fact, `symmetrizer` is a COO matrix with

```python
rows = arange(sites * n_hidden)
cols = ...
data = [1., ..., 1.]
```
so only `cols` needs to be stored, and the "matrix-vector product" that maps the reduced to the full weights reduces to

```python
full_kernel = (kernel.reshape(-1)[cols]).reshape(self.n_sites, -1)
```
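Since each row of the symmetrizer contains exactly one unit entry, the index array can be recovered from the dense matrix by taking an argmax per row. A minimal sketch (assuming `symmetrizer` is available as a dense array; the `symm_cols` used in the timings below could be obtained this way):

```python
import numpy as np

symm = np.asarray(symmetrizer)       # dense, exactly one entry of 1. per row
symm_cols = np.argmax(symm, axis=1)  # column index of the nonzero in each row

# sanity check: the gather reproduces the dense matvec
w = kernel.reshape(-1)
assert np.allclose(symm @ w, w[symm_cols])
```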
This saves the memory needed to store `symmetrizer` (one integer index per row instead of a full dense matrix), which becomes important for larger systems, and is also faster:
At least on the CPU, the comparison looks very much in favor of the indexing version (`n_sites=100`, `features=8`):
```python
In[1]: f1 = jax.jit(lambda w: (w.reshape(-1)[symm_cols]).reshape(n_sites, -1))

In[2]: f2 = jax.jit(lambda w: jnp.matmul(symmetrizer, w.reshape(-1)).reshape(n_sites, -1))

In[3]: f1(kernel); f2(kernel);  # pre-jit

In[4]: %timeit f1(kernel).block_until_ready()
101 µs ± 1.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In[5]: %timeit f2(kernel).block_until_ready()
15.7 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
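This is expected: the indexing version is a single gather of one element per output entry, while the dense matvec performs a multiply-add for every matrix element, nearly all of which are zeros.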
Still to do:

- The code currently constructs the full `symmetrizer` and then converts it to a COO matrix. This is of course redundant; that code can be rewritten to just create `cols` directly (see the sketch after this list). I'll do that soon; it just requires some more fiddling with indices.
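For illustration, a direct construction could look like the sketch below. The index convention here is an assumption (a `perms` array of shape `(n_symm, n_sites)` and a full kernel laid out as `full_kernel[i, g * features + f] = kernel[perms[g, i], f]`); the actual layout in `DenseSymm` may differ:

```python
import numpy as np

def make_symm_cols(perms, features):
    # perms: (n_symm, n_sites) integer array of site permutations
    # (hypothetical convention, for illustration only).
    n_symm, n_sites = perms.shape
    # Index grids over (site i, symmetry g, feature f), in the same
    # order as full_kernel.reshape(-1) under the assumed layout.
    i, g, f = np.meshgrid(
        np.arange(n_sites), np.arange(n_symm), np.arange(features),
        indexing="ij",
    )
    # Flat index into kernel.reshape(-1) (kernel assumed (n_sites, features))
    # for each row of the symmetrizer.
    return (perms[g, i] * features + f).reshape(-1)
```

Built this way, the dense `symmetrizer` never has to be materialized at all.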