
Only require models to work with batches (not vector) inputs

Vicentini Filippo requested to merge github/fork/PhilipVinc/pv/batch into master

Created by: PhilipVinc

Right now models provided by users must satisfy two conditions:

  • if the input is a batch of bitstrings, return a vector: (M,N) -> (M,)
  • if the input is a single bitstring (vector), return a scalar: (N,) -> ()
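The two contracts above can be sketched with a toy log-amplitude function (a minimal, hypothetical model, not actual NetKet code) that happens to satisfy both by reducing over the last axis:

```python
import jax.numpy as jnp

# Hypothetical toy model: log_psi reduces over the last axis, so it
# satisfies both calling conventions at once.
def log_psi(sigma):
    return jnp.sum(sigma, axis=-1)

batch = jnp.ones((8, 4))   # (M, N): a batch of M bitstrings of length N
single = jnp.ones((4,))    # (N,): a single bitstring

print(log_psi(batch).shape)   # batch in, vector out: (8,)
print(log_psi(single).shape)  # vector in, scalar out: ()
```

Most real models are not this forgiving: a model written with explicit batch handling (e.g. one that reshapes or pools over a batch axis) will break on the `(N,)` input, which is exactly the confusion this MR removes.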

However, this is confusing and annoying. After a discussion with @femtobit and @attila-i-szabo we agreed that the best choice is to require models to work with batches, because (a) most models in ML are written with explicit batches, and (b) it ensures we can control batching behaviour and where it ends up in intermediate calculations, something we can't do under vmap.

There are two ways to ensure that models only need to support batches, not scalars:

  1. Go through netket and make sure we never call apply_fun with a vector (this happens in about 4-5 places). We must also go through netket and make sure to reshape to a scalar before calling jax.grad (this happens in 2 places: the lindblad and the continuous kinetic operator).
  2. Just wrap all apply_fun provided to MCState and ExactState to always reshape to a batch if the input is a vector.
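The reshape-before-jax.grad part of option (1) is needed because jax.grad requires a scalar-valued function; with a batch-only model, the caller has to add a batch axis on the way in and index the single result out. A minimal sketch (hypothetical names, not the actual lindblad/kinetic-operator code):

```python
import jax
import jax.numpy as jnp

# Hypothetical batch-only model: expects (M, N), returns (M,).
def batch_apply(params, sigma):
    return jnp.dot(sigma, params)

# jax.grad needs a scalar output, so every gradient call site must
# wrap a single (N,) input into a (1, N) batch and extract element 0.
def scalar_apply(params, sigma):
    return batch_apply(params, sigma[None, :])[0]

params = jnp.arange(4.0)
g = jax.grad(scalar_apply, argnums=1)(params, jnp.ones(4))
print(g)  # gradient w.r.t. sigma is just params here: [0. 1. 2. 3.]
```

Under option (1), every such call site in netket (and in any user-written expect kernel) must remember to do this dance by hand.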

I am unconvinced by (1) because it makes the code more verbose, and, if you are writing a custom expect kernel, you have to be careful never to call the scalar version of the code. I don't like enforcing more constraints on users, so this PR opts for (2).
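Option (2) amounts to a thin adapter around the user's apply_fun. A minimal sketch of such a wrapper (hypothetical helper name, not necessarily the exact implementation in this MR): if the sample input is a single (N,) vector, promote it to a (1, N) batch, call the batch-only model, and squeeze the result back to a scalar; batched inputs pass through untouched.

```python
import jax.numpy as jnp

def ensure_batched(apply_fun):
    """Wrap a batch-only apply_fun so it also accepts single (N,) inputs."""
    def wrapped(variables, sigma, *args, **kwargs):
        if sigma.ndim == 1:
            # Promote to a batch of one, then strip the batch axis again.
            return apply_fun(variables, sigma[None, :], *args, **kwargs)[0]
        return apply_fun(variables, sigma, *args, **kwargs)
    return wrapped

# A model written only for explicit batches:
def batch_only(variables, sigma):
    assert sigma.ndim == 2, "model expects a (M, N) batch"
    return jnp.sum(sigma, axis=-1)

f = ensure_batched(batch_only)
print(f(None, jnp.ones((3, 4))).shape)  # (3,): batch passes through
print(f(None, jnp.ones((4,))).shape)    # (): vector handled transparently
```

With the wrapper applied once inside MCState/ExactState construction, neither netket internals nor custom kernels need to care which shape they pass.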

I think @femtobit won't fully agree with me?

Closes #629
