Fix shape check in jax construction
Created by: PhilipVinc
Output shapes (-1, 1)
and (-1)
are equivalent, and some jax functions can return one or the other, and our code works in both cases, so both are valid.
Created by: PhilipVinc
Output shapes (-1, 1)
and (-1)
are equivalent, and some jax functions can return one or the other, and our code works in both cases, so both are valid.