Proof of concept implementation of an RBM in pure Python

Merged Vicentini Filippo requested to merge github/fork/twesterhout/pytorch into master

Created by: twesterhout

First step towards using PyTorch modules with NetKet.

  • to_json and from_json methods are dropped in favour of Save and Load methods. That machines save their state in JSON format is now an implementation detail; machines implemented in Python are free to use pickle, for example.
  • A "trampoline" class which enables sub-classing netket.machine.Machine from Python. Since this is still a work in progress, new features are not advertised in the docs.
  • An example of a pure Python class is given which can be used with NetKet as a drop-in replacement for RbmSpin.
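
To make the three points above concrete, here is a minimal sketch of what a pure-Python RBM with pickle-based save/load could look like. It is not the PyRbm class from this merge request and does not subclass netket.machine.Machine: the class name TinyRbm, the lowercase save/load methods, and the parameter layout (a, b, W) are all assumptions made for illustration, and only the standard library is used so that it runs without NetKet.

```python
import cmath
import os
import pickle
import random
import tempfile


class TinyRbm:
    """Hypothetical stand-in for a pure-Python RBM machine (not NetKet's PyRbm)."""

    def __init__(self, n_visible, alpha=1, seed=0):
        rng = random.Random(seed)
        n_hidden = alpha * n_visible

        def small():
            return complex(rng.gauss(0, 0.1), rng.gauss(0, 0.1))

        self.a = [small() for _ in range(n_visible)]  # visible bias
        self.b = [small() for _ in range(n_hidden)]   # hidden bias
        self.w = [[small() for _ in range(n_hidden)] for _ in range(n_visible)]

    def _theta(self, v):
        # theta_j = b_j + sum_i W_ij * v_i
        return [
            b_j + sum(self.w[i][j] * v[i] for i in range(len(v)))
            for j, b_j in enumerate(self.b)
        ]

    def log_val(self, v):
        # log psi(v) = sum_i a_i * v_i + sum_j log cosh(theta_j)
        visible_term = sum(a_i * v_i for a_i, v_i in zip(self.a, v))
        return visible_term + sum(cmath.log(cmath.cosh(t)) for t in self._theta(v))

    def der_log(self, v):
        # Gradient of log psi, ordered as (a, b, W flattened row by row).
        tanh_theta = [cmath.tanh(t) for t in self._theta(v)]
        grad = [complex(v_i) for v_i in v]
        grad += tanh_theta
        grad += [v_i * t for v_i in v for t in tanh_theta]
        return grad

    def save(self, filename):
        # The on-disk format is an implementation detail; pickle is used here.
        with open(filename, "wb") as f:
            pickle.dump((self.a, self.b, self.w), f)

    def load(self, filename):
        with open(filename, "rb") as f:
            self.a, self.b, self.w = pickle.load(f)


rbm = TinyRbm(n_visible=4, alpha=2, seed=1)
v = [1, -1, 1, 1]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rbm.pkl")
    rbm.save(path)
    restored = TinyRbm(n_visible=4, alpha=2, seed=2)  # different random start
    restored.load(path)

print(rbm.log_val(v) == restored.log_val(v))
```

The callers only ever see save and load, so the class is free to switch from pickle to JSON or anything else without changing its interface.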

Merged by avatar (Apr 16, 2025 3:12am UTC)

Activity

  • requested review from @filippo.vicentini

  • Created by: gcarleo

    @twesterhout could you run a quick runtime benchmark? How much slower is it?

  • Created by: twesterhout

    import timeit
    import netket
    import numpy as np
    
    hilbert = netket.hilbert.Spin(netket.graph.Hypercube(length=200, n_dim=1), s=1/2)
    cxx_rbm = netket.machine.RbmSpin(hilbert, alpha=5)
    py_rbm = netket.machine.PyRbm(hilbert, alpha=5)
    cxx_rbm.init_random_parameters()
    py_rbm.parameters = cxx_rbm.parameters
    
    x = np.random.choice(hilbert.local_states, size=hilbert.size)
    print("log_val:")
    print("    Py :", timeit.repeat(lambda: py_rbm.log_val(x), repeat=5, number=1000))
    print("    C++:", timeit.repeat(lambda: cxx_rbm.log_val(x), repeat=5, number=1000))
    print("der_log:")
    print("    Py :", timeit.repeat(lambda: py_rbm.der_log(x), repeat=5, number=1000))
    print("    C++:", timeit.repeat(lambda: cxx_rbm.der_log(x), repeat=5, number=1000))

    results in

    log_val:
        Py : [0.6111863109981641, 0.5606954150134698, 0.5920686059980653, 0.7154442630126141, 0.591628018009942]
        C++: [0.4250708259642124, 0.32820330100366846, 0.3333854129887186, 0.33881930098868906, 0.3311629240051843]
    der_log:
        Py : [1.8661257249768823, 1.9216104139923118, 1.8938411580165848, 1.9528111109975725, 1.9478432990144938]
        C++: [0.8959374719997868, 0.8461411170428619, 0.8448376799933612, 0.8329134889645502, 0.8522247140062973]

    on the laptop provided by Flatiron Institute. But numpy is actually using multiple cores, so that's not really fair :)
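
One common way to take multi-core NumPy out of the comparison is to pin the numerical libraries to a single thread before they are loaded. The sketch below assumes NumPy is linked against an OpenMP, OpenBLAS or MKL backend; which one applies on the laptop used above is not stated in the thread, so all three variables are set.

```python
import os

# These must be set before NumPy is imported for the first time, because the
# thread pools are sized when the underlying BLAS/OpenMP runtime is loaded.
for name in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ[name] = "1"

# Import NumPy and NetKet only after this point, then rerun the benchmark.
print({name: os.environ[name] for name in
       ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")})
```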

  • Created by: gcarleo

    Interesting... a factor of 2 is not too bad, especially considering that you can still JIT it, aha
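
The "factor of 2" can be read off the posted numbers directly. The sketch below copies the timings from the benchmark output and compares the fastest repeat of each implementation; using the minimum rather than the mean is a choice made here (it is the least sensitive to noise from other processes), not something stated in the thread.

```python
# Timings copied from the benchmark output (seconds per 1000 calls, 5 repeats).
timings = {
    "log_val": {
        "py": [0.6111863109981641, 0.5606954150134698, 0.5920686059980653,
               0.7154442630126141, 0.591628018009942],
        "cxx": [0.4250708259642124, 0.32820330100366846, 0.3333854129887186,
                0.33881930098868906, 0.3311629240051843],
    },
    "der_log": {
        "py": [1.8661257249768823, 1.9216104139923118, 1.8938411580165848,
               1.9528111109975725, 1.9478432990144938],
        "cxx": [0.8959374719997868, 0.8461411170428619, 0.8448376799933612,
                0.8329134889645502, 0.8522247140062973],
    },
}


def slowdown(py_times, cxx_times):
    """Ratio of the fastest Python repeat to the fastest C++ repeat."""
    return min(py_times) / min(cxx_times)


ratios = {name: slowdown(t["py"], t["cxx"]) for name, t in timings.items()}
for name, ratio in ratios.items():
    print(f"{name}: Python is {ratio:.2f}x slower than C++")
```

On these numbers the slowdown is roughly 1.7 for log_val and 2.2 for der_log.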

namespace netket {

bool ShouldIDoIO() noexcept {
  auto rank = 0;
  auto const status = MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  if (status == MPI_SUCCESS) {
    return rank == 0;
  }
  std::fprintf(stderr,
               "[NetKet] MPI_Comm_rank failed: doing I/O on all processes.\n");
  return true;
}

template <class Function, class... Args>
auto ShouldNotThrow(Function &&function, Args &&... args) noexcept
  • }

    Complex PyAbstractMachine::LogVal(VisibleConstType v) {
      PYBIND11_OVERLOAD_PURE_NAME(Complex,         /* Return type */
                                  AbstractMachine, /* Parent class */
                                  "log_val",       /* Name of the function in Python */
                                  LogVal,          /* Name of function in C++ */
                                  v);
    }

    Complex PyAbstractMachine::LogVal(VisibleConstType v,
                                      const LookupType & /*unused*/) {
      return LogVal(v);
    }

    void PyAbstractMachine::InitLookup(VisibleConstType /*unused*/,

    • Created by: gcarleo

      Is this unused because we cannot derive a machine in Python that uses the look-up tables? Why can't we do that?

    • Created by: twesterhout

      As discussed, with PyTorch and deeper architectures it makes little sense to implement look-up tables.

    • Created by: gcarleo

      True, but there are other wave functions for which it still makes sense. If it's not too much effort, I believe we should also support look-up tables. Is there some technical problem there?

    • Created by: twesterhout

      As for technical reasons... converting std::vector to numpy.ndarray and back does a lot of unnecessary allocations. These allocations defeat the purpose of having look-up tables as an optimisation in the first place.

      So to actually benefit from look-up tables, we need to rethink the way they're used on the C++ side. This is a pretty major refactoring, I'm afraid, and as such is a topic for a different PR/issue.

    • Created by: gcarleo

      This falls under issue #52 (closed), I agree, but it's still important to do at some point in the future.
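
For context on why look-up tables matter for shallow wave functions such as an RBM, here is a minimal sketch of the usual trick: cache theta_j = b_j + sum_i W_ij v_i and patch it when one spin flips, instead of recomputing it. This is a generic illustration, not NetKet's LookupType; the function names, the ±1 spin convention, and the assumption that the table holds theta are all hypothetical.

```python
import random


def theta_full(b, w, v):
    """Recompute theta from scratch: O(n_visible * n_hidden)."""
    return [
        b_j + sum(w[i][j] * v[i] for i in range(len(v)))
        for j, b_j in enumerate(b)
    ]


def theta_after_flip(theta, w, v, site):
    """Patch a cached theta for the flip v[site] -> -v[site]: O(n_hidden)."""
    delta = -2 * v[site]  # new value minus old value for a +/-1 spin
    return [t + w[site][j] * delta for j, t in enumerate(theta)]


rng = random.Random(0)
n_visible, n_hidden = 6, 12
b = [complex(rng.gauss(0, 0.1), rng.gauss(0, 0.1)) for _ in range(n_hidden)]
w = [[complex(rng.gauss(0, 0.1), rng.gauss(0, 0.1)) for _ in range(n_hidden)]
     for _ in range(n_visible)]
v = [rng.choice([-1, 1]) for _ in range(n_visible)]

theta = theta_full(b, w, v)  # the cached "look-up table"
site = 2
theta_fast = theta_after_flip(theta, w, v, site)

v_flipped = list(v)
v_flipped[site] = -v[site]
theta_slow = theta_full(b, w, v_flipped)

max_error = max(abs(x - y) for x, y in zip(theta_fast, theta_slow))
print(max_error < 1e-12)
```

The saving comes from touching only one row of W per flipped spin, which is why copying the table between std::vector and numpy.ndarray on every call would cancel the benefit.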

  • Created by: femtobit

    Review: Approved

    Looks good to me now.

  • Created by: twesterhout

    @gcarleo can we merge?

  • Created by: gcarleo

    Review: Approved

    Good for me now

  • Merged by: gcarleo at 2019-06-14 17:38:17 UTC
