Source code for gobbli.test.interactive.test_util

import inspect
from pathlib import Path

import pytest

from gobbli.interactive.util import read_data_file


@pytest.mark.parametrize(
    "filename,contents,multilabel,limit,expected",
    [
        # ---- Comma-separated files ----
        # Header has no "text" column -> error expected
        ("test.csv", "a,b", False, None, ValueError),
        # Header has "text" but no "label" column -> labels come back as None
        (
            "test.csv",
            "text,b\ntest1,1\ntest2,2",
            False,
            None,
            (["test1", "test2"], None),
        ),
        # Both "text" and "label" columns present
        (
            "test.csv",
            "text,label\ntest1,1\ntest2,2",
            False,
            None,
            (["test1", "test2"], ["1", "2"]),
        ),
        # Row limit of 1
        ("test.csv", "text,label\ntest1,1\ntest2,2", False, 1, (["test1"], ["1"])),
        # Multilabel: quoted comma-joined labels, plus a row with no labels
        (
            "test.csv",
            'text,label\ntest1,"1,2"\ntest2,',
            True,
            None,
            (["test1", "test2"], [["1", "2"], []]),
        ),
        # ---- Tab-separated files ----
        # Header has no "text" column -> error expected
        ("test.tsv", "a\tb", False, None, ValueError),
        # Header has "text" but no "label" column -> labels come back as None
        (
            "test.tsv",
            "text\tb\ntest1\t1\ntest2\t2",
            False,
            None,
            (["test1", "test2"], None),
        ),
        # Both "text" and "label" columns present
        (
            "test.tsv",
            "text\tlabel\ntest1\t1\ntest2\t2",
            False,
            None,
            (["test1", "test2"], ["1", "2"]),
        ),
        # Row limit of 1
        ("test.tsv", "text\tlabel\ntest1\t1\ntest2\t2", False, 1, (["test1"], ["1"])),
        # Multilabel: comma-joined labels, plus a row with no labels
        (
            "test.tsv",
            "text\tlabel\ntest1\t1,2\ntest2\t",
            True,
            None,
            (["test1", "test2"], [["1", "2"], []]),
        ),
        # ---- Plain text files (one text per line, never any labels) ----
        ("test.txt", "test1\ntest2", False, None, (["test1", "test2"], None)),
        # Row limit of 1
        ("test.txt", "test1\ntest2", False, 1, (["test1"], None)),
    ],
)
def test_read_data_file(tmpdir, filename, contents, multilabel, limit, expected):
    """Check ``read_data_file`` against a small file written to ``tmpdir``.

    ``contents`` is written to ``tmpdir / filename`` and the file is read back
    with ``read_data_file(path, multilabel, n_rows=limit)``.

    ``expected`` is either:

    - an ``Exception`` subclass, in which case the call must raise it, or
    - a ``(texts, labels)`` pair that the returned texts and labels must
      equal.
    """
    data_path = Path(tmpdir).joinpath(filename)
    data_path.write_text(contents)

    # An exception class in the "expected" slot marks a failure case.
    expects_error = inspect.isclass(expected) and issubclass(expected, Exception)
    if expects_error:
        with pytest.raises(expected):
            read_data_file(data_path, multilabel, n_rows=limit)
        return

    got_texts, got_labels = read_data_file(data_path, multilabel, n_rows=limit)
    want_texts, want_labels = expected
    assert got_texts == want_texts
    assert got_labels == want_labels