import logging
from abc import ABC, abstractmethod
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, List, Optional, Tuple
from sklearn.model_selection import train_test_split
import gobbli.io
from gobbli.util import collect_labels, format_duration, gobbli_dir, shuffle_together
# Module-level logger named after this module; BaseDataset.load uses it to report dataset builds.
LOGGER = logging.getLogger(__name__)
def dataset_dir() -> Path:
    """
    Return the root directory for datasets: the ``dataset`` subdirectory of
    the directory returned by ``gobbli_dir()``. Each dataset class gets its
    own subdirectory beneath this one (see ``BaseDataset.data_dir``).
    """
    root = gobbli_dir()
    return root / "dataset"
class BaseDataset(ABC):
    """
    Abstract base class for datasets used for benchmarking and testing.

    Derived classes should account for the following:

    - Dataset order should be consistent so limiting can work correctly
    """

    def __init__(self, *args, **kwargs):
        """
        Blank constructor needed to satisfy mypy
        """

    @classmethod
    def data_dir(cls) -> Path:
        """
        Return the directory for this dataset class: the shared dataset
        directory joined with the concrete class's name.
        """
        return dataset_dir() / cls.__name__

    @classmethod
    def load(cls, *args, **kwargs) -> "BaseDataset":
        """
        Construct the dataset, building it first if it hasn't been built yet.

        Args:
          *args: Positional arguments forwarded to the constructor.
          **kwargs: Keyword arguments forwarded to the constructor.

        Returns:
          The constructed dataset instance, built if ``_is_built()`` reported
          that it wasn't.
        """
        ds = cls(*args, **kwargs)
        if not ds._is_built():
            LOGGER.info("Dataset %s hasn't been built; building.", cls.__name__)
            start = timer()
            ds._build()
            end = timer()
            # Lazy %-style logging args, consistent with the message above.
            LOGGER.info(
                "Dataset building finished in %s.", format_duration(end - start)
            )
        return ds

    @abstractmethod
    def _is_built(self) -> bool:
        """
        Return True if the dataset has already been built; ``load`` calls
        ``_build`` only when this returns False.
        """
        raise NotImplementedError

    @abstractmethod
    def _build(self):
        """
        Build the dataset. Called by ``load`` when ``_is_built`` is False.
        """
        raise NotImplementedError

    @abstractmethod
    def X_train(self):
        """
        Return the training inputs. ``_get_train_valid`` treats the result
        as a sequence paired element-wise with ``y_train()``.
        """
        raise NotImplementedError

    @abstractmethod
    def y_train(self):
        """
        Return the training labels, paired element-wise with ``X_train()``.
        """
        raise NotImplementedError

    @abstractmethod
    def X_test(self):
        """
        Return the test inputs.
        """
        raise NotImplementedError

    @abstractmethod
    def y_test(self):
        """
        Return the test labels.
        """
        raise NotImplementedError

    def _get_train_valid(
        self, limit: Optional[int] = None, shuffle_seed: int = 1234
    ) -> Tuple[List[str], List[Any]]:
        """
        Return the X and y used for training and validation with the
        appropriate limit applied. Shuffle first to minimize the possibility of
        getting only a single label in a small/limited dataset if it happens to
        be ordered by label.

        Args:
          limit: If given, keep only the first ``limit`` observations after
            shuffling.
          shuffle_seed: Seed passed to ``shuffle_together``.

        Returns:
          A 2-tuple ``(X, y)`` of new lists; the dataset's own data is not
          modified.
        """
        # Copy before shuffling. shuffle_together works in place, and
        # X_train()/y_train() may return lists the subclass keeps around.
        # Shuffling those directly would mutate the dataset. Repeated calls
        # would then reshuffle an already-shuffled order, breaking the
        # "consistent order" requirement limiting relies on.
        X_train_valid = list(self.X_train())
        y_train_valid = list(self.y_train())

        # Shuffle the two simultaneously so text and label stay together
        shuffle_together(X_train_valid, y_train_valid, shuffle_seed)

        if limit is not None:
            X_train_valid = X_train_valid[:limit]
            y_train_valid = y_train_valid[:limit]

        return X_train_valid, y_train_valid