Skip to content

WIP: state_dict

Vicentini Filippo requested to merge github/fork/twesterhout/state_dict into master

Created by: twesterhout

We talked about using Python dictionaries for updating and saving parameters. This PR actually implements this functionality for RbmSpin.

Now we have

>>> import netket
>>> import numpy
>>> h = netket.hilbert.Spin(netket.graph.Hypercube(5, n_dim=1, pbc=True), s=0.5)
>>> m = netket.machine.RbmSpin(h, n_hidden=4)
>>> m.init_random_parameters()
>>> m.state_dict()
{'a': array([ 0.06266416-0.08706595j,  0.05022808-0.00971853j,
        0.12287212+0.13265499j, -0.36866098-0.00968607j,
       -0.12293974-0.00809343j]), 'b': array([ 0.03828042-0.16867513j, -0.0536683 +0.00832681j,
        0.11986847-0.03270153j,  0.21675369-0.05061179j]), 'w': array([[ 0.10887414+9.25775000e-02j,  0.04988153+7.22126278e-03j,
        -0.00953845-3.61061884e-05j, -0.02722431-9.56652539e-02j],
       [ 0.17739797+2.67596957e-01j, -0.08609805+3.47115713e-02j,
         0.05762644-1.57047880e-02j,  0.06564118+5.19112256e-03j],
       [-0.01790749-1.57556098e-01j, -0.02993753+1.34490347e-01j,
         0.18646482-1.45914767e-01j, -0.04700775+6.12943165e-02j],
       [ 0.10217094+8.98425795e-02j,  0.04849756+1.80935579e-01j,
        -0.06182897-1.25392615e-02j,  0.04124381+1.00977662e-01j],
       [ 0.17383309+7.56729259e-02j, -0.01546016+1.61383038e-02j,
         0.13857972-3.74710975e-02j, -0.02467454+1.47849472e-01j]])}

There are also a few changes to the implementation of the RBM, namely using optionals instead of separate bool flags. Also, passing invalid arguments to RbmSpin constructor actually throws instead of crashing:

>>> # Before this PR (i.e. current master)
>>> m = netket.machine.RbmSpin(h, alpha=-2, n_hidden=-4)
python3: External/Eigen3/Eigen/src/Core/PlainObjectBase.h:285: void Eigen::PlainObjectBase<Eigen::Matrix<std::complex<double>, -1, -1, 0, -1, -1> >::resize(Eigen::Index, Eigen::Index) [Derived = Eigen::Matrix<std::complex<double>, -1, -1, 0, -1, -1>]: Assertion `(!(RowsAtCompileTime!=Dynamic) || (rows==RowsAtCompileTime)) && (!(ColsAtCompileTime!=Dynamic) || (cols==ColsAtCompileTime)) && (!(RowsAtCompileTime==Dynamic && MaxRowsAtCompileTime!=Dynamic) || (rows<=MaxRowsAtCompileTime)) && (!(ColsAtCompileTime==Dynamic && MaxColsAtCompileTime!=Dynamic) || (cols<=MaxColsAtCompileTime)) && rows>=0 && cols>=0 && "Invalid sizes when resizing a matrix or array."' failed.
[sccloan005:04474] *** Process received signal ***
[sccloan005:04474] Signal: Aborted (6)
[sccloan005:04474] Signal code:  (-6)
[sccloan005:04474] [ 0] /lib/x86_64-linux-gnu/libc.so.6(+0x43f60)[0x7f0ea7daff60]
[sccloan005:04474] [ 1] /lib/x86_64-linux-gnu/libc.so.6(gsignal+0xc7)[0x7f0ea7dafed7]
[sccloan005:04474] [ 2] /lib/x86_64-linux-gnu/libc.so.6(abort+0x121)[0x7f0ea7d91535]
[sccloan005:04474] [ 3] /lib/x86_64-linux-gnu/libc.so.6(+0x2540f)[0x7f0ea7d9140f]
[sccloan005:04474] [ 4] /lib/x86_64-linux-gnu/libc.so.6(+0x35012)[0x7f0ea7da1012]
[sccloan005:04474] [ 5] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x11151e)[0x7f0ea662851e]
[sccloan005:04474] [ 6] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x341991)[0x7f0ea6858991]
[sccloan005:04474] [ 7] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x33e957)[0x7f0ea6855957]
[sccloan005:04474] [ 8] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x3756d4)[0x7f0ea688c6d4]
[sccloan005:04474] [ 9] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x3755ef)[0x7f0ea688c5ef]
[sccloan005:04474] [10] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x375522)[0x7f0ea688c522]
[sccloan005:04474] [11] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x374a36)[0x7f0ea688ba36]
[sccloan005:04474] [12] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x374838)[0x7f0ea688b838]
[sccloan005:04474] [13] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0x374735)[0x7f0ea688b735]
[sccloan005:04474] [14] /home/twesterhout/src/netket/netket/_C_netket.cpython-37m-x86_64-linux-gnu.so(+0xde027)[0x7f0ea65f5027]
[sccloan005:04474] [15] python3(_PyMethodDef_RawFastCallDict+0x12b)[0x5d82db]
[sccloan005:04474] [16] python3[0x4d9dc7]
[sccloan005:04474] [17] python3(PyObject_Call+0x56)[0x5dbc76]
[sccloan005:04474] [18] python3[0x591298]
[sccloan005:04474] [19] python3(_PyObject_FastCallKeywords+0x129)[0x5d93c9]
[sccloan005:04474] [20] python3[0x54b0f1]
[sccloan005:04474] [21] python3(_PyEval_EvalFrameDefault+0x13ae)[0x54f0ee]
[sccloan005:04474] [22] python3(_PyEval_EvalCodeWithName+0x252)[0x54b9f2]
[sccloan005:04474] [23] python3(PyEval_EvalCode+0x23)[0x54dd33]
[sccloan005:04474] [24] python3[0x630f22]
[sccloan005:04474] [25] python3[0x480d87]
[sccloan005:04474] [26] python3(PyRun_InteractiveLoopFlags+0xd4)[0x480f09]
[sccloan005:04474] [27] python3(PyRun_AnyFileExFlags+0x53)[0x631e33]
[sccloan005:04474] [28] python3[0x65414e]
[sccloan005:04474] [29] python3(_Py_UnixMain+0x2e)[0x6544ae]
[sccloan005:04474] *** End of error message ***
Aborted (core dumped)

and

>>> # With this PR
>>> m = netket.machine.RbmSpin(h, n_hidden=-4)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: invalid number of hidden units: -4; expected a non-negative number

I'm looking for some early feedback. Can I go on and implement state_dict interface for other machines? This will greatly benefit interoperability with PyTorch later on.

Merge request reports