Skip to content

Annotate `mutable` to be `CollectionFilter` in vstates

Vicentini Filippo requested to merge github/fork/wdphy16/annot_mutable into master

Created by: wdphy16

The previous annotations always made me puzzled when reading the source code for vstates. Actually its type should be CollectionFilter = Union[bool, str, Collection[str]], as defined in flax.core.scope and used in flax.linen.Module.apply. Its default value should be False when initializing the vstate.

Methods like vstate.expect_and_grad can have a parameter mutable: Optional[CollectionFilter] = None to override the initial vstate.mutable, just like how chain_length in vstate.sample overrides the initial vstate.chain_length. Here None means not to override, and False means to disable mutation even if the vstate is mutable initially. Outside those methods, mutable should never be None.

Merge request reports