Numba4jax: Numba in HamiltonianSampler
Created by: PhilipVinc
This uses a trick to create a Numba function on the fly inside JAX-jitted functions and call it with almost zero overhead.
This makes HamiltonianSampler 5 times faster than the previous implementation, which had to exit the jitted function and call back into JAX. HamiltonianSampler is now only 50% slower than MetropolisLocal. We can also use the same trick in other places.
The limitation is that it only works for CPU arrays, not GPU arrays. We can, however, add a conversion before calling those functions. Besides, indexing in local operators would be unbearably slow on GPUs anyway, so the restriction is reasonable.
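As a rough illustration of such a conversion, the sketch below moves an array to the first CPU device with `jax.device_put` and then exposes it as a NumPy array. The helper name `to_cpu` is hypothetical, and this is not necessarily how the PR performs the conversion; it only uses standard JAX calls, outside of any jitted function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_cpu(x):
    """Hypothetical helper: place `x` on the first CPU device."""
    cpu = jax.devices("cpu")[0]
    return jax.device_put(x, cpu)


# An array that may live on any backend (GPU if one is available).
x = jnp.arange(6.0).reshape(2, 3)

# Move it to the CPU, then view it as a NumPy array on the host,
# which is what a Numba-compiled function expects to receive.
x_cpu = to_cpu(x)
host = np.asarray(x_cpu)
```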
Also, to make this work I need to add a very ugly yet necessary function to operators that returns a closure computing get_conn_flattened, with the operator's internal fields captured inside the closure. This is needed because Numba does not know how to deal with our operators.
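The closure pattern can be sketched roughly as follows. `ToyDiagonalOperator` and its method `get_conn_flattened_closure` are hypothetical stand-ins, not the classes or names used in this PR, and the compiled function is heavily simplified: it returns only one number per input row instead of connected configurations. The point is that the fields are copied into local variables first, so the Numba function captures plain arrays rather than `self`. The `try`/`except` fallback is only there so the sketch still runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba
    def njit(fn):
        return fn


class ToyDiagonalOperator:
    """Hypothetical stand-in for an operator (not the real class)."""

    def __init__(self, weights):
        self._weights = np.asarray(weights, dtype=np.float64)

    def get_conn_flattened_closure(self):
        # Copy the internal field to a local variable: Numba can type
        # a NumPy array captured by a closure, but not `self`.
        weights = self._weights

        @njit
        def conn_flattened(x):
            # x has shape (batch, n_sites); return one value per row.
            batch = x.shape[0]
            mels = np.zeros(batch)
            for i in range(batch):
                for j in range(x.shape[1]):
                    mels[i] += weights[j] * x[i, j]
            return mels

        return conn_flattened


op = ToyDiagonalOperator([1.0, 2.0, 3.0])
fun = op.get_conn_flattened_closure()
out = fun(np.array([[1.0, 1.0, 1.0], [1.0, -1.0, 1.0]]))
```

The returned function no longer refers to the operator object, so it can be compiled by Numba and called with array arguments only.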