This was brought up tangentially in #711 (closed), but I think it is worth putting in its own issue. Right now, as far as I can tell, NetKet computes all the logpsi values needed for an expectation value in a single forward pass. This causes memory errors for moderately sized models.
I think it would be really nice if there were a `max_states_per_forward_pass` parameter that let you control the memory usage.
I'm also curious if anyone has found any nice workarounds to this issue.
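As a rough sketch of the kind of thing I mean (all names here — `model`, `params`, `samples`, `chunk_size` — are placeholders, not NetKet API): evaluating the log-amplitudes chunk by chunk from a plain Python loop keeps only one chunk of activations on the device at a time, but only when it is called outside of jit:

```python
import jax
import jax.numpy as jnp

@jax.jit
def logpsi_chunk(params, sigma):
    # one forward pass over a single chunk of configurations
    return model.apply(params, sigma)

def logpsi_chunked(params, samples, chunk_size):
    # host-side loop: each chunk is a separate jitted call, so peak memory
    # is set by chunk_size (the last, smaller chunk recompiles once)
    chunks = [
        logpsi_chunk(params, samples[i : i + chunk_size])
        for i in range(0, samples.shape[0], chunk_size)
    ]
    return jnp.concatenate(chunks, axis=0)
```

Inside a traced/jitted computation this pattern does not help; there one would need `jax.lax.map`/`lax.scan` (see the sketch at the end of this thread).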
Vicentini Filippo changed title from Feature Request: Improved Control over Memory Usage for Computing Expectation Values to [Feature Request] Improved Control over Memory Usage for Computing Expectation Values
Yes, it's indeed something that should be implemented.
I think I had also listed this among the things I'd like to see in NetKet in the foreseeable future in #559.
Of course a PR addressing the issue would be accepted, if someone wanted to work on this.
We should maybe first agree on an API for this, in particular where this `max_states_per_forward_pass` should be stored (I'd say in the variational state, in order to keep the `.expect` and `.expect_and_grad` interface clean).
Also, computing the gradient (reverse pass) has a (linearly) higher memory cost, so in principle one would want two different values for this, a lower one for the gradient, but we can just keep one for simplicity.
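As a purely illustrative sketch of the kind of API I have in mind (the attribute names below are hypothetical, not existing NetKet fields):

```python
# assumes: import netket as nk; sampler, model, H defined elsewhere
# Hypothetical attributes, not an existing NetKet API: the limits live on the
# variational state so that .expect / .expect_and_grad keep their signatures.
vstate = nk.vqs.MCState(sampler, model, n_samples=8192)
vstate.max_states_per_forward_pass = 1024   # bound for plain forward evaluations
vstate.max_states_per_backward_pass = 256   # could be lower for the reverse pass

E = vstate.expect(H)                   # would internally chunk the forward pass
E, E_grad = vstate.expect_and_grad(H)  # would use the smaller backward bound
```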
To do this, another change would be required first: we should change or add a new way to compute the statistical indicators of a chain (mean/error/variance/autocorrelation/R...) online rather than offline (meaning you don't need all the data at once, but can update those indicators as more data comes in).
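A minimal sketch of such an online accumulator (hypothetical class, not NetKet's actual statistics code), merging mean and variance chunk by chunk in the Welford/Chan style; the autocorrelation-related indicators need extra per-chain state and are left out:

```python
import numpy as np

class OnlineStats:
    # streaming mean/variance; real-valued data assumed for brevity
    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, chunk):
        # merge the statistics of a new chunk into the running totals
        chunk = np.asarray(chunk)
        n_b = chunk.size
        mean_b = chunk.mean()
        m2_b = ((chunk - mean_b) ** 2).sum()
        delta = mean_b - self.mean
        n_tot = self.n + n_b
        self.mean += delta * n_b / n_tot
        self.m2 += m2_b + delta**2 * self.n * n_b / n_tot
        self.n = n_tot

    @property
    def variance(self):
        return self.m2 / self.n

    @property
    def error_of_mean(self):
        # naive i.i.d. error bar; a real implementation would also correct
        # for autocorrelation between successive samples of the chain
        return np.sqrt(self.variance / self.n)
```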
Once this is implemented, making `.expect` and `.expect_and_grad` work with batches should not be too much work.
Making SR work with this, however, is more complicated. @inailuig has ideas there, but I suspect this will take a while to see the light.
> To do this, another change would be required first: we should change or add a new way to compute the statistical indicators of a chain (mean/error/variance/autocorrelation/R...) online rather than offline (meaning you don't need all the data at once, but can update those indicators as more data comes in).
So the idea would be to break up the samples so that you never exceed max_states_per_forward_pass and then update the statistics as you feed in samples?
Another reasonable approach would be to just split the forward(sigma_p) computation into chunks.
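As a concrete sketch of how these two ideas fit together (`local_energy_fn` is a hypothetical per-chunk estimator, and `OnlineStats` refers to the accumulator sketched above, not NetKet code):

```python
def expect_chunked(params, samples, max_states, local_energy_fn):
    # process the samples in slices of at most max_states and fold each
    # slice of local energies into the running statistics
    stats = OnlineStats()
    for start in range(0, samples.shape[0], max_states):
        chunk = samples[start : start + max_states]
        stats.update(local_energy_fn(params, chunk))
    return stats.mean, stats.error_of_mean
```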
> Once this is implemented, making `.expect` and `.expect_and_grad` work with batches should not be too much work.
> Making SR work with this, however, is more complicated. @inailuig has ideas there, but I suspect this will take a while to see the light.
Yeah this seems slightly less pressing because QGTOnTheFly probably won't run out of memory for a while. Now if you want to use the precomputed Jacobian, this kind of thing could be useful.
I think this is indeed one option that we should have! Concerning the APIs, I believe that the "good" solution is to put a field `max_states_per_forward_pass` (or a better name) directly into the model. Then it is the responsibility of the drivers, or anything that calls the model, to honor that constraint.
> We should maybe first agree on an API for this, in particular where this `max_states_per_forward_pass` should be stored (I'd say in the variational state, in order to keep the `.expect` and `.expect_and_grad` interface clean).
I agree; however, I think this should not be directly in the variational state, but rather in the model itself (this seems to me like an implementation detail of the model, not a general property of variational states?).
> To do this, another change would be required first: we should change or add a new way to compute the statistical indicators of a chain (mean/error/variance/autocorrelation/R...) online rather than offline (meaning you don't need all the data at once, but can update those indicators as more data comes in).
Is this really required, though? I think it is much easier if we just split the computation into chunks of a fixed maximum size.
> I agree; however, I think this should not be directly in the variational state, but rather in the model itself (this seems to me like an implementation detail of the model, not a general property of variational states?).
I agree that this is the best solution.
As a note, I tried this with a custom model; the call function looks something like this:
```python
# split the input into chunks of at most max_states and stack the results
n_iters = (len(inp) - 1) // max_states + 1
for i in range(n_iters):
    x = inp[i * max_states : (i + 1) * max_states]
    x = some_func(x)
    if i == 0:
        out = x
    else:
        out = jnp.concatenate((out, x), 0)
return out
```
If I just jit-compile `model.apply()` it works as intended, but when I interface this model with NetKet I still run into memory errors.
The reason I suggested putting this in `MCState` is that the implementation can be a general one that always works.
Putting it in the model is also doable for forward passes, but I'm unsure of the performance for the reverse pass. It's also hard to write it properly at that level.
Also, every model will need to implement it.
@chrisrothUT this is not doing what you think.
This is also compiling the forward pass multiple times, so it increases compile time unnecessarily.
You should read about how loops work in JAX: you can't use a Python `for` loop like this.
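For reference, a sketch of a jit-friendly way to write the same chunking (`chunked_apply` is a hypothetical helper, not NetKet code): `jax.lax.map` traces the body once and runs the chunks sequentially at run time, assuming the batch divides evenly into chunks:

```python
import jax
import jax.numpy as jnp

def chunked_apply(apply_fn, params, samples, chunk_size):
    # reshape the batch to (n_chunks, chunk_size, ...) and map the forward
    # pass over the leading axis; lax.map runs the chunks one after another
    # inside a single compiled program instead of unrolling a Python loop
    n, *rest = samples.shape
    assert n % chunk_size == 0, "pad the batch so it splits into equal chunks"
    chunks = samples.reshape(n // chunk_size, chunk_size, *rest)
    out = jax.lax.map(lambda x: apply_fn(params, x), chunks)
    return out.reshape(n, *out.shape[2:])
```

If this is wrapped in `jax.jit`, `apply_fn` and `chunk_size` have to be static arguments, since the reshape needs concrete shapes at trace time.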