Only require models to work with batches (not vector) inputs
Created by: PhilipVinc
Right now models provided by users must satisfy two conditions:
- if input is a batch of bitstrings return a vector
(M,N) -> (M,)
- If input is a single bitstring (vector) return a scalar
(N,) -> ()
However this is confusing and annoying. After a discussion with @femtobit and @attila-i-szabo we agreed that the best choice is to require models to work with batches because (a) most models in ML are written with explicit batches and (b) it makes sure we can control batching behaviour and were it ends up in intermediate calculations, something we can't do under vmap.
There are 2 ways to ensure that models must only support batches and not scalars:
- Go through netket and make sure we never call
apply_fun
with a vector (there are about 4-5 places where this happens). We must also go through netket and make sure to reshape to a scalar before callingjax.grad
(happens in 2 places: lindblad and continuous kinetic operator). - Just wrap all
apply_fun
provided toMCState
andExactState
to always reshape to a batch if the input is a vector.
I am unconvinced by (1) because it makes code more verbose, and, if you are writing a custom expect
kernel you have to be careful to never call the scalar version of the code. I don't like enforcing more constraints on users so...
This PR opts for (2).
I think @femtobit won't fully agree with me?
Closes #629