Handling negative reals correctly in `logsumexp`
Created by: attila-i-szabo
- Solves the issue with GCNNs described at #1133 (closed).
- Adds
nk.nn.logsumexp_cplx
, which wraps the JAXlogsumexp
but always returns complex results, handles the logs of negative reals correctly within complex arithmetics, without promoting all real inputs to complex (which might be a memory bottleneck for some applications).
Only thing to settle is whether the new ensure_cplx
flag should default to true or false. True is a minimal breaking change (hitherto real outputs will come with identically zero imaginary parts), but it removes a silent error we had until now, so I would go for it.