# Source code for gobbli.augment.base

from abc import ABC, abstractmethod
from pathlib import Path
from typing import List

from gobbli.util import gobbli_dir


def augment_dir() -> Path:
    """
    Returns:
      The ``augment`` subdirectory of the gobbli data directory
      (:func:`gobbli.util.gobbli_dir`). Per-class augmentation data
      directories are created beneath it.
    """
    root = gobbli_dir()
    return root.joinpath("augment")
class BaseAugment(ABC):
    """
    Abstract interface that every data augmentation method implements.

    Subclasses must provide :meth:`augment`. They inherit :meth:`data_dir`,
    which gives each subclass its own directory named after the class.
    """

    @abstractmethod
    def augment(self, X: List[str], times: int = 5, p: float = 0.1) -> List[str]:
        """
        Produce new texts derived from each of the given input texts.

        Args:
          X: The texts to augment.
          times: Number of new texts to create for every input text.
          p: Chance that any single token of an input text is considered for
            replacement. A given augmentation method cannot replace some
            tokens and skips them, so the share of tokens actually replaced
            may fall well below this value.

        Returns:
          The generated texts, ``times * len(X)`` of them in total.
        """
        # Concrete augmentation methods override this; the base class has
        # no default implementation.
        raise NotImplementedError()

    @classmethod
    def data_dir(cls) -> Path:
        """
        Returns:
          The directory reserved for this augmentation class. It is the
          augmentation root directory joined with the class's own name, so
          every subclass gets a separate location.
        """
        return augment_dir().joinpath(cls.__name__)