[RFC][WIP] Overhaul of jax samplers and hilbert
Created by: PhilipVinc
This is very WIP, and half done (only LocalKernel works), but I'd like comments.
I want to solve the following issues:
- (Main) I would like to support arbitrary, non-homogeneous Hilbert spaces.
- I would like a better way to define Metropolis kernels, and to make them work better in general.
- Jax samplers on GPU are slow. Let's speed them up.
- Jax samplers recompile a lot of code for nothing every time you re-create a sampler. Can we avoid this?
Note: I tried to address these problems by ignoring the non-jax stuff.
At the moment the PR addresses points 1, 2 and 3. It would probably be possible to solve 4 as well, but I need to think about it. Incidentally, this PR speeds samplers up by 80% (at least in the spin case with 20 spins) on my laptop. I will benchmark it more thoroughly in a bit.
EDIT: The first 2 commits are unrelated to the samplers overhaul and are simply a refactor of jax machines... ignore them.
---
Idea: right now Metropolis kernels are very rigid: a kernel has a rule to randomly generate a new state and a rule to update them. However, a kernel completely ignores any structure in the Hilbert space itself. So how do we address point 1?
Solution: better distinguish what the sampling kernel does from what must be done at the level of the Hilbert space.
Proposal: add two (jax-specific) methods to all Hilbert spaces: one to generate a random state (or batch of states) and one to propose a new local_state on a site.
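For concreteness, here is a minimal sketch of what these two hooks could look like as dispatchable free functions. The names (`jax_random_state`, `jax_random_local_state`) and signatures are placeholders I made up for illustration, not the final API:

```python
# Minimal sketch of the two proposed hooks; names and signatures are placeholders.
from functools import singledispatch

@singledispatch
def jax_random_state(hilbert, key, batch_size):
    """Generate a batch of random states of `hilbert`, shape (batch_size, hilbert.size)."""
    raise NotImplementedError(f"no jax random-state rule for {type(hilbert)}")

@singledispatch
def jax_random_local_state(hilbert, key, states, sites):
    """Propose a new local_state on `sites` for every state in `states`."""
    raise NotImplementedError(f"no jax local-transition rule for {type(hilbert)}")
```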
jax_random_state can then dispatch on the Hilbert space for a specific implementation, allowing us to implement non-homogeneous Hilbert spaces (a tensor space would concatenate the states generated by its sub-spaces), but also to speed up this generation considerably. For example, generating spin states does not need to index into the list of local_states, but can be done with some smart combination of floating-point arithmetic and floor, which is well supported on GPUs.
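As an illustration of that trick (a sketch, not the code in this PR): assuming the usual spin-S local states {-2S, -2S+2, ..., 2S}, a batch of random states can be drawn with one uniform sample, a multiply and a floor per site, with no indexing at all:

```python
# Illustrative: random spin states without indexing into hilbert.local_states.
import jax
import jax.numpy as jnp

def random_spin_states(key, S, n_sites, batch_size):
    n_local = int(2 * S + 1)                       # number of local states per site
    u = jax.random.uniform(key, (batch_size, n_sites))
    i = jnp.floor(u * n_local)                     # integer index in 0 .. n_local-1
    return 2.0 * i - 2.0 * S                       # map to {-2S, -2S+2, ..., 2S}

key = jax.random.PRNGKey(0)
samples = random_spin_states(key, S=0.5, n_sites=20, batch_size=8)  # values in {-1, +1}
```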
Also, while one can define only the scalar version of those generators, the best performance is obtained by defining a function that generates them directly in batches.
The kernel transition rule can also be defined to work in batches; otherwise it falls back to vmapping the scalar function.
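A possible shape for that fallback (purely illustrative, the helper and attribute names are made up):

```python
# Illustrative fallback: use the kernel's batched transition rule if it defines one,
# otherwise vmap the scalar rule over the batch of (key, state) pairs.
import jax

def batched_transition(kernel):
    if hasattr(kernel, "transition_batch"):
        return kernel.transition_batch
    return jax.vmap(kernel.transition, in_axes=(0, 0))
```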
Here I'd like to think a bit more about possible use-cases and how to be general... so it's not definitive. But I'd like some comments.