Issue with `take` call in `DenseSymmIrrep`
Created by: attila-i-szabo
When using any model containing `DenseSymmIrrep` with complex kernels with newer JAX versions, I get a variation of the following error:
Traceback (most recent call last):
File "/home/vol06/scarf1036/NQS/new.py", line 164, in <module>
gs.run(out=ma_name, n_iter=niter)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
for step in self.iter(n_iter, step_size):
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/driver/abstract_variational_driver.py", line 168, in iter
dp = self._forward_and_backward()
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/driver/vmc.py", line 132, in _forward_and_backward
self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/mc/mc_state/state.py", line 595, in expect_and_grad
return expect_and_grad(
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/base.py", line 381, in expect_and_grad
return expect_and_grad(
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/mc/mc_state/expect_grad_chunked.py", line 49, in expect_and_grad_nochunking
return expect_and_grad(vstate, operator, use_covariance, *args, **kwargs)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/mc/mc_state/expect_grad.py", line 54, in expect_and_grad
Ō, Ō_grad, new_model_state = grad_expect_hermitian(
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/mc/mc_state/expect_grad.py", line 149, in grad_expect_hermitian
_, vjp_fun, *new_model_state = nkjax.vjp(
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/jax/_vjp.py", line 156, in vjp
return vjp_cc(fun, *primals, has_aux=has_aux, conjugate=conjugate)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/jax/_vjp.py", line 41, in vjp_cc
out, _vjp_fun = jax.vjp(fun, *primals, has_aux=False)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/mc/mc_state/expect_grad.py", line 150, in <lambda>
lambda w: model_apply_fun({"params": w, **model_state}, σ, mutable=mutable),
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/utils/jax.py", line 76, in maybe_scalar_fun
res = apply_fun(pars, xb, *args, **kwargs)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/vqs/mc/mc_state/state.py", line 191, in <lambda>
lambda model, pars, x, **kwargs: model.apply(pars, x, **kwargs),
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/models/equivariant.py", line 535, in __call__
x_flip = self.dense_symm(-1 * x)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/netket/nn/symmetric_linear.py", line 120, in __call__
kernel = jnp.take(kernel, jnp.asarray(self.symmetries), 2)
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3339, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/home/vol06/scarf1036/.venv/netket/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 3397, in _take
return lax.gather(a, indices[..., None], dimension_numbers=dnums,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: 'NoneType' object is not iterable
The error message flags up this `take` function call:
https://github.com/netket/netket/blob/1ba051e332614b23cf4d876becdad92904d2d017/netket/nn/symmetric_linear.py#L117-L120
and indeed, replacing line 120 with
kernel = kernel[:,:,jnp.asarray(self.symmetries)]
fixes the issue. I can submit this fix as a PR, but it would be good to figure out what's going on: the two bits of code are equivalent, so it's worrying that one fails. Here's what we know so far:
- The issue is there for JAX 0.3.4 but not for JAX 0.2.25
- It only affects `DenseSymmIrrep` because only that uses `take` (and probably `DenseSymmMatrix`, but we don't use that for anything serious)
- It only affects complex kernels, not real ones
- It only affects the backward pass (the exception is raised from `vjp` after local energies are computed)
- It only affects the GPU backend, but not the CPU one
The last one in particular suggests this is a JAX issue. I tried to produce a MNWE that doesn't use NetKet functionality but haven't succeeded so far.
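
For reference, a minimal sketch of the kind of NetKet-free check this means, assuming a random complex kernel and a made-up integer index table standing in for `self.symmetries` (the shapes, values, and function names are illustrative and not taken from NetKet). It pushes the backward pass through both the `jnp.take` form and the indexing form with `jax.vjp` and compares the results. Given the above, a plain check like this may well run cleanly and not trigger the error.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-ins: a small complex kernel and an in-bounds index table.
# These shapes are illustrative only; they are not the ones NetKet uses.
rng = np.random.default_rng(0)
kernel = jnp.asarray(
    rng.normal(size=(2, 3, 4)) + 1j * rng.normal(size=(2, 3, 4)),
    dtype=jnp.complex64,
)
symmetries = jnp.asarray(rng.integers(0, 4, size=(4, 4)))


def with_take(k):
    # Same call shape as the line flagged in the traceback.
    return jnp.take(k, symmetries, 2)


def with_indexing(k):
    # The proposed replacement.
    return k[:, :, symmetries]


# Forward pass plus a VJP closure for each variant.
out_take, vjp_take = jax.vjp(with_take, kernel)
out_index, vjp_index = jax.vjp(with_indexing, kernel)

# Backward pass with the same complex cotangent for both variants.
cotangent = jnp.ones_like(out_take)
(grad_take,) = vjp_take(cotangent)
(grad_index,) = vjp_index(cotangent)

forward_match = bool(jnp.allclose(out_take, out_index))
backward_match = bool(jnp.allclose(grad_take, grad_index))
print(jax.default_backend(), forward_match, backward_match)
```

Running this once on the CPU backend and once on the GPU backend, under both JAX versions, would show whether the two forms ever diverge outside NetKet.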