Netket v3
Created by: PhilipVinc
## TL;DR
I'd like some feedback on this proposal. It's not yet done, though a big part is complete (I still need to put a few `@jit` calls in the right place). Download it with `pip install git+https://github.com/netket/netket@nk3` and play around a bit.
If you want to review the PR, since it's huge, I'd suggest looking at it commit by commit: every chunk of changes is in a separate commit. Skip the first commit, which is not really relevant anymore.
See how VMC becomes easier and more generic, or this gist for an example of the API.
## Netket v3.0 master plan
For the still-unreleased netket v3.0 we have, until now, transitioned all of our infrastructure to Python and removed all C++ code. This was done by temporarily porting to numpy and later converting most of the code to jax.
Still, the API has largely remained unchanged (except for a few changes to the construction of hilbert spaces and graphs, mostly aesthetic).
v3.0 is still unreleased because @gcarleo and I wanted to get the API right before we commit to it; however, months have passed (almost a year) and it is still unreleased.
The big issues that I would like to address are the following:
- #437 Add an extensible VariationalState and use it in the drivers, making it easier to develop new applications.
- #525 Make it easy and intuitive to write custom Metropolis-Hastings transition rules and use them.
- Support non-homogeneous hilbert spaces.
- Make it easier to define new arbitrary networks, with real or complex weights.
- Actually support arbitrary networks that mix real and complex layers (right now they aren't supported).
- Support models with a state that can change (but that should not trigger recompilation).
- Remove legacy code.
In particular, while giving some lectures on netket I recently noticed points 2 and 4: it's quite messy to define those things right now.
Points 3 and 5 can be resolved by moving to flax, where (as I will argue below) defining models is much more compact and intuitive.
The following is my proposal for the netket v3.0 API.

A somewhat central point of this API stems from the discussion in #480 (closed), namely my proposal, which seemed quite well received, to remove the numpy and torch backends and convert netket to a pure jax package. It seems to me that a) jax now has full support from Google and is a stable project, and b) it allows us to write more compact, simpler code.

Roughly:
- Add new requirements: `jax` and `flax`. This will also bump the minimum required version to Python 3.7. This should not be an issue, as jax is also discussing dropping Python 3.6. EDIT: it's possible to still support 3.6.
Add to netket the following sub-packages:
- `netket.nn`: re-exporting `flax.nn`, but wrapping some functions in order to make them work better with complex numbers.
  - Jax/Flax already supports complex numbers, but some activation functions do not (for various reasons), and the kwargs necessary to use complex weights in a layer are a bit complicated to use. The main reason to re-export and wrap is to make our exported flax API work with complex numbers out of the box. Over time, I hope to get some PRs merged into flax itself so that we can drop our own code.
- Example:

```python
>>> import jax
>>> import netket as nk
>>> import flax
>>> import numpy as np
>>> from jax import numpy as jnp, random
>>> x = jax.random.normal(jax.random.PRNGKey(0), (4, 4), dtype=jax.numpy.complex64)

# flax does not work
>>> flax.nn.activation.softplus(x)
TypeError: add requires arguments to have the same dtypes, got complex64, float32.

# netket.nn works
>>> nk.nn.activation.softplus(x)
DeviceArray([[1.9650044+0.14860448j, 0.362084 -0.14605658j],
             [0.6844594+0.72896177j, 0.4604799-0.02623011j]], dtype=complex64)

# To use plain flax, we need to define our own complex init functions
def complex_kernel_init(rng, shape):
    fan_in = np.prod(shape) // shape[-1]
    x = random.normal(random.PRNGKey(0), shape) + 1j * random.normal(random.PRNGKey(0), shape)
    return x * (2 * fan_in) ** -0.5

complex_bias_init = lambda _, shape: jnp.zeros(shape, jnp.complex64)

# complex-valued dense layer, flax version
m = flax.nn.Dense(features=3, dtype=jax.numpy.complex64,
                  kernel_init=complex_kernel_init, bias_init=complex_bias_init)

# complex-valued dense layer, netket version
m = nk.nn.Dense(features=3, dtype=jax.numpy.complex64)
```
I hope the above convinces you that having flax working out of the box with complex values is handy.
- Since people might still have jax machines around, we provide a very simple function wrapping a jax module into a flax one, so that everything works out of the box (`nk.nn.wrap_jax`).
- `netket.jax`: wrapping some functions in order to support complex numbers and functions that may be R->R, R->C or C->C with the same syntax. Jax does not, and will not, support this out of the box. Notably, this will contain our own versions of `jax.vjp` and `jax.grad`, based on the code we already have in `jax_utils`, plus a few other utilities (see the sketch after this list).
- `netket.optim` (what was `netket.optimizer`):
  - I propose to change the name (with a slow deprecation, so as not to break code) because `optim` is the default in the jax and pyro world; TensorFlow uses `optimizers` and we use `optimizer`. Let's pick one and be consistent.
  - The optimisers are simply re-exported from flax. No code.
  - We also export SR.
    - SR is rewritten with a new interface; see below.
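To make the `netket.jax` point concrete, here is a minimal sketch, assuming nothing about the actual netket implementation, of a gradient that handles R->C functions with the same syntax as the real-valued case, by differentiating the real and imaginary parts separately:

```python
import jax
from jax import numpy as jnp

def grad_rc(f):
    # gradient for f: R->R or R->C; plain jax.grad only accepts real outputs
    def gradf(x):
        dre = jax.grad(lambda y: jnp.real(f(y)))(x)  # gradient of Re f
        dim = jax.grad(lambda y: jnp.imag(f(y)))(x)  # gradient of Im f
        return dre + 1j * dim
    return gradf

f = lambda x: jnp.exp(1j * x)  # an R->C function
grad_rc(f)(0.5)                # == 1j * exp(0.5j); jax.grad(f)(0.5) would raise
```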
### Remove `AbstractMachine` and its implementations

- Functionally replaced by pure `flax` modules, which will be objects containing no state (parameters) but only two functions: `init_params`, returning the pytree of parameters, and `apply`, computing the forward pass.
- Also support pure jax modules, and make it easy to support other jax frameworks (optax, for example).
- While those modules do not contain the hilbert space, provide an easy-to-use constructor that accepts a hilbert space and extracts its size.
- By default use `np.float32` and not `np.float64`, but depending on what the user uses, anything is supported.
- An advantage of this is that we can now copy-paste any machine written for jax/flax and it will work out of the box, regardless of what it does! (modulo replacing `flax.nn` -> `netket.nn` for complex-number compatibility until things are fixed upstream).
- See the example below, defining a ConvNet and an RBM with spin and phase: it's very easy compared to our old jax code for `RBMModPhase`.
```python
from typing import Any, Union

import numpy as np
import jax
from jax import numpy as jnp

import netket as nk
from netket import nn


class RBMModPhase(nn.Module):
    dtype: Any = np.float32
    activation: Any = nn.logcosh
    alpha: Union[float, int] = 1
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        re = nn.Dense(features=self.alpha * x.shape[-1], dtype=self.dtype,
                      use_bias=self.use_bias)(x)
        re = self.activation(re)
        re = jnp.sum(re, axis=-1)

        im = nn.Dense(features=self.alpha * x.shape[-1], dtype=self.dtype,
                      use_bias=self.use_bias)(x)
        im = self.activation(im)
        im = jnp.sum(im, axis=-1)

        return re + 1j * im


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.log_softmax(x)
        return x


machine = CNN(hilbert)
W = machine.init_params(rng_key)  # rng_key either an int/uint or jax.random.PRNGKey
```
- We wish to support non-differentiable variables in machines (for RNNs, batchnorm or other things). To do so, we adopt the flax standard: the parameters pytree has the shape `{'params': params_pytree, **other_state}` (see the sketch below).
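For concreteness, here is a small sketch of that pattern using flax's linen API; this is plain flax, and the module is invented for illustration:

```python
import jax
from jax import numpy as jnp
import flax.linen as nn

class DenseWithCounter(nn.Module):
    # a dense layer plus a non-differentiable call counter
    @nn.compact
    def __call__(self, x):
        # the counter lives in a separate 'state' collection, outside 'params'
        count = self.variable("state", "count", lambda: jnp.zeros((), jnp.int32))
        count.value = count.value + 1
        return nn.Dense(features=2)(x)

m = DenseWithCounter()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((3,)))
# variables has the shape {'params': ..., 'state': {'count': ...}}
y, mutated = m.apply(variables, jnp.ones((3,)), mutable=["state"])
```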
### Rewrite the samplers to be functional

Mainly:

- Samplers are now datastructures containing only the parameters of the sampler, and no state, nor machine.
  - Using datastructures allows us to pass those objects straight to jitted functions without issues.
- The state of a sampler is stored in a separate struct.
  - The state mainly carries the current rng state, plus additional stuff if needed.
- The API will be made up of the following functions:
  - `netket.sampler.init_state(sampler, machine, params)` or `sampler.init_state()`, creating the state.
  - `netket.sampler.sample(sampler, machine, params, chain_length=XXX, state=None)` or `sampler.sample(...)`, sampling the chain.
    - If `state` is `None`, then `init_state` is used to create a new state.
    - Returns the modified state and the sampled values.
    - I chose `chain_length` instead of the old `n_samples` because that's technically what it is.
- Example:

```python
>>> import netket as nk
>>> from netket import nn
>>> sampler = nk.sampler.ExactSampler(hilbert, seed=0)
ExactSampler(
    hilbert = Spin(s=1/2, N=4),
    seed = [ 0 2385058908],
    n_batches = 8,
    machine_power = 2)
# notice the seed

# create the state for the sampler:
state = nk.sampler.init_state(sampler, machine, params)

# reset the chain (could also have passed state=None, and it would create it)
state = nk.sampler.reset(sampler, machine, params, state)

samples, state = nk.sampler.sample(sampler, machine, params,
                                   chain_length=1000, state=state)

>>> samples
DeviceArray([[[-1., -1., -1., -1.],
              [-1.,  1.,  1.,  1.],
              [-1., -1., -1., -1.],
              ...,

>>> state
ExactSamplerState(pdf=DeviceArray([3.8351530e-01, 2.0624077e-02, 7.7783165e-04,
                                   7.2943498e-03, 4.4462629e-02, 1.9844480e-04,
                                   9.5614446e-03, 3.3565965e-02, 3.3565965e-02,
                                   9.5614446e-03, 1.9844480e-04, 4.4462629e-02,
                                   7.2943498e-03, 7.7783165e-04, 2.0624077e-02,
                                   3.8351530e-01], dtype=float32),
                  rng=DeviceArray([1255698341, 3859703708], dtype=uint32))
# notice the rng
```
- We store the seed in the sampler, so that if you reuse the same sampler without carrying the state around, you will get the same samples again. I had a brief discussion about this with @gcarleo and it seems the most sensible thing to do.
```python
# If state is not passed in, the state is automatically created...
samples, state = nk.sampler.sample(sampler, machine, params, chain_length=1000)

>>> samples
DeviceArray([[[-1., -1., -1., -1.],
              [-1.,  1.,  1.,  1.],
              [-1., -1., -1., -1.],
              ...,

>>> state.rng
DeviceArray([1255698341, 3859703708], dtype=uint32)
# notice the rng: it's the same as before
```
- While the interface is fully functional, you can also call those methods as `my_sampler.sample(...)` and so on.
- This is all pure jax, even the samplers themselves, so everything can be jitted through.
  - If you create a new sampler, the function is not recompiled. The only things triggering recompilation are changes to `n_batches`, the `hilbert` space, or other fields declared static (see the sketch below).
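A minimal sketch of how this plays with `jax.jit`, using `flax.struct.dataclass`; the sampler and its fields are invented for illustration:

```python
import jax
from flax import struct

# a pytree dataclass: fields marked pytree_node=False are static, so changing
# them triggers recompilation, while ordinary fields can vary freely
@struct.dataclass
class MySampler:
    n_batches: int = struct.field(pytree_node=False)
    beta: float = 1.0

@jax.jit
def step(sampler, x):
    return sampler.beta * x

step(MySampler(n_batches=8, beta=1.0), 1.0)   # compiles
step(MySampler(n_batches=8, beta=0.5), 1.0)   # no recompilation
step(MySampler(n_batches=16, beta=1.0), 1.0)  # recompiles: static field changed
```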
### Implement a VariationalState
- A `VariationalState` has the following interface (a usage sketch follows the list):
  - `vs.parameters` (or `params`?): returns the PyTree of the variational parameters one may want to optimise.
  - `vs.expect(operator)`: computes the expectation value of `operator`.
  - `vs.expect_and_grad(operator, is_hermitian=auto/True/False)`: computes the expectation value of `operator` and its gradient.
    - `is_hermitian` can be used to decide whether to use the simpler form we currently use for the energy gradient, or the more standard formula.
  - `vs.QGT() -> Callable[Grad, Grad]`: returns the quantum geometric tensor / S matrix.
    - The returned object should be a lazy object (a la scipy's `LinearOperator`) that takes a gradient as input and returns another gradient.
  - `vs.reset()`: resets the internal state between iterations.
  - `save`/`load`.
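A hypothetical usage sketch of this interface; every name is as proposed above and nothing is final:

```python
# vs is some VariationalState, ham some operator
en = vs.expect(ham)                 # expectation value of the operator
en, grad = vs.expect_and_grad(ham)  # expectation value and its gradient
S = vs.QGT()                        # lazy quantum geometric tensor
Sgrad = S(grad)                     # maps a gradient pytree to another gradient pytree
vs.reset()                          # reset the internal state between iterations
```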
- For a machine/sampler pair, the `ClassicalVariationalState` will be the first implementation of this interface.
- It will provide functionality similar to the old `sampler`/`machine` pair, with a bunch of extra tricks (see the construction sketch below):
  - A `ClassicalVariationalState` is constructed by taking:
    - a hilbert space,
    - a machine/module,
    - a sampler,
    - an optional SR object (~~maybe~~ probably),
    - some configuration data, such as `chain_length`.
  - The API is the one described above, plus:
    - `vs.sample(chain_length)`: samples and stores the samples internally until `reset()` is called.
    - `vs.samples`: accesses the samples; if the state was reset, resamples.
    - `vs.model_state`: returns the PyTree of any other parameters that might change but that we don't want to differentiate against (think batchnorm, RNNs...). Might also put it in the common API; I'm not sure.
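A hypothetical construction sketch; the module path `nk.variational` and the exact signature are placeholders, not final API:

```python
hi = nk.hilbert.Spin(s=1/2, N=4)
ma = RBMModPhase()                        # the flax module defined earlier
sa = nk.sampler.ExactSampler(hi, seed=0)  # a sampler on the same hilbert space

# hypothetical path and signature
vs = nk.variational.ClassicalVariationalState(hi, ma, sa, chain_length=1000)

vs.sample(chain_length=1000)  # sample and store internally
x = vs.samples                # cached samples; resampled automatically after reset()
vs.reset()                    # drop the cached samples between iterations
```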
### Minor changes
- The file `netket/utils.py` has been moved into a subfolder, and `utils` is now a full-fledged submodule. I moved here our logic for detecting whether MPI is installed, deprecation warnings, and so on.
- I check more carefully that the discoverable names in every module are relevant. For example, before, a module like `netket.hilbert` contained both `Spin` (the class) and `spin` (the module generated by the file). There is a small utility in `netket.utils` to hide all the file-generated modules that we don't want exposed, and I use it extensively.
- `netket.hilbert.Boson` has been renamed to `netket.hilbert.Fock`, as I find it more accurate. However, there is a deprecated constructor called `Boson` that forwards to `Fock`.
- Maybe we should rename all samplers from `MetropolisSampler`, `ExactSampler` to `Metropolis`, `Exact`? As we usually use them through their module (`netket.sampler`) it might make sense, and it makes the API lighter.
- The old API to create samplers is deprecated (I'd like to remove it), but it is still there. If you construct a `MetropolisSampler` with a machine, it will give you the same objects as before.
- I'd like to remove `netket.random`. There is nothing of interest there anymore.
## People
@inailuig If you have some time I'd like to know if this can solve your problem of recompiling. I think it should.