Skip to content
Snippets Groups Projects
Commit 2e857ffd authored by Filippo Vicentini's avatar Filippo Vicentini
Browse files

fulstate fix

parent 6d9c59e9
Branches pv/logstate
No related tags found
No related merge requests found
......@@ -13,6 +13,8 @@
# limitations under the License.
import flax.linen as nn
import jax
import jax.numpy as jnp
from netket.hilbert import DiscreteHilbert
......@@ -52,4 +54,7 @@ class LogStateVector(nn.Module):
)
def __call__(self, x_in: Array):
return self.logstate[states_to_numbers(self.hilbert, x_in)]
indices = jax.lax.stop_gradient(states_to_numbers(self.hilbert, jax.lax.stop_gradient(x_in)))
return self.logstate[indices]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment