Skip to content

Remove everything about 'chain' in exact samplers

Vicentini Filippo requested to merge github/fork/wdphy16/sampler_api into master

Created by: wdphy16

Implements #997 (closed). This is definitely not 'slightly' or 'straightforward', and I hurried this in the weekend so that others' work will not introduce more conflicts.

All the changes to MCState are in the commit 'Update MCState'. Before that, all the tests for samplers should pass.


API changes to samplers:

  • Specifying n_chains and n_chains_per_rank when constructing Sampler and exact samplers (ExactSampler and ARDirectSampler) is deprecated. Please do not use them, as previously they had no effect. MetropolisSampler still implements them.
  • The method sample_next in Sampler and exact samplers is removed. Please use the method sample instead. MetropolisSampler still implements it.
  • nk.sampler.sample_next and nk.sampler.samples now only work with MetropolisSampler. For exact samplers, please use nk.sampler.sample instead.
  • MetropolisSampler.sample_chain and nk.sampler.sample_chain become public.
  • Sampler.sample and nk.sampler.sample have a new argument n_samples.
  • Specifying chain_length in Sampler.sample and nk.sampler.sample is deprecated. For exact samplers, please specify n_samples instead. For Metropolis samplers, please use sample_chain instead.
  • The shape of the returned samples of Sampler.sample and nk.sampler.sample now depends on the type of the sampler. For exact samplers it is a 2D array (n_samples, hilbert.size), and for Metropolis samplers it is a 3D array (n_chains, chain_length, hilbert.size).
  • nk.sampler.sampler_state now accepts the argument seed, as in Sampler.init_state. (Why not?)

API changes to variational states:

  • MCState.chain_length now has the type Optional[int], and will be None if the sampler is exact.
  • Specifying chain_length in MCState.sample when the sampler is exact is deprecated. Please specify n_samples instead.
  • The shape of the returned samples of MCState.sample now depends on the type of the sampler, as in Sampler.sample.

Implementation notes on Sampler:

  • The module functions sample_next, sample_chain, and samples are moved from sampler/base.py to sampler/metropolis.py.
  • To help write deprecation messages, there are helper functions compute_n_samples and compute_n_samples_per_rank in sampler/base.py, and compute_n_chains_per_rank and compute_chain_length (moved from vqs/mc/mc_state/state.py) in sampler/metropolis.py.
  • In exact samplers, _sample_next and _sample_chain are removed, and they only need to override _sample.
  • Sampler.sample takes n_samples as an argument, while _sample takes n_samples_per_rank. This is because sample handles the logic of dividing n_samples into ranks, which may raise warnings and should not be jitted.
  • I don't think we should allow to specify n_samples_per_rank as an argument in Sampler.sample, otherwise MCState.sample needs to be consistent with it and it will become more complicated.
  • There is a property Sampler.n_batches. Currently chunked sampling is not implemented, and exact samplers only use n_batches to store n_chains_per_rank and make it possible to specify chain_length in sample before it's removed. After finishing this PR, we can continue discussing about chunked sampling for exact samplers, and even Metropolis samplers (in principle it's possible when the number of independent chains is too large to fit into the memory).

Implementation notes on MetropolisSampler:

  • In MetropolisSampler, sample calls _sample_chain rather than _sample.
  • The public method sample_chain has default values for state and chain_length, and uses wrap_afun to handle machine, while the private method _sample_chain does not. The public module function sample_chain just calls MetropolisSampler.sample_chain. The private module function _sample_chain is jitted, and is called by MetropolisSampler._sample_chain to avoid repeated jitting from different class instances. Similar fuctions init, reset, sample, and sample_next also follow this principle.
  • MetropolisPtSampler.__repr__ and __str__ are removed. Previously we did not rename reset_chain to reset_chains so they would cause errors. I checked that they actually did not override anything, so I removed them to reduce the future maintenance cost.
  • I reviewed all docstrings and type annotations in sampler/base.py and sampler/metropolis.py, which made me confused when I read them for the first time. But I didn't touch numpy, PT, and pmap samplers.

Implementation notes on MCState:

  • Attributes _n_samples_per_rank and _chain_length are now declared for the dataclass. There should be no missing attribute now.
  • MCState.n_samples.setter is modified to support exact samplers.
  • When setting MCState.sampler, it will recompute n_samples, n_samples_per_rank, and chain_length according to the type and n_chains of the new sampler.
  • MCState.sample calls sampler.sample with n_samples if the sampler is exact, and calls sampler.sample_chain with chain_length otherwise.

Implementation notes on tests:

  • In all tests, n_chains is removed when constructing exact samplers.
  • The tests for PT samplers can pass now (on my machine), so I enabled them. But I don't know why they failed in the past.

Questions:

  • Should we deprecate Sampler.is_exact in favor of not isinstance(sampler, MetropolisSampler)? Currently all the usages of Sampler.is_exact are actually 'the sampler does not have chains'.

After we agree on all this, I'll add the changelog.


Update: Now I'll break down this to several smaller PRs, and I'll add strikethroughs to the finished parts above. Now it's finally implemented!

Merge request reports