Source code for gobbli.dataset.cmu_movie_summary
import json
from typing import List, Tuple
import pandas as pd
from gobbli.dataset.base import BaseDataset
from gobbli.util import download_archive
class MovieSummaryDataset(BaseDataset):
    """
    gobbli Dataset for the CMU Movie Summary dataset, framed as a multilabel
    classification problem predicting movie genres from plot summaries.

    The plot summaries and the movie metadata are inner-joined on their first
    column (named ``wiki_id`` here), sorted by that index, and split by
    position: the first ``TRAIN_PCT`` of the rows are the training set and
    the remaining rows are the test set.

    http://www.cs.cmu.edu/~ark/personas/
    """

    # Files read by this dataset, relative to the dataset's data directory.
    PLOT_SUMMARIES_FILE = "MovieSummaries/plot_summaries.txt"
    METADATA_FILE = "MovieSummaries/movie.metadata.tsv"

    # Fraction of the joined rows that goes to the training split.
    TRAIN_PCT = 0.8

    def _build(self):
        """
        Create the data directory (including parents) and hand the archive
        URL to ``download_archive``.
        """
        target_dir = self.data_dir()
        target_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): presumably download_archive also extracts the tarball,
        # since _is_built looks for files under "MovieSummaries/" -- confirm.
        archive_url = "http://www.cs.cmu.edu/~ark/personas/data/MovieSummaries.tar.gz"
        download_archive(archive_url, target_dir)

    @staticmethod
    def _make_multilabels(genres: pd.Series) -> List[List[str]]:
        """
        Turn a series of JSON-encoded genre mappings into label lists.

        Each element of ``genres`` is parsed with ``json.loads`` and the
        values of the resulting mapping become that row's labels; the keys
        are discarded.
        """
        labels = []
        for raw_genres in genres:
            genre_map = json.loads(raw_genres)
            labels.append(list(genre_map.values()))
        return labels

    def _is_built(self) -> bool:
        """
        Return True only when both the plot summaries file and the metadata
        file exist under the data directory.
        """
        root = self.data_dir()
        required_files = (
            MovieSummaryDataset.PLOT_SUMMARIES_FILE,
            MovieSummaryDataset.METADATA_FILE,
        )
        return all((root / relative_path).exists() for relative_path in required_files)

    def _get_source_df_split(self) -> Tuple[pd.DataFrame, int]:
        """
        Return the joined source frame and the train/test split position.

        The frame is loaded on the first call and cached on the instance as
        ``_source_df``; later calls reuse the cached frame. The second tuple
        element is the positional index at which the test rows begin.
        """
        if not hasattr(self, "_source_df"):
            root = self.data_dir()

            # Both files are headerless and tab-delimited, with the id used
            # for joining in the first column.
            shared_opts = {"delimiter": "\t", "index_col": 0, "header": None}

            plots = pd.read_csv(
                root / MovieSummaryDataset.PLOT_SUMMARIES_FILE,
                names=["wiki_id", "plot"],
                **shared_opts,
            )

            metadata_columns = [
                "wiki_id",
                "freebase_id",
                "name",
                "release_date",
                "revenue",
                "runtime",
                "languages",
                "countries",
                "genres",
            ]
            metadata = pd.read_csv(
                root / MovieSummaryDataset.METADATA_FILE,
                names=metadata_columns,
                **shared_opts,
            )

            # Keep only movies present in both files, then sort by id so the
            # positional split below is deterministic.
            joined = plots.join(metadata, how="inner")
            self._source_df = joined[["plot", "genres"]].sort_index()

        split_ndx = int(len(self._source_df) * MovieSummaryDataset.TRAIN_PCT)
        return self._source_df, split_ndx

    def X_train(self):
        """Return the plot summaries of the training rows as a list."""
        frame, boundary = self._get_source_df_split()
        all_plots = frame["plot"].tolist()
        return all_plots[:boundary]

    def y_train(self):
        """Return the genre label lists of the training rows."""
        frame, boundary = self._get_source_df_split()
        train_genres = frame["genres"][:boundary]
        return MovieSummaryDataset._make_multilabels(train_genres)

    def X_test(self):
        """Return the plot summaries of the test rows as a list."""
        frame, boundary = self._get_source_df_split()
        all_plots = frame["plot"].tolist()
        return all_plots[boundary:]

    def y_test(self):
        """Return the genre label lists of the test rows."""
        frame, boundary = self._get_source_df_split()
        test_genres = frame["genres"][boundary:]
        return MovieSummaryDataset._make_multilabels(test_genres)