Skip to content

Throw informative errors if one uses operators/hilbert inside of jax

Vicentini Filippo requested to merge github/fork/PhilipVinc/pv/error into master

Created by: PhilipVinc

This is a common error I've seen several people make, especially when trying to use operators from netket. So let's try to throw more informative errors now that we actually do support some jax operators:

to be merged after the previous PR

In [1]: import netket as nk; g=nk.graph.Square(4); hi=nk.hilbert.Spin(0.5, g.n_nodes); import jax; import jax.numpy as jnp; import numpy as np

In [2]: ha=nk.operator.Ising(hi, graph=g, h=0.5)

In [3]: jax.jit(lambda x:ha.get_conn_padded(x))(ha.hilbert.all_states())
...
File ~/Dropbox/Ricerca/Codes/Python/netket/netket/errors.py:264, in concrete_or_error(force, value, error_class, *args, **kwargs)
    257     return jax.core.concrete_or_error(
    258         force,
    259         value,
    260         """
    261       """,
    262     )
    263 except ConcretizationTypeError as err:
--> 264     raise error_class(*args, **kwargs) from err

NumbaOperatorGetConnDuringTracingError:
Attempted to use a Numba-based operator (<class 'netket.operator._ising.numba.Ising'>) inside a Jax function transformation (jax.jit, jax.grad & others).

Numba-based operators are not compatible with Jax functiontransformations, and can only be used outside of jax-functionboundaries.

Some operators can be converted to a Jax-compatible version bycalling `operator.to_jax_operator()`, but not all support it.

-------------------------------------------------------
For more detailed informations, visit the following link:
	 https://netket.readthedocs.io/en/latest/api/_generated/errors/netket.errors.NumbaOperatorGetConnDuringTracingError.html
or the list of all common errors at
	 https://netket.readthedocs.io/en/latest/api/errors.html
-------------------------------------------------------

Merge request reports