Source code for gobbli.model.transformer.model
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import gobbli.io
from gobbli.docker import maybe_mount, run_container
from gobbli.model.base import BaseModel
from gobbli.model.context import ContainerTaskContext
from gobbli.model.mixin import EmbedMixin, PredictMixin, TrainMixin
from gobbli.util import assert_type, escape_line_delimited_texts
class Transformer(BaseModel, TrainMixin, PredictMixin, EmbedMixin):
"""
Classifier/embedding wrapper for any of the Transformers from
`transformers <https://github.com/huggingface/transformers>`__.
"""
_BUILD_PATH = Path(__file__).parent
_TRAIN_INPUT_FILE = "train.tsv"
_VALID_INPUT_FILE = "dev.tsv"
_TEST_INPUT_FILE = "test.tsv"
_LABELS_INPUT_FILE = "labels.tsv"
_CONFIG_OVERRIDE_FILE = "config.json"
_TRAIN_OUTPUT_CHECKPOINT = "checkpoint"
_VALID_OUTPUT_FILE = "valid_results.json"
_TEST_OUTPUT_FILE = "test_results.tsv"
_EMBEDDING_INPUT_FILE = "input.tsv"
_EMBEDDING_OUTPUT_FILE = "embeddings.jsonl"
_CONTAINER_CACHE_DIR = Path("/cache")
def init(self, params: Dict[str, Any]):
"""
See :meth:`gobbli.model.base.BaseModel.init`.
Transformer parameters:
- ``transformer_model`` (:obj:`str`): Name of a transformer model architecture to use.
For training/prediction, the value must be one for which
``from transformers import <value>ForSequenceClassification`` is
a valid import; for example, the value "Bert" yields
``from transformers import BertForSequenceClassification``. Note this means
only a subset of transformers models are supported for these tasks -- search
`the docs <https://huggingface.co/transformers/search.html?q=forsequenceclassification&check_keywords=yes&area=default>`__ to see which ones you can use.
For embedding generation, the import is ``<value>Model``, so any transformer
model is supported.
- ``transformer_weights`` (:obj:`str`): Name of the pretrained weights to use.
See the `transformers docs <https://huggingface.co/transformers/pretrained_models.html>`__
for supported values. These depend on the ``transformer_model`` chosen.
- ``config_overrides`` (:obj:`dict`): Dictionary of keys and values that will
override the model's default configuration.
- ``max_seq_length``: Truncate all sequences to this length after tokenization.
Used to save memory.
- ``lr``: Learning rate for the AdamW optimizer.
- ``adam_eps``: Epsilon value for the AdamW optimizer.
- ``gradient_accumulation_steps``: Number of iterations to accumulate gradients before
updating the model. Used to allow larger effective batch sizes for models too big to
fit a large batch on the GPU. The "effective batch size" is
``gradient_accumulation_steps`` * :paramref:`TrainInput.params.train_batch_size`.
If you encounter memory errors while training, try decreasing the batch size and
increasing ``gradient_accumulation_steps``. For example, if a training batch size of
32 causes memory errors, try decreasing batch size to 16 and increasing
``gradient_accumulation_steps`` to 2. If you still have problems with memory, you can
drop the batch size to 8 and raise ``gradient_accumulation_steps`` to 4, and so on.
Note that gobbli relies on transformers to perform validation on these parameters,
so initialization errors may not be caught until model runtime.
"""
self.transformer_model = "Bert"
self.transformer_weights = "bert-base-uncased"
self.config_overrides = {} # type: Dict[str, Any]
self.max_seq_length = 128
self.lr = 5e-5
self.adam_eps = 1e-8
self.gradient_accumulation_steps = 1
for name, value in params.items():
if name == "transformer_model":
self.transformer_model = value
elif name == "transformer_weights":
self.transformer_weights = value
elif name == "config_overrides":
assert_type(name, value, dict)
self.config_overrides = value
elif name == "max_seq_length":
assert_type(name, value, int)
self.max_seq_length = value
elif name == "lr":
assert_type(name, value, float)
self.lr = value
elif name == "adam_eps":
assert_type(name, value, float)
self.adam_eps = value
elif name == "gradient_accumulation_steps":
assert_type(name, value, int)
self.gradient_accumulation_steps = value
else:
raise ValueError(f"Unknown param '{name}'")
@property
def image_tag(self) -> str:
"""
Returns:
The Docker image tag to be used for the transformer container.
"""
return "gobbli-transformer"
def _build(self):
self.docker_client.images.build(
path=str(Transformer._BUILD_PATH),
tag=self.image_tag,
**self._base_docker_build_kwargs,
)
@staticmethod
def _get_checkpoint(
user_checkpoint: Optional[Path], context: ContainerTaskContext
) -> Tuple[Optional[Path], Optional[Path]]:
"""
Determines the host checkpoint directory and container checkpoint directory
using the user-requested checkpoint (if any) and the container context.
Args:
user_checkpoint: An optional checkpoint passed in by the user. If the user doesn't
pass one, use the default pretrained checkpoint.
context: The container context to create the checkpoint in.
Returns:
A 2-tuple: the host checkpoint directory (if any) and
the container checkpoint directory (if any)
"""
if user_checkpoint is None:
host_checkpoint_dir = None
container_checkpoint_dir = None
else:
host_checkpoint_dir = user_checkpoint
container_checkpoint_dir = context.container_root_dir / "checkpoint"
return host_checkpoint_dir, container_checkpoint_dir
def _get_weights(
self, container_checkpoint_dir: Optional[Path]
) -> Union[str, Path]:
"""
Determine the weights to pass to the run_model script. If we don't have a
checkpoint, use the pretrained weights; otherwise, use the checkpoint weights.
"""
if container_checkpoint_dir is None:
return self.transformer_weights
else:
return container_checkpoint_dir
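# Illustrative only, tying the two helpers above together ("context" stands in
# for a ContainerTaskContext; paths are made up):
#
#     host_dir, container_dir = Transformer._get_checkpoint(None, context)
#     # host_dir is None, container_dir is None
#     # -> self._get_weights(None) returns self.transformer_weights,
#     #    e.g. the default "bert-base-uncased"
#
#     host_dir, container_dir = Transformer._get_checkpoint(Path("/tmp/ckpt"), context)
#     # host_dir == Path("/tmp/ckpt")
#     # container_dir == context.container_root_dir / "checkpoint"
#     # -> self._get_weights(container_dir) returns the container checkpoint path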
@property
def host_cache_dir(self):
"""
Directory to be used for downloaded transformers files.
Should be the same across all instances of the class, since these are
generally static model weights/config files that can be reused.
"""
cache_dir = Transformer.model_class_dir() / "cache"
cache_dir.mkdir(exist_ok=True, parents=True)
return cache_dir
def _write_input(
self,
X: List[str],
labels: Optional[Union[List[str], List[List[str]]]],
input_path: Path,
):
"""
Write the given input texts and (optionally) labels to the file pointed to by
``input_path``.
"""
df = pd.DataFrame({"Text": X})
if labels is not None:
df["Label"] = labels
df.to_csv(input_path, sep="\t", index=False)
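# For reference, the TSV written above has a "Text" column and, when labels are
# given, a "Label" column; a made-up two-row example (tab-separated):
#
#     Text\tLabel
#     this movie was great\tpositive
#     terrible acting\tnegative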
def _write_labels(self, labels: List[str], labels_path: Path):
"""
Write the given labels to the file pointed at by ``labels_path``.
"""
labels_path.write_text(escape_line_delimited_texts(labels))
def _write_config(self, config_path: Path):
"""
Write our model configuration overrides to the given path.
"""
with open(config_path, "w") as f:
json.dump(self.config_overrides, f)
def _train(
self, train_input: gobbli.io.TrainInput, context: ContainerTaskContext
) -> gobbli.io.TrainOutput:
self._write_input(
train_input.X_train,
train_input.y_train,
context.host_input_dir / Transformer._TRAIN_INPUT_FILE,
)
self._write_input(
train_input.X_valid,
train_input.y_valid,
context.host_input_dir / Transformer._VALID_INPUT_FILE,
)
self._write_config(context.host_input_dir / Transformer._CONFIG_OVERRIDE_FILE)
labels = train_input.labels()
self._write_labels(
labels, context.host_input_dir / Transformer._LABELS_INPUT_FILE
)
# Determine checkpoint to use
host_checkpoint_dir, container_checkpoint_dir = self._get_checkpoint(
train_input.checkpoint, context
)
cmd = (
"python3 run_model.py"
" train"
f" --input-dir {context.container_input_dir}"
f" --output-dir {context.container_output_dir}"
f" --config-overrides {context.container_input_dir / Transformer._CONFIG_OVERRIDE_FILE}"
f" --model {self.transformer_model}"
f" --weights {self._get_weights(container_checkpoint_dir)}"
f" --cache-dir {Transformer._CONTAINER_CACHE_DIR}"
f" --max-seq-length {self.max_seq_length}"
f" --train-batch-size {train_input.train_batch_size}"
f" --valid-batch-size {train_input.valid_batch_size}"
f" --num-train-epochs {train_input.num_train_epochs}"
f" --lr {self.lr}"
f" --adam-eps {self.adam_eps}"
f" --gradient-accumulation-steps {self.gradient_accumulation_steps}"
)
if train_input.multilabel:
cmd += " --multilabel"
run_kwargs = self._base_docker_run_kwargs(context)
# Mount the checkpoint in the container if needed
maybe_mount(
run_kwargs["volumes"], host_checkpoint_dir, container_checkpoint_dir
)
# Mount the cache directory
maybe_mount(
run_kwargs["volumes"], self.host_cache_dir, Transformer._CONTAINER_CACHE_DIR
)
container_logs = run_container(
self.docker_client, self.image_tag, cmd, self.logger, **run_kwargs
)
# Read in the generated evaluation results
with open(context.host_output_dir / Transformer._VALID_OUTPUT_FILE, "r") as f:
results = json.load(f)
return gobbli.io.TrainOutput(
valid_loss=results["mean_valid_loss"],
valid_accuracy=results["valid_accuracy"],
train_loss=results["mean_train_loss"],
multilabel=train_input.multilabel,
labels=labels,
checkpoint=context.host_output_dir / Transformer._TRAIN_OUTPUT_CHECKPOINT,
_console_output=container_logs,
)
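# A minimal usage sketch, assuming the public ``train`` entry point provided by
# ``TrainMixin`` wraps this method and that ``gobbli.io.TrainInput`` accepts the
# fields referenced above (data values are made up):
#
#     train_input = gobbli.io.TrainInput(
#         X_train=["good film", "bad film"],
#         y_train=["positive", "negative"],
#         X_valid=["fine film"],
#         y_valid=["positive"],
#         train_batch_size=16,
#         num_train_epochs=1,
#     )
#     train_output = model.train(train_input)
#     print(train_output.valid_accuracy, train_output.checkpoint)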
def _read_predictions(self, predict_path: Path):
return pd.read_csv(predict_path, sep="\t")
def _predict(
self, predict_input: gobbli.io.PredictInput, context: ContainerTaskContext
) -> gobbli.io.PredictOutput:
self._write_input(
predict_input.X, None, context.host_input_dir / Transformer._TEST_INPUT_FILE
)
self._write_config(context.host_input_dir / Transformer._CONFIG_OVERRIDE_FILE)
labels = predict_input.labels
self._write_labels(
labels, context.host_input_dir / Transformer._LABELS_INPUT_FILE
)
host_checkpoint_dir, container_checkpoint_dir = self._get_checkpoint(
predict_input.checkpoint, context
)
cmd = (
"python3 run_model.py"
" predict"
f" --input-dir {context.container_input_dir}"
f" --output-dir {context.container_output_dir}"
f" --config-overrides {context.container_input_dir / Transformer._CONFIG_OVERRIDE_FILE}"
f" --model {self.transformer_model}"
f" --weights {self._get_weights(container_checkpoint_dir)}"
f" --cache-dir {Transformer._CONTAINER_CACHE_DIR}"
f" --max-seq-length {self.max_seq_length}"
f" --predict-batch-size {predict_input.predict_batch_size}"
)
if predict_input.multilabel:
cmd += " --multilabel"
run_kwargs = self._base_docker_run_kwargs(context)
# Mount the checkpoint in the container if needed
maybe_mount(
run_kwargs["volumes"], host_checkpoint_dir, container_checkpoint_dir
)
# Mount the cache directory
maybe_mount(
run_kwargs["volumes"], self.host_cache_dir, Transformer._CONTAINER_CACHE_DIR
)
container_logs = run_container(
self.docker_client, self.image_tag, cmd, self.logger, **run_kwargs
)
return gobbli.io.PredictOutput(
y_pred_proba=self._read_predictions(
context.host_output_dir / Transformer._TEST_OUTPUT_FILE
),
_console_output=container_logs,
)
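# A minimal usage sketch, assuming the public ``predict`` entry point provided
# by ``PredictMixin`` wraps this method; ``labels`` should match the labels
# used during training, and ``checkpoint`` typically comes from a prior
# TrainOutput (values are made up):
#
#     predict_input = gobbli.io.PredictInput(
#         X=["an unseen review"],
#         labels=["negative", "positive"],
#         checkpoint=train_output.checkpoint,
#     )
#     predict_output = model.predict(predict_input)
#     print(predict_output.y_pred_proba)  # DataFrame read from the output TSV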
def _read_embeddings(
self, embed_path: Path, pooling: gobbli.io.EmbedPooling
) -> Tuple[List[np.ndarray], Optional[List[List[str]]]]:
embeddings = [] # type: List[np.ndarray]
doc_tokens = [] # type: List[List[str]]
with open(embed_path, "r") as f:
for line in f:
line_json = json.loads(line)
embeddings.append(np.array(line_json["embedding"]))
if pooling == gobbli.io.EmbedPooling.NONE:
doc_tokens.append(line_json["tokens"])
tokens = None
if pooling == gobbli.io.EmbedPooling.NONE:
tokens = doc_tokens
return embeddings, tokens
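# Each line of the embeddings file parsed above is a JSON object with an
# "embedding" array and (at least when pooling is NONE) a "tokens" array; a
# made-up single-document example:
#
#     {"embedding": [0.12, -0.43, 0.05], "tokens": ["an", "unseen", "review"]}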
def _embed(
self, embed_input: gobbli.io.EmbedInput, context: ContainerTaskContext
) -> gobbli.io.EmbedOutput:
self._write_input(
embed_input.X,
None,
context.host_input_dir / Transformer._EMBEDDING_INPUT_FILE,
)
self._write_config(context.host_input_dir / Transformer._CONFIG_OVERRIDE_FILE)
host_checkpoint_dir, container_checkpoint_dir = self._get_checkpoint(
embed_input.checkpoint, context
)
cmd = (
"python3 run_model.py"
" embed"
f" --input-dir {context.container_input_dir}"
f" --output-dir {context.container_output_dir}"
f" --config-overrides {context.container_input_dir / Transformer._CONFIG_OVERRIDE_FILE}"
f" --model {self.transformer_model}"
f" --weights {self._get_weights(container_checkpoint_dir)}"
f" --cache-dir {Transformer._CONTAINER_CACHE_DIR}"
f" --max-seq-length {self.max_seq_length}"
f" --embed-batch-size {embed_input.embed_batch_size}"
f" --embed-pooling {embed_input.pooling.value}"
f" --embed-layer -2"
)
run_kwargs = self._base_docker_run_kwargs(context)
# Mount the checkpoint in the container if needed
maybe_mount(
run_kwargs["volumes"], host_checkpoint_dir, container_checkpoint_dir
)
# Mount the cache directory
maybe_mount(
run_kwargs["volumes"], self.host_cache_dir, Transformer._CONTAINER_CACHE_DIR
)
container_logs = run_container(
self.docker_client, self.image_tag, cmd, self.logger, **run_kwargs
)
X_embedded, embed_tokens = self._read_embeddings(
context.host_output_dir / Transformer._EMBEDDING_OUTPUT_FILE,
embed_input.pooling,
)
return gobbli.io.EmbedOutput(
X_embedded=X_embedded,
embed_tokens=embed_tokens,
_console_output=container_logs,
)
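# A minimal usage sketch, assuming the public ``embed`` entry point provided by
# ``EmbedMixin`` wraps this method and that ``EmbedPooling.MEAN`` exists
# alongside the ``NONE`` value referenced above (data values are made up):
#
#     embed_input = gobbli.io.EmbedInput(
#         X=["a document to embed"],
#         pooling=gobbli.io.EmbedPooling.MEAN,
#     )
#     embed_output = model.embed(embed_input)
#     print(embed_output.X_embedded[0].shape)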