Source code for gobbli.model.use.model

import json
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

import gobbli.io
from gobbli.docker import run_container
from gobbli.model.base import BaseModel
from gobbli.model.context import ContainerTaskContext
from gobbli.model.mixin import EmbedMixin
from gobbli.util import assert_in, download_archive, escape_line_delimited_texts


def _read_embeddings(output_file: Path) -> List[np.ndarray]:
    embeddings = []  # type: List[np.ndarray]
    with open(output_file, "r") as f:
        for line in f:
            embeddings.append(np.array(json.loads(line)))
    return embeddings


USE_MODEL_ARCHIVES = {
    "universal-sentence-encoder": "https://tfhub.dev/google/universal-sentence-encoder/4?tf-hub-format=compressed",
    "universal-sentence-encoder-large": "https://tfhub.dev/google/universal-sentence-encoder-large/5?tf-hub-format=compressed",
    "universal-sentence-encoder-multilingual": "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3?tf-hub-format=compressed",
    "universal-sentence-encoder-multilingual-large": "https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3?tf-hub-format=compressed",
}
"""
A mapping from model names to TFHub URLs.
"universal-sentence-encoder" is a safe default for most situations.
Larger models require more time and GPU memory to run.
"""


[docs]class USE(BaseModel, EmbedMixin): """ Wrapper for Universal Sentence Encoder embeddings: https://tfhub.dev/google/universal-sentence-encoder/4 """ _BUILD_PATH = Path(__file__).parent _INPUT_FILE = "input.txt" _OUTPUT_FILE = "output.jsonl"
[docs] def init(self, params: Dict[str, Any]): """ See :meth:`gobbli.model.base.BaseModel.init`. USE parameters: - ``use_model`` (:obj:`str`): Name of a USE model to use. See :obj:`USE_MODEL_ARCHIVES` for a listing of available USE models. """ self.use_model = "universal-sentence-encoder" for name, value in params.items(): if name == "use_model": assert_in(name, value, set(USE_MODEL_ARCHIVES.keys())) self.use_model = value else: raise ValueError(f"Unknown param '{name}'")
@property def image_tag(self) -> str: """ Returns: The Docker image tag to be used for the USE container. """ device = "gpu" if self.use_gpu else "cpu" return f"gobbli-use-embeddings-{device}" @property def weights_dir(self) -> Path: """ Returns: Directory containing pretrained weights for this instance. """ return self.class_weights_dir / self.use_model def _build(self): # Download data if we don't already have it if not self.weights_dir.exists(): with tempfile.TemporaryDirectory() as tmpdir: tmp_weights_dir = Path(tmpdir) / self.weights_dir.name tmp_weights_dir.mkdir() self.logger.info("Downloading pre-trained weights.") download_archive( USE_MODEL_ARCHIVES[self.use_model], tmp_weights_dir, filename=f"{self.use_model}.tar.gz", ) shutil.move(tmp_weights_dir, self.weights_dir) self.logger.info("Weights downloaded.") # Build the docker image self.docker_client.images.build( path=str(USE._BUILD_PATH), tag=self.image_tag, **self._base_docker_build_kwargs, ) def _embed( self, embed_input: gobbli.io.EmbedInput, context: ContainerTaskContext ) -> gobbli.io.EmbedOutput: if embed_input.pooling == gobbli.io.EmbedPooling.NONE: raise ValueError( "Universal Sentence Encoder does sentence encoding, so pooling is required." ) (context.host_input_dir / USE._INPUT_FILE).write_text( escape_line_delimited_texts(embed_input.X) ) cmd = ( "python use.py" f" --input-file={context.container_input_dir / USE._INPUT_FILE}" f" --output-file={context.container_output_dir / USE._OUTPUT_FILE}" f" --module-dir={BaseModel._CONTAINER_WEIGHTS_PATH}" f" --batch-size={embed_input.embed_batch_size}" ) container_logs = run_container( self.docker_client, self.image_tag, cmd, self.logger, **self._base_docker_run_kwargs(context), ) return gobbli.io.EmbedOutput( X_embedded=_read_embeddings(context.host_output_dir / USE._OUTPUT_FILE), _console_output=container_logs, )