Callbacks and MPI Synchronization -> deadlocks. What to do?
Created by: PhilipVinc
Right now, callbacks are executed on all MPI ranks. That is different from loggers, which are executed only on the root rank:
```python
# Create loggers only on the root rank; all other ranks log nothing.
if self._mynode == 0:
    # if out is a path (string), create an overwriting JsonLog for output
    if isinstance(out, str):
        loggers = (JsonLog(out, "w", save_params_every, write_every),)
    else:
        loggers = _to_iterable(out)
else:
    loggers = tuple()
    show_progress = False

callbacks = _to_iterable(callback)
callback_stop = False

for step in itr:
    for callback in callbacks:
        if not callback(step, log_data, self):
            callback_stop = True
    for logger in loggers:
        logger(self.step_count, log_data, self.state)
    if callback_stop:
        break
```
That means that if a callback signals the driver to stop the evolution on only one rank (for whatever reason), that rank will stop while the others continue, leading to a deadlock.
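For illustration, here is a minimal, hypothetical mpi4py reproduction of that failure mode (not actual NetKet code): the `allreduce` stands in for whatever collective the driver performs during an update step, and the ranks that keep iterating block in it once rank 0 has left the loop.

```python
# Run with e.g. `mpirun -np 2 python repro.py`.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

for step in range(10):
    # Stand-in for a callback whose stop condition is rank-local:
    # here it fires only on rank 0, at step 3.
    callback_stop = (rank == 0 and step == 3)
    if callback_stop:
        break  # rank 0 leaves the loop ...

    # Stand-in for the MPI collectives the driver performs during an
    # update step (e.g. averaging gradients). From step 3 on, ranks != 0
    # block here forever, because rank 0 no longer participates.
    comm.allreduce(step, op=MPI.SUM)

print(f"rank {rank} done")  # never reached by the non-root ranks
```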
We have two solutions: run the callbacks only on the root rank and then broadcast the stop signal to all other ranks, or run the callbacks on all ranks and combine the callback_stop flags with an MPI.Allreduce(MPI.LOR) (logical or) across all ranks, so that if at least one rank signals to stop, all of them stop. Sketches of both options follow below.
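A minimal sketch of the first option, assuming mpi4py; names like `callbacks` and the callback signature are illustrative, not the actual driver API:

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Dummy callback: returning False means "stop the driver".
def stop_at_step_3(step, log_data, driver):
    return step != 3

callbacks = (stop_at_step_3,)

for step in range(10):
    callback_stop = False
    if rank == 0:
        # Only the root evaluates the callbacks ...
        for callback in callbacks:
            if not callback(step, None, None):
                callback_stop = True
    # ... and broadcasts its decision so every rank takes the same branch.
    callback_stop = comm.bcast(callback_stop, root=0)
    if callback_stop:
        break

print(f"rank {rank} stopped after step {step}")  # identical on all ranks
```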
I would go with the latter option, MPI.Allreduce(MPI.LOR). @femtobit do you agree?
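For comparison, a sketch of the MPI.Allreduce(MPI.LOR) option under the same illustrative assumptions: every rank evaluates its callbacks, and a logical-or reduction makes the stop decision collective, so a stop condition raised on any single rank stops all of them.

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Dummy callback with a rank-local stop condition: only rank 1 wants
# to stop, at step 3.
def stop_on_rank_1(step, log_data, driver):
    return not (rank == 1 and step == 3)

callbacks = (stop_on_rank_1,)

for step in range(10):
    callback_stop = False
    # Every rank evaluates its callbacks ...
    for callback in callbacks:
        if not callback(step, None, None):
            callback_stop = True
    # ... and the flags are combined with a logical OR, so a stop
    # requested by any single rank stops all of them.
    callback_stop = comm.allreduce(callback_stop, op=MPI.LOR)
    if callback_stop:
        break

print(f"rank {rank} stopped after step {step}")  # identical on all ranks
```

The broadcast variant saves the callback work on non-root ranks but can only react to conditions visible to the root; the allreduce variant lets any rank trigger the stop, which matches the rank-local failure case described above.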