diff --git a/netket/vmc_common.py b/netket/vmc_common.py index a47b20a1eb8f8d58864f2645fe41bf4afb8f1463..444cb79c83ae04ada6a05a484c1a06dbd1bad5be 100644 --- a/netket/vmc_common.py +++ b/netket/vmc_common.py @@ -26,44 +26,44 @@ if jax_available: import jax.numpy as jnp -def shape_for_sr(grads, jac): - r"""Reshapes grads and jax from tree like structures to arrays if jax_available + def shape_for_sr(grads, jac): + r"""Reshapes grads and jax from tree like structures to arrays if jax_available - Args: - grads,jac: pytrees of jax arrays or numpy array + Args: + grads,jac: pytrees of jax arrays or numpy array - Returns: - A 1D array of gradients and a 2D array of the jacobian - """ + Returns: + A 1D array of gradients and a 2D array of the jacobian + """ - grads = jnp.concatenate(tuple(fd.reshape(-1) for fd in tree_flatten(grads)[0])) - jac = jnp.concatenate( - tuple(fd.reshape(len(fd), -1) for fd in tree_flatten(jac)[0]), -1 - ) + grads = jnp.concatenate(tuple(fd.reshape(-1) for fd in tree_flatten(grads)[0])) + jac = jnp.concatenate( + tuple(fd.reshape(len(fd), -1) for fd in tree_flatten(jac)[0]), -1 + ) - return grads, jac + return grads, jac -def shape_for_update(update, shape_like): - r"""Reshapes grads from array to tree like structure if neccesary for update + def shape_for_update(update, shape_like): + r"""Reshapes grads from array to tree like structure if neccesary for update - Args: - grads: a 1d jax/numpy array - shape_like: this as in instance having the same type and shape of - the desired conversion. + Args: + grads: a 1d jax/numpy array + shape_like: this as in instance having the same type and shape of + the desired conversion. - Returns: - A possibly non-flat structure of jax arrays containing a copy of data - compatible with the given shape if jax_available and a copy of update otherwise - """ + Returns: + A possibly non-flat structure of jax arrays containing a copy of data + compatible with the given shape if jax_available and a copy of update otherwise + """ - shf, tree = tree_flatten(shape_like) + shf, tree = tree_flatten(shape_like) - updatelist = [] - k = 0 - for s in shf: - size = s.size - updatelist.append(jnp.asarray(update[k : k + size]).reshape(s.shape)) - k += size + updatelist = [] + k = 0 + for s in shf: + size = s.size + updatelist.append(jnp.asarray(update[k : k + size]).reshape(s.shape)) + k += size - return tree_unflatten(tree, updatelist) + return tree_unflatten(tree, updatelist)