Source code for gobbli.test.augment.test_marian

import pytest

from gobbli.augment.marian import MarianMT
from gobbli.test.util import model_test_dir


[docs]@pytest.mark.parametrize( "params,exception", [ # Unknown param ({"unknown": None}, ValueError), # Bad type (batch_size) ({"batch_size": 2.5}, TypeError), # Bad type (target_languages) ({"target_languages": "english"}, TypeError), # Bad value (batch_size) ({"batch_size": 0}, ValueError), # Bad value (target_languages) ({"target_languages": ["not a language"]}, ValueError), # Bad value, one OK value (target_languages) ({"target_languages": ["french", "not a language"]}, ValueError), # OK values ({"batch_size": 16, "target_languages": ["russian", "french"]}, None), ], ) def test_init(params, exception): if exception is None: MarianMT(**params) else: with pytest.raises(exception): MarianMT(**params)
[docs]def test_marianmt_augment(model_gpu_config, gobbli_dir): # Don't go overboard with the languages here, since each # one requires a separate model (few hundred MB) to be downloaded target_languages = ["russian", "french"] model = MarianMT( data_dir=model_test_dir(MarianMT), load_existing=True, target_languages=target_languages, **model_gpu_config, ) model.build() # Can't augment more times than target languages invalid_num_times = len(target_languages) + 1 with pytest.raises(ValueError): model.augment(["This is a test."], times=invalid_num_times) valid_num_times = len(target_languages) new_texts = model.augment(["This is a test."], times=valid_num_times) assert len(new_texts) == valid_num_times