
Netket v3

Vicentini Filippo requested to merge nk3 into master

Created by: PhilipVinc

TLDR

I'd like some feedback on this proposal. It's not yet done, though a large part of it is complete (I still need to put a few @jit calls in the right place). Download it with pip install git+https://github.com/netket/netket@nk3 and play around a bit.

If you want to review the PR, since it's huge, I'd suggest looking at it commit by commit. Every chunk of changes is in a separate commit. Skip the first commit, which is not really important anymore.

See how VMC becomes easier and more generic, or these gists for an example of the API.

Netket v3.0 master plan

For the so-far unreleased version of netket v3.0, we have transitioned all of our infrastructure to Python and removed all C++ code. This was done by temporarily using numpy and later converting most of the code to jax.

Still, the API has largely remained unchanged (except for a few, mostly aesthetic, changes to the construction of Hilbert spaces and graphs).

v3.0 is still unreleased because @gcarleo and I wanted to get the API right before committing to it; however, months have passed (almost a year) and it is still unreleased.

The big issues that I would like to address are the following:

  1. #437 Add an extensible VariationalState and use it in the drivers, making it easier to develop new applications.
  2. #525 Make it easy and intuitive to write custom Metropolis-Hastings transition rules and use them.
  3. Support non-homogeneous Hilbert spaces.
  4. Make it easier to define new arbitrary networks, with real or complex weights.
  5. Actually support arbitrary networks that mix real and complex layers (right now they aren't supported)
  6. Support models with a state that can change (but should not trigger recompilation)
  7. Remove legacy code.

In particular, while giving some lectures on netket I recently ran into points 2 and 4: it's quite messy to define those things right now.

Points 3/5 can be resolved by moving to flax, where (as I will argue below) defining models is much more compact and intuitive.

The following is my proposal for the netket v3.0 API.

A somewhat central point of this API stems from the discussion in #480 (closed), namely my proposal, which seemed to be accepted, to remove the numpy and torch backends and convert netket to a pure jax package. This was already quite well received, and it seems to me that a) jax now has full support from Google and is a stable project, and b) it allows us to write more compact, simpler code.

Roughly, the changes are the following:

  • add new requirements: jax and flax. This will also bump the minimum required python version to 3.7. This should not be an issue, as jax itself is discussing dropping python 3.6. EDIT: it's possible to still support 3.6.

Add to netket the following sub-packages:

  • netket.nn : re-exporting flax.nn but wrapping some functions in order to make them work better with complex numbers.

    • Jax/Flax already supports complex numbers, but some activation functions do not (for different reasons), and the kwargs necessary to use complex weights in a layer are a bit complicated. The main reason to re-export and wrap is to make our exported flax API work out of the box with complex numbers. Over time, I hope to get some PRs merged into flax itself so we can drop our own code.

    • Example:

      >>> import jax
      >>> import numpy as np
      >>> import netket as nk
      >>> import flax
      >>> from jax import numpy as jnp
      
      >>> x = jax.random.normal(jax.random.PRNGKey(0), (2, 2), dtype=jnp.complex64)
      # flax does not work
      >>> flax.nn.activation.softplus(x)
      TypeError: add requires arguments to have the same dtypes, got complex64, float32.
      # netket.nn works
      >>> nk.nn.activation.softplus(x)
      DeviceArray([[1.9650044+0.14860448j, 0.362084 -0.14605658j],
                   [0.6844594+0.72896177j, 0.4604799-0.02623011j]], dtype=complex64)
      
      # To use plain flax, we need to define our own complex init functions
      def complex_kernel_init(rng, shape):
          fan_in = np.prod(shape) // shape[-1]
          k1, k2 = jax.random.split(rng)
          x = jax.random.normal(k1, shape) + 1j * jax.random.normal(k2, shape)
          return x * (2 * fan_in) ** -0.5
      
      complex_bias_init = lambda _, shape: jnp.zeros(shape, jnp.complex64)
      
      # complex-valued dense layer, flax version
      m = flax.nn.Dense(features=3, dtype=jnp.complex64, kernel_init=complex_kernel_init, bias_init=complex_bias_init)
      
      # netket version
      m = nk.nn.Dense(features=3, dtype=jnp.complex64)

      I hope the above convinces you that having flax working out of the box with complex values is handy.

    • Since people might still have jax machines around, we provide a very simple function wrapping a jax module into a flax one, so that everything works out of the box (nk.nn.wrap_jax).

  • netket.jax : wrapping some functions in order to support, with the same syntax, complex numbers and functions that may be any of R->R, R->C, or C->C. Jax does not and will not support this out of the box. Notably, this will contain our own versions of jax.vjp and jax.grad based on the code we already have in jax_utils, plus a few other utilities. See the example below.
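    • Example: a minimal sketch (not the actual netket.jax implementation) of how a gradient for an R->C function can be built on top of jax.grad, by differentiating the real and imaginary parts of the output separately and recombining them:

      import jax
      from jax import numpy as jnp

      def grad_rc(f):
          # gradient of an R->C function: grad of the real part
          # plus 1j times the grad of the imaginary part
          def gradf(x):
              gr = jax.grad(lambda y: f(y).real)(x)
              gi = jax.grad(lambda y: f(y).imag)(x)
              return gr + 1j * gi
          return gradf

      f = lambda x: jnp.sum(x ** 2) + 1j * jnp.sum(x ** 3)
      grad_rc(f)(jnp.array([1.0, 2.0]))  # DeviceArray([2.+3.j, 4.+12.j])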

  • netket.optim: (What was netket.optimizer).

    • I propose to change the name (with a slow deprecation so as not to break code) because optim is the default in the jax and pyro world. TensorFlow uses optimizers, and we use optimizer. Let's pick one and be consistent.
    • the optimizers are simply re-exported from flax. No code.
    • We also export SR
      • SR is rewritten with a new interface; see below.

Remove AbstractMachine and its implementations

  • Functionally replaced by pure flax modules, which are objects containing no state (parameters) but only two functions: init_params, returning the pytree of parameters, and apply, computing the forward pass.
    • Also support pure jax modules and make it easy to support other jax frameworks (optax, for example).

    • While those modules do not contain the Hilbert space, we provide an easy-to-use constructor that accepts a Hilbert space and extracts its size.

    • By default we use np.float32 and not np.float64, but anything the user chooses is supported.

    • An advantage of this is that we can now copy-paste any machine written for jax/flax and it will work out of the box, regardless of what it does! (modulo replacing flax.nn -> netket.nn for complex-number compatibility, until things are fixed upstream).

    • See for example how to define a ConvNet or an RBM with modulus and phase below: it's very easy. Compare it with our old jax code for RBMModPhase, which was much more verbose.

    from typing import Any, Union
    
    import numpy as np
    from jax import numpy as jnp
    
    import netket as nk
    from netket import nn
    
    class RBMModPhase(nn.Module):
        dtype : Any = np.float32
        activation : Any = nn.logcosh
        alpha : Union[float, int] = 1
        use_bias : bool = True
    
        @nn.compact
        def __call__(self, x):
            re = nn.Dense(features=self.alpha * x.shape[-1], dtype=self.dtype, use_bias=self.use_bias)(x)
            re = self.activation(re)
            re = jnp.sum(re, axis=-1)
    
            im = nn.Dense(features=self.alpha * x.shape[-1], dtype=self.dtype, use_bias=self.use_bias)(x)
            im = self.activation(im)
            im = jnp.sum(im, axis=-1)
    
            return re + 1j * im
    
    
    class CNN(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(features=32, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
            x = x.reshape((x.shape[0], -1))  # flatten
            x = nn.Dense(features=256)(x)
            x = nn.relu(x)
            x = nn.log_softmax(x)
            return x
    
    machine = CNN(hilbert)
    W = machine.init_params(rng_key)  # rng_key either an int/uint or a jax.random.PRNGKey
    • We wish to support non-differentiable variables in machines (for RNNs, batchnorm and other things). To do so, we adopt the flax standard: the parameters pytree has the shape {'params': params_pytree, **other_state}, as sketched below.
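    A minimal sketch of this convention, with hypothetical entry names (only the 'params' entry is differentiated; everything else is carried along as mutable state):

    from jax import numpy as jnp

    variables = {
        'params': {'kernel': jnp.ones((2, 3)), 'bias': jnp.zeros(3)},
        'batch_stats': {'mean': jnp.zeros(3)},  # non-differentiable state
    }

    params = variables['params']
    model_state = {k: v for k, v in variables.items() if k != 'params'}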

Rewrite the samplers to be functional. Mainly:

  • Samplers are now data structures containing only the parameters of the sampler: no state, no machine.

    • Using data structures allows us to pass those objects straight to jitted functions without issues.
    • The state of a sampler is stored in a separate struct.
      • The state mainly carries the current rng state, plus additional stuff if needed.
    • The API comprises the following functions:
      • netket.sampler.init_state(sampler, machine, params) or sampler.init_state(...), creating the state
      • netket.sampler.sample(sampler, machine, params, chain_length=XXX, state=None) or sampler.sample(...), drawing the samples
        • If state is None, then init_state is used to create a new state
        • Returns the modified state and the sampled values.
        • I chose chain_length instead of the old n_samples because that's technically what it is.
  • Example:

    >>> import netket as nk
    >>> hilbert = nk.hilbert.Spin(s=1/2, N=4)
    >>> sampler = nk.sampler.ExactSampler(hilbert, seed=0)
    >>> sampler
    ExactSampler(
      hilbert = Spin(s=1/2, N=4),
      seed = [         0 2385058908],
      n_batches = 8,
      machine_power = 2)
    # notice the seed
    
    # create the state for the sampler:
    state = nk.sampler.init_state(sampler, machine, params)
    # reset the chain (could also have passed state = None, and it would create it)
    state = nk.sampler.reset(sampler, machine, params, state)
    
    samples, state = nk.sampler.sample(sampler, machine, params, chain_length = 1000, state=state)
    >>> samples
    DeviceArray([[[-1., -1., -1., -1.],
                  [-1.,  1.,  1.,  1.],
                  [-1., -1., -1., -1.],
                  ...,
    >>> state
    ExactSamplerState(pdf=DeviceArray([3.8351530e-01, 2.0624077e-02, 7.7783165e-04, 7.2943498e-03,
                 4.4462629e-02, 1.9844480e-04, 9.5614446e-03, 3.3565965e-02,
                 3.3565965e-02, 9.5614446e-03, 1.9844480e-04, 4.4462629e-02,
                 7.2943498e-03, 7.7783165e-04, 2.0624077e-02, 3.8351530e-01], dtype=float32), 
                rng=DeviceArray([1255698341, 3859703708], dtype=uint32))
                # notice the rng
    • We store the seed in the sampler, so that if you reuse the same sampler without carrying the state along, you will get the same samples again. I had a brief discussion about this with @gcarleo and it seems the most sensible thing to do. For example:
    # If state is not passed in, the state is automatically created...
    samples, state = nk.sampler.sample(sampler, machine, params, chain_length = 1000)
    >>> samples
    DeviceArray([[[-1., -1., -1., -1.],
                  [-1.,  1.,  1.,  1.],
                  [-1., -1., -1., -1.],
                  ...,
    >>> state.rng
    DeviceArray([1255698341, 3859703708], dtype=uint32)                
    # notice the rng: it's the same as before
    • While the interface is fully functional, you can also call those methods as my_sampler.sample(...) and so on.
    • This is all pure jax, even the samplers themselves, so everything can be jitted through.
      • If you create a new sampler, the function is not recompiled. The only things triggering recompilation are changes to n_batches, the hilbert space, or other fields declared static; see the toy sketch below.
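    • Toy sketch (hypothetical field names, not netket's actual classes) of why this works: samplers are jax pytrees, so jitted functions trace their array leaves and only recompile when a field declared static changes.

      import jax
      from flax import struct

      @struct.dataclass
      class ToySampler:
          seed: jax.numpy.ndarray                           # pytree leaf: traced, no recompilation
          n_batches: int = struct.field(pytree_node=False)  # static: changing it recompiles

      @jax.jit
      def step(sampler):
          return jax.random.normal(sampler.seed, (sampler.n_batches,))

      step(ToySampler(seed=jax.random.PRNGKey(0), n_batches=8))
      step(ToySampler(seed=jax.random.PRNGKey(1), n_batches=8))   # compiled once, reused
      step(ToySampler(seed=jax.random.PRNGKey(0), n_batches=16))  # static change -> recompile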

Implement a VariationalState

  • A VariationalState has the following interface:
    • vs.parameters : (or params?) returns the PyTree of the variational parameters one may want to optimise.

    • vs.expect(operator) : computes expectation value of operator

    • vs.expect_and_grad(operator, is_hermitian=auto/True/False) : computes the expectation value of operator and its gradient.

      • is_hermitian can be used to decide whether to use the simpler form we currently use for the energy gradient, or the more standard formula.
    • vs.QGT() -> Callable[[Grad], Grad] : returns the quantum geometric tensor / S matrix.

      • The returned object should be a lazy object (à la scipy's LinearOperator) that takes a gradient as input and returns another gradient; see the sketch below.
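      • A hedged sketch of such a lazy object, assuming we already have the per-sample log-derivatives O (n_samples x n_params): after centering, S can be applied as O^†(O v)/N without ever materializing the full matrix.

        import jax
        from jax import numpy as jnp

        # hypothetical helper, not the actual SR implementation
        def make_qgt(O):
            O = O - O.mean(axis=0)              # center the log-derivatives
            def matvec(v):
                return O.conj().T @ (O @ v) / O.shape[0]
            return matvec

        O = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
        S = make_qgt(O)
        Sv = S(jnp.ones(5))                     # a gradient in, a gradient out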
    • vs.reset() : resets the internal state among iterations

    • save/load

    • For a machine/sampler pair, the ClassicalVariationalState will be the first implementation of this interface (a skeleton of the interface is sketched at the end of this section).

    • It will provide some functionality similar to a sampler/machine from before, with a bunch of extra tricks:

      • A ClassicalVariationalState is constructed by taking:
        • hilbert,
        • Machine/Module
        • Sampler
        • optional SR object (maybe, probably)
        • some configuration data
          • chain_length
        • the API is the one described above, plus:
          • vs.sample(chain_length) : sample and store the samples internally until reset() is called
          • vs.samples : access the samples; if the state was reset, resample.
          • vs.model_state : returns the PyTree of any other parameters that might change but that we don't want to differentiate against (think batchnorm, rnn...). Might also put it in the common API; I'm not sure.
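  • To make the shape of the API concrete, here is a minimal skeleton of the interface described above (illustrative only; the names follow this proposal and are not final):

    from typing import Any, Callable, Optional

    class VariationalState:
        @property
        def parameters(self) -> Any:
            """PyTree of the variational parameters to optimise."""
            ...

        def expect(self, operator) -> Any:
            """Expectation value of operator."""
            ...

        def expect_and_grad(self, operator, is_hermitian: Optional[bool] = None):
            """Expectation value of operator, and its gradient."""
            ...

        def QGT(self) -> Callable[[Any], Any]:
            """Lazy quantum geometric tensor: maps a gradient to a gradient."""
            ...

        def reset(self) -> None:
            """Reset the internal state between iterations."""
            ...


    class ClassicalVariationalState(VariationalState):
        def sample(self, chain_length: int):
            """Sample and cache configurations until the next reset()."""
            ...

        @property
        def samples(self):
            """Cached samples; resamples if the state was reset."""
            ...

        @property
        def model_state(self) -> Any:
            """PyTree of mutable, non-differentiated variables (batchnorm, rnn, ...)."""
            ...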

Minor changes:

  • the file netket/utils.py has been moved into a subfolder and utils is now a full-fledged submodule. I moved here our logic for detecting whether MPI is installed, deprecation warnings, and so on.
  • I check more carefully that the discoverable names in every module are relevant. For example, before, every module like netket.hilbert contained both Spin (the class) and spin (the module generated by the file). There is a small utility in netket.utils to hide all file-generated modules that we don't want to expose, and I use it extensively.
  • netket.hilbert.Boson has been renamed to netket.hilbert.Fock, as I find it more accurate. However, there is a deprecated constructor called Boson that forwards to Fock.
  • Maybe we should rename all samplers from MetropolisSampler, ExactSampler to Metropolis, Exact? As we usually access them through their module (netket.sampler), it might make sense and would make the API lighter.
  • The old API to create samplers is deprecated (I'd like to remove it) but it is still there. If you try to construct a MetropolisSampler with a machine, it will give you the same objects as before.
  • I'd like to remove netket.random. There is nothing of interest there anymore.

People

@inailuig If you have some time, I'd like to know whether this can solve your recompilation problem. I think it should.
