[RFC/Plan] Jax-Operators design
Created by: PhilipVinc
This is a sketch of the API I'd like to see for Jax-aware operators. I welcome constructive ideas and opinions on how to improve this.
Objective:
The objective is to implement some jax-aware operators that can be passed as arguments to `jax.jit`-ted functions, and are capable of fully operating and executing their logic within `jax.jit` blocks (possibly without using callbacks).
The motivations that I see for this are:
- Better composability with `jax.grad` & co: we could eventually compute gradients w.r.t. parameters in the Hamiltonian (for optimal-control applications).
- Possibly better performance on systems where generating connected elements takes a long time, as we could parallelise it on a GPU, for example.
- Memory reasons: we would not need to store the whole tensor of connected elements in memory, but we could store only a block at a time.
Complications
No dynamic shapes. Within a Jax routine we cannot have arrays of unknown size.
So, unless the number of connected elements is known, we have to pad everything to the largest possible size (determined by `operator.max_conn * n_samples`; a shape sketch follows this list).
- This behaviour is different from the current behaviour of Numba-aware operators, which compute the `max_conn` on the current batch of samples, which in practice can slightly lower the padding.
- This behaviour is incompatible with using `op.get_conn_flattened`, which would be the ideal behaviour, with no (or almost no) padding.
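For illustration, the fixed-shape constraint means a jax-aware `get_conn_padded` must always return worst-case shapes (the sizes below are made-up placeholders):

```python
import jax.numpy as jnp

# Made-up sizes, purely for illustration.
n_samples, n_sites, max_conn = 16, 10, 11   # max_conn plays the role of operator.max_conn

# Inside jax.jit the output shapes cannot depend on the sampled configurations,
# so connected elements are always padded up to max_conn:
sigma_p = jnp.zeros((n_samples, max_conn, n_sites))  # connected configurations
mels = jnp.zeros((n_samples, max_conn))              # matrix elements, zero-padded
```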
I believe we will have dynamic shapes in 6 to 12 months. So whatever design we come up with should be ready to support that use case.
Key API points
- The fact that an operator has a jax or numba implementation should be transparent to the user, as it is only an implementation detail.
  - This means that, for example, one among `HamiltonianSamplerJax` and `HamiltonianSamplerNumba` should be automatically selected when the user calls `HamiltonianSampler(operator)`, depending on the implementation of `operator`.
- We should provide sane defaults.
- For hamiltonians where it's unclear which implementation is best, we should provide an easy mechanism for users to switch from one implementation to the other (a hypothetical sketch follows this list).
- No breaking the current API !!!
- The operator's methods such as `n_conns` and `get_conn_padded` should work when used within a jax context. `get_conn_flattened` can raise an error or call the numba version (note, however, that in order not to break the API, raising an error would make `to_dense()` not work, so a solution must be found).
- Support `SumOperator` of jax operators.
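As referenced above, one purely hypothetical shape for the switching mechanism (neither conversion method exists today; their names are placeholders):

```python
import netket as nk

hi = nk.hilbert.Spin(0.5, N=10)
g = nk.graph.Chain(10)
op = nk.operator.Ising(hi, graph=g, h=1.0)  # default implementation

# Hypothetical, for illustration only: explicit opt-in conversions between
# the two implementations of the same operator.
op_jax = op.to_jax_operator()
op_numba = op_jax.to_numba_operator()
```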
Implementation proposal
Unpacking operators
Making an operator Jax-compatible is a relatively easy endeavour. One just needs to teach jax how to unpack and repack it by defining the methods described in jax's documentation on custom pytree nodes.
For `Ising`, for example, these methods could be something like:

```python
def unpack_ising(op):
    dynamic_data = (op.h, op.J)
    static_data = {"hilbert": op.hilbert, "graph": op.graph}
    return dynamic_data, static_data

def unflatten_ising(aux_data, dyn_data):
    return Ising(aux_data["hilbert"], graph=aux_data["graph"], h=dyn_data[0], J=dyn_data[1])
```
With that, `Ising` would be able to be passed as an argument to a `jax.jit`-ted function.
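For concreteness, here is a minimal self-contained sketch of the registration step, using a stand-in class rather than the real netket `Ising` (note that jax requires the static/auxiliary data to be hashable under `jit`, which is why a tuple is used here instead of the dict above):

```python
import jax
import jax.numpy as jnp

# Tiny stand-in class, only to illustrate the registration mechanics.
class MyIsing:
    def __init__(self, hilbert, graph, h, J):
        self.hilbert, self.graph, self.h, self.J = hilbert, graph, h, J

def flatten_ising(op):
    # dynamic (traced) leaves first, then static auxiliary data;
    # the auxiliary data must be hashable for jax.jit's caching
    return (op.h, op.J), (op.hilbert, op.graph)

def unflatten_ising(aux_data, dyn_data):
    hilbert, graph = aux_data
    h, J = dyn_data
    return MyIsing(hilbert, graph, h, J)

jax.tree_util.register_pytree_node(MyIsing, flatten_ising, unflatten_ising)

@jax.jit
def double_field(op):
    # the operator crosses the jit boundary: h and J are traced, the rest is static
    return op.h * 2.0

op = MyIsing(hilbert="hi", graph="chain", h=jnp.array(1.0), J=jnp.array(0.5))
print(double_field(op))  # 2.0
```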
Be careful that, in order to avoid overheads, the `dynamic_data` should be stored in a `jnp.array`. This is the opposite of what we do now (`np.array`). The two choices are incompatible with one another and favour the jax/numba implementations respectively.
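A small illustration of the overhead (assumed behaviour of jit with numpy vs jax array leaves):

```python
import numpy as np
import jax.numpy as jnp

# A np.ndarray leaf is copied from host to device on every call of the
# jitted function, while a jnp.ndarray leaf already lives on the device.
h_numba_style = np.array(1.0)   # favours the numba implementation
h_jax_style = jnp.array(1.0)    # favours the jax implementation
```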
Note: I don't think it's a good idea to use automatic unpack/pack methods such as those generated by `struct.dataclass`, because they require making the whole class hierarchy dataclasses, and for the moment I would avoid such deep changes. This proposal is a bit uglier but will result in fewer changes across netket for the moment. We can always switch later.
Supporting `state.expect`
Our current Operator API for computing expectation values is defined in our beautiful documentation. The "easy way" API is the one used internally by netket and we should support this one.
The API was designed for numba-operators. It has two functions:
- `nk.vqs.get_local_kernel(vstate, operator, chunk_size)`, which, using dispatch, selects the right 'kernel' to be used to compute local estimators for this state and operator. It returns a function that is then passed to a jax method. The returned function has signature `f(log_psi, pars, samples, *extra_args)`.
- `nk.vqs.get_local_kernel_arguments(vstate, operator)`, which, using dispatch, computes the `*extra_args` that will be passed to the kernel. This function essentially calls `get_conn_padded` and passes the result on to the jax kernel, because the jax kernel itself cannot call `get_conn_padded`.
Assuming Jax-operators inherit from a custom type `JaxDiscreteOperator`, we could dispatch on this type and have `nk.vqs.get_local_kernel_arguments` simply return the operator itself:
```python
@nk.vqs.get_local_kernel_arguments.dispatch
def get_local_kernel_args_jax(vstate: nk.vqs.MCState, op: JaxDiscreteOperator):
    return vstate.samples, op  # the operator itself is jax-aware now

@nk.vqs.get_local_kernel.dispatch
def get_local_kernel_jax(vstate: nk.vqs.MCState, op: JaxDiscreteOperator, chunk_size: None):
    def _impl(logpsi, pars, samples, operator):
        # this can be called inside jit because the operator is a pytree
        sigmap, mels = operator.get_conn_padded(samples)
        ...
        return op_loc
    return _impl
```
The kernel implementation should also work with a set chunk size.
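To make the elided `_impl` body concrete, here is a hedged sketch of the standard local-estimator formula it could compute (chunking omitted; names follow the kernel signature above):

```python
import jax.numpy as jnp

def _impl(logpsi, pars, samples, operator):
    # connected configurations (..., max_conn, N) and matrix elements (..., max_conn)
    sigmap, mels = operator.get_conn_padded(samples)
    logpsi_sigma = logpsi(pars, samples)    # shape (...,)
    logpsi_sigmap = logpsi(pars, sigmap)    # shape (..., max_conn)
    # O_loc(sigma) = sum_sigma' <sigma|O|sigma'> psi(sigma') / psi(sigma);
    # padded entries have mels == 0 and thus do not contribute.
    op_loc = jnp.sum(mels * jnp.exp(logpsi_sigmap - logpsi_sigma[..., None]), axis=-1)
    return op_loc
```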
The `JaxDiscreteOperator` class
The implementation above would make it easier for 3rd parties to define custom hamiltonians that are jax-compatible, provided we give some sane defaults in this base class.
As our continuous operators are already Jax-aware, this could be a mixin class, namely we would inherit both from `DiscreteOperator` and from `JaxDiscreteOperator`. Alternatively, we should decide where to keep this object in the class hierarchy...
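A rough, non-binding sketch of one of the options (making `JaxDiscreteOperator` itself a subclass of `DiscreteOperator`); everything here is an assumption about the eventual design:

```python
import jax
from netket.operator import DiscreteOperator

@jax.tree_util.register_pytree_node_class
class JaxDiscreteOperator(DiscreteOperator):
    """Hypothetical base class for operators whose connected elements are
    computed with jax.numpy and can therefore run inside jax.jit."""

    def get_conn_padded(self, sigma):
        # subclasses implement this with jax.numpy, padding to self.max_conn
        raise NotImplementedError

    def get_conn_flattened(self, x, sections):
        # a sane default could convert the output of get_conn_padded on the
        # host so that to_dense() keeps working; left unimplemented here
        raise NotImplementedError

    # pytree hooks, so that instances can cross the jax.jit boundary
    def tree_flatten(self):
        raise NotImplementedError

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        raise NotImplementedError
```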
Open Questions:
- This works with `SumOperator`. Would this work with a non-existent `TensorOperator`?
- ...