Strange errors with `nk.optimizer.SRLazyGMRES()`
Created by: SemyonSinchenko
With `nk.optimizer.SRLazyCG()` my code works well, but when I replace it with `nk.optimizer.SRLazyGMRES()` I get a long error:
```
Traceback (most recent call last):
  File "main.py", line 84, in <module>
    model.train()
  File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 71, in train
    callback=SigmaCallback(),
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
    for step in itr:
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 167, in iter
    dp = self._forward_and_backward()
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/vmc.py", line 109, in _forward_and_backward
    self._dp, self._sr_info = self._S.solve(self._loss_grad, x0=x0)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py", line 187, in solve
    x0,
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py", line 230, in apply_onthefly
    out, info = solve_fun(_mat_vec, grad, x0=x0)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/scipy/sparse/linalg.py", line 661, in gmres
    restart = min(restart, size)
jax._src.traceback_util.FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function apply_onthefly at /home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py:205, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to apply_onthefly at /home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py:205, transformed by jit. at flattened positions [8], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "main.py", line 84, in <module>
    model.train()
  File "/home/sem/GitHub/XXModelNQS/project/src/model/nqs.py", line 71, in train
    callback=SigmaCallback(),
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 252, in run
    for step in itr:
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/abstract_variational_driver.py", line 167, in iter
    dp = self._forward_and_backward()
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/driver/vmc.py", line 109, in _forward_and_backward
    self._dp, self._sr_info = self._S.solve(self._loss_grad, x0=x0)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py", line 187, in solve
    x0,
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/api.py", line 338, in cache_miss
    donated_invars=donated_invars)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/interpreters/xla.py", line 577, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py", line 230, in apply_onthefly
    out, info = solve_fun(_mat_vec, grad, x0=x0)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/_src/scipy/sparse/linalg.py", line 661, in gmres
    restart = min(restart, size)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/core.py", line 529, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/jax/core.py", line 957, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function apply_onthefly at /home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py:205, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to apply_onthefly at /home/sem/.cache/pypoetry/virtualenvs/xxmodelnqs--vHqlMQ2-py3.7/lib/python3.7/site-packages/netket/optimizer/sr/sr_onthefly.py:205, transformed by jit. at flattened positions [8], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)
```
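The failing frame is `restart = min(restart, size)` inside `jax.scipy.sparse.linalg.gmres`, and the value JAX cannot convert is a weak-typed scalar `bool`. That pattern suggests (an inference from the traceback, not confirmed against NetKet's source) that `restart` reached `gmres` as a traced argument of the jitted `apply_onthefly`, so Python's built-in `min` tries to turn a traced comparison into a `bool`. The sketch below reproduces the same failure in plain JAX without NetKet; the system `A`, `b` and the helpers `matvec`, `solve`, and `restart_is_traced_error` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Hypothetical small symmetric positive-definite system A x = b.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def matvec(x):
    return A @ x


def solve(b, restart):
    # gmres evaluates `min(restart, size)` in Python, so `restart`
    # must be a concrete Python int, not a traced value.
    x, _ = gmres(matvec, b, restart=restart)
    return x


def restart_is_traced_error():
    """Return True if jitting `solve` with a traced `restart` fails."""
    try:
        # Without static_argnums/static_argnames, jit traces every argument,
        # including the plain Python int passed as `restart` (weak-typed int).
        jax.jit(solve)(b, 20)
    except jax.errors.ConcretizationTypeError as err:
        print("raised:", type(err).__name__)
        return True
    return False


if __name__ == "__main__":
    restart_is_traced_error()
```

On newer JAX releases the raised class may be `TracerBoolConversionError`, which subclasses `ConcretizationTypeError`, so the `except` clause still catches it.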
Reading the JAX documentation for this error (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError) didn't help me. How should I use `nk.optimizer.SRLazyGMRES` correctly?
Thank you!
NetKet version: 3.0b1.post6
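In plain JAX, the usual way around this class of error is to keep `restart` out of the traced arguments: either mark it static in `jax.jit`, or bind it into the solver with `functools.partial` before the jitted function is built, so it becomes a closure constant. This is a minimal sketch of those two JAX patterns only; whether `nk.optimizer.SRLazyGMRES` in NetKet 3.0b1 lets user code apply either of them is not established here, and `matvec`, `solve_static`, `bound_gmres`, and `solve_bound` are hypothetical names.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Hypothetical small symmetric positive-definite system A x = b.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def matvec(x):
    return A @ x


# Option 1: mark `restart` as static, so jit specializes on its concrete
# value instead of tracing it (one compilation per distinct `restart`).
@functools.partial(jax.jit, static_argnames=("restart",))
def solve_static(b, restart=20):
    x, _ = gmres(matvec, b, restart=restart)
    return x


# Option 2: bind `restart` into the solver before jit sees it,
# so it is a Python constant captured by the closure, not a jit argument.
bound_gmres = functools.partial(gmres, restart=20)


@jax.jit
def solve_bound(b):
    x, _ = bound_gmres(matvec, b)
    return x


if __name__ == "__main__":
    print(solve_static(b, restart=2))
    print(solve_bound(b))
```

Option 1 recompiles for each new `restart` value; option 2 fixes `restart` once, which matches how a solver configured up front would normally be used.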