Source code for gobbli.test.model.test_bert
import pytest
from gobbli.model.bert import BERT
[docs]@pytest.mark.parametrize(
"params,exception",
[
# Unknown param
({"unknown": None}, ValueError),
# Bad type (max_seq_length)
({"max_seq_length": "100"}, TypeError),
# Bad value (bert_model)
({"bert_model": "ernie"}, ValueError),
# OK type (max_seq_length)
({"max_seq_length": 100}, None),
# OK value (bert_model)
({"bert_model": "bert-base-uncased"}, None),
# OK values (both params)
({"max_seq_length": 100, "bert_model": "bert-base-uncased"}, None),
],
)
def test_init(params, exception):
if exception is None:
BERT(**params)
else:
with pytest.raises(exception):
BERT(**params)