Proof of concept implementation of an RBM in pure Python
Created by: twesterhout
First step towards using PyTorch modules with NetKet.
-
to_json
andfrom_json
methods are dropped in favour ofSave
andLoad
methods. The fact that machines save their state in JSON format is now an implementation detail. Machines implemented in Python are free to usepickle
for example. - A "trampoline" class which enables sub-classing
netket.machine.Machine
from Python. Since this is still a work in progress, new features are not advertised in the docs. - An example of a pure Python class is given which can be used with NetKet as a drop-in replacement for
RbmSpin
.
Merge request reports
Activity
requested review from @filippo.vicentini
requested review from @filippo.vicentini
Created by: twesterhout
import timeit import netket import numpy as np hilbert = netket.hilbert.Spin(netket.graph.Hypercube(length=200, n_dim=1), s=1/2) cxx_rbm = netket.machine.RbmSpin(hilbert, alpha=5) py_rbm = netket.machine.PyRbm(hilbert, alpha=5) cxx_rbm.init_random_parameters() py_rbm.parameters = cxx_rbm.parameters x = np.random.choice(hilbert.local_states, size=hilbert.size) print("log_val:") print(" Py :", timeit.repeat(lambda: py_rbm.log_val(x), repeat=5, number=1000)) print(" C++:", timeit.repeat(lambda: cxx_rbm.log_val(x), repeat=5, number=1000)) print("der_log:") print(" Py :", timeit.repeat(lambda: py_rbm.der_log(x), repeat=5, number=1000)) print(" C++:", timeit.repeat(lambda: cxx_rbm.der_log(x), repeat=5, number=1000))
results in
log_val: Py : [0.6111863109981641, 0.5606954150134698, 0.5920686059980653, 0.7154442630126141, 0.591628018009942] C++: [0.4250708259642124, 0.32820330100366846, 0.3333854129887186, 0.33881930098868906, 0.3311629240051843] der_log: Py : [1.8661257249768823, 1.9216104139923118, 1.8938411580165848, 1.9528111109975725, 1.9478432990144938] C++: [0.8959374719997868, 0.8461411170428619, 0.8448376799933612, 0.8329134889645502, 0.8522247140062973]
on the laptop provided by Flatiron Institute. But numpy is actually using multiple cores, so that's not really fair :)
24 25 namespace netket { 26 27 bool ShouldIDoIO() noexcept { 28 auto rank = 0; 29 auto const status = MPI_Comm_rank(MPI_COMM_WORLD, &rank); 30 if (status == MPI_SUCCESS) { 31 return rank == 0; 32 } 33 std::fprintf(stderr, 34 "[NetKet] MPI_Comm_rank failed: doing I/O on all processes.\n"); 35 return true; 36 } 37 38 template <class Function, class... Args> 39 auto ShouldNotThrow(Function &&function, Args &&... args) noexcept 125 } 126 127 Complex PyAbstractMachine::LogVal(VisibleConstType v) { 128 PYBIND11_OVERLOAD_PURE_NAME(Complex, /* Return type */ 129 AbstractMachine, /* Parent class */ 130 "log_val", /* Name of the function in Python */ 131 LogVal, /* Name of function in C++ */ 132 v); 133 } 134 135 Complex PyAbstractMachine::LogVal(VisibleConstType v, 136 const LookupType & /*unused*/) { 137 return LogVal(v); 138 } 139 140 void PyAbstractMachine::InitLookup(VisibleConstType /*unused*/, Created by: twesterhout
As for technical reasons... converting
std::vector
tonumpy.ndarray
and back does a lot of unnecessary allocations. These allocations defeat the purpose of having look-up tables as an optimisation in the first place.So to actually benefit from look-up tables, we need to rethink the way they're used on the C++ side. This is a pretty major refactoring, I'm afraid, and as such is a topic for a different PR/issue.
Created by: gcarleo
This falls within issue #52 (closed) , I agree, but it's also important to do at some point in the future
Created by: gcarleo
Wait
On Fri, Jun 14, 2019, 10:39 Tom Westerhout notifications@github.com wrote:
@gcarleo https://github.com/gcarleo can we merge?
— You are receiving this because you were mentioned. Reply to this email directly, view it on GitHub https://gitlab.labos.polytechnique.fr/filippo.vicentini/netket/-/merge_requests/219?email_source=notifications&email_token=AGWYRBEW24A6CQYB74EJFGLP2OUSBA5CNFSM4HX52VE2YY3PNVWWK3TUL52HS4DFVREXG43VMVBW63LNMVXHJKTDN5WW2ZLOORPWSZGODXW7Q2I#issuecomment-502134889, or mute the thread https://github.com/notifications/unsubscribe-auth/AGWYRBBNYL7PT3UQN4DKHZLP2OUSBANCNFSM4HX52VEQ .