implemented periodic mps in jax
Created by: tvieijra
So I implemented the periodic MPS in Jax. All the functionality of the c++ PeriodicMPS code is also in the Jax code.
The main limitation right now is that it is limited to local Hilbert spaces with local states which are evenly spaced. This is due to the fact that the input needs to be transformed to a form where the local state can be used as an index for the MPS tensor. Using a dictionary from localstate to index does not work with jitted functions as the input cannot be hashed. So the solution right now is to define the transformation using a bias and a scale transformation on the input. I think this limitation is not much of an issue as most Hilbert spaces are evenly spaced.
The initialization of the tensors is fixed to random normally distributed values around zero with standard deviation of 1e-2. I noticed that the initialization of Jax machines does not allow to set the sigma right now, so 1e-2 seems a good default value.
Let me know if you have questions, remarks or suggestions!