`float0` error in chunked VJP for complex operators with an integer sampler dtype
Created by: MandMarc
When training a variational quantum state (VQS) to find the ground state of an operator, I get the `float0` error shown below whenever all of the following hold:
- the operator has complex matrix elements (e.g. `sigmay`),
- chunking is enabled, and
- the sampler dtype is set to `jnp.int8`.
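I'm using NetKet 3.6 and JAX 0.3.25. I suspect a type-casting problem in the chunked version of `expect_and_grad`, but it's possible I'm just using things wrong. The traceback ends in `netket/jax/_vjp.py`, where `re - 1j * im` is applied to a `float0` array. JAX itself notes that float0s typically come from taking a gradient with respect to an integer argument, so the `int8` samples may be involved.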
Minimal example that reproduces the error for me:
# Minimal reproducer: a chunked VMC run with a complex-valued operator and an
# int8 sampler dtype raises a float0 TypeError (see the traceback below).
import netket as nk
import jax.numpy as jnp
hilbert = nk.hilbert.Spin(1/2, 2)
# Operator with complex matrix elements: the sigmay term is what makes it
# complex, which is one of the conditions needed to trigger the error.
ha = nk.operator.LocalOperator(hilbert, dtype=complex)
ha += nk.operator.spin.sigmaz(hilbert, 0)
ha += nk.operator.spin.sigmay(hilbert, 1)
# Real-valued RBM parameters (param_dtype=float).
model = nk.models.RBM(alpha=2, param_dtype=float)
# Integer sampler dtype (jnp.int8): another trigger condition.
sampler = nk.sampler.MetropolisLocal(hilbert, n_chains=512, dtype=jnp.int8)
# chunk_size enables the chunked code path (expect_grad_chunked.py,
# expect_forces_chunked.py and _vjp_chunked.py in the traceback): the third
# trigger condition.
vstate = nk.vqs.MCState(sampler, model, n_samples=1024, chunk_size=256)
optimizer = nk.optimizer.Sgd(learning_rate=0.05)
precon = nk.optimizer.SR(nk.optimizer.qgt.QGTOnTheFly, diag_shift=0.01)
gs = nk.driver.VMC(ha, optimizer, variational_state=vstate, preconditioner=precon)
# The TypeError is raised from inside gs.run, via
# VMC._forward_and_backward -> MCState.expect_and_grad.
gs.run(100)
The full error message:
Traceback (most recent call last):
File "<input>", line 19, in <module>
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
for step in self.iter(n_iter, step_size):
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 168, in iter
dp = self._forward_and_backward()
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/vmc.py", line 169, in _forward_and_backward
self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 634, in expect_and_grad
return expect_and_grad(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 584, in __call__
return method(*args, **kw_args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/base.py", line 395, in expect_and_grad
return expect_and_grad(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
return _convert(method(*args, **kw_args), return_type)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_grad_chunked.py", line 75, in expect_and_grad_covariance_chunked
Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, mutable=mutable)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
return _convert(method(*args, **kw_args), return_type)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 81, in expect_and_forces_impl
Ō, Ō_grad, new_model_state = forces_expect_hermitian_chunked(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
ans = call(fun, *args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 145, in forces_expect_hermitian_chunked
Ō_grad = vjp_fun_chunked(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
return self.fun(*args, **kw)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 58, in __vjp_fun_chunked
res = scanmap(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in f_
return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 73, in scan_append_reduce
carry_init = True, _get_op_part(_tree_zeros_like(jax.eval_shape(f, x0)))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/api.py", line 3201, in eval_shape
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 660, in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in <lambda>
return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/utils.py", line 294, in <lambda>
return lambda *args, **kwargs: f(g(*args, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 29, in _vjp
res = vjp_fun(cotangents)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in vjp_fun
out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 207, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in <lambda>
out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/core.py", line 606, in __rmul__
def __rmul__(self, other): return self.aval._rmul(self, other)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4936, in deferring_binary_op
return binary_op(*args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/api.py", line 626, in cache_miss
top_trace.process_call(primitive, fun_, tracers, params))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1739, in process_call
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 97, in fn
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 364, in _promote_args
_check_no_float0s(fun_name, *args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 351, in _check_no_float0s
raise TypeError(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/code.py", line 90, in runcode
exec(code, self.locals)
File "<input>", line 19, in <module>
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
for step in self.iter(n_iter, step_size):
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 168, in iter
dp = self._forward_and_backward()
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/vmc.py", line 169, in _forward_and_backward
self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 634, in expect_and_grad
return expect_and_grad(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 584, in __call__
return method(*args, **kw_args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/base.py", line 395, in expect_and_grad
return expect_and_grad(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
return _convert(method(*args, **kw_args), return_type)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_grad_chunked.py", line 75, in expect_and_grad_covariance_chunked
Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, mutable=mutable)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
return _convert(method(*args, **kw_args), return_type)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 81, in expect_and_forces_impl
Ō, Ō_grad, new_model_state = forces_expect_hermitian_chunked(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 145, in forces_expect_hermitian_chunked
Ō_grad = vjp_fun_chunked(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 58, in __vjp_fun_chunked
res = scanmap(
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in f_
return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 73, in scan_append_reduce
carry_init = True, _get_op_part(_tree_zeros_like(jax.eval_shape(f, x0)))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in <lambda>
return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/utils.py", line 294, in <lambda>
return lambda *args, **kwargs: f(g(*args, **kwargs))
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 29, in _vjp
res = vjp_fun(cotangents)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in vjp_fun
out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in <lambda>
out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4936, in deferring_binary_op
return binary_op(*args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 97, in fn
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 364, in _promote_args
_check_no_float0s(fun_name, *args)
File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 351, in _check_no_float0s
raise TypeError(
TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array.
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.