Source code for gobbli.augment.wordnet
import functools
import random
from typing import Any, List

from gobbli.augment.base import BaseAugment


def _detokenize_doc(doc: Any) -> str:
"""
Detokenize a spaCy Doc object back into a string, applying our custom replacements
as needed. This requires the associated extension to have been registered appropriately.
The :class:`WordNet` constructor should handle registering the extension.
"""
return "".join([f"{tok._.replacement}{tok.whitespace_}" for tok in doc])


def _get_lemmas(synsets: List[Any]) -> List[str]:
"""
Return all the lemma names associated with a list of synsets.
"""
return [lemma_name for synset in synsets for lemma_name in synset.lemma_names()]


@functools.lru_cache(maxsize=256)
def _get_wordnet_lemmas(word: str, pos: str) -> List[str]:
"""
Determine all the lemmas for a given word to be considered candidates for
replacement. Wrap this function in an LRU cache to keep from recalculating common words
or terms reused frequently in the same document.
"""
    # The WordNet constructor verifies that nltk is importable, so this
    # import is safe here.
from nltk.corpus import wordnet
synsets = wordnet.synsets(word, pos)
hypernyms = [hypernym for synset in synsets for hypernym in synset.hypernyms()]
hyponyms = [hyponym for synset in synsets for hyponym in synset.hyponyms()]
return list(
frozenset(_get_lemmas(synsets) + _get_lemmas(hypernyms) + _get_lemmas(hyponyms))
)
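
# Illustrative only (actual results depend on the installed WordNet data):
# _get_wordnet_lemmas("dog", wordnet.NOUN) yields lemma names drawn from the
# synsets of "dog" plus their hypernyms (e.g. "canine") and hyponyms
# (e.g. "puppy"), with WordNet's underscores intact (e.g. "domestic_dog").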


class WordNet(BaseAugment):
"""
    Data augmentation method based on WordNet. Replaces words with similar
    words according to the WordNet ontology. Texts are part-of-speech tagged
    using spaCy to help ensure that only sensible replacements (i.e. those
    within the same part of speech) are considered.

    Args:
        skip_download_check: If True, don't try to download the WordNet corpus;
            assume it's already been downloaded.
        spacy_model: The spaCy language model to use for part-of-speech tagging.
The model must already have been installed.
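
    Example (illustrative; token replacement is random, so the output will
    vary between runs)::

        augmenter = WordNet()
        augmenter.augment(["The quick brown fox jumps."], times=2, p=0.3)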
"""

    def __init__(
        self, skip_download_check: bool = False, spacy_model: str = "en_core_web_sm"
    ):
try:
from nltk.corpus import wordnet
import nltk
except ImportError:
raise ImportError(
"WordNet-based data augmentation requires nltk to be installed."
)
self.wn = wordnet
try:
import spacy
from spacy.tokens import Token
except ImportError:
raise ImportError(
"WordNet-based data augmentation requires spaCy and a language "
"model to be installed (for part of speech tagging)."
)
if not skip_download_check:
nltk.download("wordnet")
        # Only the POS tagger is needed; disable the parser and named entity
        # recognizer to speed up processing.
        self.nlp = spacy.load(spacy_model, disable=["parser", "ner"])
Token.set_extension("replacement", default=None, force=True)

    def _maybe_replace_token(self, token: Any) -> str:
        """
        Return a replacement for the given spaCy token drawn from WordNet,
        or the token's original text if WordNet has no suitable candidates.
        """
if token.pos_ == "ADJ":
wordnet_pos = self.wn.ADJ
elif token.pos_ == "NOUN":
wordnet_pos = self.wn.NOUN
elif token.pos_ == "VERB":
wordnet_pos = self.wn.VERB
elif token.pos_ == "ADV":
wordnet_pos = self.wn.ADV
else:
# If the token's part of speech isn't recognized by WordNet,
# return it without replacing.
return token.text
all_lemma_names = _get_wordnet_lemmas(token.text, wordnet_pos)
        if all_lemma_names:
# WordNet lemmas have underscores where spaces should be, so apply spaces
# appropriately
return random.choice(all_lemma_names).replace("_", " ")
else:
return token.text
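
    # Illustrative only: tokens tagged with parts of speech WordNet doesn't
    # cover (e.g. pronouns or punctuation) pass through unchanged, since only
    # ADJ, NOUN, VERB, and ADV map to WordNet parts of speech.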

    def augment(self, X: List[str], times: int = 5, p: float = 0.1) -> List[str]:
        """
        Generate ``times`` augmented versions of each text in ``X``, replacing
        each token with probability ``p``.
        """
new_texts = []
        tagged_docs = list(self.nlp.pipe(X))
for _ in range(times):
for doc in tagged_docs:
for tok in doc:
if random.random() < p:
tok._.replacement = self._maybe_replace_token(tok)
else:
tok._.replacement = tok.text
new_texts.append(_detokenize_doc(doc))
return new_texts
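

# Minimal usage sketch (illustrative, not part of the library). Assumes nltk,
# spaCy, and the "en_core_web_sm" model are installed; output varies across
# runs because replacements are sampled at random.
if __name__ == "__main__":
    augmenter = WordNet()
    texts = ["The quick brown fox jumps over the lazy dog."]
    for augmented in augmenter.augment(texts, times=3, p=0.2):
        print(augmented)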