Skip to content

[WIP] Add `run` function to VMC driver

Vicentini Filippo requested to merge github/fork/femtobit/pr-vmc-run into v2.1

Created by: femtobit

This PR contains a first work-in-progress version of the Vmc.run for the new VMC class.

There are a couple of differences to the old implementation, which we should discuss:

  1. vmc.add_observable is deprecated. I do not see a need to store the observables as an attribute of the driver. It is just as easy-to-use to have the user create a dict of observables in their own code and pass it to Vmc.run:
vmc = nk.Vmc(...)
obs = {"X": sigmax, "Y": sigmay, ...}
vmc.run(n_iter=1000, output_prefix="...", obs=obs)
  1. Instead of get_observable_stats with a somewhat complicated interface (include_energy, the option of either using the stored obs or passing a dict of other ones) I've added the method vmc.estimate(obs) which takes just a dict of observables and uses the pre-compiled samples of the current step and vmc.energy (following the original suggrstion of @orialb in #344 ) returning the already computed energy. get_observable_stats is thus also deprecated (but trivial to implement in terms of vmc.estimate).

  2. When I was saying vmc.estimate accepts a dict, that was not the full story: I have experimented here with using a JAX pytree, which is the JAX term for any type of list [obs1, obs2], dict {"Obs1": ..., "Obs2": ...}, tuples, or nested structure thereof, i.e., you could pass a nested dict like {0: {1: sigmax01, 2: sigmax02, ...}, 1: {0: sigmax10, 2: sigmax12, ...}, ...}). The estimate method will then return a pytree of the same structure with the observables replaced by their MC stats via the JAX tree_map function. It seems a bit silly to depend on JAX only for that (maybe we could just extract the function we need, it's also Apache licensed), but I like the resulting interface. Users can just store the observables in any structure they like (be it a list, dict, tuple, or a nested structure of these) and NetKet will do the right thing with it.

  3. The JSON output logic is now contained in an internal class _JsonLog. We should make sure to check how to generalize it so it can be used by all the drivers.

For now, let me know what you think of the general ideas here.

Merge request reports