Source code for gobbli.test.dataset.test_newsgroups
import pandas as pd
from gobbli.dataset.newsgroups import NewsgroupsDataset
[docs]def test_load_newsgroups(tmp_gobbli_dir):
ds = NewsgroupsDataset.load()
X_train = ds.X_train()
X_test = ds.X_test()
y_train = ds.y_train()
y_test = ds.y_test()
assert len(X_train) == 11314
assert len(y_train) == 11314
assert len(X_test) == 7532
assert len(y_test) == 7532
assert len(pd.unique(y_train)) == 20
assert len(pd.unique(y_test)) == 20
# Ensure these objects pass validation
ds.train_input()
ds.predict_input()