dtype errors with adaptive integrator
Created by: jwnys
When using an adaptive integrator with dtype=jnp.float32 for the parameters, samples, etc., the integrator raises dtype errors. They arise for the following reasons:
- the dt's returned by the two branches of the jax.lax.cond in the accepted case can have different dtypes (next_dt vs rk_state.dt)
- the error norm inherits its dtype from the variational state, so in the accepted case the new last_norm and last_scaled_error can have a different dtype from the values they replace
- the latter happens because last_norm is initialized with a Python float (e.g. 0.) in the adaptive case, while the norm computed afterwards can be float32.
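A minimal reproduction of the first point, as a sketch: `select_dt` is a hypothetical helper (not the integrator's actual code) that mimics the accepted/rejected branch with jax.lax.cond. It assumes 64-bit mode is enabled (`jax_enable_x64`); without it, float64 is silently downcast to float32 and the mismatch does not appear.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit mode is on; otherwise float64 becomes float32
# and both branches would trivially agree.
jax.config.update("jax_enable_x64", True)


def select_dt(accepted, dt, next_dt):
    # Hypothetical helper mirroring the accepted/rejected branch:
    # take the proposed step if accepted, otherwise keep the stored one.
    # lax.cond requires both branches to return identical dtypes.
    return jax.lax.cond(accepted, lambda: next_dt, lambda: dt)


dt = jnp.asarray(0.01, dtype=jnp.float64)       # stored step size (float64)
next_dt = jnp.asarray(0.02, dtype=jnp.float32)  # proposal derived from a float32 error norm

try:
    select_dt(True, dt, next_dt)
    mismatch_raised = False
except TypeError as err:
    # JAX rejects branches whose outputs differ in dtype when tracing the cond.
    mismatch_raised = True
    print(err)
```

Both branches are traced even for a concrete predicate, so the error appears regardless of whether the step is actually accepted.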
I'm not sure what the best solution is (to be honest, I wasn't even expecting errors of this kind), but some possibilities are:
- make sure all replaced dt's have the same dtype as rk_state.dt
- initialize the last_norm and last_scaled_error fields with a jnp.array of a predefined fixed dtype (it's not clear how to determine this in general, but float64 would make sense)
- initialize and convert everything to float64 in the RKState.
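The first two options could look roughly like the sketch below. It is not the integrator's actual implementation: `STATE_DTYPE`, `init_last_norm` and `accept_or_reject` are hypothetical names, float64 is an assumed choice for the bookkeeping dtype, and 64-bit mode is assumed to be enabled.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical choice: keep all scalar integrator bookkeeping in float64.
STATE_DTYPE = jnp.float64


def init_last_norm():
    # Option 2: start from a jnp.array with a fixed dtype instead of the
    # Python float 0., so the stored value has a well-defined dtype.
    return jnp.zeros((), dtype=STATE_DTYPE)


def accept_or_reject(accepted, dt, next_dt, last_norm, new_norm):
    # Option 1: cast the proposed step to the dtype of the stored step.
    next_dt = jnp.asarray(next_dt, dtype=dt.dtype)
    # Cast the freshly computed norm (which inherits the float32 dtype of
    # the variational state) to the fixed bookkeeping dtype.
    new_norm = jnp.asarray(new_norm, dtype=STATE_DTYPE)
    return jax.lax.cond(
        accepted,
        lambda: (next_dt, new_norm),
        lambda: (dt, last_norm),
    )


dt = jnp.asarray(0.01, dtype=STATE_DTYPE)
last_norm = init_last_norm()
next_dt = jnp.asarray(0.02, dtype=jnp.float32)
new_norm = jnp.asarray(0.5, dtype=jnp.float32)

new_dt, new_last_norm = accept_or_reject(True, dt, next_dt, last_norm, new_norm)
rej_dt, rej_last_norm = accept_or_reject(False, dt, next_dt, last_norm, new_norm)
```

Casting at the point where the cond is built keeps the stored state's dtype stable, at the cost of always carrying float64 scalars even when the model itself is float32.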