Throw informative errors if one uses operators/hilbert inside of jax
Created by: PhilipVinc
This is a common error I've seen several people make, especially when trying to use operators from netket. So let's try to throw more informative errors now that we actually do support some jax operators:
to be merged after the previous PR
In [1]: import netket as nk; g=nk.graph.Square(4); hi=nk.hilbert.Spin(0.5, g.n_nodes); import jax; import jax.numpy as jnp; import numpy as np
In [2]: ha=nk.operator.Ising(hi, graph=g, h=0.5)
In [3]: jax.jit(lambda x:ha.get_conn_padded(x))(ha.hilbert.all_states())
...
File ~/Dropbox/Ricerca/Codes/Python/netket/netket/errors.py:264, in concrete_or_error(force, value, error_class, *args, **kwargs)
257 return jax.core.concrete_or_error(
258 force,
259 value,
260 """
261 """,
262 )
263 except ConcretizationTypeError as err:
--> 264 raise error_class(*args, **kwargs) from err
NumbaOperatorGetConnDuringTracingError:
Attempted to use a Numba-based operator (<class 'netket.operator._ising.numba.Ising'>) inside a Jax function transformation (jax.jit, jax.grad & others).
Numba-based operators are not compatible with Jax functiontransformations, and can only be used outside of jax-functionboundaries.
Some operators can be converted to a Jax-compatible version bycalling `operator.to_jax_operator()`, but not all support it.
-------------------------------------------------------
For more detailed informations, visit the following link:
https://netket.readthedocs.io/en/latest/api/_generated/errors/netket.errors.NumbaOperatorGetConnDuringTracingError.html
or the list of all common errors at
https://netket.readthedocs.io/en/latest/api/errors.html
-------------------------------------------------------