`QGTJacobianDense` does not work with models using single-precision (float32/complex64) parameters
Created by: PhilipVinc
This was brought to my attention just now. I think it happens because `QGTJacobianDense` concatenates the parameters, and the way we do this plays badly when the parameters are single precision but the output of the model is double precision...
Should be relatively easy to fix, if someone wanted to pick it up (@inailuig ?)
Note: this came up for someone using a GCNN with single-precision parameters. The default QGT implementation in that case picks `QGTJacobianDense`, which then hits this bug.
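The kind of dtype mismatch involved can be illustrated with plain NumPy (a minimal sketch of array type promotion, not NetKet's actual internals): concatenating single- and double-precision arrays silently promotes everything to float64, so any downstream code that assumes the flattened vector keeps the parameters' original dtype will break.

```python
import numpy as np

# Single-precision parameters, as in the GCNN/RBM case above.
params = np.ones(4, dtype=np.float32)
# Double-precision quantity coming from the model output.
jac_row = np.ones(4, dtype=np.float64)

# Concatenation promotes the result to the wider dtype:
stacked = np.concatenate([params, jac_row])
print(stacked.dtype)  # float64, no longer matching params.dtype
```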
Reproducer
```python
import netket as nk
import jax.numpy as jnp

# 1D Lattice
L = 20
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)

# Hilbert space of spins on the graph
hi = nk.hilbert.Spin(s=1 / 2, N=g.n_nodes)

# Ising spin hamiltonian
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)

# RBM Spin Machine with single-precision parameters
ma = nk.models.RBM(alpha=1, dtype=jnp.float32)

# Metropolis Local Sampling
sa = nk.sampler.MetropolisLocal(hi, n_chains=16)

# Optimizer
op = nk.optimizer.Sgd(learning_rate=0.1)

# SR with the dense Jacobian QGT
sr = nk.optimizer.SR(qgt=nk.optimizer.qgt.QGTJacobianDense(diag_shift=0.01))

# Variational state
vs = nk.vqs.MCState(sa, ma, n_samples=1000, n_discard_per_chain=100)

# Variational monte carlo driver with a variational state
gs = nk.VMC(ha, op, variational_state=vs, preconditioner=sr)

# Run the optimization for 300 iterations
gs.run(n_iter=300, out=None)
```