QGTJacobianDense/PyTree of R->C model broken for complex gradient
Created by: PhilipVinc
cc @attila-i-szabo, @inailuig (could you cook up a fix?)
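As the title says: multiplying a QGTJacobianDense or QGTJacobianPyTree built from an R->C model (real parameters, complex output) by a complex-valued gradient raises an error, while QGTOnTheFly works.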
This does not surface in the tests because R->C models have a real energy gradient (the energy is real), but it does surface with dissipative systems (see #850 (closed)).
MWE:
import netket as nk
import jax
import jax.numpy as jnp
g = nk.graph.Hypercube(length=10, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
ma = nk.models.RBMModPhase(alpha=1, dtype=float)
sa = nk.sampler.MetropolisLocal(hi, n_chains=16)
vs = nk.vqs.MCState(sa, ma, n_samples=1000, n_discard_per_chain=100)
So = vs.quantum_geometric_tensor(nk.optimizer.qgt.QGTOnTheFly)
Sjd = vs.quantum_geometric_tensor(nk.optimizer.qgt.QGTJacobianDense)
Sjo = vs.quantum_geometric_tensor(nk.optimizer.qgt.QGTJacobianPyTree)
_, F = vs.expect_and_grad(ha)
F = jax.tree_map(lambda x: jnp.asarray(x, dtype=complex), F)
So@F # works
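Sjd@F # fails: TypeError (dot_general contracting dimensions mismatch, got [220] and [440])
Sjo@F # fails: AttributeError ('tuple' object has no attribute 'ndim')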
The stack traces are:
>>> Sjo@F
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 137, in __matmul__
return _matmul(self, vec)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/api.py", line 399, in cache_miss
out_flat = xla.xla_call(
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1561, in bind
return call_bind(self, fun, *args, **params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1552, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1564, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 607, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 682, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1285, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1263, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 185, in _matmul
result = mat_vec(vec, self.O, self.diag_shift)
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 285, in mat_vec
return tree_axpy(diag_shift, v, _mat_vec(v, centered_oks))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 173, in _mat_vec
res = tree_conj(_vjp(oks, _jvp(oks, v).conjugate()))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 158, in _jvp
return jax.tree_util.tree_reduce(jnp.add, jax.tree_multimap(td, oks, v))
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 168, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 157, in <lambda>
td = lambda x, y: jnp.tensordot(x, y, axes=y.ndim)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'tuple' object has no attribute 'ndim'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 137, in __matmul__
return _matmul(self, vec)
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree.py", line 185, in _matmul
result = mat_vec(vec, self.O, self.diag_shift)
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 285, in mat_vec
return tree_axpy(diag_shift, v, _mat_vec(v, centered_oks))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 173, in _mat_vec
res = tree_conj(_vjp(oks, _jvp(oks, v).conjugate()))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 158, in _jvp
return jax.tree_util.tree_reduce(jnp.add, jax.tree_multimap(td, oks, v))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_pytree_logic.py", line 157, in <lambda>
td = lambda x, y: jnp.tensordot(x, y, axes=y.ndim)
AttributeError: 'tuple' object has no attribute 'ndim'
and
>>> Sjd@F
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 133, in __matmul__
return _matmul(self, vec)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/api.py", line 399, in cache_miss
out_flat = xla.xla_call(
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1561, in bind
return call_bind(self, fun, *args, **params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1552, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 1564, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 607, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 607, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/xla.py", line 682, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1285, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1263, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 171, in _matmul
mpi.mpi_sum_jax(((self.O @ vec).T.conj() @ self.O).T.conj())[0]
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 521, in __matmul__
def __matmul__(self, other): return self.aval._matmul(self, other)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5871, in deferring_binary_op
return binary_op(self, other)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4220, in matmul
out = lax.dot_general(
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 699, in dot_general
return dot_general_p.bind(lhs, rhs,
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/core.py", line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1059, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 2136, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3405, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: dot_general requires contracting dimensions to have the same shape, got [220] and [440].
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 133, in __matmul__
return _matmul(self, vec)
File "/Users/filippovicentini/Dropbox/Ricerca/Codes/Python/netket/netket/optimizer/qgt/qgt_jacobian_dense.py", line 171, in _matmul
mpi.mpi_sum_jax(((self.O @ vec).T.conj() @ self.O).T.conj())[0]
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5871, in deferring_binary_op
return binary_op(self, other)
File "/Users/filippovicentini/Documents/pythonenvs/netket_env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4220, in matmul
out = lax.dot_general(
TypeError: dot_general requires contracting dimensions to have the same shape, got [220] and [440].