# NOTE: the lines below were web-page banner residue from the scraped
# Hugging Face Space ("Spaces", "Runtime error") — not part of the code.
| from typing import List | |
| from sentence_transformers import SentenceTransformer | |
| from kmeans_pytorch import kmeans | |
| import torch | |
| from sklearn.cluster import KMeans | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,Text2TextGenerationPipeline | |
| from inference_hf import InferenceHF | |
| from .dimension_reduction import PCA | |
| from unsupervised_learning.clustering import GaussianMixture | |
| from models import KeyBartAdapter | |
class Template:
    """Registry of the model/algorithm choices understood by __create_model__."""

    def __init__(self):
        # Sentence-embedding checkpoints (Hugging Face hub ids).
        self.PLM = {
            'sentence-transformer-mini': '''sentence-transformers/all-MiniLM-L6-v2''',
            'sentence-t5-xxl': '''sentence-transformers/sentence-t5-xxl''',
            'all-mpnet-base-v2': '''sentence-transformers/all-mpnet-base-v2''',
        }
        # Dimension-reduction backends; 'vae' and 'cnn' are unimplemented placeholders.
        self.dimension_reduction = {
            'pca': PCA,
            'vae': None,
            'cnn': None,
        }
        # Clustering backends (kmeans_pytorch, scikit-learn, project GMM).
        self.clustering = {
            'kmeans-cosine': kmeans,
            'kmeans-euclidean': KMeans,
            'gmm': GaussianMixture,
        }
        # Keyword-extraction checkpoints (Hugging Face hub ids).
        self.keywords_extraction = {
            'keyphrase-transformer': '''snrspeaks/KeyPhraseTransformer''',
            'KeyBartAdapter': '''Adapting/KeyBartAdapter''',
            'KeyBart': '''bloomberg/KeyBART''',
        }
# Module-level singleton consulted by __create_model__ for checkpoint lookups.
template = Template()
def _to_keyword_sets(outputs: List[dict], separator: str) -> List[set]:
    """Split each output's ``generated_text`` field into a set of keyphrases.

    :param outputs: list of dicts carrying a 'generated_text' string
        (pipeline results or inference-API responses)
    :param separator: delimiter between keyphrases ('|' or ';')
    :return: one set of stripped keyphrase strings per input text
    """
    return [
        set(map(str.strip, out['generated_text'].split(separator)))
        for out in outputs
    ]


def _hf_keywords_fn(model_name: str, separator: str):
    """Build a keyword-extraction callable for a Hugging Face checkpoint.

    The returned callable first tries the hosted inference API; when the
    response is not a list (inference failed) it falls back to downloading
    the checkpoint and running a local Text2Text pipeline.

    :param model_name: hub id of the seq2seq keyword-extraction model
    :param separator: delimiter the model uses between keyphrases
    :return: callable mapping List[str] -> List[set] of keyphrases
    """
    def ret(texts: List[str]) -> List[set]:
        # first try inference API
        response = InferenceHF.inference(
            inputs=texts,
            model_name=model_name
        )
        # inference success
        if isinstance(response, list):
            return _to_keyword_sets(response, separator)
        # inference failed: load and run the checkpoint locally
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
        return _to_keyword_sets(pipe(texts), separator)
    return ret


def __create_model__(model_ckpt):
    '''
    Instantiate the model or callable registered under *model_ckpt*.

    :param model_ckpt: keys in Template class
    :return: model/function: callable (or None for 'none')
    :raises RuntimeError: if *model_ckpt* is not a supported key
    '''
    if model_ckpt in ('sentence-transformer-mini', 'sentence-t5-xxl', 'all-mpnet-base-v2'):
        # All three keys are plain sentence-transformer checkpoints.
        return SentenceTransformer(template.PLM[model_ckpt])
    elif model_ckpt == 'none':
        return None
    elif model_ckpt == 'kmeans-cosine':
        def ret(x, k):
            # kmeans_pytorch returns (cluster_ids, cluster_centers) tensors;
            # convert both back to numpy for a uniform clustering interface.
            labels, centers = template.clustering[model_ckpt](
                X=torch.from_numpy(x), num_clusters=k, distance='cosine',
                device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
            )
            return labels.cpu().detach().numpy(), centers.cpu().detach().numpy()
        return ret
    elif model_ckpt == 'pca':
        # Project PCA wrapper configured to keep 95% of the variance.
        return template.dimension_reduction[model_ckpt](0.95)
    elif model_ckpt == 'kmeans-euclidean':
        def ret(x, k):
            fitted = KMeans(n_clusters=k, random_state=50).fit(x)
            return fitted.labels_, fitted.cluster_centers_
        return ret
    elif model_ckpt == 'gmm':
        def ret(x, k):
            model = GaussianMixture(k, 50)
            model.fit(x)
            return model.getLabels(), model.getClusterCenters()
        return ret
    elif model_ckpt == 'keyphrase-transformer':
        # KeyPhraseTransformer separates generated keyphrases with '|'.
        return _hf_keywords_fn(template.keywords_extraction[model_ckpt], '|')
    elif model_ckpt == 'KeyBart':
        # KeyBART separates generated keyphrases with ';'.
        return _hf_keywords_fn(template.keywords_extraction[model_ckpt], ';')
    elif model_ckpt == 'KeyBartAdapter':
        def ret(texts: List[str]) -> List[set]:
            # Revision is pinned so the adapter weights match adapter_hid_dim=32.
            model = KeyBartAdapter.from_pretrained(
                'Adapting/KeyBartAdapter',
                revision='3aee5ecf1703b9955ab0cd1b23208cc54eb17fce',
                adapter_hid_dim=32,
            )
            tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART")
            pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
            return _to_keyword_sets(pipe(texts), ';')
        return ret
    else:
        raise RuntimeError(f'The model {model_ckpt} is not supported. Please open an issue on the GitHub about the model.')