# Source code for gobbli.test.model.test_mtdnn

import pytest

from gobbli.model.mtdnn import MTDNN


@pytest.mark.parametrize(
    "params,exception",
    [
        # Unknown param
        ({"unknown": None}, ValueError),
        # Bad type (max_seq_length)
        ({"max_seq_length": "100"}, TypeError),
        # Bad value (mtdnn_model)
        ({"mtdnn_model": "bert"}, ValueError),
        # OK type (max_seq_length)
        ({"max_seq_length": 100}, None),
        # OK value (mtdnn_model)
        ({"mtdnn_model": "mt-dnn-base"}, None),
        # OK values (both params)
        ({"max_seq_length": 100, "mtdnn_model": "mt-dnn-base"}, None),
    ],
)
def test_init(params, exception):
    """Construct ``MTDNN`` with ``params`` and check the outcome.

    Args:
        params: Keyword arguments forwarded to the ``MTDNN`` constructor.
        exception: Exception type the constructor is expected to raise, or
            ``None`` when construction is expected to succeed.
    """
    if exception is None:
        # Valid parameters: construction must complete without raising.
        MTDNN(**params)
    else:
        # Invalid parameters: construction must raise the expected type.
        with pytest.raises(exception):
            MTDNN(**params)