gobbli.experiment.classification module

class gobbli.experiment.classification.ClassificationExperiment(*args, **kwargs)[source]

Bases: gobbli.experiment.base.BaseExperiment

Run a classification experiment. This entails training a model to make predictions given some input.

The experiment will, for each combination of model hyperparameters, train the model on a training set and evaluate it on a validation set. The best combination of hyperparameters will be retrained on the combined training/validation sets and evaluated on the test set. After completion, the experiment will return ClassificationExperimentResults, which will allow the user to examine the results in various ways.
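
A minimal usage sketch is below. It assumes the BaseExperiment constructor accepts model_cls, dataset, and param_grid keyword arguments and that gobbli's MajorityClassifier baseline model is available; treat the exact argument names as assumptions and check BaseExperiment for the authoritative signature.

    from gobbli.experiment.classification import ClassificationExperiment
    from gobbli.model.majority import MajorityClassifier  # assumed baseline model

    # Toy dataset: parallel lists of texts and labels.
    X = ["the acting was superb", "terrible plot", "loved every minute", "not worth watching"]
    y = ["positive", "negative", "positive", "negative"]

    # Assumed BaseExperiment keyword arguments: the model class to train, the dataset
    # as (X, y), and a hyperparameter grid (see BaseExperiment.params.param_grid).
    exp = ClassificationExperiment(
        model_cls=MajorityClassifier,
        dataset=(X, y),
        param_grid={},  # illustrative: an empty grid, i.e. a single run with defaults
    )

    results = exp.run()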

data_dir()[source]

Return type

Path

Returns

The main data directory unique to this experiment.

property metadata_path
Return type

Path

Returns

The path to the experiment’s metadata file containing information about the experiment parameters.

run(dataset_split=None, seed=1, train_batch_size=32, valid_batch_size=32, test_batch_size=32, num_train_epochs=5)[source]

Run the experiment.

Parameters

  • dataset_split (Union[Tuple[float, float], Tuple[float, float, float], None]) – A tuple describing the proportion of the dataset to be added to the train/validation/test splits. If the experiment uses an explicit test set (passes BaseExperiment.params.test_dataset), this should be a 2-tuple describing the train/validation split. Otherwise, it should be a 3-tuple describing the train/validation/test split. The tuple must sum to 1.

  • seed (int) – Random seed to be used for dataset splitting for reproducibility.

  • train_batch_size (int) – Number of observations per batch on the training dataset.

  • valid_batch_size (int) – Number of observations per batch on the validation dataset.

  • test_batch_size (int) – Number of observations per batch on the test dataset.

  • num_train_epochs (int) – Number of epochs to use for training.

Return type

ClassificationExperimentResults

Returns

The results of the experiment.
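
Continuing the sketch above: with no explicit test set, dataset_split is a 3-tuple summing to 1. The split proportions and batch sizes below are illustrative, not recommendations.

    # 70% train, 10% validation, 20% test.
    results = exp.run(
        dataset_split=(0.7, 0.1, 0.2),
        seed=1234,
        train_batch_size=16,
        valid_batch_size=16,
        test_batch_size=16,
        num_train_epochs=3,
    )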

class gobbli.experiment.classification.ClassificationExperimentResults(training_results, labels, X, y_true, y_pred_proba, best_model_checkpoint, best_model_checkpoint_name, metric_funcs=None)[source]

Bases: object

Results from a classification experiment. An experiment entails training a set of models based on a grid of parameters, retraining on the full train/validation dataset with the best set of parameters, and evaluating predictions on the test set.

Parameters

  • training_results (List[Dict[str, Any]]) – A list of dictionaries containing information about each training run, one for each unique combination of hyperparameters in BaseExperiment.params.param_grid.

  • labels (List[str]) – The set of unique labels in the dataset.

  • X (List[str]) – The list of texts to classify.

  • y_true (Union[List[str], List[List[str]]]) – The true labels for the test set, as passed by the user.

  • y_pred_proba (DataFrame) – A dataframe containing a row for each observation in the test set and a column for each label in the training data. Cells are predicted probabilities.

  • best_model_checkpoint (Union[bytes, Path]) – If the results came from another process on the master node, this is the path to the directory containing the checkpoint. If they came from a remote worker node, this is a bytes object containing the compressed model weights.

  • best_model_checkpoint_name (str) – Path to the best checkpoint within the directory or compressed blob.

  • metric_funcs (Optional[Dict[str, Callable[[Sequence[+T_co], Sequence[+T_co]], float]]]) – Overrides for the default set of metric functions used to evaluate the classifier’s performance.
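
Assuming the constructor arguments above are exposed as attributes of the same name, hard predictions can be recovered from y_pred_proba by taking the highest-probability column per row. A sketch for the single-label case (the multilabel case, where y_true is List[List[str]], needs per-label probability thresholds instead):

    # Predicted label per test observation: the column with the highest probability.
    y_pred = results.y_pred_proba.idxmax(axis=1)

    # Quick accuracy check against the ground-truth labels.
    accuracy = sum(p == t for p, t in zip(y_pred, results.y_true)) / len(results.y_true)
    print(f"Test accuracy: {accuracy:.3f}")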

errors(*args, **kwargs)[source]

See ClassificationEvaluation.errors().

Return type

Dict[str, Tuple[List[ClassificationError], List[ClassificationError]]]
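
A sketch of consuming the returned mapping, assuming it is keyed by label and each tuple holds (false positives, false negatives) in that order, as in ClassificationEvaluation.errors():

    for label, (false_positives, false_negatives) in results.errors().items():
        print(f"{label}: {len(false_positives)} false positives, "
              f"{len(false_negatives)} false negatives")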

errors_report(*args, **kwargs)[source]

See ClassificationEvaluation.errors_report().

Return type

str

get_checkpoint(base_path=None)[source]

Return a filesystem path to our checkpoint, which can be used to initialize future models from the same state. If a base_path is provided, copy/extract the checkpoint under that path.

NOTE: If no base_path is provided and the checkpoint comes from a remote worker, the checkpoint will be extracted to a temporary directory, and a warning will be emitted. gobbli will make no effort to ensure the temporary directory is cleaned up after creation.


Parameters

base_path (Optional[Path]) – Optional directory to extract/copy the checkpoint to. If not provided, the original path will be returned if the checkpoint already existed on the current machine’s filesystem. If the checkpoint is a bytes object, a temporary directory will be created. The directory must not already exist.

Return type

Path

Returns

The path to the extracted checkpoint.
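
A usage sketch: extract the checkpoint under a directory you control so it survives beyond any temporary directory. The destination path is illustrative and, per the note above, must not already exist.

    from pathlib import Path

    # Copies/extracts the best model's weights under base_path and returns
    # the path to the checkpoint itself, which can seed a new model.
    checkpoint = results.get_checkpoint(base_path=Path("./best_checkpoint"))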

metric_funcs = None
metrics(*args, **kwargs)[source]

See ClassificationEvaluation.metrics().

Return type

Dict[str, float]
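
Continuing the sketch above; the metric names in the returned dictionary depend on metric_funcs, so the example value in the comment is illustrative.

    # Per-metric scores on the test set, e.g. {"Accuracy": 0.91, ...}.
    for name, value in results.metrics().items():
        print(f"{name}: {value:.3f}")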

metrics_report(*args, **kwargs)[source]

See ClassificationEvaluation.metrics_report().

Return type

str

plot(*args, **kwargs)[source]

See ClassificationEvaluation.plot().

Return type

Chart

class gobbli.experiment.classification.RemoteTrainResult(metadata, labels, checkpoint_name, checkpoint_id, model_params, ip_address)[source]

Bases: object

Results from a training process on a (possibly remote) worker.

Parameters

  • metadata (Dict[str, Any]) – Metadata from the training output.

  • labels (List[str]) – List of labels identified in the data.

  • checkpoint_name (str) – Name of the checkpoint under the checkpoint directory.

  • checkpoint_id (ObjectID) – Ray ObjectID for the checkpoint directory (path or bytes).

  • model_params (Dict[str, Any]) – Parameters used to initialize the model.

  • ip_address (str) – IP address of the node that ran training.
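
A sketch of resolving the checkpoint on the master node, assuming the standard ray workflow of calling ray.get() on the stored ObjectID; remote_result here is a hypothetical RemoteTrainResult instance. Per best_model_checkpoint above, the resolved value is a path when training ran on the master node and compressed bytes when it ran on a remote worker.

    import ray

    # Resolve the checkpoint referenced by a RemoteTrainResult.
    checkpoint = ray.get(remote_result.checkpoint_id)

    if isinstance(checkpoint, bytes):
        print("Remote worker result: compressed model weights as bytes")
    else:
        print(f"Master node result: checkpoint directory at {checkpoint}")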