Remove everything about 'chain' in exact samplers
Created by: wdphy16
Implements #997 (closed). This is definitely not 'slightly' or 'straightforward', and I hurried this in the weekend so that others' work will not introduce more conflicts.
All the changes to MCState
are in the commit 'Update MCState'. Before that, all the tests for samplers should pass.
API changes to samplers:
Specifyingn_chains
andn_chains_per_rank
when constructingSampler
and exact samplers (ExactSampler
andARDirectSampler
) is deprecated. Please do not use them, as previously they had no effect.MetropolisSampler
still implements them.The methodsample_next
inSampler
and exact samplers is removed. Please use the methodsample
instead.MetropolisSampler
still implements it.nk.sampler.sample_next
andnk.sampler.samples
now only work withMetropolisSampler
. For exact samplers, please usenk.sampler.sample
instead.MetropolisSampler.sample_chain
andnk.sampler.sample_chain
become public.Sampler.sample
andnk.sampler.sample
have a new argumentn_samples
.Specifyingchain_length
inSampler.sample
andnk.sampler.sample
is deprecated. For exact samplers, please specifyn_samples
instead. For Metropolis samplers, please usesample_chain
instead.The shape of the returned samples ofSampler.sample
andnk.sampler.sample
now depends on the type of the sampler. For exact samplers it is a 2D array(n_samples, hilbert.size)
, and for Metropolis samplers it is a 3D array(n_chains, chain_length, hilbert.size)
.nk.sampler.sampler_state
now accepts the argumentseed
, as inSampler.init_state
. (Why not?)
API changes to variational states:
MCState.chain_length
now has the typeOptional[int]
, and will beNone
if the sampler is exact.Specifyingchain_length
inMCState.sample
when the sampler is exact is deprecated. Please specifyn_samples
instead.The shape of the returned samples ofMCState.sample
now depends on the type of the sampler, as inSampler.sample
.
Implementation notes on Sampler
:
The module functionssample_next
,sample_chain
, andsamples
are moved fromsampler/base.py
tosampler/metropolis.py
.To help write deprecation messages, there are helper functionscompute_n_samples
andcompute_n_samples_per_rank
insampler/base.py
, andcompute_n_chains_per_rank
andcompute_chain_length
(moved fromvqs/mc/mc_state/state.py
) insampler/metropolis.py
.In exact samplers,_sample_next
and_sample_chain
are removed, and they only need to override_sample
.Sampler.sample
takesn_samples
as an argument, while_sample
takesn_samples_per_rank
. This is becausesample
handles the logic of dividingn_samples
into ranks, which may raise warnings and should not be jitted.I don't think we should allow to specifyn_samples_per_rank
as an argument inSampler.sample
, otherwiseMCState.sample
needs to be consistent with it and it will become more complicated.-
There is a propertyAfter finishing this PR, we can continue discussing about chunked sampling for exact samplers,Sampler.n_batches
. Currently chunked sampling is not implemented, and exact samplers only usen_batches
to storen_chains_per_rank
and make it possible to specifychain_length
insample
before it's removed.and even Metropolis samplers (in principle it's possible when the number of independent chains is too large to fit into the memory).
Implementation notes on MetropolisSampler
:
InMetropolisSampler
,sample
calls_sample_chain
rather than_sample
.The public methodsample_chain
has default values forstate
andchain_length
, and useswrap_afun
to handlemachine
, while the private method_sample_chain
does not. The public module functionsample_chain
just callsMetropolisSampler.sample_chain
. The private module function_sample_chain
is jitted, and is called byMetropolisSampler._sample_chain
to avoid repeated jitting from different class instances. Similar fuctionsinit
,reset
,sample
, andsample_next
also follow this principle.MetropolisPtSampler.__repr__
and__str__
are removed. Previously we did not renamereset_chain
toreset_chains
so they would cause errors. I checked that they actually did not override anything, so I removed them to reduce the future maintenance cost.I reviewed all docstrings and type annotations insampler/base.py
andsampler/metropolis.py
, which made me confused when I read them for the first time. But I didn't touch numpy, PT, and pmap samplers.
Implementation notes on MCState
:
Attributes_n_samples_per_rank
and_chain_length
are now declared for the dataclass. There should be no missing attribute now.MCState.n_samples.setter
is modified to support exact samplers.When settingMCState.sampler
, it will recomputen_samples
,n_samples_per_rank
, andchain_length
according to the type andn_chains
of the new sampler.MCState.sample
callssampler.sample
withn_samples
if the sampler is exact, and callssampler.sample_chain
withchain_length
otherwise.
Implementation notes on tests:
In all tests,n_chains
is removed when constructing exact samplers.The tests for PT samplers can pass now (on my machine), so I enabled them. But I don't know why they failed in the past.
Questions:
Should we deprecateSampler.is_exact
in favor ofnot isinstance(sampler, MetropolisSampler)
? Currently all the usages ofSampler.is_exact
are actually 'the sampler does not have chains'.
After we agree on all this, I'll add the changelog.
Update: Now I'll break down this to several smaller PRs, and I'll add strikethroughs to the finished parts above. Now it's finally implemented!