Skip to content

Change `tree_multimap` to `tree_map` for jax 0.3.5

Vicentini Filippo requested to merge github/fork/wdphy16/tree_map into master

Created by: wdphy16

As tree_multimap is deprecated in jax 0.3.5 . There are still some deprecation warnings caused by optax, and they already did a PR to fix them (deepmind/optax#330).

Merge request reports