Source code for gobbli.test.util
from pathlib import Path
from typing import Any
import pytest
from gobbli.dataset.base import BaseDataset
from gobbli.experiment.base import BaseExperiment
from gobbli.model.base import BaseModel
from gobbli.model.bert import BERT
from gobbli.model.fasttext import FastText, FastTextCheckpoint
from gobbli.model.mtdnn import MTDNN
from gobbli.model.transformer import Transformer
from gobbli.util import gobbli_dir
def skip_if_no_gpu(config):
    """
    Skip the calling test unless GPU tests were enabled.

    Args:
      config: pytest config object; ``config.option.use_gpu`` is read.
    """
    if config.option.use_gpu:
        # GPU tests were requested; let the test run.
        return
    pytest.skip("needs --use-gpu option to run")
def skip_if_low_resource(config):
    """
    Skip the calling test when running in a resource-limited environment.

    Intended for tests that need a lot of CPU, memory, etc. The skip happens
    only when the user has set ``config.option.low_resource``.

    Args:
      config: pytest config object; ``config.option.low_resource`` is read.
    """
    running_low = config.option.low_resource
    if not running_low:
        # Resources are not constrained, so the large test may run.
        return
    pytest.skip("skipping large test due to --low-resource option")
# TODO: tighten the ``model_cls`` annotation so it says "a class derived from
# BaseModel" (e.g. ``Type[BaseModel]``) instead of ``Any``.
def model_test_dir(model_cls: Any) -> Path:
    """
    Return the directory used for test artifacts of the given model class.

    The path is ``gobbli_dir() / "model_test" / <class name>``. A stable,
    per-class location lets the user persist data (e.g. downloaded weights)
    across test runs instead of reloading it every time.

    Args:
      model_cls: The model class; only its ``__name__`` is used.

    Returns:
      The per-class test directory path. It is not created here.
    """
    base_dir = gobbli_dir() / "model_test"
    return base_dir / model_cls.__name__
def validate_checkpoint(model_cls: Any, checkpoint: Path):
    """
    Assert that ``checkpoint`` looks valid for the model class that produced it.

    The checks per class are:

    - BERT: the checkpoint isn't an actual file, so only the companion
      ``<checkpoint>.meta`` file must exist.
    - MTDNN: ``checkpoint`` must be a single file.
    - FastText: the ``model`` and ``vectors`` components of a
      ``FastTextCheckpoint`` built from ``checkpoint`` must both exist.
    - Transformer: ``checkpoint`` must be a directory.

    Any other class is accepted without checks.

    Raises:
      AssertionError: if a check fails (unless assertions are disabled
      with ``python -O``).
    """
    if model_cls == BERT:
        meta_file = Path(str(checkpoint) + ".meta")
        assert meta_file.is_file()
        return

    if model_cls == MTDNN:
        assert checkpoint.is_file()
        return

    if model_cls == FastText:
        # A FastText checkpoint is made up of multiple on-disk pieces.
        ft_ckpt = FastTextCheckpoint(checkpoint)
        assert ft_ckpt.model.exists()
        assert ft_ckpt.vectors.exists()
        return

    if model_cls == Transformer:
        assert checkpoint.is_dir()
class MockDataset(BaseDataset):
    """
    Smallest concrete BaseDataset, used to test the abstract base class.

    It serves fixed in-memory train/test data from class-level lists and
    records how many times ``_build`` has been invoked.
    """

    # Fixed data returned (not copied) by the accessor methods below.
    X_TRAIN_VALID = ["train1", "train2", "train3", "train4"]
    Y_TRAIN_VALID = ["0", "1", "0", "1"]
    X_TEST = ["test1", "test2"]
    Y_TEST = ["1", "0"]

    def __init__(self, *args, **kwargs):
        # Every constructor argument is accepted and ignored;
        # BaseDataset.__init__ is not invoked.
        self._built = False
        self._build_count = 0

    def _is_built(self) -> bool:
        """Report whether ``_build`` has run on this instance."""
        return self._built

    def _build(self):
        """Mark the dataset as built, counting each call for test assertions."""
        self._build_count = self._build_count + 1
        self._built = True

    def X_train(self):
        """Return the class-level training inputs."""
        return MockDataset.X_TRAIN_VALID

    def y_train(self):
        """Return the class-level training labels."""
        return MockDataset.Y_TRAIN_VALID

    def X_test(self):
        """Return the class-level test inputs."""
        return MockDataset.X_TEST

    def y_test(self):
        """Return the class-level test labels."""
        return MockDataset.Y_TEST
class MockModel(BaseModel):
    """
    Smallest concrete BaseModel, used to test the abstract base class.
    """

    def init(self, params):
        """
        Keep ``params`` exactly as given, on ``self.params``.
        """
        self.params = params

    def _build(self):
        """
        Do nothing: the mock model has nothing to build.
        """
class MockExperiment(BaseExperiment):
    """
    A minimal experiment derived from BaseExperiment, used in tests to
    exercise the abstract base class's logic.
    """