RFC: removing numpy and pytorch and going towards a jax-only world
Created by: PhilipVinc
This discussion refers to v3.0. I'd appreciate comments from everyone (seriously, everyone).
Currently, the unreleased v3.0 has three backends, that is, three frameworks you can use to define your neural networks:
- numpy, there mainly for historical reasons, where you have to hand-write everything in numpy (and numba to get decent performance). Of course, you also have to hand-code the derivative of your network.
- Jax, which has seen a lot of recent development. It's usually at least as efficient as numpy/numba, and often more so. Defining networks is easy, and you get the derivative for free (see the sketch after this list). With my recent PR we now also support MPI operations for SR, but we don't have as many SR solvers as the numpy backend (diagonalization, for example). Very soon this will also let us support multi-GPU execution without changing any code.
- PyTorch, which is similar to Jax in that networks are easy to define and gradients are computed automatically. However, SR is very slow due to limitations in PyTorch that won't be addressed by its developers any time soon. You also can't use MPI with PyTorch unless someone writes the equivalent of mpi4jax for it, something that requires a lot of effort and that I'm not going to do.
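To make the "derivative for free" point concrete, here is a minimal sketch (not the NetKet API; `log_psi` and its toy parameters are purely illustrative) of what the Jax backend buys us: the log-wavefunction is just a pure function of its parameters, and `jax.grad` produces its derivative without any hand-written code.

```python
import jax
import jax.numpy as jnp

def log_psi(params, x):
    # toy single-layer "network": sum(tanh(W @ x + b)) as a scalar log-amplitude
    W, b = params
    return jnp.sum(jnp.tanh(W @ x + b))

key = jax.random.PRNGKey(0)
params = (jax.random.normal(key, (4, 8)), jnp.zeros(4))
x = jnp.ones(8)

# derivative of log_psi w.r.t. the whole parameter pytree, "for free"
grads = jax.grad(log_psi)(params, x)
```

With the numpy backend, every line of that gradient would have to be derived and coded by hand.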
@gcarleo and I have recently been discussing simplifying the internals of netket, notably to fully exploit Automatic Differentiation in order to write more compact drivers for unsupervised/supervised learning and other future applications.
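As a rough illustration of what "more compact drivers" could look like (a hypothetical sketch, not the actual NetKet v3 design; `make_sgd_driver` and `mse_loss` are made-up names), full AD lets one generic optimization step serve supervised, unsupervised, or variational objectives, since only the loss function changes:

```python
import jax
import jax.numpy as jnp

def make_sgd_driver(loss, learning_rate=0.01):
    """Return a jitted optimization step for any differentiable loss."""
    @jax.jit
    def step(params, batch):
        loss_value, grads = jax.value_and_grad(loss)(params, batch)
        # plain SGD update over an arbitrary pytree of parameters
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )
        return new_params, loss_value
    return step

# e.g. a supervised mean-squared-error loss; swapping in a variational
# energy estimate would reuse the exact same driver
def mse_loss(params, batch):
    inputs, targets = batch
    W, b = params
    return jnp.mean((jnp.tanh(inputs @ W + b) - targets) ** 2)

step = make_sgd_driver(mse_loss)
```

The point is that the driver itself carries no backend-specific code: everything framework-dependent is hidden behind AD.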
Simplifying the internals means, mainly, removing backends; ideally all but one. Our current idea is that numpy can be dropped, as it's quite inflexible, provided we port its missing features to Jax/PyTorch. However, we have also been toying with the idea of removing PyTorch support, in order to have a much more compact, functional and streamlined codebase.
This would also simplify the API, because we would no longer need to distinguish between backends.
What do people think? Can anyone make a strong case for keeping PyTorch around? How many people are using it heavily?
cc @femtobit @kchoo1118 @twesterhout @fabienalet and those who contributed the PyTorch backend @ChenAo-Phys @nikita-astronaut