Annotate `mutable` to be `CollectionFilter` in vstates
Created by: wdphy16
The previous annotations always made me puzzled when reading the source code for vstates. Actually its type should be CollectionFilter = Union[bool, str, Collection[str]]
, as defined in flax.core.scope
and used in flax.linen.Module.apply
. Its default value should be False
when initializing the vstate.
Methods like vstate.expect_and_grad
can have a parameter mutable: Optional[CollectionFilter] = None
to override the initial vstate.mutable
, just like how chain_length
in vstate.sample
overrides the initial vstate.chain_length
. Here None
means not to override, and False
means to disable mutation even if the vstate is mutable initially. Outside those methods, mutable
should never be None
.