
(Excessively) out of memory in `expect_and_grad`

Created by: attila-i-szabo

I tried to run a moderately sized model on a single GPU (no MPI). The variational state's `expect_and_grad` (called by the VMC driver) crashes when it attempts to allocate about 1 TB of memory:

```
2021-05-15 17:06:18.153586: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Resource exhausted: Allocating 174487240704 bytes exceeds the memory limit of 4294967296 bytes.

Convolution performance may be suboptimal.
2021-05-15 17:06:18.159994: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Resource exhausted: Allocating 465299308544 bytes exceeds the memory limit of 4294967296 bytes.

Convolution performance may be suboptimal.
2021-05-15 17:06:18.161074: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Resource exhausted: Allocating 465299308544 bytes exceeds the memory limit of 4294967296 bytes.

Convolution performance may be suboptimal.
2021-05-15 17:06:33.088411: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:457] Allocator (GPU_0_bfc) ran out of memory trying to allocate 977.79GiB (rounded to 1049890968832)requested by op 
2021-05-15 17:06:33.088942: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:468] *****_______________________________________________________________________________________________
2021-05-15 17:06:33.089058: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1981] Execution of replica 0 failed: Resource exhausted: Out of memory while trying to allocate 1049890968768 bytes.
Traceback (most recent call last):
  File "/home/vol06/scarf1036/NQS/python/expt1.py", line 88, in <module>
    vmc.run(n_iter=1000, out=f"{output_dir}/expt1")
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
    for step in itr:
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/netket/driver/abstract_variational_driver.py", line 167, in iter
    dp = self._forward_and_backward()
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/netket/driver/vmc.py", line 115, in _forward_and_backward
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/netket/variational/mc_state.py", line 469, in expect_and_grad
    Ō, Ō_grad, new_model_state = grad_expect_hermitian(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 143, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/_src/api.py", line 426, in cache_miss
    out_flat = xla.xla_call(
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/core.py", line 1565, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/core.py", line 1556, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/core.py", line 1568, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/core.py", line 609, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/vol06/scarf1036/.venv/netket/lib/python3.8/site-packages/jax/interpreters/xla.py", line 874, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Out of memory while trying to allocate 1049890968768 bytes.
```
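For scale, the byte counts in the log can be converted to GiB and compared with the 4294967296-byte limit quoted in the cudnn autotuning warnings. This is only a small arithmetic sketch: the numbers are copied from the log, and the dictionary labels are descriptive names, not anything XLA reports.

```python
# Byte counts copied verbatim from the log above.
MEMORY_LIMIT = 4294967296  # limit quoted in the gpu_conv_algorithm_picker warnings
REQUESTS = {
    "cudnn autotuning, 1st warning": 174487240704,
    "cudnn autotuning, 2nd and 3rd warnings": 465299308544,
    "BFC allocator (rounded)": 1049890968832,
    "XLA execution (OOM)": 1049890968768,
}

GIB = 1024**3  # bytes per GiB, the unit the BFC allocator message uses


def to_gib(n_bytes: int) -> float:
    """Convert a byte count to GiB."""
    return n_bytes / GIB


if __name__ == "__main__":
    print(f"limit: {to_gib(MEMORY_LIMIT):.2f} GiB")
    for label, n_bytes in REQUESTS.items():
        ratio = n_bytes / MEMORY_LIMIT
        print(f"{label}: {to_gib(n_bytes):.2f} GiB ({ratio:.1f}x the limit)")
```

The final request, 1049890968768 bytes, is about 977.8 GiB (roughly 1.05 TB in decimal units), about 244 times the 4 GiB limit and about 700 times the ~1.5 GB estimated for a single VJP.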

The demand for all this memory seems to come from `mc_state.grad_expect_hermitian`, which contains only a single VJP. With the settings I was using (1024 samples on a 10x10 lattice), I checked that a VJP requires only about 1.5 GB of memory, not the ~1 TB that the run tries to allocate. I cannot see what would require this much memory. In case it helps with debugging: the model is written in STAX and uses LAX-based convolutional layers. I have also run the same code on CPUs and never had this problem.
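For reference, this is roughly what a gradient built from a single VJP looks like. It is a minimal sketch, not NetKet's `grad_expect_hermitian`: it assumes a real-valued log-amplitude, a hypothetical STAX model (one `Conv` layer with an assumed 8 channels, ReLU, and a `Dense` readout) and random stand-in local energies, with 1024 samples on a 10x10 lattice as in the report. It computes the force F = ⟨(E_loc − ⟨E_loc⟩) ∂θ log ψ⟩ by pulling the centered local energies back through `jax.vjp`; for a real wavefunction and real parameters the energy gradient is 2F. Running such a function in isolation is one way to see how much memory the VJP itself needs.

```python
import jax
import jax.numpy as jnp

try:  # newer JAX releases
    from jax.example_libraries import stax
except ImportError:  # older JAX releases, such as those in the traceback
    from jax.experimental import stax

L = 10            # 10x10 lattice, as in the report
N_SAMPLES = 1024  # number of samples, as in the report
CHANNELS = 8      # hypothetical: the report does not give the layer sizes

# Hypothetical real-valued log-amplitude: conv -> ReLU -> dense readout.
init_fun, apply_fun = stax.serial(
    stax.Conv(CHANNELS, (3, 3), padding="SAME"),
    stax.Relu,
    stax.Flatten,
    stax.Dense(1),
)


def log_psi(params, samples):
    """samples: (n, L, L) array of +-1 spins; returns log psi with shape (n,)."""
    out = apply_fun(params, samples[..., None])  # stax.Conv expects NHWC input
    return out[:, 0]


def force_single_vjp(params, samples, local_energies):
    """F = <(E_loc - <E_loc>) d log_psi / d theta>, using one VJP over the batch."""
    n = samples.shape[0]
    centered = local_energies - jnp.mean(local_energies)
    _, vjp_fun = jax.vjp(lambda p: log_psi(p, samples), params)
    (force,) = vjp_fun(centered / n)  # cotangent has the same shape as log_psi's output
    return force


if __name__ == "__main__":
    k_init, k_s, k_e = jax.random.split(jax.random.PRNGKey(0), 3)
    _, params = init_fun(k_init, (N_SAMPLES, L, L, 1))
    samples = jax.random.choice(k_s, jnp.array([-1.0, 1.0]), (N_SAMPLES, L, L))
    e_loc = jax.random.normal(k_e, (N_SAMPLES,))  # stand-in for real local energies
    force = force_single_vjp(params, samples, e_loc)
    print(jax.tree_util.tree_map(jnp.shape, force))
```

Because the centered local energies are contracted during the backward pass, the per-sample Jacobian (one gradient per sample) is never materialized; memory is dominated by the activations kept for a single backward pass over the batch.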