Skip to content

QGTJacobianDense/PyTree of R->C model broken for complex gradient

Created by: PhilipVinc

cc @attila-i-szabo, @inailuig (could you cook up a fix?)

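As the title says: applying a `QGTJacobianDense` or `QGTJacobianPyTree` built from an R->C model (real parameters, complex output) to a complex-valued gradient raises an error.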

This does not surface in the tests because, for R->C models, the energy gradient is real (the energy itself is real). It does surface with dissipative systems (see #850 (closed)).

MWE:

# Minimal reproduction: build the quantum geometric tensor (QGT) of a model
# with real parameters (dtype=float) using three implementations, then apply
# each of them to a gradient that has been cast to complex.
import netket as nk
import jax
import jax.numpy as jnp

# 10-site one-dimensional periodic lattice, spin-1/2 Hilbert space on its
# nodes, and an Ising operator with h=1.0 defined on that lattice.
g = nk.graph.Hypercube(length=10, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
# dtype=float: this is the "R->C" model the issue is about.
ma = nk.models.RBMModPhase(alpha=1, dtype=float)
sa = nk.sampler.MetropolisLocal(hi, n_chains=16)

# Monte Carlo variational state combining the sampler and the model.
vs = nk.vqs.MCState(sa, ma, n_samples=1000, n_discard_per_chain=100)

# Same variational state, three QGT implementations to compare.
So = vs.quantum_geometric_tensor(nk.optimizer.qgt.QGTOnTheFly)
Sjd = vs.quantum_geometric_tensor(nk.optimizer.qgt.QGTJacobianDense)
Sjo = vs.quantum_geometric_tensor(nk.optimizer.qgt.QGTJacobianPyTree)

# Take the gradient returned by expect_and_grad and cast every leaf of the
# pytree to complex, so that the QGTs are applied to a complex vector.
_, F = vs.expect_and_grad(ha)
F = jax.tree_map(lambda x: jnp.asarray(x, dtype=complex), F)

So@F # works
Sjd@F # fails: TypeError, dot_general contracting dimensions [220] and [440]
Sjo@F # fails: AttributeError, 'tuple' object has no attribute 'ndim'

The stack traces are:

>>> Sjo@F
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 137, in __matmul__
    return _matmul(self, vec)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/api.py", line 399, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1561, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1552, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1564, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 607, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 682, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1285, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1263, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 185, in _matmul
    result = mat_vec(vec, self.O, self.diag_shift)
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 285, in mat_vec
    return tree_axpy(diag_shift, v, _mat_vec(v, centered_oks))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 173, in _mat_vec
    res = tree_conj(_vjp(oks, _jvp(oks, v).conjugate()))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 158, in _jvp
    return jax.tree_util.tree_reduce(jnp.add, jax.tree_multimap(td, oks, v))
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 168, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 157, in <lambda>
    td = lambda x, y: jnp.tensordot(x, y, axes=y.ndim)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'tuple' object has no attribute 'ndim'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 137, in __matmul__
    return _matmul(self, vec)
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 185, in _matmul
    result = mat_vec(vec, self.O, self.diag_shift)
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 285, in mat_vec
    return tree_axpy(diag_shift, v, _mat_vec(v, centered_oks))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 173, in _mat_vec
    res = tree_conj(_vjp(oks, _jvp(oks, v).conjugate()))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 158, in _jvp
    return jax.tree_util.tree_reduce(jnp.add, jax.tree_multimap(td, oks, v))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 157, in <lambda>
    td = lambda x, y: jnp.tensordot(x, y, axes=y.ndim)
AttributeError: 'tuple' object has no attribute 'ndim'

and

>>> Sjd@F
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 133, in __matmul__
    return _matmul(self, vec)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/api.py", line 399, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1561, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1552, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1564, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 607, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 682, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1285, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1263, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 171, in _matmul
    mpi.mpi_sum_jax(((self.O @ vec).T.conj() @ self.O).T.conj())[0]
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 521, in __matmul__
    def __matmul__(self, other): return self.aval._matmul(self, other)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5871, in deferring_binary_op
    return binary_op(self, other)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4220, in matmul
    out = lax.dot_general(
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 699, in dot_general
    return dot_general_p.bind(lhs, rhs,
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1059, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 2136, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3405, in _dot_general_shape_rule
    raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: dot_general requires contracting dimensions to have the same shape, got [220] and [440].

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 133, in __matmul__
    return _matmul(self, vec)
  File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 171, in _matmul
    mpi.mpi_sum_jax(((self.O @ vec).T.conj() @ self.O).T.conj())[0]
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5871, in deferring_binary_op
    return binary_op(self, other)
  File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4220, in matmul
    out = lax.dot_general(
TypeError: dot_general requires contracting dimensions to have the same shape, got [220] and [440].