[RFC] Automatic tuning of chunk size in VMC driver

Created by: wdphy16

This trick has already helped me a lot when training large models. We start by setting the chunk size to a large number, and each time it causes an out-of-memory (OOM) error, we halve it. A good initial value is n_samples_per_rank * hilbert.size * a multiplier, because for each sample the number of connected configurations in the local energy is on the order of hilbert.size. Following Clemens' advice, we keep the chunk size a power of 2 for performance reasons.
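
To make the arithmetic concrete, here is a small standalone sketch of the initial value and the halving sequence. The function names and the numbers (1024 samples per rank, 20 sites) are made up for illustration, and it rounds with integer bit arithmetic rather than `ceil(log2(...))`, which should give the same result for these sizes.

```python
def initial_chunk_size(n_samples_per_rank, hilbert_size, multiplier=16):
    """Round n_samples_per_rank * hilbert_size * multiplier up to a power of 2."""
    raw = n_samples_per_rank * hilbert_size * multiplier
    # (raw - 1).bit_length() is the exponent of the smallest power of 2 >= raw
    return 1 << (raw - 1).bit_length()


def halving_schedule(start, minimum):
    """List the chunk sizes tried if every attempt runs out of memory."""
    sizes = []
    chunk_size = start
    while chunk_size >= minimum:
        sizes.append(chunk_size)
        chunk_size //= 2
    return sizes


# Hypothetical example values: 1024 samples per rank, 20 lattice sites
start = initial_chunk_size(1024, 20)
print(start)  # 1024 * 20 * 16 = 327680, rounded up to 524288 = 2**19
print(halving_schedule(start, minimum=1024))
```

Since the starting value is a power of 2, every halved value is too, so at most about log2(start / minimum) retries are needed before the minimum is reached.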

A possible implementation:

import warnings
from math import ceil, log2

from netket.driver import VMC


class VMCAutoChunk(VMC):
    def __init__(self, *args, **kwargs):
        init_chunk_size_multiplier = kwargs.pop("init_chunk_size_multiplier", 16)
        min_chunk_size = kwargs.pop("min_chunk_size", None)

        super().__init__(*args, **kwargs)

        chunk_size = self.state.chunk_size
        # If `state.chunk_size` is already set, we use that as the initial value
        if chunk_size is None:
            chunk_size = self.state.n_samples_per_rank * self._ham.hilbert.size * init_chunk_size_multiplier
        # Round up to a power of 2
        chunk_size = 2 ** int(ceil(log2(chunk_size)))
        self.state.chunk_size = chunk_size

        if min_chunk_size is None:
            min_chunk_size = self.state.n_samples_per_rank
        # Store the lower bound so that `_forward_and_backward` can read it
        self._min_chunk_size = min_chunk_size

    def _forward_and_backward(self):
        while True:
            try:
                return super()._forward_and_backward()
            except RuntimeError:
                # Any RuntimeError is treated as OOM here, which is an assumption
                chunk_size = self.state.chunk_size // 2
                if chunk_size < self._min_chunk_size:
                    warnings.warn(f"Minimum chunk size {self._min_chunk_size} reached")
                    # Re-raise the original error with its traceback
                    raise
                warnings.warn(f"Reducing chunk size to {chunk_size}")
                # This driver modifies `state.chunk_size` in place
                self.state.chunk_size = chunk_size

If we want to integrate this into NetKet, we can discuss the API further. Similar tricks are already implemented in some high-level ML frameworks like toma and PyTorch Lightning.
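
As a starting point for the API discussion, the halve-and-retry loop can also be sketched independently of the driver. This is only an illustration: `run_with_auto_chunk` and `fake_step` are hypothetical names, not part of NetKet, toma, or PyTorch Lightning, and the OOM is simulated with a plain `RuntimeError`.

```python
import warnings


def run_with_auto_chunk(fn, chunk_size, min_chunk_size=1):
    """Call fn(chunk_size), halving chunk_size after each RuntimeError.

    Returns (result, chunk_size that succeeded). Re-raises the error once
    halving would go below min_chunk_size.
    """
    while True:
        try:
            return fn(chunk_size), chunk_size
        except RuntimeError:
            if chunk_size // 2 < min_chunk_size:
                warnings.warn(f"Minimum chunk size {min_chunk_size} reached")
                raise
            chunk_size //= 2
            warnings.warn(f"Reducing chunk size to {chunk_size}")


def fake_step(chunk_size, limit=4096):
    # Stand-in for a memory-hungry computation: fails above `limit`
    if chunk_size > limit:
        raise RuntimeError("out of memory (simulated)")
    return "ok"


result, final_chunk = run_with_auto_chunk(fake_step, 65536, min_chunk_size=1024)
print(result, final_chunk)  # ok 4096
```

One design question this raises is whether the retry logic should live in the driver, as in the class above, or in a reusable wrapper like this one that could also cover `state.expect` and similar calls.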