[MPI] add a token-aware tree map
Created by: inailuig
Recently, while running VMC, I have been seeing MPI errors (on cpu) which look like
0%| | 0/1000 [00:00<?, ?it/s]r0 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
--------------------------------------------------------------------------
MPI_ABORT was invoked on rank 0 in communicator MPI COMMUNICATOR 4 CREATE FROM 0
with errorcode 15.
NOTE: invoking MPI_ABORT causes Open MPI to kill all MPI processes.
You may or may not see output from other processes, depending on
exactly when Open MPI kills them.
I believe they could be caused by reordered mpi allreduce commands over the different leaves in https://github.com/netket/netket/blob/fae7c32914ddc041b7fe4c5fb74febe1ac8264b1/netket/optimizer/qgt/qgt_onthefly_logic.py#L55
This is an attempt to fix this by enforcing consistent ordering using tokens.
Tests will fail because of the solver trying to transpose the tokens with AD. @PhilipVinc any ideas?