diff --git a/netket/vmc_common.py b/netket/vmc_common.py
index a47b20a1eb8f8d58864f2645fe41bf4afb8f1463..444cb79c83ae04ada6a05a484c1a06dbd1bad5be 100644
--- a/netket/vmc_common.py
+++ b/netket/vmc_common.py
@@ -26,44 +26,44 @@ if jax_available:
     import jax.numpy as jnp
 
 
-def shape_for_sr(grads, jac):
-    r"""Reshapes grads and jax from tree like structures to arrays if jax_available 
+    def shape_for_sr(grads, jac):
+        r"""Reshapes grads and jax from tree like structures to arrays if jax_available 
 
-    Args:
-        grads,jac: pytrees of jax arrays or numpy array
+        Args:
+            grads,jac: pytrees of jax arrays or numpy array
 
-    Returns:
-        A 1D array of gradients and a 2D array of the jacobian
-    """
+        Returns:
+            A 1D array of gradients and a 2D array of the jacobian
+        """
 
-    grads = jnp.concatenate(tuple(fd.reshape(-1) for fd in tree_flatten(grads)[0]))
-    jac = jnp.concatenate(
-        tuple(fd.reshape(len(fd), -1) for fd in tree_flatten(jac)[0]), -1
-    )
+        grads = jnp.concatenate(tuple(fd.reshape(-1) for fd in tree_flatten(grads)[0]))
+        jac = jnp.concatenate(
+            tuple(fd.reshape(len(fd), -1) for fd in tree_flatten(jac)[0]), -1
+        )
 
-    return grads, jac
+        return grads, jac
 
 
-def shape_for_update(update, shape_like):
-    r"""Reshapes grads from array to tree like structure if neccesary for update 
+    def shape_for_update(update, shape_like):
+        r"""Reshapes grads from array to tree like structure if neccesary for update 
 
-    Args:
-        grads: a 1d jax/numpy array
-        shape_like: this as in instance having the same type and shape of
-                    the desired conversion.
+        Args:
+            grads: a 1d jax/numpy array
+            shape_like: this as in instance having the same type and shape of
+                        the desired conversion.
 
-    Returns:
-        A possibly non-flat structure of jax arrays containing a copy of data
-        compatible with the given shape if jax_available and a copy of update otherwise
-    """
+        Returns:
+            A possibly non-flat structure of jax arrays containing a copy of data
+            compatible with the given shape if jax_available and a copy of update otherwise
+        """
 
-    shf, tree = tree_flatten(shape_like)
+        shf, tree = tree_flatten(shape_like)
 
-    updatelist = []
-    k = 0
-    for s in shf:
-        size = s.size
-        updatelist.append(jnp.asarray(update[k : k + size]).reshape(s.shape))
-        k += size
+        updatelist = []
+        k = 0
+        for s in shf:
+            size = s.size
+            updatelist.append(jnp.asarray(update[k : k + size]).reshape(s.shape))
+            k += size
 
-    return tree_unflatten(tree, updatelist)
+        return tree_unflatten(tree, updatelist)