import re
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
import gobbli.io
from gobbli.docker import maybe_mount, run_container
from gobbli.model.base import BaseModel
from gobbli.model.context import ContainerTaskContext
from gobbli.model.mixin import EmbedMixin, PredictMixin, TrainMixin
from gobbli.util import (
assert_in,
assert_type,
download_archive,
escape_line_delimited_text,
multilabel_to_indicator_df,
pred_prob_to_pred_label,
pred_prob_to_pred_multilabel,
)
FASTTEXT_VECTOR_ARCHIVES = {
"wiki-news-300d": "https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip",
"wiki-news-300d-subword": "https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M-subword.vec.zip",
"crawl-300d": "https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M.vec.zip",
"crawl-300d-subword": "https://dl.fbaipublicfiles.com/fasttext/vectors-english/crawl-300d-2M-subword.zip",
"wiki-crawl-300d": "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz",
"wiki-aligned-300d": "https://dl.fbaipublicfiles.com/fasttext/vectors-aligned/wiki.en.align.vec",
}
"""
A mapping from pretrained vector names to archives.
See `the fastText docs <https://fasttext.cc/docs/en/english-vectors.html>`__ for information
about each set of vectors. Note, some sets of vectors are very, very large.
"""
_DIM_REGEX = re.compile(r"([0-9]+)d")
def _parse_dim(model_name: str) -> int:
"""
Parse the number of dimensions from a FastText model name.
"""
match = _DIM_REGEX.search(model_name)
if match is None:
raise ValueError(
f"Failed to parse number of dimensions from model name {model_name}"
)
return int(match.group(1))
def _fasttext_preprocess(text: str) -> str:
"""
Preprocess text for the fasttext model.
Lowercase text and escape newlines. Removing or separately
tokenizing punctuation is recommended, but there are too many different
ways to do it, so we leave that up to the user.
"""
return escape_line_delimited_text(text).lower()
[docs]@dataclass
class FastTextCheckpoint:
path: Path
@property
def vectors(self) -> Path:
"""
From a checkpoint, return the path to the text vectors.
"""
return self.path.parent / f"{self.path.stem}.vec"
@property
def model(self) -> Path:
"""
From a checkpoint, return the path to the binary model.
"""
return self.path.parent / f"{self.path.stem}.bin"
[docs]class FastText(BaseModel, TrainMixin, PredictMixin, EmbedMixin):
"""
Wrapper for Facebook's fastText model:
https://github.com/facebookresearch/fastText
Note: fastText benefits from some preprocessing steps:
https://fasttext.cc/docs/en/supervised-tutorial.html#preprocessing-the-data
gobbli will only lowercase and escape newlines in your input by default.
If you want more sophisticated preprocessing for punctuation, stemming, etc,
consider performing some preprocessing on your own beforehand.
"""
_BUILD_PATH = Path(__file__).parent
_TRAIN_INPUT_FILE = "train.txt"
_VALID_INPUT_FILE = "valid.txt"
_TEST_INPUT_FILE = "test.txt"
_PREDICT_OUTPUT_FILE = "predict.txt"
_EMBEDDING_INPUT_FILE = "input.txt"
_EMBEDDING_OUTPUT_FILE = "embeddings.txt"
_CHECKPOINT_BASE = "model"
_LABEL_SPACE_ESCAPE = "$gobbli_space$"
@property
def image_tag(self) -> str:
"""
Returns:
The tag to use for the fastText image.
"""
return "gobbli-fasttext"
@property
def weights_dir(self) -> Path:
"""
Returns:
The directory containing pretrained weights for this instance.
"""
# Weights won't be used if we don't have a model to use
if self.fasttext_model is None:
return self.class_weights_dir
return self.class_weights_dir / self.fasttext_model
[docs] def init(self, params: Dict[str, Any]):
"""
See :meth:`gobbli.model.base.BaseModel.init`.
For more info on fastText parameter semantics, see
`the docs <https://fasttext.cc/docs/en/options.html>`__. The fastText
`supervised tutorial <https://fasttext.cc/docs/en/supervised-tutorial.html>`__ has
some more detailed explanation.
fastText parameters:
- ``word_ngrams`` (:obj:`int`): Max length of word n-grams.
- ``lr`` (:obj:`float`): Learning rate.
- ``dim`` (:obj:`int`): Dimension of learned vectors.
- ``ws`` (:obj:`int`): Context window size.
- ``autotune_duration`` (:obj:`int`): Duration in seconds to spend autotuning parameters.
Any of the above parameters will not be autotuned if they are manually specified.
- ``autotune_modelsize`` (:obj:`str`): Maximum size of autotuned model (ex "2M" for 2
megabytes). Any of the above parameters will not be autotuned if they are manually
specified.
- ``fasttext_model`` (:obj:`str`): Name of a pretrained fastText model to use.
See :obj:`FASTTEXT_VECTOR_ARCHIVES` for a listing of available pretrained models.
"""
self.word_ngrams = None
self.lr = None
self.ws = None
self.fasttext_model = None
# Default to dimensionality of the passed model, if any
if "fasttext_model" in params:
self.dim: Optional[int] = _parse_dim(params["fasttext_model"])
else:
self.dim = None
self.autotune_duration = None
self.autotune_modelsize = None
for name, value in params.items():
if name == "word_ngrams":
assert_type(name, value, int)
self.word_ngrams = value
elif name == "lr":
assert_type(name, value, float)
self.lr = value
elif name == "dim":
assert_type(name, value, int)
self.dim = value
elif name == "ws":
assert_type(name, value, int)
self.ws = value
elif name == "fasttext_model":
assert_in(name, value, set(FASTTEXT_VECTOR_ARCHIVES.keys()))
self.fasttext_model = value
elif name == "autotune_duration":
assert_type(name, value, int)
self.autotune_duration = value
elif name == "autotune_modelsize":
assert_type(name, value, str)
self.autotune_modelsize = value
else:
raise ValueError(f"Unknown param '{name}'")
if (
self.fasttext_model is not None
and f"{self.dim}d" not in self.fasttext_model
):
raise ValueError(
"When using pretrained vectors, 'dim' must match the"
f" dimensionality of the vectors; 'dim' value of {self.dim}"
f" is incompatible with vectors {self.fasttext_model}."
)
def _build(self):
# Download data if we need it and don't already have it
if (
self.fasttext_model is not None
and not (self.weights_dir / self.fasttext_model).exists()
):
with tempfile.TemporaryDirectory() as tmpdir:
tmp_weights_dir = Path(tmpdir) / self.weights_dir.name
tmp_weights_dir.mkdir()
self.logger.info("Downloading pre-trained weights.")
download_archive(
FASTTEXT_VECTOR_ARCHIVES[self.fasttext_model], tmp_weights_dir
)
shutil.move(tmp_weights_dir, self.weights_dir)
self.logger.info("Weights downloaded.")
# Build the custom docker image
self.docker_client.images.build(
path=str(FastText._BUILD_PATH),
tag=self.image_tag,
**self._base_docker_build_kwargs,
)
@staticmethod
def _escape_label(label: str) -> str:
"""
Escape a label for use in fastText's label format. Spaces
must be replaced, or the label will be interpreted as part of the text.
Args:
label: Label to escape
Returns:
The escaped label
"""
return label.replace(" ", FastText._LABEL_SPACE_ESCAPE)
@staticmethod
def _unescape_label(label: str) -> str:
"""
Reverse escaping for a label read from fastText's output format.
Args:
label: Label to unescape
Returns:
The unescaped label
"""
return label.replace(FastText._LABEL_SPACE_ESCAPE, " ")
@staticmethod
def _locate_checkpoint(weights_dir: Path) -> FastTextCheckpoint:
"""
Locate a fastText checkpoint under the given directory,
regardless of its filename.
Args:
weights_dir: The directory to search for a checkpoint (not recursive).
Returns:
A fastText checkpoint.
"""
candidates = list(weights_dir.glob("*.vec"))
if len(candidates) == 0:
raise ValueError(f"No weights files found in '{weights_dir}'.")
elif len(candidates) > 1:
raise ValueError(
f"Multiple weights files found in '{weights_dir}': {candidates}"
)
return FastTextCheckpoint(path=candidates[0].parent / candidates[0].stem)
def _get_checkpoint(
self, user_checkpoint: Optional[Path], context: ContainerTaskContext
) -> Tuple[Optional[FastTextCheckpoint], Optional[FastTextCheckpoint]]:
"""
Determines, if any, the host checkpoint file and container checkpoint file
using the user-requested checkpoint and the container context.
Args:
user_checkpoint: An optional checkpoint passed in by the user. If the user doesn't
pass one, use the default pretrained checkpoint, if any, or no checkpoint.
context: The container context to create the checkpoint in.
Returns:
A 2-tuple: the host checkpoint (if any) and
the container checkpoint (if any)
"""
host_checkpoint = None # type: Optional[FastTextCheckpoint]
container_checkpoint = None # type: Optional[FastTextCheckpoint]
if self.fasttext_model is None and user_checkpoint is None:
# No pretrained vectors
return host_checkpoint, container_checkpoint
elif self.fasttext_model is not None and user_checkpoint is None:
host_checkpoint = FastText._locate_checkpoint(self.weights_dir)
container_checkpoint = FastTextCheckpoint(
BaseModel._CONTAINER_WEIGHTS_PATH / host_checkpoint.path.name
)
return host_checkpoint, container_checkpoint
else: # user_checkpoint is not None; user_checkpoint overrides pretrained model
# This should never happen by the conditional checks above
assert user_checkpoint is not None
host_checkpoint = FastTextCheckpoint(user_checkpoint)
container_checkpoint = FastTextCheckpoint(
BaseModel._CONTAINER_WEIGHTS_PATH / host_checkpoint.path.name
)
return host_checkpoint, container_checkpoint
def _write_input(
self, X: List[str], y: Optional[List[List[str]]], input_path: Path
):
"""
Write the given input and labels (if any) into the format expected by fastText.
Make sure the given directory exists first.
"""
with open(input_path, "w") as f:
if y is not None:
for text, labels in zip(X, y):
label_str = " ".join(
f"__label__{FastText._escape_label(label)}" for label in labels
)
f.write(f"{label_str} {_fasttext_preprocess(text)}\n")
elif y is None:
for text in X:
f.write(f"{_fasttext_preprocess(text)}\n")
def _run_supervised(
self,
user_checkpoint: Optional[Path],
container_input_path: Path,
container_output_path: Path,
context: ContainerTaskContext,
num_epochs: int,
autotune_validation_file_path: Optional[Path] = None,
freeze_vectors: bool = False,
) -> Tuple[str, float]:
"""
Run the fastText "supervised" command. Used for both training and getting
validation loss.
Args:
user_checkpoint: A checkpoint passed by the user
container_input_path: Path to the input file in the container
container_output_path: Path to the output checkpoint in the container
context: Container task context.
validation_file_path: Optional file to use for autotune validation when training.
freeze_vectors: If true, use 0 learning rate; train solely for
the purpose of calculating loss.
Returns:
A 2-tuple: container logs and loss.
"""
host_checkpoint, container_checkpoint = self._get_checkpoint(
user_checkpoint, context
)
cmd = (
"supervised"
f" -input {container_input_path}"
f" -output {container_output_path}"
f" -epoch {num_epochs}"
)
if autotune_validation_file_path is not None:
cmd += f" -autotune-validation {autotune_validation_file_path}"
lr = self.lr
if freeze_vectors:
lr = 0.0
if lr is not None:
cmd += f" -lr {lr}"
for arg_name, attr in (
("wordNgrams", "word_ngrams"),
("dim", "dim"),
("ws", "ws"),
("autotune-duration", "autotune_duration"),
("autotune-modelsize", "autotune_modelsize"),
):
attr_val = getattr(self, attr)
if attr_val is not None:
cmd += f" -{arg_name} {attr_val}"
run_kwargs = self._base_docker_run_kwargs(context)
if host_checkpoint is not None and container_checkpoint is not None:
maybe_mount(
run_kwargs["volumes"],
host_checkpoint.vectors,
container_checkpoint.vectors,
)
cmd += f" -pretrainedVectors {container_checkpoint.vectors}"
container_logs = run_container(
self.docker_client, self.image_tag, cmd, self.logger, **run_kwargs
)
# Parse the training loss out of the console output
last_loss_ndx = container_logs.rfind("avg.loss:")
failed_parse_msg = (
"Failed to parse loss information from fastText container logs."
" Run with debug logging to"
" see why this might have happened."
)
if last_loss_ndx == -1:
raise ValueError(failed_parse_msg)
# Skip over the word "avg.loss:" - next field in the output is "ETA:"
loss_start_ndx = last_loss_ndx + len("avg.loss:")
loss_end_ndx = container_logs.find("ETA:", loss_start_ndx)
loss = float(container_logs[loss_start_ndx:loss_end_ndx].strip())
return container_logs, loss
def _run_predict_prob(
self,
user_checkpoint: Path,
labels: List[str],
container_input_path: Path,
context: ContainerTaskContext,
) -> Tuple[str, pd.DataFrame]:
"""
Run the fastText "predict-prob" command. Used for obtaining
label predicted probabilities on a dataset.
Args:
container_trained_model_path: Trained model passed by the user (.bin file)
labels: Set of all labels to be used in prediction.
container_input_path: Path to the input file in the container.
context: Container task context.
Returns:
A 2-tuple: container logs and a dataframe of predicted probabilities.
"""
host_checkpoint, container_checkpoint = self._get_checkpoint(
user_checkpoint, context
)
if host_checkpoint is None or container_checkpoint is None:
raise ValueError("A trained checkpoint is required to run prediction.")
host_output_path = context.host_output_dir / FastText._PREDICT_OUTPUT_FILE
container_output_path = (
context.container_output_dir / FastText._PREDICT_OUTPUT_FILE
)
cmd = (
"bash -c './fasttext predict-prob"
f" {container_checkpoint.model}"
f" {container_input_path}"
f" {len(labels)}"
f" >{container_output_path}'"
)
run_kwargs = self._base_docker_run_kwargs(context)
# Override the entrypoint so we can use 'bash -c ...' above
run_kwargs["entrypoint"] = ""
maybe_mount(
run_kwargs["volumes"], host_checkpoint.model, container_checkpoint.model
)
container_logs = run_container(
self.docker_client, self.image_tag, cmd, self.logger, **run_kwargs
)
# Parse the predicted probabilities out of the output file
pred_prob_data = []
with open(host_output_path, "r") as f:
for line in f:
tokens = line.split()
# Seems that fastText doesn't always return a probability for
# every label, so start out with default = 0.0 so the shape of
# the returned DataFrame will be consistent with the number
# of labels
row_data = {label: 0.0 for label in labels}
for raw_label, prob in zip(tokens[0::2], tokens[1::2]):
# Strip the "__label__" prefix and undo escaping
label = FastText._unescape_label(raw_label[9:])
row_data[label] = float(prob)
pred_prob_data.append(row_data)
return (container_logs, pd.DataFrame(pred_prob_data))
def _train(
self, train_input: gobbli.io.TrainInput, context: ContainerTaskContext
) -> gobbli.io.TrainOutput:
self._write_input(
train_input.X_train,
train_input.y_train_multilabel,
context.host_input_dir / FastText._TRAIN_INPUT_FILE,
)
self._write_input(
train_input.X_valid,
train_input.y_valid_multilabel,
context.host_input_dir / FastText._VALID_INPUT_FILE,
)
container_validation_input_path = (
context.container_input_dir / FastText._VALID_INPUT_FILE
)
train_logs, train_loss = self._run_supervised(
train_input.checkpoint,
context.container_input_dir / FastText._TRAIN_INPUT_FILE,
context.container_output_dir / FastText._CHECKPOINT_BASE,
context,
train_input.num_train_epochs,
autotune_validation_file_path=container_validation_input_path,
)
host_checkpoint_path = context.host_output_dir / f"{FastText._CHECKPOINT_BASE}"
labels = train_input.labels()
# Calculate validation accuracy on our own, since the CLI only provides
# precision/recall
predict_logs, pred_prob_df = self._run_predict_prob(
host_checkpoint_path, labels, container_validation_input_path, context
)
if train_input.multilabel:
pred_labels = pred_prob_to_pred_multilabel(pred_prob_df)
gold_labels = multilabel_to_indicator_df(
train_input.y_valid_multilabel, labels
)
else:
pred_labels = pred_prob_to_pred_label(pred_prob_df)
gold_labels = train_input.y_valid_multiclass
valid_accuracy = accuracy_score(gold_labels, pred_labels)
# Not ideal, but fastText doesn't provide a way to get validation loss;
# Negate the validation accuracy instead
valid_loss = -valid_accuracy
return gobbli.io.TrainOutput(
train_loss=train_loss,
valid_loss=valid_loss,
valid_accuracy=valid_accuracy,
labels=labels,
multilabel=train_input.multilabel,
checkpoint=host_checkpoint_path,
_console_output="\n".join((train_logs, predict_logs)),
)
def _predict(
self, predict_input: gobbli.io.PredictInput, context: ContainerTaskContext
) -> gobbli.io.PredictOutput:
host_input_path = context.host_input_dir / FastText._TEST_INPUT_FILE
self._write_input(predict_input.X, None, host_input_path)
container_input_path = context.to_container(host_input_path)
if predict_input.checkpoint is None:
raise ValueError("fastText requires a trained checkpoint for prediction.")
predict_logs, pred_prob_df = self._run_predict_prob(
predict_input.checkpoint,
predict_input.labels,
container_input_path,
context,
)
return gobbli.io.PredictOutput(
y_pred_proba=pred_prob_df, _console_output=predict_logs
)
def _embed(
self, embed_input: gobbli.io.EmbedInput, context: ContainerTaskContext
) -> gobbli.io.EmbedOutput:
# Check for null checkpoint here to give quick feedback to the user
if embed_input.checkpoint is None:
raise ValueError(
"fastText requires a trained checkpoint to generate embeddings."
)
if embed_input.pooling == gobbli.io.EmbedPooling.NONE:
raise ValueError(
"fastText prints sentence vectors, so pooling is required."
)
host_input_path = context.host_input_dir / FastText._EMBEDDING_INPUT_FILE
self._write_input(embed_input.X, None, host_input_path)
container_input_path = context.to_container(host_input_path)
host_checkpoint, container_checkpoint = self._get_checkpoint(
embed_input.checkpoint, context
)
# We shouldn't get Nones here if the user didn't pass a null checkpoint, but
# check anyway to satisfy mypy
if host_checkpoint is None or container_checkpoint is None:
raise ValueError(
"fastText requires a trained checkpoint to generate embeddings."
)
host_output_path = context.host_output_dir / FastText._EMBEDDING_OUTPUT_FILE
container_output_path = (
context.container_output_dir / FastText._EMBEDDING_OUTPUT_FILE
)
cmd = (
"bash -c './fasttext print-sentence-vectors"
f" {container_checkpoint.model}"
f" <{container_input_path}"
f" >{container_output_path}'"
)
run_kwargs = self._base_docker_run_kwargs(context)
# Override the entrypint so we can use 'bash -c ...' above
run_kwargs["entrypoint"] = ""
maybe_mount(
run_kwargs["volumes"], host_checkpoint.model, container_checkpoint.model
)
container_logs = run_container(
self.docker_client, self.image_tag, cmd, self.logger, **run_kwargs
)
# Parse the embeddings out of the output file
embeddings = np.loadtxt(host_output_path, comments=None, ndmin=2)
return gobbli.io.EmbedOutput(
X_embedded=embeddings, embed_tokens=None, _console_output=container_logs
)