Spaces:
Paused
Paused
import operator
import re
from math import isnan

import datasets
import pandas as pd
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel
# Hub client; in this file it is only referenced by the disabled snapshot
# download below.
api = HfApi()
# Directory of the retrieval index that ABSTRACT_RETRIEVER is loaded from.
INDEX_DIR_PATH ="./content/my_index/"
# Disabled alternative: download the published index from the Hub into a
# RAGatouille index directory instead of using the local path above.
#INDEX_DIR_PATH = ".ragatouille/colbert/indexes/CVPR2024-papers-abstract-index/"
#api.snapshot_download(
# repo_id="CVPR2024/CVPR2024-papers-abstract-index",
# repo_type="dataset",
# local_dir=INDEX_DIR_PATH,
#)
# Module-level retriever used by PaperList.search for abstract queries.
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
# Run once to initialize the retriever
# (the result of this warm-up query is discarded).
ABSTRACT_RETRIEVER.search("LLM")
class PaperList:
    """Paper table with a display-ready view and title/abstract search.

    ``df_raw`` is the merged paper/statistics table built by :meth:`get_df`.
    ``df_prettified`` is the same data rendered by :meth:`prettify` into the
    columns listed in ``COLUMN_INFO``.
    """

    # [display column name, datatype label] pairs, in display order.
    # NOTE(review): the 4th and 5th names (and the marker used for claimed
    # papers in prettify) look like mis-decoded emoji. They are kept verbatim
    # because callers select columns by these exact strings - confirm and
    # repair together with the callers.
    COLUMN_INFO = [
        ["Title", "str"],
        ["Authors", "str"],
        ["Paper page", "markdown"],
        ["π", "number"],
        ["π¬", "number"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    ]

    # Raw columns that hold a collection of links / repo ids per paper.
    _LINK_COLUMNS = ("GitHub", "Space", "Model", "Dataset")

    # Matches one single- or double-quoted item inside a stringified sequence.
    _QUOTED_ITEM = re.compile(r"'([^']*)'|\"([^\"]*)\"")

    def __init__(self):
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for the scalar "no value" markers None, pd.NA and NaN."""
        return value is None or value is pd.NA or (isinstance(value, float) and isnan(value))

    @staticmethod
    def _as_list(value) -> list:
        """Normalize a cell to a list of items.

        Missing values and blank strings give ``[]``. A bracketed string such
        as ``"['a', 'b']"`` or ``"['a' 'b']"`` gives its quoted items, any
        other string gives a one-element list, and any other iterable is
        converted with ``list``.
        """
        if PaperList._is_missing(value):
            return []
        if isinstance(value, str):
            text = value.strip()
            if not text:
                return []
            if text.startswith("[") and text.endswith("]"):
                if not text[1:-1].strip():
                    return []
                items = [single or double for single, double in PaperList._QUOTED_ITEM.findall(text)]
                # A bracketed string without quoted items is kept whole rather
                # than silently dropped.
                return items if items else [value]
            return [value]
        try:
            return list(value)
        except TypeError:
            # Non-iterable scalar: treat it as a single item.
            return [value]

    @staticmethod
    def _join_links(links) -> str:
        """Join rendered links with newlines; a single space when there are none."""
        return "\n".join(links) or " "

    @staticmethod
    def get_df() -> pd.DataFrame:
        """Load the paper and paper-stats datasets and merge them into one table.

        The two tables are left-joined on ``id``, ``authors`` and ``title``.
        Missing counts are stored as ``-1`` and a ``paper_page`` URL column is
        derived from ``arxiv_id`` (empty string when there is no id).
        """
        left_df = datasets.load_dataset("CVPR2024/CVPR2024-papers", split="train").to_pandas()
        right_df = datasets.load_dataset("CVPR2024/CVPR2024-paper-stats", split="train").to_pandas()

        # Join keys must have the same dtype on both sides.
        for frame in (left_df, right_df):
            frame["id"] = frame["id"].astype(int)
            frame["authors"] = frame["authors"].astype(str)
            frame["title"] = frame["title"].astype(str)

        df = pd.merge(
            left=left_df,
            right=right_df,
            on=["id", "authors", "title"],
            how="left",
        )

        keys = ["n_authors", "n_linked_authors", "upvotes", "num_comments"]
        # -1 marks "no statistics"; prettify() renders it as an empty cell.
        df[keys] = df[keys].fillna(-1).astype(int)

        def paper_page(arxiv_id) -> str:
            # NaN is truthy, so missing ids need an explicit check to avoid
            # producing a ".../papers/nan" link.
            if PaperList._is_missing(arxiv_id) or not arxiv_id:
                return ""
            return f"https://huggingface.co/papers/{arxiv_id}"

        df["paper_page"] = df["arxiv_id"].apply(paper_page)
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        """Return an HTML anchor for *url* labelled *text* that opens in a new tab."""
        return f'<a href="{url}" target="_blank">{text}</a>'

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """Render a raw table into the display columns of ``COLUMN_INFO``.

        Counts stored as ``-1`` become empty cells, link columns become
        newline-separated HTML anchors (a single space when empty), and
        ``claimed`` shows ``linked/total`` authors plus a marker when at least
        one author is linked.
        """
        rows = []
        for _, row in df.iterrows():
            linked = row["n_linked_authors"]
            author_linked = "β " if linked > 0 else ""
            n_linked_authors = "" if linked == -1 else linked
            n_authors = "" if row["n_authors"] == -1 else row["n_authors"]
            claimed_paper = "" if n_linked_authors == "" else f"{n_linked_authors}/{n_authors} {author_linked}"
            upvotes = "" if row["upvotes"] == -1 else row["upvotes"]
            num_comments = "" if row["num_comments"] == -1 else row["num_comments"]

            arxiv_id = row["arxiv_id"]
            if PaperList._is_missing(arxiv_id) or arxiv_id == "":
                paper_page = " "
            else:
                paper_page = PaperList.create_link(arxiv_id, row["paper_page"])

            new_row = {
                "Title": row["title"],
                # get_df() casts authors to str, so joining the raw value would
                # join single characters; normalize to a list of names first.
                "Authors": ", ".join(str(name) for name in PaperList._as_list(row["authors"])),
                "Paper page": paper_page,
                "π": upvotes,
                "π¬": num_comments,
                "GitHub": PaperList._join_links(
                    PaperList.create_link("GitHub", url) for url in PaperList._as_list(row["GitHub"])
                ),
                "Spaces": PaperList._join_links(
                    PaperList.create_link(repo_id, f"https://huggingface.co/spaces/{repo_id}")
                    for repo_id in PaperList._as_list(row["Space"])
                ),
                "Models": PaperList._join_links(
                    PaperList.create_link(repo_id, f"https://huggingface.co/{repo_id}")
                    for repo_id in PaperList._as_list(row["Model"])
                ),
                "Datasets": PaperList._join_links(
                    PaperList.create_link(repo_id, f"https://huggingface.co/datasets/{repo_id}")
                    for repo_id in PaperList._as_list(row["Dataset"])
                ),
                "claimed": claimed_paper,
            }
            rows.append(new_row)
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names():
        """Return the display column names in ``COLUMN_INFO`` order."""
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        """Return the datatype label of each requested display column.

        Raises ``KeyError`` for a name that is not in ``COLUMN_INFO``.
        """
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        columns_names: list[str],
    ) -> pd.DataFrame:
        """Filter the papers and return the selected display columns.

        Args:
            title_search_query: Case-insensitive pattern matched against
                titles. It is used as a regular expression when it compiles
                and as a literal substring otherwise.
            abstract_search_query: When non-empty, only papers returned by the
                abstract retriever are kept, in retrieval order.
            max_num_to_retrieve: Number of results requested from the
                retriever.
            filter_names: Any of "Paper page", "GitHub", "Space", "Model",
                "Dataset"; each keeps only papers that have that resource.
            columns_names: Display columns to return.
        """
        df = self.df_raw.copy()
        # As ragatouille uses str for document_id
        df["id"] = df["id"].astype(str)

        # Filter by title. User input such as "(" is not a valid regex, so
        # validate first and fall back to a literal match.
        try:
            re.compile(title_search_query)
            use_regex = True
        except re.error:
            use_regex = False
        title_mask = df["title"].str.contains(title_search_query, case=False, regex=use_regex, na=False)
        df = df[title_mask.astype(bool)]

        if "Paper page" in filter_names:
            df = df[df["paper_page"] != ""]
        for name in PaperList._LINK_COLUMNS:
            if name in filter_names:
                has_items = df[name].apply(lambda cell: bool(PaperList._as_list(cell)))
                # astype(bool) keeps the mask boolean even when df is empty.
                df = df[has_items.astype(bool)]

        # Filter by abstract
        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(map(str, df["id"]))
            # dict.fromkeys de-duplicates while keeping retrieval order.
            found_ids = list(
                dict.fromkeys(x["document_id"] for x in results if x["document_id"] in remaining_ids)
            )
            df = df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]