
Minor optimizations to autoregressive layers

Vicentini Filippo requested to merge github/fork/wdphy16/autoreg into master

Created by: wdphy16

These changes were proposed in a discussion with @PhilipVinc a while ago:

  1. Change lax.cond to jnp.where, so the device no longer has to wait for the CPU to resolve the branch predicate. This gives more than a 30% speedup for the script Benchmarks/fast_autoreg.py on my workstation. By the way, I didn't find any significant difference between jnp.where and manual arithmetic like x * a + (1 - x) * b (see the first sketch after this list).
  2. Evaluate some constants in numpy rather than jnp, since JAX does not constant-fold well yet (see the second sketch below).
  3. Copy the private helper flax.linen.linear._conv_dimension_numbers into our code, so we don't depend on a Flax internal that may change without notice (see the third sketch below).
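
A minimal sketch of the first point; the function names and shapes here are made up for illustration and are not taken from the diff. Under jit, lax.cond dispatches exactly one of two traced branches, while jnp.where computes both operands and selects elementwise, which fuses into the surrounding kernel and avoids the branch dispatch:

```python
import jax
import jax.numpy as jnp
from jax import lax

def select_cond(flag, a, b):
    # lax.cond runs exactly one branch, but the scalar predicate must be
    # resolved before either branch can execute.
    return lax.cond(flag, lambda: a, lambda: b)

def select_where(flag, a, b):
    # jnp.where evaluates both operands and selects elementwise; under jit
    # this fuses into a single kernel with no branch dispatch.
    return jnp.where(flag, a, b)

a, b = jnp.ones(4), jnp.zeros(4)
flag = jnp.array(True)
assert jnp.allclose(jax.jit(select_cond)(flag, a, b),
                    jax.jit(select_where)(flag, a, b))
```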
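
For the second point, a sketch with a hypothetical mask constant (not the actual constants from the diff). Anything computed with numpy runs eagerly on the host at trace time and is baked into the compiled program as a literal, whereas the equivalent jnp ops stay in the trace and may be re-executed on every call if XLA fails to fold them:

```python
import numpy as np
import jax
import jax.numpy as jnp

L = 16

# Host-side numpy: evaluated once, embedded as a constant under jit.
mask_np = np.tril(np.ones((L, L), dtype=np.float32))

@jax.jit
def apply_mask(x):
    # mask_np appears as a literal in the compiled program
    return x * mask_np

# The jnp version would instead emit tril/ones ops into the trace:
#     mask = jnp.tril(jnp.ones((L, L)))
# relying on XLA to constant-fold them, which it currently may not do.

print(apply_mask(jnp.ones((L, L))))
```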
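
For the third point, the vendored helper as I read it in the Flax source at the time, so treat the exact body as an approximation. It builds the channel-last ConvDimensionNumbers (batch and spatial dimensions first, features last) that lax.conv_general_dilated expects:

```python
from jax import lax

def _conv_dimension_numbers(input_shape):
    """Compute channel-last dimension numbers from the input shape.

    Copied from the private flax.linen.linear module so we no longer
    depend on Flax internals.
    """
    ndim = len(input_shape)
    # (batch, feature, *spatial) indices for inputs/outputs,
    # (out_feature, in_feature, *spatial) indices for the kernel
    lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
    rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
    out_spec = lhs_spec
    return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
```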
