Source code for gobbli.test.dataset.test_base_dataset
import pytest
from gobbli.test.util import MockDataset
def test_base_dataset_load():
    """Verify the build counts seen around ``MockDataset.load``.

    A directly constructed dataset reports zero builds, a dataset obtained
    through ``load`` reports exactly one, and calling ``load`` again on that
    dataset leaves its count at one.
    """
    unbuilt = MockDataset()
    # Plain construction must not trigger a build.
    assert unbuilt._build_count == 0

    loaded = MockDataset.load()
    # Going through load() is expected to build exactly once.
    assert loaded._build_count == 1

    # The return value is deliberately discarded; only the count on the
    # already-loaded dataset is checked afterwards.
    loaded.load()
    assert loaded._build_count == 1
def test_base_dataset_train_input():
    """Verify ``train_input`` split sizes with and without a row limit.

    Also checks that calling ``train_input`` on a dataset that was never
    loaded raises ``ValueError``.
    """
    # An unbuilt dataset must refuse to hand out training input.
    with pytest.raises(ValueError):
        MockDataset().train_input()

    dataset = MockDataset.load()
    split_attrs = ("X_train", "y_train", "X_valid", "y_valid")

    # Without a limit, a 0.5 validation proportion should put half of the
    # train/valid pool into each of the four sequences.
    unlimited = dataset.train_input(valid_proportion=0.5)
    expected_half = len(MockDataset.X_TRAIN_VALID) / 2
    for attr in split_attrs:
        assert len(getattr(unlimited, attr)) == expected_half

    # With limit=2 and the same proportion, each sequence holds one element.
    limited = dataset.train_input(valid_proportion=0.5, limit=2)
    for attr in split_attrs:
        assert len(getattr(limited, attr)) == 1
def test_base_dataset_predict_input():
    """Verify ``predict_input`` contents with and without a row limit.

    Also checks that calling ``predict_input`` on a dataset that was never
    loaded raises ``ValueError``.
    """
    # Exercise predict_input() itself here; the previous version called
    # train_input(), so the unbuilt guard of predict_input() was never tested.
    # NOTE(review): assumes predict_input() raises ValueError when unbuilt,
    # mirroring train_input() -- confirm against the dataset base class.
    with pytest.raises(ValueError):
        MockDataset().predict_input()

    ds = MockDataset.load()

    # No limit: every test document and every test label is present.
    predict_input = ds.predict_input()
    X_len = len(MockDataset.X_TEST)
    assert len(predict_input.X) == X_len
    assert set(predict_input.labels) == set(MockDataset.Y_TEST)

    # Limit applied: only one document comes back.
    predict_input = ds.predict_input(limit=1)
    assert len(predict_input.X) == 1

    # The labels must come from the limited subset only, hence a strict
    # subset of the full test label set.
    assert set(predict_input.labels) < set(MockDataset.Y_TEST)