Minor optimizations to autoregressive layers
Created by: wdphy16
These changes were proposed in a discussion with @PhilipVinc a while ago:
- Change `lax.cond` to `jnp.where`, so we don't need to wait for the CPU to evaluate the control flow. This results in a more than 30% speedup for the script `Benchmarks/fast_autoreg.py` on my workstation (a sketch of the change is shown after this list). By the way, I didn't find any significant difference between `jnp.where` and manual arithmetic like `x * a + (1 - x) * b`.
- Evaluate some constants in NumPy rather than `jnp`, as JAX does not do constant folding well yet (see the second sketch below).
- Copy-paste the private API `flax.linen.linear._conv_dimension_numbers` into our code, so we don't break when flax changes its internals (see the third sketch below).
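Here is a minimal sketch of the first change; `select_cond` / `select_where` and their arguments are hypothetical names for illustration, not the actual layer code:

```python
import jax.numpy as jnp
from jax import lax

def select_cond(pred, a, b):
    # lax.cond compiles to an XLA conditional: on GPU the boolean predicate
    # must be copied back to the host to decide which branch to launch,
    # so the device stalls waiting for the CPU.
    return lax.cond(pred, lambda ab: ab[0], lambda ab: ab[1], (a, b))

def select_where(pred, a, b):
    # jnp.where evaluates both operands and selects elementwise, entirely
    # on the device; this is fine here because both branches are cheap.
    return jnp.where(pred, a, b)
```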
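A sketch of the constant-folding point, using a hypothetical mask rather than the actual layer constants: an array built with NumPy enters the trace as a ready-made literal, while one built with `jnp` inside a jitted function becomes device ops that XLA may or may not fold away.

```python
import numpy as np
import jax
import jax.numpy as jnp

L = 4  # hypothetical size

@jax.jit
def f_jnp(x):
    # Built with jnp inside the trace: these show up as ops in the jaxpr,
    # and we rely on XLA to fold them into a constant.
    mask = jnp.tril(jnp.ones((L, L)))
    return x * mask

# Built once with numpy: already a concrete constant when f_np is traced,
# so nothing is left for the compiler to fold.
mask_np = np.tril(np.ones((L, L), dtype=np.float32))

@jax.jit
def f_np(x):
    return x * mask_np
```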
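Finally, the vendored helper itself. This sketch assumes the channel-last layout used by `flax.linen.Conv` (inputs `(batch, *spatial, features)`, kernels `(*spatial, in_features, out_features)`) and may not match the flax source in every version, so check the upstream code when copying:

```python
from jax import lax

def _conv_dimension_numbers(input_shape):
    """Compute ConvDimensionNumbers for channel-last inputs (N, ..., C)."""
    ndim = len(input_shape)
    # batch dim first, feature dim last, spatial dims in between
    lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
    # kernel layout: (..., in_features, out_features)
    rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
    out_spec = lhs_spec
    return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
```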