Skip to content

TensorBoard support

Vicentini Filippo requested to merge PhilipVinc/tb into v3.0

Created by: PhilipVinc

import netket as nk
import numpy as np
import jax
from jax.experimental.optimizers import sgd as JaxSgd
import cProfile

# 1D Lattice
L = 20
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)

# Hilbert space of spins on the graph
hi = nk.hilbert.Spin(s=0.5, graph=g)

ha = nk.operator.Ising(h=1.0, hilbert=hi)

alpha = 1
ma = nk.machine.JaxRbm(hi, alpha, dtype=float)
ma.init_random_parameters(seed=1232)

# Jax Sampler
sa = nk.sampler.jax.MetropolisLocal(machine=ma, n_chains=8)

# Using a Jax Optimizer
j_op = JaxSgd(0.1)
op = nk.optimizer.Jax(ma, j_op)


# Stochastic Reconfiguration
sr = nk.optimizer.JaxSR(diag_shift=0.1)

# Create the optimization driver
gs = nk.Vmc(
    hamiltonian=ha, sampler=sa, optimizer=op, n_samples=1000, sr=sr, n_discard=0
)

log_tb = nk.logging.TBLog("prova")
log_wf = nk.logging.JsonLog("ndm", write_every=10)
logs = nk.logging.CombineLogs(log_tb, log_wf)

# The first iteration is slower because of start-up jit times
gs.run(1, out=log)
gs.run(300, out=log)

Merge request reports