NetKet Developer Meeting Notes
Created by: PhilipVinc
27 April 2022 Meeting
Participants: @PhilipVinc @gcarleo @femtobit @inailuig @attila-i-szabo @chrisrothUT @danielalcalde
1. Discussion on log-gradient implementation (#1170)
The issue was raised that the log-gradient of non-holomorphic models with complex parameters is not sufficient to compute the gradient of expectation values, the QGT, or other objects. Models with real parameters, or holomorphic models with complex parameters, pose no problem.
In the non-holomorphic case with complex parameters one must compute both the complex-valued gradient with respect to the real part of the parameters and the one with respect to the imaginary part. The main problem is that handling those two objects is significantly more complex, as one often has to recast the resulting gradient back into the same structure as the parameters.
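For reference, here is a minimal, NetKet-independent sketch of the two derivatives in question. The toy model `log_psi` and the helper `split_jacobian` are purely illustrative; forward-mode differentiation is used because it handles real inputs with complex outputs directly.

```python
import jax
import jax.numpy as jnp

def log_psi(theta, x):
    # toy non-holomorphic model: it depends on conj(theta),
    # so it has no complex derivative in the holomorphic sense
    return jnp.sum(theta * x) + jnp.sum(jnp.conj(theta) * x**2)

def split_jacobian(theta, x):
    # view the complex parameters as two real vectors
    def f(theta_r, theta_i):
        return log_psi(theta_r + 1j * theta_i, x)

    # complex-valued derivatives w.r.t. the real and imaginary parts
    jac_r = jax.jacfwd(f, argnums=0)(theta.real, theta.imag)
    jac_i = jax.jacfwd(f, argnums=1)(theta.real, theta.imag)
    return jac_r, jac_i

theta = jnp.array([1.0 + 2.0j, -0.5 + 0.3j])
x = jnp.array([0.1, 0.7])
jac_r, jac_i = split_jacobian(theta, x)  # two complex vectors of length 2
```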
Accepted proposals:
- Expose a function `nk.nn.split_complex_parameters_of_model` (name to be decided) which takes a jax/flax/haiku model definition and returns a new model that has only real-valued parameters (a rough sketch of the idea is given after this list).
- Add a new flag to the `MCState` constructor, similar to the old `holomorphic=True/False`. Normally this flag should not be set.
- Add a new `MCState.log_gradient(samples)` function to `MCState`, computing the Jacobian of the model for the samples provided. If the model has real parameters this works fine; if the model has complex parameters and `holomorphic=False`, raise an error or a warning.
- Add a utility constructor to `QGTJacobianPyTree` and `QGTJacobianDense` to construct them directly from the output of `log_gradient(samples)`.
- The Jacobian should not be cached, as it might take up a lot of memory. The utility constructors of the previous point should allow users to write their own logic that computes the gradient only once.
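A minimal sketch of the idea behind the first proposal, assuming one works at the level of an `apply_fn(params, x)` and its parameter pytree rather than a full flax/haiku module (the function name and the `{"real", "imag"}` layout below are illustrative assumptions, not the planned NetKet implementation):

```python
import jax
import jax.numpy as jnp

def split_complex_parameters(apply_fn, params):
    """Return (new_apply_fn, real_params) where every complex leaf of `params`
    is replaced by a {"real": ..., "imag": ...} pair of real arrays."""

    def split_leaf(p):
        if jnp.iscomplexobj(p):
            return {"real": p.real, "imag": p.imag}
        return p

    def is_split(node):
        return isinstance(node, dict) and set(node.keys()) == {"real", "imag"}

    def merge_leaf(p):
        return p["real"] + 1j * p["imag"] if is_split(p) else p

    real_params = jax.tree_util.tree_map(split_leaf, params)

    def new_apply_fn(real_params, *args, **kwargs):
        # recombine the (real, imag) pairs into complex leaves before calling the model
        merged = jax.tree_util.tree_map(merge_leaf, real_params, is_leaf=is_split)
        return apply_fn(merged, *args, **kwargs)

    return new_apply_fn, real_params
```

With such a wrapper the real-parameter code path (`log_gradient`, the QGT constructors above) can be reused unchanged; mapping the result back onto the original complex parameters is exactly the recasting step discussed at the top of this section.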
TODO:
- Decide the names of the function wrapping the module, of `MCState.log_gradient`, and of the `holomorphic` flag.
2. Discussion on utility function to provide local estimators (linked to #1154)
It was discussed that Damian's proposal from #1154, adding a flag to `MCState.expect` to also return the local estimators, such as:

    expval = MCState.expect(operator, *, [return_estimators=False])
    expval, E_loc = MCState.expect(operator, *, return_estimators=True)

is not desirable because (i) the flag is MCState-specific (it does not apply to ExactState) (@gcarleo) and (ii) it is hard to use in generic drivers.
@attila-i-szabo's proposal was instead to include them in the `Stats` object:

    expval = vstate.expect(operator, return_estimators=True/False)
    type(expval) --> Stats
    expval.local_values # vector ## bikeshed name
- @PhilipVinc pointed out that this approach would be nice because it allows us to do propagation of errors when multiplying/adding different expectation values (see the sketch after this list).
- The memory consumption of a single instance is negligible compared to the samples that we cache anyway, but @gcarleo mentions that if you have tens or hundreds of operators and expectation values around this takes a lot of memory.
- @gcarleo says this allows ALPS style bootstrap and other statistical niceties.
- Serialisation and History objects are not a problem according to @PhilipVinc
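A minimal sketch of the error-propagation point made by @PhilipVinc above. The function name is an assumption, real-valued local estimators are assumed, and the naive estimate ignores autocorrelation, which a real implementation would have to handle:

```python
import jax.numpy as jnp

def error_of_sum(a_loc, b_loc):
    """Standard error of <A + B> estimated from the per-sample local values,
    including the covariance between the two estimators."""
    n = a_loc.size
    cov = jnp.cov(jnp.stack([a_loc, b_loc]))  # 2x2 sample covariance matrix
    var_sum = cov[0, 0] + cov[1, 1] + 2.0 * cov[0, 1]
    return jnp.sqrt(var_sum / n)
```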
@attila-i-szabo wonders how to pass this flag to the inside of a Driver, so that it can be used to implement, for example, Markus Schmitt's renormalisation scheme.
Accepted Proposal:
Only implement a function `E_loc = MCState.local_estimators(operator)` that returns the local estimators. Users can build on it together with `log_gradient` (see the sketch below).
Further integrations can be re-evaluated in the future.
- Additional Comment: can we make this differentiable?
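A hedged sketch of how a user could combine the two building blocks, `local_estimators` and `log_gradient` (names still to be decided, see above), to assemble the usual VMC gradient estimator for a model with real parameters. The helper below and the array shapes are assumptions, not NetKet API:

```python
import jax.numpy as jnp

def energy_and_gradient(e_loc, jacobian):
    """e_loc: (n_samples,) local estimators of the operator.
    jacobian: (n_samples, n_params) log-derivatives O_k(x) = d log psi(x) / d p_k.
    Returns the mean energy and the gradient estimator
        grad_k = 2 Re < conj(O_k) (E_loc - <E_loc>) >   (real parameters).
    """
    n_samples = e_loc.shape[0]
    energy = jnp.mean(e_loc)
    de = e_loc - energy
    grad = 2.0 * jnp.real(jnp.conj(jacobian).T @ de) / n_samples
    return energy, grad
```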
3. Schedules for preconditioners (linked to #1142)
Right now, the API for preconditioners is:

    new_gradient = preconditioner(vstate, gradient)
Attila has proposed to:
- pass the time/iteration, so that we can have a schedule;
- ask what happens if the preconditioner is stateful: can we pass the whole driver?
  - @PhilipVinc is strongly against this, because he wants the driver to be optional: everything must work without drivers.
- proposal: `driver.preconditioner(vstate, gradient, driver.step_value)`
Filippo wants to push this further and also implement stateful schedules, ADAM-style, for the `diagonal_shift`. To do this, he proposes a design similar to optax:

    precond_state = preconditioner.init(vstate, gradient)
    new_grad, precond_state = preconditioner.apply(vstate, gradient, driver.step_value, precond_state)
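A possible shape for that init/apply interface, as a minimal sketch: all names are illustrative assumptions, and a dense matrix `S` stands in for whatever the preconditioner would actually extract from the `vstate`, just to keep the example self-contained.

```python
from typing import Callable, NamedTuple
import jax.numpy as jnp

class PrecondState(NamedTuple):
    last_diag_shift: float

class ScheduledSR(NamedTuple):
    # maps the driver step to a diagonal shift, e.g. an exponential decay
    diag_shift_schedule: Callable[[int], float]

    def init(self, S, gradient):
        return PrecondState(last_diag_shift=self.diag_shift_schedule(0))

    def apply(self, S, gradient, step, state):
        shift = self.diag_shift_schedule(step)
        # toy dense SR-like solve: (S + shift * I) new_grad = gradient
        new_grad = jnp.linalg.solve(S + shift * jnp.eye(gradient.shape[0]), gradient)
        return new_grad, PrecondState(last_diag_shift=shift)

# usage with a decaying diagonal shift
precond = ScheduledSR(diag_shift_schedule=lambda step: 0.1 * 0.99**step)
S, g = jnp.eye(4), jnp.ones(4)
state = precond.init(S, g)
new_g, state = precond.apply(S, g, step=0, state=state)
```

A genuinely stateful schedule (ADAM-like moving averages of the shift) would simply carry its accumulators inside the state tuple.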
Result:
No conclusion was reached. It seems the only two people interested are @PhilipVinc and @attila-i-szabo. It will be discussed again in the future.
4. Overall View
An item that came up several times during the discussion, and was highlighted by @chrisrothUT at the end, is the need for some mechanism to easily specify a particular gradient to be computed in an efficient way, effectively erasing the distinction between `expect_and_grad` and the `preconditioner`, and possibly reusing/caching quantities between the two parts.
Almost everyone agrees that we need this.
Attila's proposal is to have a `GradientType` object that only takes a vstate at every iteration (a rough interface sketch is given below).
Filippo highlights that this is similar to defining a particular gradient dispatch rule, but acknowledges that nobody besides himself understands how to properly define dispatch rules.
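To make the proposal slightly more concrete, a rough interface sketch follows; `GradientType` and its signature here are purely illustrative assumptions, not an agreed design:

```python
from typing import Any, Protocol

class GradientType(Protocol):
    """A self-contained recipe for the update direction: it is handed only the
    variational state at every iteration and may internally fuse and reuse
    whatever it needs (local estimators, Jacobian, QGT solve, ...)."""

    def __call__(self, vstate: Any) -> Any:
        """Return the (possibly preconditioned) gradient computed from vstate."""
        ...
```

A concrete implementation could, for instance, compute the local estimators and the Jacobian once and use them both for the gradient and for the preconditioning solve, which is the kind of fusion requested above.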
Result:
This will be discussed in the following meetings; participants are encouraged to bring forward some design proposals.