[RFC] Automatic tuning of chunk size in VMC driver
Created by: wdphy16
This trick has already helped me a lot when training large models. We start by setting the chunk size to a large number, and if it causes an OOM error, we halve it until the computation fits. A good initial value is n_samples_per_rank * hilbert.size * a multiplier, because for each sample the number of connected configurations in the local energy is on the order of hilbert.size. Following Clemens' advice, we keep the chunk size a power of 2 for performance reasons.
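For concreteness, the initial value can be computed as follows; the concrete numbers are hypothetical placeholders, not taken from any real run:

```python
from math import ceil, log2

# Hypothetical values, just to illustrate the heuristic
n_samples_per_rank = 1024
hilbert_size = 20  # stands in for hilbert.size
multiplier = 16

chunk_size = n_samples_per_rank * hilbert_size * multiplier  # 327680
# Round up to the next power of 2
chunk_size = 2 ** int(ceil(log2(chunk_size)))  # 524288 == 2**19
print(chunk_size)
```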
The implementation could look like this (the exact exception type raised on OOM depends on the backend; here we catch `RuntimeError` as a common case):

```python
import warnings
from math import ceil, log2

from netket.driver import VMC


class VMCAutoChunk(VMC):
    def __init__(self, *args, **kwargs):
        init_chunk_size_multiplier = kwargs.pop("init_chunk_size_multiplier", 16)
        min_chunk_size = kwargs.pop("min_chunk_size", None)
        super().__init__(*args, **kwargs)

        # If `state.chunk_size` is already set, we use that as the initial value
        chunk_size = self.state.chunk_size
        if chunk_size is None:
            chunk_size = (
                self.state.n_samples_per_rank
                * self._ham.hilbert.size
                * init_chunk_size_multiplier
            )
        # Round up to a power of 2
        chunk_size = 2 ** int(ceil(log2(chunk_size)))
        self.state.chunk_size = chunk_size

        if min_chunk_size is None:
            min_chunk_size = self.state.n_samples_per_rank
        self._min_chunk_size = min_chunk_size

    def _forward_and_backward(self):
        while True:
            try:
                return super()._forward_and_backward()
            except RuntimeError as e:
                # OOM during the forward/backward pass: halve the chunk size
                # and retry, until we hit the minimum chunk size
                chunk_size = self.state.chunk_size // 2
                if chunk_size < self._min_chunk_size:
                    warnings.warn(f"Minimum chunk size {self._min_chunk_size} reached")
                    raise e
                warnings.warn(f"Reducing chunk size to {chunk_size}")
                # This driver modifies `state.chunk_size` in place
                self.state.chunk_size = chunk_size
```
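The retry logic itself can be exercised in isolation, without NetKet. The sketch below is a minimal stand-in: `run_with_auto_chunk` and `fake_step` are hypothetical names, and `fake_step` simulates an OOM whenever the chunk size is too large:

```python
import warnings


def run_with_auto_chunk(step, chunk_size, min_chunk_size):
    """Retry `step(chunk_size)`, halving the chunk size on each OOM-like error."""
    while True:
        try:
            return step(chunk_size), chunk_size
        except RuntimeError:
            chunk_size //= 2
            if chunk_size < min_chunk_size:
                raise
            warnings.warn(f"Reducing chunk size to {chunk_size}")


# A fake step that "runs out of memory" whenever the chunk size exceeds 128
def fake_step(chunk_size):
    if chunk_size > 128:
        raise RuntimeError("RESOURCE_EXHAUSTED: out of memory")
    return "ok"


result, final_chunk_size = run_with_auto_chunk(fake_step, 1024, 16)
print(result, final_chunk_size)  # -> ok 128
```

Starting from 1024, the loop fails at 1024, 512, and 256, then succeeds at 128, emitting one warning per reduction.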
If we want to integrate this into NetKet, we can discuss the API further. Similar tricks are already implemented in some high-level ML frameworks, such as toma and PyTorch Lightning.