gobbli.experiment.classification module
-
class gobbli.experiment.classification.ClassificationExperiment(*args, **kwargs)
Bases: gobbli.experiment.base.BaseExperiment
Run a classification experiment. This entails training a model to make predictions given some input.
For each combination of model hyperparameters, the experiment trains the model on a training set and evaluates it on a validation set. The best combination of hyperparameters is then retrained on the combined training/validation sets and evaluated on the test set. After completion, the experiment returns ClassificationExperimentResults, which allows the user to examine the results in various ways.
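The workflow described above can be sketched in plain Python. This is a minimal illustration and not gobbli's implementation: `expand_grid`, `run_grid_search`, `fit`, and `score` are hypothetical names, and the sketch assumes that a higher validation score is better.

```python
from itertools import product
from typing import Any, Callable, Dict, List, Tuple


def expand_grid(param_grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Return one dict per combination of values in the grid."""
    keys = sorted(param_grid)
    combos = product(*(param_grid[key] for key in keys))
    return [dict(zip(keys, values)) for values in combos]


def run_grid_search(
    param_grid: Dict[str, List[Any]],
    train: List[Any],
    valid: List[Any],
    test: List[Any],
    fit: Callable[[List[Any], Dict[str, Any]], Any],
    score: Callable[[Any, List[Any]], float],
) -> Tuple[List[Dict[str, Any]], Dict[str, Any], float]:
    """Hypothetical helper: select on validation, retrain, score on test."""
    training_results = []
    for params in expand_grid(param_grid):
        # Train on the training set, evaluate on the validation set.
        model = fit(train, params)
        training_results.append(
            {"params": params, "valid_score": score(model, valid)}
        )

    # Assumption: the best combination has the highest validation score.
    best = max(training_results, key=lambda result: result["valid_score"])

    # Retrain on train + validation, then evaluate once on the test set.
    final_model = fit(train + valid, best["params"])
    return training_results, best["params"], score(final_model, test)


# Toy usage: the "model" is just a threshold on a single number.
def toy_fit(data, params):
    return params["threshold"]


def toy_score(threshold, data):
    hits = sum(("pos" if x > threshold else "neg") == y for x, y in data)
    return hits / len(data)


results, best_params, test_score = run_grid_search(
    {"threshold": [0.0, 0.5, 1.0]},
    train=[(0.2, "neg"), (0.8, "pos")],
    valid=[(0.4, "neg"), (0.6, "pos")],
    test=[(0.1, "neg"), (0.9, "pos")],
    fit=toy_fit,
    score=toy_score,
)
```

The test set is touched only once, after selection, so the reported test score is not biased by the hyperparameter search.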
data_dir()
- Return type
Path
- Returns
The main data directory unique to this experiment.
-
property metadata_path
- Return type
Path
- Returns
The path to the experiment’s metadata file containing information about the experiment parameters.
-
run(dataset_split=None, seed=1, train_batch_size=32, valid_batch_size=32, test_batch_size=32, num_train_epochs=5)
Run the experiment.
- Parameters
dataset_split (Union[Tuple[float,float],Tuple[float,float,float],None]) – A tuple describing the proportion of the dataset to be added to the train/validation/test splits. If the experiment uses an explicit test set (passes BaseExperiment.params.test_dataset), this should be a 2-tuple describing the train/validation split. Otherwise, it should be a 3-tuple describing the train/validation/test split. The elements of the tuple must sum to 1.
seed (int) – Random seed used for dataset splitting, for reproducibility.
train_batch_size (int) – Number of observations per batch on the training dataset.
valid_batch_size (int) – Number of observations per batch on the validation dataset.
test_batch_size (int) – Number of observations per batch on the test dataset.
num_train_epochs (int) – Number of epochs to use for training.
- Return type
ClassificationExperimentResults
- Returns
The results of the experiment.
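To make the `dataset_split` and `seed` semantics concrete, here is a minimal sketch of a seeded proportional split. It is not gobbli's splitting code: `split_dataset` is a hypothetical helper, and the rounding strategy (the last split absorbs the remainder) is an assumption made for illustration.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_dataset(
    items: Sequence[T], dataset_split: Tuple[float, ...], seed: int = 1
) -> List[List[T]]:
    """Hypothetical helper: shuffle with a fixed seed, then cut by proportion."""
    if len(dataset_split) not in (2, 3):
        raise ValueError("dataset_split must be a 2-tuple or a 3-tuple")
    if abs(sum(dataset_split) - 1.0) > 1e-9:
        raise ValueError("dataset_split must sum to 1")

    # A dedicated Random instance keeps the split reproducible for a seed
    # without touching the global random state.
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)

    parts: List[List[T]] = []
    start = 0
    for index, proportion in enumerate(dataset_split):
        if index == len(dataset_split) - 1:
            # Assumption: the last split takes whatever is left over.
            end = len(shuffled)
        else:
            end = start + int(round(proportion * len(shuffled)))
        parts.append(shuffled[start:end])
        start = end
    return parts


# 3-tuple: train/validation/test when no explicit test set is supplied.
train, valid, test = split_dataset(list(range(10)), (0.7, 0.1, 0.2), seed=1)

# 2-tuple: train/validation only, for use with an explicit test set.
train_only, valid_only = split_dataset(list(range(10)), (0.8, 0.2), seed=1)
```

Calling the helper twice with the same seed yields the same partition, which is the property the `seed` parameter is documented to provide.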
-
-
class gobbli.experiment.classification.ClassificationExperimentResults(training_results, labels, X, y_true, y_pred_proba, best_model_checkpoint, best_model_checkpoint_name, metric_funcs=None)
Bases: object
Results from a classification experiment. An experiment entails training a set of models based on a grid of parameters, retraining on the full train/validation dataset with the best set of parameters, and evaluating predictions on the test set.
- Parameters
training_results (List[Dict[str,Any]]) – A list of dictionaries containing information about each training run, one for each unique combination of hyperparameters in BaseExperiment.params.param_grid.
labels (List[str]) – The set of unique labels in the dataset.
X (List[str]) – The list of texts to classify.
y_true (Union[List[str],List[List[str]]]) – The true labels for the test set, as passed by the user.
y_pred_proba (DataFrame) – A dataframe containing a row for each observation in the test set and a column for each label in the training data. Cells are predicted probabilities.
best_model_checkpoint (Union[bytes,Path]) – If the results came from another process on the master node, this is the directory containing the checkpoint. If the results came from a worker node, this is a bytes object containing the compressed model weights.
best_model_checkpoint_name (str) – Path to the best checkpoint within the directory or compressed blob.
metric_funcs (Optional[Dict[str,Callable[[Sequence[+T_co],Sequence[+T_co]],float]]]) – Overrides for the default set of metric functions used to evaluate the classifier’s performance.
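The `metric_funcs` type above describes a mapping from a metric name to a function that takes two sequences and returns a float. The sketch below shows what such a mapping might look like. The argument order `(y_true, y_pred)` is an assumption, and `accuracy`, `error_rate`, and `evaluate` are illustrative names, not part of gobbli.

```python
from typing import Callable, Dict, Sequence

MetricFunc = Callable[[Sequence[str], Sequence[str]], float]


def accuracy(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Fraction of predictions that exactly match the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        return 0.0
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Complement of accuracy."""
    return 1.0 - accuracy(y_true, y_pred)


# A mapping with the same shape as the documented metric_funcs type.
custom_metric_funcs: Dict[str, MetricFunc] = {
    "Accuracy": accuracy,
    "Error rate": error_rate,
}


def evaluate(
    y_true: Sequence[str],
    y_pred: Sequence[str],
    metric_funcs: Dict[str, MetricFunc],
) -> Dict[str, float]:
    """Hypothetical helper: apply each metric, returning name -> value."""
    return {name: func(y_true, y_pred) for name, func in metric_funcs.items()}


scores = evaluate(
    ["a", "b", "a", "b"], ["a", "b", "b", "b"], custom_metric_funcs
)
```

The resulting dictionary has the same `Dict[str, float]` shape as the documented return type of `metrics()`.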
-
errors(*args, **kwargs)
See ClassificationEvaluation.errors().
- Return type
Dict[str,Tuple[List[ClassificationError],List[ClassificationError]]]
-
errors_report(*args, **kwargs)
See ClassificationEvaluation.errors_report().
- Return type
str
-
get_checkpoint(base_path=None)
Return a filesystem path to our checkpoint, which can be used to initialize future models from the same state. If a base_path is provided, copy/extract the checkpoint under that path.
NOTE: If no base_path is provided and the checkpoint comes from a remote worker, the checkpoint will be extracted to a temporary directory, and a warning will be emitted. gobbli will make no effort to ensure the temporary directory is cleaned up after creation.
- Parameters
base_path (Optional[Path]) – Optional directory to extract/copy the checkpoint to. If not provided, the original path will be returned if the checkpoint already existed on the current machine’s filesystem. If the checkpoint is a bytes object, a temporary directory will be created. The directory must not already exist.
- Return type
Path
- Returns
The path to the extracted checkpoint.
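The branching described for `get_checkpoint` can be sketched as follows. This is not gobbli's implementation. `resolve_checkpoint` is a hypothetical function, the bytes payload is assumed to be a gzip-compressed tar archive (the documentation says only "compressed model weights"), and the sketch reads "must not already exist" as applying to `base_path`.

```python
import io
import shutil
import tarfile
import tempfile
import warnings
from pathlib import Path
from typing import Optional, Union


def resolve_checkpoint(
    checkpoint: Union[bytes, Path],
    checkpoint_name: str,
    base_path: Optional[Path] = None,
) -> Path:
    """Hypothetical sketch of the documented copy/extract behavior."""
    if base_path is None:
        if isinstance(checkpoint, Path):
            # Already on this machine's filesystem: return it unchanged.
            return checkpoint / checkpoint_name
        # Bytes from a remote worker: fall back to a temporary directory
        # that the caller is responsible for cleaning up.
        warnings.warn("Extracting checkpoint to a temporary directory.")
        target = Path(tempfile.mkdtemp())
    else:
        if base_path.exists():
            raise FileExistsError(f"{base_path} already exists")
        target = base_path

    if isinstance(checkpoint, Path):
        # copytree creates the target directory itself.
        shutil.copytree(checkpoint, target)
    else:
        target.mkdir(parents=True, exist_ok=True)
        # Assumption: the blob is a gzip-compressed tar archive.
        with tarfile.open(fileobj=io.BytesIO(checkpoint), mode="r:gz") as archive:
            archive.extractall(target)
    return target / checkpoint_name
```

Passing an explicit `base_path` avoids the temporary-directory fallback, so the caller decides where the checkpoint lives and when it is removed.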
-
metric_funcs = None
-
metrics(*args, **kwargs)
See ClassificationEvaluation.metrics().
- Return type
Dict[str,float]
-
class gobbli.experiment.classification.RemoteTrainResult(metadata, labels, checkpoint_name, checkpoint_id, model_params, ip_address)
Bases: object
Results from a training process on a (possibly remote) worker.
- Parameters
metadata (Dict[str,Any]) – Metadata from the training output.
labels (List[str]) – List of labels identified in the data.
checkpoint_name (str) – Name of the checkpoint under the checkpoint directory.
checkpoint_id (ObjectID) – Ray ObjectID for the checkpoint directory (path or bytes).
model_params (Dict[str,Any]) – Parameters used to initialize the model.
ip_address (str) – IP address of the node that ran training.