Skip to content

Speed up DenseSymm by removing dense matvec product

Vicentini Filippo requested to merge github/fork/femtobit/densesymm into master

Created by: femtobit

The symmetrizer in DenseSymm is a very sparse matrix. While jax does not yet support sparse operations, we can just do this manually (and more efficiently) in this case.

In fact, symmetrizer is a COO matrix with

rows = arange(sites * n_hidden)
cols = ...
data = [1., ..., 1.]

so only cols need to be stored and the "matrix-vector product" to map reduced to full weights reduces to

full_kernel = (kernel.reshape(-1)[cols]).reshape(self.n_sites, -1)

This saves memory for storing symmetrizer, which becomes important in larger systems, and is also faster: At least on the CPU, the comparison looks very much in favor of the indexing version (n_sites=100, features=8):

In[1]: f1 = jax.jit(lambda w: (w.reshape(-1)[symm_cols]).reshape(n_sites, -1))
In[2]: f2 = jax.jit(lambda w: jnp.matmul(symmetrizer, w.reshape(-1)).reshape(n_sites, -1))
In[3]: f1(kernel); f2(kernel); # pre-jit
In[4]: %timeit f1(kernel).block_until_ready()
101 µs ± 1.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In[5]: %timeit f2(kernel).block_until_ready()
15.7 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Still todo:

  • The code constructs the full symmetrizer and then converts it to a COO matrix. This is of course redundant - that code can be rewritten to just create cols directly. I'll do that soon; it just requires some more fiddling with indices.

Merge request reports