Skip to content

Numba4jax: Numba in HamiltonianSampler

Vicentini Filippo requested to merge numba4jax into master

Created by: PhilipVinc

This uses a trick to create on the fly a numba function inside of jax jitted functions and call it with almost-0 overhead.

This makes HamiltonianSampler be 5 times faster then the previous implementation that had to exit the jitted function and call back into jax. Now HamiltonianSampler is only 50% slower than MetropolisLocal. Also, we can use the same trick in over places.

The limitation is that it only works for cpu-arrays, not gpu. But we can put a conversion before calling those functions. Besides, indexing in localoperators makes no sense/would be unbearably slow on gpus so that makes sense.

Also, to make this work i need to add a very ugly yet necessary function to operators returning a closure computing the get_conn_flattened with their internal fields inside the closure. That is because numba does not know how to deal with our operators.

Merge request reports