gobbli.io module

class gobbli.io.EmbedInput(X, embed_batch_size=32, pooling=<EmbedPooling.MEAN: 'mean'>, checkpoint=None)
Bases: gobbli.io.TaskIO

Input for generating embeddings using a model. See gobbli.model.mixin.EmbedMixin.embed().

Parameters:
- X (List[str]) – Documents to generate embeddings for.
- embed_batch_size (int) – Number of documents to embed at a time.
- pooling (EmbedPooling) – Pooling method to use for the resulting embeddings.
- checkpoint (Optional[Path]) – Checkpoint containing trained weights for the model. See the checkpoint parameter of TrainOutput.

checkpoint = None

embed_batch_size = 32

metadata()
Returns: The task information that constitutes its metadata: generally the parameters of an input task and/or the summarized results of an output task.
Return type: Dict[str, Any]

pooling = 'mean'
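
The embed_batch_size parameter sets how many documents are processed per step. The sketch below shows that kind of fixed-size chunking in plain Python. The helper name `batched` is hypothetical and is not part of gobbli. gobbli's own batching may be implemented differently, for example inside the model container.

```python
from typing import Iterator, List


def batched(docs: List[str], batch_size: int = 32) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size documents (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(docs), batch_size):
        yield docs[start:start + batch_size]


# Five documents with a batch size of 2 give batches of sizes 2, 2 and 1.
docs = ["a", "b", "c", "d", "e"]
print([len(batch) for batch in batched(docs, batch_size=2)])  # [2, 2, 1]
```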

class gobbli.io.EmbedOutput(X_embedded, embed_tokens=None, _console_output='')
Bases: gobbli.io.TaskIO

Output from generating embeddings. See gobbli.model.mixin.EmbedMixin.embed().

Parameters:
- X_embedded (List[ndarray]) – A list of ndarrays representing the embedding for each document. The shape of each array depends on the pooling method, where l = length of the document and d = dimensionality of the embedding:
  - Mean pooling (default): (d,)
  - No pooling: (l, d)
- embed_tokens (Optional[List[List[str]]]) – If the pooling strategy is EmbedPooling.NONE, the list of tokens corresponding to each embedding for each document. Otherwise, None.
- _console_output (str) – Raw console output from the container used to generate the embeddings.

embed_tokens = None

class gobbli.io.EmbedPooling
Bases: enum.Enum

Enum describing the pooling methods that can be used when generating embeddings.

MEAN = 'mean'
Take the mean across all tokens as the embedding for the document.

NONE = 'none'
Return the token-wise embeddings for each document.
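
The two pooling methods correspond to the array shapes listed under EmbedOutput. Mean pooling reduces an (l, d) matrix of token embeddings to a (d,) vector. No pooling leaves the matrix as it is. The sketch below uses nested lists instead of ndarrays. `EmbedPoolingSketch` is a local stand-in for EmbedPooling, and `pool_tokens` is hypothetical. This is not gobbli's implementation.

```python
from enum import Enum
from typing import List, Union

Vector = List[float]


class EmbedPoolingSketch(Enum):
    """Local stand-in mirroring EmbedPooling's values (not imported from gobbli)."""
    MEAN = "mean"
    NONE = "none"


def pool_tokens(
    token_vectors: List[Vector], pooling: EmbedPoolingSketch
) -> Union[Vector, List[Vector]]:
    """Reduce an (l, d) list of token embeddings according to the pooling method."""
    if pooling is EmbedPoolingSketch.NONE:
        # No pooling: keep one d-dimensional vector per token, shape (l, d).
        return token_vectors
    if not token_vectors:
        raise ValueError("cannot mean-pool a document with no tokens")
    num_tokens = len(token_vectors)
    # Mean pooling: average each of the d dimensions over the l tokens, shape (d,).
    return [sum(column) / num_tokens for column in zip(*token_vectors)]


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # l = 3 tokens, d = 2
print(pool_tokens(tokens, EmbedPoolingSketch.MEAN))  # [3.0, 4.0]
```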

class gobbli.io.PredictInput(X, labels, multilabel=False, predict_batch_size=32, checkpoint=None)
Bases: gobbli.io.TaskIO

Input for generating predictions using a model. See gobbli.model.mixin.PredictMixin.predict().

Parameters:
- X (List[str]) – Documents to predict labels for.
- labels (List[str]) – See the labels parameter of TrainOutput.
- multilabel (bool) – True if the model was trained in a multilabel context, otherwise False (indicating a multiclass context).
- predict_batch_size (int) – Number of documents to predict in each batch.
- checkpoint (Optional[Path]) – Checkpoint containing trained weights for the model. See the checkpoint parameter of TrainOutput.

checkpoint = None

metadata()
Returns: The task information that constitutes its metadata: generally the parameters of an input task and/or the summarized results of an output task.
Return type: Dict[str, Any]

multilabel = False

predict_batch_size = 32

class gobbli.io.PredictOutput(y_pred_proba, _console_output='')
Bases: gobbli.io.TaskIO

Output from generating predictions using a model. See gobbli.model.mixin.PredictMixin.predict().

metadata()
Returns: The task information that constitutes its metadata: generally the parameters of an input task and/or the summarized results of an output task.
Return type: Dict[str, Any]

property y_pred
Returns: The most likely predicted label for each observation.
Return type: List[str]
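
y_pred picks the most likely label for each observation. This section does not document the structure of y_pred_proba. The sketch below assumes it holds one probability per class for each observation, and represents that as a list of {label: probability} dicts. The real container type may differ. `most_likely_labels` is a hypothetical helper that illustrates the argmax step. It is not gobbli's code.

```python
from typing import Dict, List


def most_likely_labels(pred_proba: List[Dict[str, float]]) -> List[str]:
    """Return the highest-probability label for each observation.

    Assumes one {label: probability} mapping per observation; ties resolve
    to the label that appears first in the mapping.
    """
    return [max(row, key=row.get) for row in pred_proba]


proba = [
    {"neg": 0.8, "pos": 0.2},
    {"neg": 0.3, "pos": 0.7},
]
print(most_likely_labels(proba))  # ['neg', 'pos']
```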

gobbli.io.T = ~T
Type variable (TypeVar) for the target type accepted and returned by make_document_windows().

class gobbli.io.TrainInput(X_train, y_train, X_valid, y_valid, train_batch_size=32, valid_batch_size=8, num_train_epochs=3, checkpoint=None)
Bases: gobbli.io.TaskIO

Input for training a model. See gobbli.model.mixin.TrainMixin.train().

For usage specific to a multiclass or multilabel paradigm, consider using the more specifically checked and typed properties y_{train,valid}_{multiclass,multilabel} instead of the more generically typed y_{train,valid} attributes.

Parameters:
- X_train (List[str]) – Documents used for training.
- y_train (Union[List[str], List[List[str]]]) – Labels for training documents.
- X_valid (List[str]) – Documents used for validation.
- y_valid (Union[List[str], List[List[str]]]) – Labels for validation documents.
- train_batch_size (int) – Number of observations per batch on the training dataset.
- valid_batch_size (int) – Number of observations per batch on the validation dataset.
- num_train_epochs (int) – Number of epochs to use for training.
- checkpoint (Optional[Path]) – Checkpoint containing trained weights for the model. If passed, training continues from the checkpoint instead of starting from scratch. See the checkpoint parameter of TrainOutput.

checkpoint = None

labels()
Returns: The unique labels in the data, as a sorted list so that the ordering is consistent.
Return type: List[str]

metadata()
Returns: The task information that constitutes its metadata: generally the parameters of an input task and/or the summarized results of an output task.
Return type: Dict[str, Any]

num_train_epochs = 3

train_batch_size = 32

valid_batch_size = 8

property y_train_multiclass
Return type: List[str]

property y_train_multilabel
Return type: List[List[str]]

property y_valid_multiclass
Return type: List[str]

property y_valid_multilabel
Return type: List[List[str]]
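
labels() returns the unique labels as a sorted list, and the targets can use either the multiclass shape (List[str]) or the multilabel shape (List[List[str]]). The sketch below computes such a list for both shapes. It assumes the labels are drawn from both the training and the validation targets; the text above only says "the data". `unique_labels` is hypothetical and is not gobbli's implementation.

```python
from typing import List, Set, Union

Labels = Union[List[str], List[List[str]]]


def unique_labels(y_train: Labels, y_valid: Labels) -> List[str]:
    """Sorted unique labels across training and validation targets.

    Accepts the multiclass shape (one label per document) and the
    multilabel shape (a list of labels per document).
    """
    seen: Set[str] = set()
    for y in (y_train, y_valid):
        for target in y:
            if isinstance(target, str):
                seen.add(target)  # multiclass: a single label
            else:
                seen.update(target)  # multilabel: zero or more labels
    return sorted(seen)


print(unique_labels(["b", "a"], ["c", "a"]))  # ['a', 'b', 'c']
print(unique_labels([["x", "y"], []], [["y"]]))  # ['x', 'y']
```

Sorting the set gives the same order on every call, which is the consistency the labels() docstring refers to.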

class gobbli.io.TrainOutput(valid_loss, valid_accuracy, train_loss, labels, multilabel, checkpoint=None, _console_output='')
Bases: gobbli.io.TaskIO

Output from model training. See gobbli.model.mixin.TrainMixin.train().

Parameters:
- valid_loss (float) – Loss on the validation dataset.
- valid_accuracy (float) – Accuracy on the validation dataset.
- train_loss (float) – Loss on the training dataset.
- labels (List[str]) – List of labels present in the training data. Used to initialize the model for prediction.
- multilabel (bool) – True if the model was trained in a multilabel context, otherwise False (indicating a multiclass context).
- checkpoint (Optional[Path]) – Path to the best checkpoint from training. This may not be a literal file path (for example, with TensorFlow), but it should give the user everything they need to run prediction using the results of training.
- _console_output (str) – Raw console output from the container used to train the model.

checkpoint = None

class gobbli.io.WindowPooling
Bases: enum.Enum

Enum describing the pooling methods that can be used to combine model output from windowed documents.

MEAN = 'mean'
For each dimension/class, take the mean across the document's windows as the output for the document.

MAX = 'max'
For each dimension/class, take the max across the document's windows as the output for the document.

MIN = 'min'
For each dimension/class, take the min across the document's windows as the output for the document.

gobbli.io.make_document_windows(X, window_len, y=None, tokenize_method=<TokenizeMethod.SPLIT: 'split'>, model_path=None, vocab_size=None)

A helper for datasets with long documents that will be passed through a model with a fixed maximum sequence length. You may not have enough memory to raise the maximum sequence length but still want to keep the information in longer documents. In that case, use this helper to generate a dataset that splits each document into windows roughly the size of your max_seq_len. The resulting dataset can then be used to train your model. Afterwards, use pool_document_windows() to pool the results from downstream tasks (e.g., predictions or embeddings).

Note that there may still be some mismatch between the window size and the size as tokenized by your model, since some models use custom tokenization methods.

Parameters:
- X (List[str]) – List of texts to make windows out of.
- window_len (int) – The maximum length of each window. This should roughly correspond to the max_seq_len of your model.
- y (Optional[List[~T]]) – Optional list of classes (or list of lists of labels). If passed, a corresponding list of targets for each window (the target(s) associated with the window’s document) will be returned.
- tokenize_method (TokenizeMethod) – gobbli.util.TokenizeMethod corresponding to the tokenization method to use for determining windows.
- model_path (Optional[Path]) – Path for a tokenization model. Used only if the tokenization method requires training a model; otherwise ignored. If the path doesn’t exist, a new tokenization model will be trained and saved at that path. If it does exist, the existing model will be used. If no path is given, a temporary directory will be created, used, and discarded.
- vocab_size (Optional[int]) – Number of terms in the vocabulary for tokenization. May be ignored depending on the tokenization method and whether a model is already trained.

Return type: Tuple[List[str], List[int], Optional[List[~T]]]
Returns: A 3-tuple containing a new list of texts split into windows, a corresponding list containing the index of the original document for each window, and (optionally) a list containing a target per window. The indices should be used to pool the output from the windowed text (see pool_document_windows()).
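
The 3-tuple described above can be illustrated with a simplified stand-in. The sketch below makes three assumptions. It splits on whitespace, as a rough analogue of TokenizeMethod.SPLIT. It uses non-overlapping windows of at most window_len tokens. It gives an empty document a single empty window, so that every document still appears in the index list. `make_windows` is hypothetical, and gobbli's actual windowing may differ in these details.

```python
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def make_windows(
    X: Sequence[str], window_len: int, y: Optional[Sequence[T]] = None
) -> Tuple[List[str], List[int], Optional[List[T]]]:
    """Split each document into non-overlapping windows of at most
    window_len whitespace-separated tokens (simplified, hypothetical)."""
    if window_len < 1:
        raise ValueError("window_len must be a positive integer")
    if y is not None and len(y) != len(X):
        raise ValueError("X and y must have the same length")
    windows: List[str] = []
    indices: List[int] = []
    targets: Optional[List[T]] = [] if y is not None else None
    for doc_ix, doc in enumerate(X):
        tokens = doc.split()
        # Keep at least one (possibly empty) window per document.
        starts = range(0, len(tokens), window_len) if tokens else [0]
        for start in starts:
            windows.append(" ".join(tokens[start:start + window_len]))
            indices.append(doc_ix)  # index of the source document
            if targets is not None:
                targets.append(y[doc_ix])  # each window inherits its document's target
    return windows, indices, targets


windows, indices, targets = make_windows(["a b c d e", "f g"], window_len=2, y=["pos", "neg"])
print(windows)  # ['a b', 'c d', 'e', 'f g']
print(indices)  # [0, 0, 0, 1]
print(targets)  # ['pos', 'pos', 'pos', 'neg']
```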

gobbli.io.pool_document_windows(unpooled_output, window_indices, pooling=<WindowPooling.MEAN: 'mean'>)

This helper pools output from a model whose input was document windows generated by make_document_windows(). The output can be pooled in multiple ways; see WindowPooling for more info.

This function modifies the passed output object in place, so that any other information stored on the object is preserved.

Parameters:
- unpooled_output (Union[PredictOutput, EmbedOutput]) – The output from the model to be pooled.
- window_indices (List[int]) – A list (size = number of rows in unpooled_output) of integers corresponding to the index of the original document for each window. These are used to group the window output appropriately.
- pooling (WindowPooling) – The method to use for pooling.
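
Pooling consists of two steps: group the output rows by their window index, then reduce each column (dimension/class) across the windows of each document. The sketch below works on plain lists of floats, and `WindowPoolingSketch` is a local stand-in for WindowPooling. Unlike pool_document_windows(), this hypothetical `pool_windows` returns a new list instead of modifying an output object in place. It also assumes documents should be emitted in ascending index order.

```python
from enum import Enum
from statistics import mean
from typing import Dict, List


class WindowPoolingSketch(Enum):
    """Local stand-in mirroring WindowPooling's values (not imported from gobbli)."""
    MEAN = "mean"
    MAX = "max"
    MIN = "min"


REDUCERS = {
    WindowPoolingSketch.MEAN: mean,
    WindowPoolingSketch.MAX: max,
    WindowPoolingSketch.MIN: min,
}


def pool_windows(
    rows: List[List[float]],
    window_indices: List[int],
    pooling: WindowPoolingSketch = WindowPoolingSketch.MEAN,
) -> List[List[float]]:
    """Pool per-window output rows back to one row per original document."""
    if len(rows) != len(window_indices):
        raise ValueError("need exactly one window index per output row")
    groups: Dict[int, List[List[float]]] = {}
    for row, doc_ix in zip(rows, window_indices):
        groups.setdefault(doc_ix, []).append(row)
    reducer = REDUCERS[pooling]
    # Reduce each column across all windows belonging to the same document.
    return [[reducer(col) for col in zip(*groups[ix])] for ix in sorted(groups)]


rows = [[0.2, 0.8], [0.6, 0.4], [0.9, 0.1]]
print(pool_windows(rows, [0, 0, 1], WindowPoolingSketch.MAX))  # [[0.6, 0.8], [0.9, 0.1]]
```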

gobbli.io.validate_X(X)
Confirm a given array matches the expected type for input.

Parameters:
- X (List[str]) – Something that should be valid model input.
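
The expected input type is List[str]. The sketch below shows what a check against that type could look like. The function name `check_X` and the choice of TypeError are assumptions; the real validate_X may raise a different exception or check differently.

```python
from typing import Any


def check_X(X: Any) -> None:
    """Raise TypeError unless X is a list of strings (hypothetical check)."""
    if not isinstance(X, list):
        raise TypeError(f"X must be a list, got {type(X).__name__}")
    for i, doc in enumerate(X):
        if not isinstance(doc, str):
            raise TypeError(f"X[{i}] must be a str, got {type(doc).__name__}")


check_X(["first document", "second document"])  # passes silently
```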

gobbli.io.validate_X_y(X, y)
Assuming X is valid input and y is valid output, ensure that their sizes match.
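
Matching sizes here means one target per document. A minimal sketch of that check follows. The name `check_X_y` and the use of ValueError are assumptions, not gobbli's API.

```python
from typing import Sequence


def check_X_y(X: Sequence[str], y: Sequence[object]) -> None:
    """Raise ValueError unless there is exactly one target per document."""
    if len(X) != len(y):
        raise ValueError(f"X and y must have the same length (got {len(X)} and {len(y)})")


check_X_y(["doc one", "doc two"], ["pos", "neg"])  # passes silently
```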