import distutils.file_util
import enum
import gzip
import io
import json
import logging
import os
import random
import shutil
import tarfile
import tempfile
import uuid
import warnings
import zipfile
from distutils.util import strtobool
from pathlib import Path
from typing import (
Any,
Container,
Dict,
Iterator,
List,
Optional,
Set,
TypeVar,
Union,
cast,
)
import humanize
import pandas as pd
import pkg_resources
import requests
from sklearn.preprocessing import MultiLabelBinarizer
LOGGER = logging.getLogger(__name__)
def default_gobbli_dir() -> Path:
    """
    Returns:
        The default directory to be used to store gobbli data if there's no
        user-specified default.
    """
    home = Path.home()
    return home / ".gobbli"
def gobbli_dir() -> Path:
    """
    Returns:
        The directory used to store gobbli data. Can be overridden using the
        ``GOBBLI_DIR`` environment variable.
    """
    # Renamed the local (the original shadowed this function's own name)
    data_dir = Path(os.getenv("GOBBLI_DIR", str(default_gobbli_dir())))
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir
def model_dir() -> Path:
    """
    Returns:
        The subdirectory storing all model data and task run input/output.
    """
    return gobbli_dir().joinpath("model")
def gobbli_version() -> str:
    """
    Returns:
        The version of gobbli installed.
    """
    # NOTE(review): pkg_resources is deprecated in favor of importlib.metadata,
    # but swapping would change the exception type raised for a missing
    # distribution -- left as-is.
    distribution = pkg_resources.get_distribution("gobbli")
    return distribution.version
def disk_usage() -> int:
    """
    Returns:
        Disk usage (in bytes) of all gobbli data in the current gobbli directory.
    """
    total_bytes = 0
    # Recursive glob covers every file under the gobbli directory
    for path in gobbli_dir().glob("**/*"):
        if path.is_file():
            total_bytes += path.stat().st_size
    return total_bytes
def human_disk_usage() -> str:
    """
    Returns:
        Human-readable string representing disk usage of all gobbli data in the
        current gobbli directory.
    """
    usage_bytes = disk_usage()
    return humanize.naturalsize(usage_bytes)
def _confirm(prompt: str) -> bool:
    """
    Prompt the user and interpret the response as a yes/no answer.

    Mirrors the semantics of ``distutils.util.strtobool`` (deprecated by
    PEP 632 and removed in Python 3.12): y/yes/t/true/on/1 -> True,
    n/no/f/false/off/0 -> False.

    Args:
        prompt: Message shown to the user.

    Raises:
        ValueError: If the response isn't a recognized truth value.
    """
    response = input(prompt).lower()
    if response in ("y", "yes", "t", "true", "on", "1"):
        return True
    if response in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {response!r}")


def cleanup(force: bool = False, full: bool = False):
    """
    Cleans up the gobbli directory, removing files and directories. By default, removes
    only task input/output. If the ``full`` argument is True, removes all data,
    including downloaded model weights and datasets.

    Args:
        force: If True, don't prompt for user confirmation before cleaning up.
        full: If True, remove all data (including downloaded model weights and datasets).
    """
    if not force:
        msg = "Cleanup will remove all task input/output, including trained models. Are you sure? [Y/n]"
        if not _confirm(msg):
            return
        if full:
            full_msg = "Full cleanup will additionally remove all downloaded model weights; they'll need to be redownloaded to be used again. Are you sure? [Y/n]"
            if not _confirm(full_msg):
                return
    if full:
        # Remove everything, including cached weights and downloaded datasets
        shutil.rmtree(gobbli_dir())
    else:
        base_model_dir = gobbli_dir() / "model"
        if base_model_dir.exists():
            # Reuse base_model_dir (the original recomputed gobbli_dir()/"model")
            # and avoid naming the loop variable `model_dir`, which shadowed
            # the module-level model_dir() helper.
            for single_model_dir in base_model_dir.iterdir():
                for instance_dir in single_model_dir.iterdir():
                    # Special directories containing cached weights or models
                    # which are reusable -- don't delete them here
                    if instance_dir.name not in ("cache", "weights"):
                        shutil.rmtree(instance_dir)
def pred_prob_to_pred_label(y_pred_proba: pd.DataFrame) -> List[str]:
    """
    Convert a dataframe of predicted probabilities (shape (n_samples, n_classes)) to
    a list of predicted classes.

    Args:
        y_pred_proba: Dataframe with one column per class and one row per sample.

    Returns:
        For each row, the label of the column holding the highest probability.
    """
    if y_pred_proba.shape[0] == 0:
        return []
    # idxmax(axis=1) yields the column label of the max value in each row
    predicted = y_pred_proba.idxmax(axis=1)
    return list(predicted)
def pred_prob_to_pred_multilabel(
    y_pred_proba: pd.DataFrame, threshold: float = 0.5
) -> pd.DataFrame:
    """
    Convert a dataframe of predicted probabilities (shape (n_samples, n_classes)) to
    a dataframe of predicted label indicators (shape (n_samples, n_classes)).

    Args:
        y_pred_proba: Dataframe of predicted probabilities.
        threshold: A label is predicted when its probability strictly exceeds this.

    Returns:
        Boolean dataframe with the same shape and column labels.
    """
    # Elementwise strict greater-than; .gt() is equivalent to the > operator
    return y_pred_proba.gt(threshold)
def is_multilabel(y: Union[List[str], List[List[str]]]) -> bool:
    """
    Returns:
        True if the passed dataset is formatted as a multilabel problem, otherwise false
        (it's multiclass).
    """
    # An empty dataset is treated as multiclass
    if not y:
        return False
    # Multilabel data carries a list of labels per observation
    return isinstance(y[0], list)
def as_multilabel(
    y: Union[List[str], List[List[str]]], is_multilabel: bool
) -> List[List[str]]:
    """
    Returns:
        Labels in multilabel format, validated against
        possible mismatch between the multilabel attribute and data format.
    """
    if not is_multilabel:
        # Multiclass input: promote each label to a singleton list
        return multiclass_to_multilabel_target(cast(List[str], y))
    return cast(List[List[str]], y)
def as_multiclass(
    y: Union[List[str], List[List[str]]], is_multilabel: bool
) -> List[str]:
    """
    Returns:
        Labels in multiclass format, validated against
        possible mismatch between the multilabel attribute and data format.

    Raises:
        ValueError: If the input is multilabel, which can't be represented
            as multiclass.
    """
    if not is_multilabel:
        return cast(List[str], y)
    raise ValueError("Multilabel training input can't be converted to multiclass.")
def multiclass_to_multilabel_target(y: List[str]) -> List[List[str]]:
    """
    Args:
        y: Multiclass-formatted target variable.

    Returns:
        The multiclass-formatted variable generalized to a multilabel context
        (each label wrapped in a singleton list).
    """
    return [[label] for label in y]
def multilabel_to_indicator_df(y: List[List[str]], labels: List[str]) -> pd.DataFrame:
    """
    Convert a list of label lists to a 0/1 indicator dataframe.

    Args:
        y: List of label lists.
        labels: List of all unique labels found in y.

    Returns:
        The dataframe will have a column for each label and a row for each observation,
        with a 1 if the observation has that label or a 0 if not.
    """
    binarizer = MultiLabelBinarizer(classes=labels)
    indicator_matrix = binarizer.fit_transform(y)
    # classes_ preserves the order of the `labels` argument
    return pd.DataFrame(indicator_matrix, columns=binarizer.classes_)
def collect_labels(y: Any) -> List[str]:
    """
    Collect the unique labels in the given target variable, which could be in
    multiclass or multilabel format.

    Args:
        y: Target variable.

    Returns:
        An ordered list of all labels in y.
    """
    # Multilabel data is a non-empty list whose first element is itself a
    # list of labels (same check is_multilabel() performs, inlined here)
    if len(y) > 0 and isinstance(y[0], list):
        unique_labels: Set[str] = set()
        for observation_labels in y:
            unique_labels.update(observation_labels)
    else:
        unique_labels = set(y)
    return sorted(unique_labels)
def truncate_text(text: str, length: int) -> str:
    """
    Truncate the text to the given length. Append an ellipsis to make it clear the
    text was truncated.

    Args:
        text: The text to truncate.
        length: Maximum length of the truncated text.

    Returns:
        The truncated text.
    """
    if len(text) <= length:
        # Nothing to cut; return the text untouched
        return text
    return text[:length] + "..."
def assert_type(name: str, val: Any, cls: Any):
    """
    Raise an informative error if the passed value isn't of the given type.

    Args:
        name: User-friendly name of the value to be printed in an exception if raised.
        val: The value to check type of.
        cls: The type or tuple of types to compare the value's type against.

    Raises:
        TypeError: If ``val`` isn't an instance of ``cls``.
    """
    if isinstance(val, cls):
        return
    raise TypeError(f"{name} must be of type(s) '{cls}'; got '{type(val)}'")
def assert_in(name: str, val: Any, container: Container):
    """
    Raise an informative error if the passed value isn't in the given container.

    Args:
        name: User-friendly name of the value to be printed in an exception if raised.
        val: The value to check for membership in the container.
        container: The container that should contain the value.

    Raises:
        ValueError: If ``val`` isn't contained in ``container``.
    """
    if val in container:
        return
    raise ValueError(f"invalid {name} '{val}'; valid values are {container}")
def generate_uuid() -> str:
    """
    Generate a universally unique ID to be used for randomly naming directories,
    models, tasks, etc.
    """
    # .hex gives the 32-character lowercase hex form without dashes
    new_id = uuid.uuid4()
    return new_id.hex
def download_dir() -> Path:
    """
    Returns:
        The directory used to store gobbli downloaded files.
    """
    path = gobbli_dir() / "download"
    path.mkdir(parents=True, exist_ok=True)
    return path
def escape_line_delimited_text(text: str) -> str:
    """
    Convert a single text possibly containing newlines and other troublesome whitespace
    into a string suitable for writing and reading to a file where newlines will
    divide up the texts.

    Args:
        text: The text to convert.

    Returns:
        The text with newlines and whitespace taken care of.
    """
    # Collapse Windows-style "\r\n" to a single space first, then any
    # remaining bare newlines or carriage returns.  The previous
    # implementation handled only "\n", letting "\r" leak into the
    # line-delimited output.
    cleaned = text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
    return cleaned.strip()
def escape_line_delimited_texts(texts: List[str]) -> str:
    """
    Convert a list of texts possibly containing newlines and other troublesome whitespace
    into a string suitable for writing and reading to a file where newlines
    will divide up the texts.

    Args:
        texts: The list of texts to convert.

    Returns:
        The newline-delimited string.
    """
    return "\n".join(escape_line_delimited_text(text) for text in texts)
def is_dir_empty(dir_path: Path) -> bool:
    """
    Determine whether a given directory is empty. Assumes the directory exists.

    Args:
        dir_path: The directory to check for emptiness.

    Returns:
        :obj:`True` if the directory is empty, otherwise :obj:`False`.
    """
    # any() pulls at most one entry from the directory iterator, so this is
    # exactly as lazy as the manual loop-and-break it replaces.
    return not any(dir_path.iterdir())
def copy_file(src_path: Path, dest_path: Path) -> bool:
    """
    Copy a file from the source to destination. Be smart -- only copy if
    the source is newer than the destination.

    Implemented with :func:`os.stat` and :func:`shutil.copy2` instead of
    ``distutils.file_util.copy_file``, since distutils is deprecated
    (PEP 632) and removed in Python 3.12.  ``copy2`` preserves mode and
    timestamps, matching distutils' defaults.

    Args:
        src_path: The path to the source file.
        dest_path: The path to the destination where the source should be copied.

    Returns:
        True if the file was copied, otherwise False if the destination was already
        up to date.
    """
    # "Up to date" means the destination exists and is at least as new as
    # the source (same comparison distutils' update=True performed).
    if dest_path.exists() and src_path.stat().st_mtime <= dest_path.stat().st_mtime:
        return False
    shutil.copy2(str(src_path), str(dest_path))
    return True
def download_file(url: str, filename: Optional[str] = None) -> Path:
    """
    Save a file in the gobbli download directory if it doesn't already exist there.
    Stream the download to avoid running out of memory.

    Args:
        url: URL for the file.
        filename: If passed, use this as the filename instead of the best-effort one
            determined from the URL.

    Returns:
        The path to the downloaded file.

    Raises:
        requests.HTTPError: If the server responds with an error status code.
    """
    if filename is None:
        # Kind of hacky... pull out the last path component as the filename
        # and strip a trailing query, if any
        local_filename = url.split("/")[-1]
        query_start_ndx = local_filename.find("?")
        if query_start_ndx != -1:
            local_filename = local_filename[:query_start_ndx]
    else:
        local_filename = filename

    local_filepath = download_dir() / local_filename
    if local_filepath.exists():
        LOGGER.debug(f"Download for URL '{url}' already exists at '{local_filepath}'")
        return local_filepath

    LOGGER.debug(f"Downloading URL '{url}' to '{local_filepath}'")
    try:
        with requests.get(url, stream=True) as r:
            # Fail fast on HTTP errors instead of saving an error page to disk
            r.raise_for_status()
            with open(local_filepath, "wb") as f:
                shutil.copyfileobj(r.raw, f)
    except Exception:
        # Don't leave the file in a partially downloaded state
        try:
            # Guard on existence: if the request failed before the file was
            # opened, there's nothing to remove
            if local_filepath.exists():
                local_filepath.unlink()
                LOGGER.debug(
                    f"Removed partially downloaded file at '{local_filepath}'"
                )
        except Exception:
            # Fixed: the original warning string was missing the f-prefix,
            # so the path placeholder was never interpolated
            warnings.warn(
                f"Failed to remove partially downloaded file at '{local_filepath}'."
            )
        raise
    return local_filepath
def _extract_tar_junk_path(tarfile_obj: tarfile.TarFile, archive_extract_dir: Path):
    """
    Extract a tarfile while flattening any directory hierarchy
    in the archive.
    """
    for member in tarfile_obj.getmembers():
        # Directories carry no content once paths are flattened; skip them
        if member.isdir():
            continue
        # Strip the leading directories, keeping only the base filename
        member.name = Path(member.name).name
        destination = archive_extract_dir / member.name
        LOGGER.debug(f"Extracting member '{member.name}' to '{destination}'")
        tarfile_obj.extract(member, path=archive_extract_dir)
def _extract_zip_junk_path(zipfile_obj: zipfile.ZipFile, archive_extract_dir: Path):
    """
    Extract a zip file while flattening any directory hierarchy in the archive.
    """
    for info in zipfile_obj.infolist():
        # Skip directory entries; files are written directly to the target dir
        if info.is_dir():
            continue
        flattened_name = Path(info.filename).name
        destination = archive_extract_dir / flattened_name
        LOGGER.debug(f"Extracting member '{flattened_name}' to '{destination}'")
        with zipfile_obj.open(info, "r") as member_file:
            with destination.open("wb") as out_file:
                shutil.copyfileobj(member_file, out_file)
# Archive file extensions understood by extract_archive()
_SUPPORTED_ARCHIVE_EXTENSIONS = (".gz", ".zip")


def is_archive(filepath: Path) -> bool:
    """
    Args:
        filepath: Path to a file.

    Returns:
        Whether the file is an archive supported by :func:`extract_archive`.
    """
    # str.endswith accepts a tuple of suffixes, replacing the manual loop
    return filepath.name.endswith(_SUPPORTED_ARCHIVE_EXTENSIONS)
def download_archive(
    archive_url: str,
    archive_extract_dir: Path,
    junk_paths: bool = False,
    filename: Optional[str] = None,
) -> Path:
    """
    Save an archive in the given directory and extract its contents. Automatically
    retry the download once if the file comes back corrupted (in case it's left over
    from a partial download that was cancelled before).

    Args:
        archive_url: URL for the archive.
        archive_extract_dir: Download the archive and extract it to this directory.
        junk_paths: If True, disregard the archive's internal directory hierarchy
            and extract all files directly to the output directory.
        filename: If given, store the downloaded file under this name instead of
            one automatically inferred from the URL.

    Returns:
        Path to the directory containing the extracted file contents.
    """
    archive_path = download_file(archive_url, filename=filename)
    try:
        return extract_archive(archive_path, archive_extract_dir, junk_paths=junk_paths)
    except (zipfile.BadZipFile, tarfile.ReadError, OSError, EOFError):
        # Likely a stale partial download -- fetch a fresh copy and retry once
        LOGGER.warning(
            f"Downloaded archive at '{archive_path}' is corrupted. Retrying..."
        )
        archive_path.unlink()
        archive_path = download_file(archive_url, filename=filename)
        return extract_archive(archive_path, archive_extract_dir, junk_paths=junk_paths)
def dir_to_blob(dir_path: Path) -> bytes:
    """
    Archive a directory and save it as a blob in-memory.
    Useful for storing a directory's contents in an in-memory object store.
    Use compression to reduce file size. Extract with :func:`blob_to_dir`.

    Args:
        dir_path: Path to the directory to be archived.

    Returns:
        The compressed directory as a binary buffer.
    """
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as archive:
        # arcname="." makes member paths relative so extraction lands
        # directly inside the target directory
        archive.add(str(dir_path), arcname=".", recursive=True)
    # Rewind before handing the buffer's contents back
    buffer.seek(0)
    return buffer.getvalue()
def blob_to_dir(blob: bytes, dir_path: Path):
    """
    Extract the given blob (assumed to be a compressed directory
    created by :func:`dir_to_blob`) to the given directory.

    Args:
        blob: The compressed directory as a binary buffer.
        dir_path: Path to extract the directory to.
    """
    buffer = io.BytesIO(blob)
    with tarfile.open(fileobj=buffer, mode="r:gz") as archive:
        # NOTE(review): blobs are assumed to come from dir_to_blob (trusted);
        # extractall on untrusted data would need a member filter
        archive.extractall(dir_path)
# Generic element types for the two parallel lists shuffled together below
T1 = TypeVar("T1")
T2 = TypeVar("T2")


def shuffle_together(l1: List[T1], l2: List[T2], seed: Optional[int] = None):
    """
    Shuffle two lists together, so their order is random but the individual
    elements still correspond. Ex. shuffle a list of texts and a list of labels
    so the labels still correspond correctly to the texts. The lists are shuffled
    in-place. The lists must be the same length for this to make sense.

    Args:
        l1: The first list to be shuffled.
        l2: The second list to be shuffled.
        seed: Seed for the random number generator, if any.

    Raises:
        ValueError: If the lists have different lengths.
    """
    if len(l1) != len(l2):
        raise ValueError("Lists have unequal length.")
    if not l1:
        return
    paired = list(zip(l1, l2))
    # Use a private RNG instance instead of random.seed() so we don't
    # reseed/perturb the module-level generator shared by the whole process.
    # For a given seed this produces the same permutation the old code did.
    rng = random.Random(seed)
    rng.shuffle(paired)
    # Slice-assign so the shuffle is visible to callers holding the lists
    l1[:], l2[:] = zip(*paired)
def _train_sentencepiece(spm: Any, texts: List[str], model_path: Path, vocab_size: int):
    """
    Train a SentencePiece model on the given data and save it to the given
    location.

    Args:
        spm: Imported sentencepiece module.
        texts: Data to train on.
        model_path: Path to write the trained model to.
        vocab_size: Size of the vocabulary for the model to train.
    """
    with tempfile.NamedTemporaryFile(mode="w") as train_file:
        # One text per line -- the input format sentencepiece expects
        train_file.write("\n".join(texts))
        train_file.flush()
        # log levels:
        # https://github.com/google/sentencepiece/blob/d4dd947fe71c4fa4ee24ad8297beee32887d8828/src/common.h#L132
        # Default is waaay too chatty
        spm.SentencePieceTrainer.train(
            f"--input={train_file.name} --model_prefix={model_path} "
            f"--vocab_size={vocab_size} --minloglevel=2"
        )
@enum.unique
class TokenizeMethod(enum.Enum):
    """
    Enum describing the different canned tokenization methods gobbli supports.
    Processes requiring tokenization should generally allow a user to pass in
    a custom tokenization function if their needs aren't met by one of these.

    Attributes:
        SPLIT: Naive tokenization based on whitespace. Probably only useful for testing.
            Tokens will be lowercased.
        SPACY: Simple tokenization using spaCy's English language model.
            Tokens will be lowercased, and non-alphabetic tokens will be filtered out.
        SENTENCEPIECE: `SentencePiece <https://github.com/google/sentencepiece>`__-based tokenization.
    """

    # Values are the user-facing string names for each method
    SPLIT = "split"
    SPACY = "spacy"
    SENTENCEPIECE = "sentencepiece"
def tokenize(
    method: TokenizeMethod,
    texts: List[str],
    model_path: Optional[Path] = None,
    vocab_size: int = 2000,
) -> List[List[str]]:
    """
    Tokenize a list of texts using a predefined method.

    Args:
        method: The type of tokenization to apply.
        texts: Texts to tokenize.
        model_path: Path to save a trained model to. Required if the tokenization method
            requires training a model; otherwise ignored. If the model doesn't exist, it will
            be trained; if it does, the trained model will be reused.
        vocab_size: Number of terms in the vocabulary for tokenization methods with a fixed
            vocabulary size. You may need to lower this if you get tokenization errors or
            raise it if your texts have a very diverse vocabulary.

    Returns:
        List of tokenized texts.

    Raises:
        ImportError: If the optional dependency required by the chosen method
            isn't installed.
        ValueError: If ``method`` isn't a supported tokenization method.
    """
    if method == TokenizeMethod.SPLIT:
        # Naive whitespace split; tokens lowercased for consistency
        return [[tok.lower() for tok in text.split()] for text in texts]
    elif method == TokenizeMethod.SPACY:
        try:
            from spacy.lang.en import English
        except ImportError:
            raise ImportError(
                "spaCy tokenization method requires spaCy and an english language "
                "model to be installed."
            )
        nlp = English()
        tokenizer = nlp.Defaults.create_tokenizer(nlp)
        processed_texts = []
        # tokenizer.pipe streams the texts; keep only alphabetic tokens, lowercased
        for doc in tokenizer.pipe(texts):
            processed_texts.append([tok.lower_ for tok in doc if tok.is_alpha])
        return processed_texts
    elif method == TokenizeMethod.SENTENCEPIECE:
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError(
                "SentencePiece tokenization requires the sentencepiece module "
                "to be installed."
            )
        # Train only if the model file path doesn't already exist
        # If we weren't given a path to save to, just make a temp file which will
        # be discarded after the run
        sp = spm.SentencePieceProcessor()
        if model_path is None:
            with tempfile.TemporaryDirectory() as temp_model_dir:
                temp_model_path = Path(temp_model_dir) / "temp"
                _train_sentencepiece(spm, texts, temp_model_path, vocab_size)
                # Have to add the extension here for sentencepiece to find the model
                sp.load(f"{temp_model_path}.model")
                return [sp.encode_as_pieces(text) for text in texts]
        else:
            model_file_path = Path(f"{model_path}.model")
            if not model_file_path.exists():
                _train_sentencepiece(spm, texts, model_path, vocab_size)
            sp.load(str(model_file_path))
            return [sp.encode_as_pieces(text) for text in texts]
    else:
        raise ValueError(f"Unsupported tokenization method '{method}'.")
def detokenize(
    method: TokenizeMethod,
    all_tokens: Iterator[List[str]],
    model_path: Optional[Path] = None,
) -> List[str]:
    """
    Detokenize a nested list of tokens into a list of strings, assuming
    they were created using the given predefined method.

    Args:
        method: The type of tokenization to reverse.
        all_tokens: The nested list of tokens to detokenize.
        model_path: Path to load a trained model from. Required if the
            tokenization method requires training a model; otherwise ignored.

    Returns:
        List of texts.
    """
    if method in (TokenizeMethod.SPLIT, TokenizeMethod.SPACY):
        # TODO can this be better?  Whitespace join is lossy but reversible enough
        return [" ".join(token_list) for token_list in all_tokens]
    if method == TokenizeMethod.SENTENCEPIECE:
        try:
            import sentencepiece as spm
        except ImportError:
            raise ImportError(
                "SentencePiece detokenization requires the sentencepiece module "
                "to be installed."
            )
        processor = spm.SentencePieceProcessor()
        processor.load(f"{model_path}.model")
        return [processor.decode_pieces(token_list) for token_list in all_tokens]
    raise ValueError(f"Unsupported tokenization method '{method}'.")