| import streamlit as st |
|
|
| from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertConfig |
| import torch |
| from torch.nn.functional import softmax |
|
|
|
|
|
|
# Hugging Face model id of the pretrained DistilBERT base checkpoint.
base_model_name = "distilbert-base-uncased"
|
|
@st.cache_data
def load_tags_info(path='tags.txt'):
    """Read tag descriptions, one per line, into an index -> description map.

    Line ``i`` (0-based) of the file becomes entry ``i``. Callers look this map
    up with model output positions, so the line order must match the
    classifier's label order.

    Args:
        path: Text file with one tag description per line. Defaults to
            ``'tags.txt'`` relative to the current working directory.

    Returns:
        dict mapping label index (int) to description (str).
    """
    with open(path, 'r', encoding='utf-8') as file:
        # rstrip('\n') instead of line[:-1]: the last line may lack a trailing
        # newline, and slicing would then drop its final real character.
        return {index: line.rstrip('\n') for index, line in enumerate(file)}


# Loaded once at import time; maps model output index -> tag description.
id_to_description = load_tags_info()
|
|
@st.cache_resource
def load_model():
    """Build the fine-tuned DistilBERT classifier from local files.

    The architecture is read from ``./config.json`` and the weights from
    ``./pytorch_model.bin``, mapped onto the CPU. ``st.cache_resource`` keeps
    the model around so it is not rebuilt on every script rerun.

    Returns:
        DistilBertForSequenceClassification, switched to evaluation mode.
    """
    config = DistilBertConfig.from_json_file('./config.json')
    model = DistilBertForSequenceClassification(config)
    # NOTE: torch.load unpickles the checkpoint; only load files you trust.
    state_dict = torch.load('./pytorch_model.bin', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # A model constructed directly from a config starts in training mode
    # (from_pretrained would call eval() for us), so dropout would stay active
    # and make predictions non-deterministic. Switch to inference mode.
    model.eval()
    return model
|
|
@st.cache_resource
def load_tokenizer():
    """Return the tokenizer for the base checkpoint ``base_model_name``.

    Cached with ``st.cache_resource``: Streamlit re-executes the whole script
    on every interaction, and without caching the tokenizer would be reloaded
    each time.
    """
    return AutoTokenizer.from_pretrained(base_model_name)
|
|
def top_xx(preds, xx=95, labels=None):
    """Return the most probable tags until their total probability covers ``xx``%.

    Tags are taken in order of decreasing probability. Tags are added until
    their cumulative probability reaches ``xx / 100`` ("top-p" style selection).

    Args:
        preds: Probability tensor indexed as ``preds[0, label_id]``. Only row 0
            is used, so a ``(1, num_labels)`` softmax output is expected.
        xx: Coverage threshold in percent (default 95). ``0`` yields ``[]``.
        labels: Mapping from label index to description. Defaults to the
            module-level ``id_to_description``.

    Returns:
        List of descriptions, most probable first. If every label is used
        before the threshold is reached, all labels are returned. Float
        rounding can stop the sum just short of 1.0 when ``xx`` is near 100.
    """
    if labels is None:
        labels = id_to_description
    row = preds[0]
    order = torch.argsort(row, descending=True).tolist()
    threshold = xx / 100
    total = 0.0  # plain float, not a tensor, to keep the comparison simple
    result = []
    for label_id in order:
        if total >= threshold:
            break
        total += row[label_id].item()
        result.append(labels[label_id])
    return result
|
|
model = load_model()
tokenizer = load_tokenizer()
# Softmax temperature applied to the logits; 1 leaves probabilities unchanged.
temperature = 1


st.title('ArXivTager')
# Caption (Russian): enter the paper Title and Abstract, in English.
st.caption('Напишите тему (Title) и параграф из статьи (Abstract). Поля должны быть ЗАПОЛНЕНЫ текстом на АНГЛИЙСКОМ языке для корректной классификации.')


with st.form("ArXivTager"):
    title = st.text_area(label='Title', height=100)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    # Caption (Russian): output is a set of topics in decreasing probability.
    st.caption('ВЫВОД: набор тем в порядке уменьшения вероятностей.')

    submitted = st.form_submit_button("Get tags")
    if submitted:
        # Reject empty and whitespace-only titles alike.
        if not title.strip():
            st.markdown("Нужно хоть что-то написать")
        else:
            prompt = 'Title: ' + title + ' Abstract: ' + abstract
            # A single input needs no padding. Passing the attention mask keeps
            # any padding tokens from influencing the prediction.
            encoded = tokenizer(prompt, truncation=True, return_tensors='pt')
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                logits = model(
                    input_ids=encoded['input_ids'],
                    attention_mask=encoded['attention_mask'],
                ).logits
            preds = softmax(logits / temperature, dim=1)
            tags = top_xx(preds)
            st.header('Inferred tags:')
            for tag_data in tags:
                st.markdown('* ' + tag_data)