[NFC] Clean-up some old workarounds in sampler, simplify some code, and add error checking
Created by: PhilipVinc
This PR does the following minor changes:
- removes the jit function trampolines used by samplers to jit the code, as it's no longer needed thanks to improvements in jax caching mechanism, making the code easier to read
- remove some custom
__repr__
definitions for samplers, in favour of usingshow_repr=False
onn_chains
of the samplers. This cuts some more LOCs. - Makes
MetropolisRule
an abstract base class (it was raising errors anyway, and this makes it nicer in generated docs).
The following major changes:
- move
nk.sampler.MetropolisRule
abstract base class tonk.sampler.rules.MetropolisRule
. In a later PR I'd like to deprecate the original binding. part of the reason is to make themetropolis.py
file a bit smaller. - Add some error checking code to
MetropolisSampler
, so that if one defines a custom rule or uses a model that returns the wrong shapes, it gives a more comprehensible error instead of imperscrutabili jax errors.