`get_conn` should throw an error on wrong input size

Created by: PhilipVinc

Currently `get_conn` does no input shape validation and then reshapes the input unconditionally. So if you give it a batch of states, which is not supported, it silently flattens the whole batch into a single row and returns garbage, because:

    def get_conn(self, x):
        return self.get_conn_flattened(
            x.reshape((1, -1)),
            np.ones(1),
        )
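
A minimal sketch of the kind of check that could go in front of the reshape. `ToyOperator`, its `n_sites` attribute, and the placeholder body of `get_conn_flattened` are hypothetical stand-ins so the example runs on its own; they are not the library's actual class or implementation.

```python
import numpy as np


class ToyOperator:
    """Hypothetical stand-in for an operator; not the library's class."""

    def __init__(self, n_sites):
        # Assumed attribute: the expected length of a single state.
        self.n_sites = n_sites

    def get_conn_flattened(self, x, sections):
        # Placeholder body: echoes the input so the sketch is runnable.
        return x, np.ones(x.shape[0])

    def get_conn(self, x):
        x = np.asarray(x)
        # Reject batches (ndim > 1) and wrong-length states before reshaping.
        if x.ndim != 1 or x.shape[0] != self.n_sites:
            raise ValueError(
                f"get_conn expects a single state of shape ({self.n_sites},), "
                f"got shape {x.shape}."
            )
        return self.get_conn_flattened(x.reshape((1, -1)), np.ones(1))
```

With this check, `ToyOperator(4).get_conn(np.zeros((3, 4)))` raises a `ValueError` instead of flattening the batch into one row of 12 entries.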