Source code for gobbli.model.mixin

"""
Mixins which can be applied to classes derived from BaseModel.
"""
from abc import ABCMeta, abstractmethod
from pathlib import Path
from timeit import default_timer as timer
from typing import Any, Callable, Optional, cast

import gobbli.io
from gobbli.model.context import ContainerTaskContext
from gobbli.util import format_duration, generate_uuid, write_metadata


def _run_task(
    task_func: Callable[[Any, ContainerTaskContext], Any],
    task_input: gobbli.io.TaskIO,
    root_dir: Path,
    dir_name: Optional[str] = None,
) -> gobbli.io.TaskIO:
    """
    Execute a task function inside its own task directory and record the
    metadata of both its input and its output.

    Args:
      task_func: Callable invoked as ``task_func(task_input, context)``; its
        return value is treated as a :obj:`gobbli.io.TaskIO`.
      task_input: Input handed to the task; its metadata is written before
        the task runs.
      root_dir: Parent directory under which the task directory is placed.
      dir_name: Name of the task directory.  When omitted, a freshly
        generated UUID is used instead.

    Returns:
      Whatever ``task_func`` returned.

    Raises:
      ValueError: If the resolved task directory already exists.
    """
    # Only generate an id when the caller did not pick a name themselves.
    leaf_name = generate_uuid() if dir_name is None else dir_name
    task_root_dir = root_dir / leaf_name

    # Refuse to reuse a directory so an earlier run is never mixed with this one.
    if task_root_dir.exists():
        raise ValueError(
            f"Directory '{task_root_dir}' already exists.  Supply a different `dir_name`."
        )

    context = ContainerTaskContext(task_root_dir=task_root_dir)

    # Record the input's metadata before the task gets a chance to run.
    input_metadata = task_input.metadata()
    input_metadata_path = (
        context.host_input_dir / gobbli.io.TaskIO._METADATA_FILENAME
    )
    write_metadata(input_metadata, input_metadata_path)

    task_output = cast(gobbli.io.TaskIO, task_func(task_input, context))

    # Mirror the same bookkeeping for whatever the task produced.
    output_metadata = task_output.metadata()
    output_metadata_path = (
        context.host_output_dir / gobbli.io.TaskIO._METADATA_FILENAME
    )
    write_metadata(output_metadata, output_metadata_path)

    return task_output


class TrainMixin(metaclass=ABCMeta):
    """
    Mixin for models that support some form of training.
    """

    @abstractmethod
    def data_dir(self):
        # Supplied by the concrete model; the result is joined with "train"
        # via the ``/`` operator in :meth:`train_dir`.
        raise NotImplementedError

    @property
    @abstractmethod
    def logger(self):
        # Supplied by the concrete model; :meth:`train` calls ``.info`` on it.
        raise NotImplementedError

    def train_dir(self) -> Path:
        """
        Location for everything related to training (data files, etc).

        Returns:
          Path to the training data directory.
        """
        return self.data_dir() / "train"

    def train(
        self, train_input: gobbli.io.TrainInput, train_dir_name: Optional[str] = None
    ) -> gobbli.io.TrainOutput:
        """
        Train the model according to the given :obj:`gobbli.io.TrainInput`.

        How training proceeds depends on the model, but it generally looks
        like this:

        - Update weights using the training dataset.
        - Evaluate performance on the validation dataset.
        - Repeat for a number of epochs.
        - When finished, report loss/accuracy and return the trained weights.

        Args:
          train_input: Parameters describing how to train and which data to
            train on.
          train_dir_name: Optional name of the directory holding the training
            input and output.  The directory is located under the model's
            :meth:`train_dir`.  When no name is given, a unique one is
            generated from a UUID.  When a name is given, that directory must
            not already exist.

        Returns:
          Output of training.
        """
        self.logger.info("Starting training.")

        started_at = timer()
        result = cast(
            gobbli.io.TrainOutput,
            _run_task(self._train, train_input, self.train_dir(), train_dir_name),
        )
        elapsed = timer() - started_at

        # Build the whole report first, then emit it one line per log record.
        report = [
            f"Training finished in {format_duration(elapsed)}.",
            "RESULTS:",
            f" Validation loss: {result.valid_loss}",
            f" Validation accuracy: {result.valid_accuracy}",
            f" Training loss: {result.train_loss}",
        ]
        for message in report:
            self.logger.info(message)

        return result

    @abstractmethod
    def _train(
        self, train_input: gobbli.io.TrainInput, context: ContainerTaskContext
    ) -> gobbli.io.TrainOutput:
        # Model-specific training routine invoked through ``_run_task``.
        raise NotImplementedError
class PredictMixin(metaclass=ABCMeta):
    """
    Mixin for models that can produce predictions for new data.
    """

    @abstractmethod
    def data_dir(self):
        # Supplied by the concrete model; the result is joined with "predict"
        # via the ``/`` operator in :meth:`predict_dir`.
        raise NotImplementedError

    @property
    @abstractmethod
    def logger(self):
        # Supplied by the concrete model; :meth:`predict` calls ``.info`` on it.
        raise NotImplementedError

    def predict_dir(self) -> Path:
        """
        Location for everything related to prediction (weights, predictions, etc).

        Returns:
          Path to the prediction data directory.
        """
        return self.data_dir() / "predict"

    def predict(
        self,
        predict_input: gobbli.io.PredictInput,
        predict_dir_name: Optional[str] = None,
    ) -> gobbli.io.PredictOutput:
        """
        Run prediction on new data using the params contained in the given
        :obj:`gobbli.io.PredictInput`.

        Args:
          predict_input: Parameters describing how to run prediction and
            which data to predict for.
          predict_dir_name: Optional name of the directory holding the
            prediction input and output.  The directory is located under the
            model's :meth:`predict_dir`.  When no name is given, a unique one
            is generated from a UUID.  When a name is given, that directory
            must not already exist.

        Returns:
          Output of prediction.
        """
        self.logger.info("Starting prediction.")

        started_at = timer()
        result = cast(
            gobbli.io.PredictOutput,
            _run_task(
                self._predict, predict_input, self.predict_dir(), predict_dir_name
            ),
        )
        elapsed = timer() - started_at

        self.logger.info(f"Prediction finished in {format_duration(elapsed)}.")
        return result

    @abstractmethod
    def _predict(
        self, predict_input: gobbli.io.PredictInput, context: ContainerTaskContext
    ) -> gobbli.io.PredictOutput:
        # Model-specific prediction routine invoked through ``_run_task``.
        raise NotImplementedError
class EmbedMixin(metaclass=ABCMeta):
    """
    Mixin for models that can generate embeddings from data.
    """

    @abstractmethod
    def data_dir(self):
        # Supplied by the concrete model; the result is joined with "embed"
        # via the ``/`` operator in :meth:`embed_dir`.
        raise NotImplementedError

    @property
    @abstractmethod
    def logger(self):
        # Supplied by the concrete model; :meth:`embed` calls ``.info`` on it.
        raise NotImplementedError

    def embed_dir(self) -> Path:
        """
        Location for everything related to embedding (weights, embeddings, etc).

        Returns:
          Path to the embedding data directory.
        """
        return self.data_dir() / "embed"

    def embed(
        self, embed_input: gobbli.io.EmbedInput, embed_dir_name: Optional[str] = None
    ) -> gobbli.io.EmbedOutput:
        """
        Generate embeddings with the model using the params in the given
        :obj:`gobbli.io.EmbedInput`.

        Args:
          embed_input: Parameters describing how to generate embeddings and
            which data to generate them for.
          embed_dir_name: Optional name of the directory holding the embedding
            input and output.  The directory is located under the model's
            :meth:`embed_dir`.  When no name is given, a unique one is
            generated from a UUID.  When a name is given, that directory must
            not already exist.

        Returns:
          Output of embedding generation.
        """
        self.logger.info("Generating embeddings.")

        started_at = timer()
        result = cast(
            gobbli.io.EmbedOutput,
            _run_task(self._embed, embed_input, self.embed_dir(), embed_dir_name),
        )
        elapsed = timer() - started_at

        self.logger.info(
            f"Embedding generation finished in {format_duration(elapsed)}."
        )
        return result

    @abstractmethod
    def _embed(
        self, embed_input: gobbli.io.EmbedInput, context: ContainerTaskContext
    ) -> gobbli.io.EmbedOutput:
        # Model-specific embedding routine invoked through ``_run_task``.
        raise NotImplementedError