`get_conn` should throw an error on wrong input size
Created by: PhilipVinc
Currently `get_conn` does no input shape validation and then reshapes the input carelessly. So if you give it a batch of states, which is not supported, it returns garbage, because the whole batch is flattened into a single state:
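```python
import numpy as np

def get_conn(self, x):
    # No shape check: a batch of shape (n_states, N) is flattened into
    # a single row of length n_states * N and treated as one state.
    return self.get_conn_flattened(
        x.reshape((1, -1)),
        np.ones(1),
    )
```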
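A minimal sketch of the check this issue asks for, assuming `get_conn` should accept exactly one state of length N and raise otherwise. `ToyOperator`, its `n_sites` attribute and its trivial `get_conn_flattened` (every state connects only to itself) are hypothetical stand-ins, not NetKet's actual API; only the validation in `get_conn` is the point.

```python
import numpy as np


class ToyOperator:
    """Hypothetical operator used only to illustrate input validation."""

    def __init__(self, n_sites):
        # Assumed attribute: number of sites, i.e. the length of one state.
        self.n_sites = n_sites

    def get_conn_flattened(self, x, sections):
        # Toy rule (assumption): each state connects only to itself with
        # matrix element 1; sections holds the cumulative connection counts.
        sections[:] = np.arange(1, x.shape[0] + 1)
        return x.copy(), np.ones(x.shape[0])

    def get_conn(self, x):
        x = np.asarray(x)
        # Reject anything that is not exactly one state of length n_sites,
        # instead of silently reshaping a batch into one long row.
        if x.ndim != 1 or x.shape[0] != self.n_sites:
            raise ValueError(
                f"get_conn expects a single state of shape ({self.n_sites},), "
                f"got shape {x.shape}"
            )
        sections = np.ones(1, dtype=np.int64)
        return self.get_conn_flattened(x.reshape((1, -1)), sections)
```

With this check, `ToyOperator(4).get_conn(np.ones((3, 4)))` raises a `ValueError` instead of returning connections for a nonexistent 12-site state.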