Skip to content

Remove vendored flax layers

Vicentini Filippo requested to merge pv/flax-future into master

Created by: PhilipVinc

https://github.com/google/flax/pull/2058 is going to land soon.

Once they tag a new version, I will bump NetKet's minimum flax version to match and would like to merge this PR. This comes with the deletion of 400 lines of code (and we add 150 that aree mainly deprecation warnings...)

This PR removes our own implementation of Dense and Conv layers, deprecating nk.nn.Dense/Conv and instructing users to use flax.linen.Dense and flax.linen.Conv.

Post-flax/#2058 their layers will have two attributes: dtype and param_dtype.

  • param_dtype specifies the dtype of the parameters in the layer (and is equivalent to netket's current dtype attribute`)
  • dtype is an optional type, None by default, and specifies the precision used to compute the layer. This will only work on accelerators (GPU/TPU) .

Most likely, netket's users won't need dtype and only param_dtype.

This will only affect those that define their own architectures, and the error message is quite clear. The old syntax will still be supported for the time being with a (very loud) deprecation warning.

Merge request reports