"""Tests for :mod:`gobbli.augment.word2vec` (module ``gobbli.test.augment.test_word2vec``)."""

from pathlib import Path

import gensim
import pytest

from gobbli.augment.word2vec import _WORD2VEC_MODEL_ARCHIVES, WORD2VEC_MODELS, Word2Vec
from gobbli.util import TokenizeMethod


def test_word2vec_models():
    """Every entry in WORD2VEC_MODELS must point at a registered archive."""
    for model_name, (archive_name, _filename) in WORD2VEC_MODELS.items():
        # Name the model in the failure message so a bad registry entry is
        # identifiable without re-running under a debugger.
        assert archive_name in _WORD2VEC_MODEL_ARCHIVES, (
            "model {!r} references unknown archive {!r}".format(
                model_name, archive_name
            )
        )
_TEST_WORD2VEC_VECTORS = """2 3 test 0 1 1 test2 1 0 1 """
def make_weights_file(tmpdir, contents=None):
    """Write a word2vec text-format weights file into ``tmpdir``.

    Args:
      tmpdir: Directory (str or path-like) in which ``word2vec.txt`` is created.
      contents: Text to write. Defaults to the module-level
        ``_TEST_WORD2VEC_VECTORS`` fixture when ``None``.

    Returns:
      The :class:`pathlib.Path` of the written file.
    """
    if contents is None:
        contents = _TEST_WORD2VEC_VECTORS
    weights_file = Path(tmpdir) / "word2vec.txt"
    # Encode explicitly rather than with the locale default: gensim's
    # load_word2vec_format decodes as UTF-8 unless told otherwise.
    weights_file.write_text(contents, encoding="utf-8")
    return weights_file
def test_word2vec_init(tmpdir):
    """Word2Vec accepts a weights path, loaded KeyedVectors, or a pretrained model name."""
    weights_file = make_weights_file(tmpdir)
    keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(str(weights_file))

    accepted_sources = (
        weights_file,  # path to an existing weights file on disk
        keyed_vectors,  # a gensim model that is already loaded
        "fasttext-simple",  # pretrained model name; its weights get downloaded
    )
    for source in accepted_sources:
        Word2Vec(source)

    # Arguments of any other kind are rejected outright.
    with pytest.raises(TypeError):
        Word2Vec(None)
@pytest.mark.parametrize(
    "n_similar,diversity,tokenizer,exception_cls",
    [
        # n_similar must be an int, not a float
        (5.5, 1.0, TokenizeMethod.SPLIT, TypeError),
        # diversity must be a float, not an int
        (5, 1, TokenizeMethod.SPLIT, TypeError),
        # tokenizer of an unsupported type
        (5, 0.8, 1, TypeError),
        # n_similar out of range
        (0, 1.0, TokenizeMethod.SPLIT, ValueError),
        # diversity at or below zero
        (5, 0.0, TokenizeMethod.SPLIT, ValueError),
        # diversity above one
        (5, 1.1, TokenizeMethod.SPLIT, ValueError),
        # string tokenizer naming no known method
        (5, 0.8, "no such tokenizer", KeyError),
        # valid: TokenizeMethod member
        (5, 0.8, TokenizeMethod.SPLIT, None),
        # valid: arbitrary callable
        (5, 0.8, lambda s: s.split(), None),
        # valid: TokenizeMethod name as a string
        (5, 0.8, "SPLIT", None),
    ],
)
def test_word2vec_kwargs(tmpdir, n_similar, diversity, tokenizer, exception_cls):
    """Constructor keyword validation raises the expected error (or none)."""
    weights_file = make_weights_file(tmpdir)
    kwargs = dict(n_similar=n_similar, diversity=diversity, tokenizer=tokenizer)

    if exception_cls is not None:
        with pytest.raises(exception_cls):
            Word2Vec(weights_file, **kwargs)
        return

    Word2Vec(weights_file, **kwargs)
def test_word2vec_replace(tmpdir):
    """_maybe_replace_token swaps in-vocabulary words and leaves unknown ones alone."""
    augmenter = Word2Vec(make_weights_file(tmpdir))

    # "test" is in the fixture vocabulary, whose only other word is "test2".
    assert augmenter._maybe_replace_token("test") == "test2"

    # A word absent from the vocabulary is returned unchanged.
    unknown = "bad_token"
    assert augmenter._maybe_replace_token(unknown) == unknown