Add RNN

Vicentini Filippo requested to merge github/fork/wdphy16/rnn into master

Created by: wdphy16

RNNs are implemented as a subclass of ARNNs and reuse ARNN's sampler. They support complex parameters using the same normalization as ARNNs; I'll leave the mod-phase version to a more general implementation of ModPhase.

When using an RNN on a 2D (or higher-dimensional) lattice, we usually need a non-trivial autoregressive order (like the snake order) to exploit the locality of the lattice. For the 1D RNN this is straightforward to implement, because the RNN cell only needs to access the previous site in the autoregressive order and does not need to know the graph. The 1D RNN already supports arbitrary ordering.
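To make the snake order concrete, here is a minimal sketch in plain Python. The helper name `snake_order` and the row-major site labelling are assumptions made for illustration; this is not the code in this PR.

```python
def snake_order(n_rows, n_cols):
    """Return row-major site labels in the order a snake path visits them.

    Even rows are traversed left to right, odd rows right to left, so
    consecutive sites in the returned list are always lattice neighbors.
    """
    order = []
    for i in range(n_rows):
        if i % 2 == 0:
            cols = range(n_cols)
        else:
            cols = range(n_cols - 1, -1, -1)
        order.extend(i * n_cols + j for j in cols)
    return order


order = snake_order(3, 3)
print(order)  # [0, 1, 2, 5, 4, 3, 6, 7, 8]
```

A 1D RNN that follows this list only ever needs the hidden state of the previous entry, which is why it does not need the graph.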

However, for the 2D RNN the ordering is complicated to implement, because the RNN cell needs to access the previous spatial neighbors, which are defined by both the ordering ('previous') and the graph. Currently the 2D RNN only supports snake ordering on a square lattice (and nothing actually checks that the lattice is square); this is hard-coded in nk.nn.rnn_2d._get_h_xy and nk.models.rnn.LSTMNet2D.setup. If we really want to support arbitrary ordering, we need a fast and easy-to-use way to access the previous spatial neighbors given the ordering and the graph, and we need to make sure that the number of neighbors is the same at every step (if not... that's the job of a graph RNN).
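As a rough illustration of what "previous spatial neighbors given the ordering and the graph" could mean, the sketch below computes them for a small open-boundary grid. The helpers `grid_adjacency` and `previous_neighbors` are hypothetical names and are not what `_get_h_xy` does internally; the point is only that the neighbor count varies from step to step, so a real implementation would have to pad (for example with zero hidden states) to a fixed number.

```python
def grid_adjacency(n_rows, n_cols):
    """Nearest-neighbor adjacency of an open-boundary grid, row-major labels."""
    adj = {}
    for i in range(n_rows):
        for j in range(n_cols):
            nbrs = []
            if i > 0:
                nbrs.append((i - 1) * n_cols + j)
            if i < n_rows - 1:
                nbrs.append((i + 1) * n_cols + j)
            if j > 0:
                nbrs.append(i * n_cols + j - 1)
            if j < n_cols - 1:
                nbrs.append(i * n_cols + j + 1)
            adj[i * n_cols + j] = nbrs
    return adj


def previous_neighbors(order, adjacency):
    """For each step, list the graph neighbors already visited in `order`."""
    position = {site: step for step, site in enumerate(order)}
    result = []
    for step, site in enumerate(order):
        prev = sorted(n for n in adjacency[site] if position[n] < step)
        result.append(prev)
    return result


# Snake order on a 3x3 grid, written out by hand.
order = [0, 1, 2, 5, 4, 3, 6, 7, 8]
prev = previous_neighbors(order, grid_adjacency(3, 3))
counts = [len(p) for p in prev]
print(prev)
print(counts)
```

With snake order on a square lattice each step has at most two previous neighbors (one along the row, one from the row before), which is what makes the hard-coded 2D case tractable.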

The methods reorder and inverse_reorder have been added to AbstractARNN because ARDirectSampler needs them, but currently there is no way to specify ordering for existing ARNNs.
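For readers unfamiliar with the idea, a pair of methods like these can be thought of as a permutation and its inverse. The sketch below is written in plain Python with a hypothetical factory `make_reorder`; the actual signatures in AbstractARNN, and which direction each method maps, are not shown here and are assumptions of this example.

```python
def make_reorder(order):
    """Build a (reorder, inverse_reorder) pair from an autoregressive order.

    `order[step]` is the site visited at that step. Assumption of this
    sketch: `reorder` maps site-indexed values to step-indexed values,
    and `inverse_reorder` maps them back.
    """
    inverse = [0] * len(order)
    for step, site in enumerate(order):
        inverse[site] = step

    def reorder(values):
        # Element at `step` is the value of the site visited at that step.
        return [values[site] for site in order]

    def inverse_reorder(values):
        # Element at `site` is the value produced at the step visiting it.
        return [values[inverse[site]] for site in range(len(order))]

    return reorder, inverse_reorder


reorder, inverse_reorder = make_reorder([0, 1, 2, 5, 4, 3])
sites = ["a", "b", "c", "d", "e", "f"]
steps = reorder(sites)
print(steps)                   # ['a', 'b', 'c', 'f', 'e', 'd']
print(inverse_reorder(steps))  # ['a', 'b', 'c', 'd', 'e', 'f']
```

A sampler that generates sites in autoregressive order would apply the inverse map at the end so that the returned configuration is indexed by site again.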

There is also a proposal google/flax#2396 to add RNN in Flax. We've diverged a lot from them because we need to put RNN into the existing framework of ARNN and sampler, and we need to support any geometry. After they finalize their implementation, I'll be happy to refactor this and reuse what they have.
