float0 error in the chunked VJP for complex operators with an integer sampler dtype

Created by: MandMarc

When training a VQS (variational quantum state) to find the ground state of an operator, I get the float0 error shown below when all three of the following hold:

  1. the operator has complex matrix elements (e.g. sigmay),
  2. chunking is enabled, and
  3. the sampler dtype is set to jnp.int8.

I'm using NetKet 3.6 and JAX 0.3.25. I suspect a type-casting problem in the chunked version of expect_and_grad, but it's possible I'm just using things wrong.

Here is a minimal example that reproduces the error for me:

import netket as nk
import jax.numpy as jnp

hilbert = nk.hilbert.Spin(1/2, 2)

ha = nk.operator.LocalOperator(hilbert, dtype=complex)
ha += nk.operator.spin.sigmaz(hilbert, 0)
ha += nk.operator.spin.sigmay(hilbert, 1)  # condition 1: complex matrix elements

model = nk.models.RBM(alpha=2, param_dtype=float)
sampler = nk.sampler.MetropolisLocal(hilbert, n_chains=512, dtype=jnp.int8)  # condition 3: integer sampler dtype

vstate = nk.vqs.MCState(sampler, model, n_samples=1024, chunk_size=256)  # condition 2: chunking enabled

optimizer = nk.optimizer.Sgd(learning_rate=0.05)
precon = nk.optimizer.SR(nk.optimizer.qgt.QGTOnTheFly, diag_shift=0.01)
gs = nk.driver.VMC(ha, optimizer, variational_state=vstate, preconditioner=precon)

gs.run(100)

The full error message:

Traceback (most recent call last):
  File "<input>", line 19, in <module>
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
    for step in self.iter(n_iter, step_size):
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 168, in iter
    dp = self._forward_and_backward()
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/vmc.py", line 169, in _forward_and_backward
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 634, in expect_and_grad
    return expect_and_grad(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 584, in __call__
    return method(*args, **kw_args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/base.py", line 395, in expect_and_grad
    return expect_and_grad(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_grad_chunked.py", line 75, in expect_and_grad_covariance_chunked
    Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, mutable=mutable)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 81, in expect_and_forces_impl
    Ō, Ō_grad, new_model_state = forces_expect_hermitian_chunked(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 145, in forces_expect_hermitian_chunked
    Ō_grad = vjp_fun_chunked(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 301, in __call__
    return self.fun(*args, **kw)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 58, in __vjp_fun_chunked
    res = scanmap(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in f_
    return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 73, in scan_append_reduce
    carry_init = True, _get_op_part(_tree_zeros_like(jax.eval_shape(f, x0)))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/api.py", line 3201, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 660, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in <lambda>
    return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/utils.py", line 294, in <lambda>
    return lambda *args, **kwargs: f(g(*args, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 29, in _vjp
    res = vjp_fun(cotangents)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in vjp_fun
    out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 207, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in <lambda>
    out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/core.py", line 606, in __rmul__
    def __rmul__(self, other): return self.aval._rmul(self, other)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4936, in deferring_binary_op
    return binary_op(*args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/api.py", line 626, in cache_miss
    top_trace.process_call(primitive, fun_, tracers, params))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1739, in process_call
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 97, in fn
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 364, in _promote_args
    _check_no_float0s(fun_name, *args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 351, in _check_no_float0s
    raise TypeError(
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/usr/lib/python3.10/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 19, in <module>
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
    for step in self.iter(n_iter, step_size):
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/abstract_variational_driver.py", line 168, in iter
    dp = self._forward_and_backward()
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/driver/vmc.py", line 169, in _forward_and_backward
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 634, in expect_and_grad
    return expect_and_grad(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 584, in __call__
    return method(*args, **kw_args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/base.py", line 395, in expect_and_grad
    return expect_and_grad(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_grad_chunked.py", line 75, in expect_and_grad_covariance_chunked
    Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, mutable=mutable)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/plum/function.py", line 586, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 81, in expect_and_forces_impl
    Ō, Ō_grad, new_model_state = forces_expect_hermitian_chunked(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/expect_forces_chunked.py", line 145, in forces_expect_hermitian_chunked
    Ō_grad = vjp_fun_chunked(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 58, in __vjp_fun_chunked
    res = scanmap(
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in f_
    return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 73, in scan_append_reduce
    carry_init = True, _get_op_part(_tree_zeros_like(jax.eval_shape(f, x0)))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_scanmap.py", line 128, in <lambda>
    return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/utils.py", line 294, in <lambda>
    return lambda *args, **kwargs: f(g(*args, **kwargs))
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp_chunked.py", line 29, in _vjp
    res = vjp_fun(cotangents)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in vjp_fun
    out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/netket/jax/_vjp.py", line 74, in <lambda>
    out = tree_map(lambda re, im: re - 1j * im, out_r, out_i)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4936, in deferring_binary_op
    return binary_op(*args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 97, in fn
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 364, in _promote_args
    _check_no_float0s(fun_name, *args)
  File "/home/marc/programming/projects/geneqs/venv/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 351, in _check_no_float0s
    raise TypeError(
TypeError: Called multiply with a float0 array. float0s do not support any operations by design because they are not compatible with non-trivial vector spaces. No implicit dtype conversion is done. You can use np.zeros_like(arr, dtype=np.float) to cast a float0 array to a regular zeros array. 
If you didn't expect to get a float0 you might have accidentally taken a gradient with respect to an integer argument.
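The last frames of the traceback point at netket/jax/_vjp.py, line 74, where the cotangents of the real and imaginary parts are combined leaf by leaf with re - 1j * im. A plausible reading, not confirmed against the NetKet source, is that the chunked VJP also differentiates with respect to the int8 samples. JAX gives integer inputs cotangents of dtype float0, and the multiplication by 1j then fails, which matches the hint at the end of the error message. The sketch below reproduces that mechanism in plain JAX and shows one possible guard that passes float0 leaves through unchanged. The toy function f, the arrays w, sigma and ct, and the helper combine are hypothetical stand-ins, not NetKet's code or its actual fix. Run eagerly, the float0 cotangent is a plain NumPy array, so the wording of the TypeError may differ from the traced version above.

```python
import jax
import jax.numpy as jnp
from jax import dtypes


def f(w, sigma):
    # Hypothetical toy complex-valued function of real weights w and integer
    # samples sigma (a stand-in for a model evaluated on int8 configurations).
    return jnp.exp(1j * w * sigma.astype(w.dtype))


w = jnp.array([0.1, 0.2], dtype=jnp.float32)
sigma = jnp.array([1, -1], dtype=jnp.int8)  # like a sampler with dtype=jnp.int8
ct = jnp.ones(2, dtype=jnp.float32)         # real-valued cotangent

# Differentiate the real and imaginary parts separately with respect to BOTH
# arguments, including the integer samples.
_, vjp_re = jax.vjp(lambda w, s: f(w, s).real, w, sigma)
_, vjp_im = jax.vjp(lambda w, s: f(w, s).imag, w, sigma)
out_r, out_i = vjp_re(ct), vjp_im(ct)

# The cotangent of the int8 argument has the special dtype float0.
print(out_r[1].dtype == dtypes.float0)

# Combining every leaf as in the traceback fails on the float0 leaf.
try:
    jax.tree_util.tree_map(lambda re, im: re - 1j * im, out_r, out_i)
except TypeError as err:
    print("TypeError:", err)


def combine(re, im):
    # Hypothetical guard: float0 leaves belong to integer inputs and carry no
    # information, so pass them through instead of doing arithmetic on them.
    if re.dtype == dtypes.float0:
        return re
    return re - 1j * im


fixed = jax.tree_util.tree_map(combine, out_r, out_i)
print(fixed[0].dtype, fixed[1].dtype == dtypes.float0)
```

Filtering on the leaf dtype keeps the tree structure of the gradient intact, so downstream code still receives one entry per input argument.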