"""Tests for ``BERTMaskedLM`` (module ``gobbli.test.augment.test_bertmaskedlm``)."""
import pytest
from gobbli.augment.bert import BERTMaskedLM
from gobbli.test.util import model_test_dir
@pytest.mark.parametrize(
    "params,exception",
    [
        # Unknown param
        ({"unknown": None}, ValueError),
        # Bad type (diversity): an int where a float is presumably expected
        ({"diversity": 2}, TypeError),
        # Bad type (batch size)
        ({"batch_size": 2.5}, TypeError),
        # Bad type (n_probable)
        ({"n_probable": 2.5}, TypeError),
        # Bad value (diversity)
        ({"diversity": 0.0}, ValueError),
        # Bad value (batch_size)
        ({"batch_size": 0}, ValueError),
        # Bad value (n_probable)
        ({"n_probable": 0}, ValueError),
        # OK values
        ({"diversity": 0.5, "n_probable": 3, "batch_size": 16}, None),
    ],
)
def test_init(params, exception):
    """Check that ``BERTMaskedLM`` validates its constructor arguments.

    Args:
        params: Keyword arguments passed straight to ``BERTMaskedLM``.
        exception: The exception type construction must raise, or ``None``
            when construction is expected to succeed.
    """
    if exception is None:
        # Valid parameters: construction must not raise.
        BERTMaskedLM(**params)
    else:
        with pytest.raises(exception):
            BERTMaskedLM(**params)
def test_bertmaskedlm_augment(model_gpu_config, gobbli_dir):
    """Smoke-test ``BERTMaskedLM.augment`` on a single input text.

    Args:
        model_gpu_config: Fixture whose mapping is unpacked as extra keyword
            arguments to the ``BERTMaskedLM`` constructor.
        gobbli_dir: Fixture that is requested but not referenced in the body.
            NOTE(review): presumably requested for its setup side effects —
            confirm against the fixture definition before removing it.
    """
    model = BERTMaskedLM(
        data_dir=model_test_dir(BERTMaskedLM), load_existing=True, **model_gpu_config
    )
    model.build()
    times = 5
    new_texts = model.augment(["This is a test."], times=times)
    # The test expects one result per requested augmentation of the single input.
    assert len(new_texts) == times