Skip to content

Improve API to define autoregressive networks, maybe make them public?

Vicentini Filippo requested to merge pv/ancestral-improvements into master

Created by: PhilipVinc

One year ago @wdphy16 had to work around some bugs in Flax inheritance and put several functions (__call__ and __conditional_log_psi__) of the models as standalone functions and all the interface was pretty messy.

Those bugs have long been fixed in Flax, so I would like to clean up our Autoregressive method inheritance so that the interface is cleaner and implementing a custom Neural Network becomes simpler.

Compared to the previous interface, this one essentially asks users to inherit either from AbstractRecursiveARNN and only define a set of layers in the setup() method, or to inherit from AbstractARNN and declare a conditionals_log_psi function that returns the set of conditional log wavefucntions.

The resulting interface is quite clean, in my opinion. I would like the opinion of @wdphy16 and Zakari. In particular, I would like to understand if this can play well with the RNN Pr of @wdphy16

Merge request reports