Change `tree_multimap` to `tree_map` for jax 0.3.5
Created by: wdphy16
As tree_multimap
is deprecated in jax 0.3.5 . There are still some deprecation warnings caused by optax, and they already did a PR to fix them (deepmind/optax#330).
Created by: wdphy16
As tree_multimap
is deprecated in jax 0.3.5 . There are still some deprecation warnings caused by optax, and they already did a PR to fix them (deepmind/optax#330).