Improve API to define autoregressive networks, maybe make them public?
Created by: PhilipVinc
One year ago @wdphy16 had to work around some bugs in Flax inheritance and put several functions (__call__
and __conditional_log_psi__
) of the models as standalone functions and all the interface was pretty messy.
Those bugs have long been fixed in Flax, so I would like to clean up our Autoregressive method inheritance so that the interface is cleaner and implementing a custom Neural Network becomes simpler.
Compared to the previous interface, this one essentially asks users to inherit either from AbstractRecursiveARNN
and only define a set of layers in the setup()
method, or to inherit from AbstractARNN
and declare a conditionals_log_psi
function that returns the set of conditional log wavefucntions.
The resulting interface is quite clean, in my opinion. I would like the opinion of @wdphy16 and Zakari. In particular, I would like to understand if this can play well with the RNN Pr of @wdphy16