from typing import Optional, List, Union, Tuple
import random

import torch
import torch.nn as nn
import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint in BartDecoderPlus.forward
from torch.nn import CrossEntropyLoss

from transformers import AutoModelForSeq2SeqLM
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    add_end_docstrings,
    replace_return_docstrings,
)
from transformers.models.bart.modeling_bart import (
    BartForConditionalGeneration,
    _expand_mask,
    logger,
    shift_tokens_right,
    BartPretrainedModel,
    BART_INPUTS_DOCSTRING,
    _CONFIG_FOR_DOC,
    BART_GENERATION_EXAMPLE,
    BartModel,
    BartDecoder,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqModelOutput,
    BaseModelOutput,
    Seq2SeqLMOutput,
)

from .adapter import Adapter
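
# `Adapter` is defined in the accompanying adapter.py, which is not shown here. It is
# assumed to be a standard bottleneck adapter (Houlsby et al., 2019): a residual
# down-projection / non-linearity / up-projection block that preserves d_model, matching
# the call Adapter(d_model, adapter_hid_dim) used below. A minimal sketch of that assumed
# interface, for reference only:
#
#     class Adapter(nn.Module):
#         def __init__(self, d_model: int, hidden_dim: int) -> None:
#             super().__init__()
#             self.down_proj = nn.Linear(d_model, hidden_dim)
#             self.activation = nn.ReLU()
#             self.up_proj = nn.Linear(hidden_dim, d_model)
#
#         def forward(self, x: torch.Tensor) -> torch.Tensor:
#             # Residual bottleneck: the output has the same shape as the input.
#             return x + self.up_proj(self.activation(self.down_proj(x)))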


class KeyBartAdapter(BartForConditionalGeneration):
    def __init__(self, adapter_hid_dim: int) -> None:
        # Load the pretrained KeyBART checkpoint and freeze all of its weights;
        # only the newly added adapter layers remain trainable.
        keyBart = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
        self.__fix_weights__(keyBart)

        super().__init__(keyBart.model.config)
        self.lm_head = keyBart.lm_head
        self.model = BartPlus(keyBart, adapter_hid_dim)
        self.register_buffer(
            "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
        )

    def __fix_weights__(self, keyBart: BartForConditionalGeneration) -> None:
        # Freeze the pretrained encoder/decoder and LM head.
        for p in keyBart.model.parameters():
            p.requires_grad = False
        for p in keyBart.lm_head.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices
            set to `-100` are ignored (masked), the loss is only computed for the tokens with
            labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
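
# Optional sanity-check helper (illustrative): with the pretrained KeyBART weights frozen
# in `__fix_weights__`, only the adapter projections added in `BartDecoderPlus` should
# still report `requires_grad=True`. This can be used to verify that before training.
def count_trainable_parameters(model: nn.Module) -> int:
    """Return the number of parameters that will receive gradients during training."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)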


class BartDecoderPlus(BartDecoder):
    def __init__(self, keyBart: BartForConditionalGeneration, adapter_hid_dim: int) -> None:
        super().__init__(keyBart.get_decoder().config)
        # Wrap the frozen pretrained decoder and add one trainable Adapter per decoder layer.
        self.decoder = keyBart.model.decoder
        self.adapters = nn.ModuleList(
            [Adapter(self.decoder.config.d_model, adapter_hid_dim) for _ in range(len(self.decoder.layers))]
        )
        # Mirror the wrapped decoder's attributes so this module behaves like a plain BartDecoder.
        self.config = self.decoder.config
        self.dropout = self.decoder.dropout
        self.layerdrop = self.decoder.layerdrop
        self.padding_idx = self.decoder.padding_idx
        self.max_target_positions = self.decoder.max_target_positions
        self.embed_scale = self.decoder.embed_scale
        self.embed_tokens = self.decoder.embed_tokens
        self.embed_positions = self.decoder.embed_positions
        self.layers = self.decoder.layers
        self.layernorm_embedding = self.decoder.layernorm_embedding
        self.gradient_checkpointing = self.decoder.gradient_checkpointing

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_shape = input.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.decoder.embed_tokens(input) * self.decoder.embed_scale

        attention_mask = self.decoder._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

        # embed positions
        positions = self.decoder.embed_positions(input, past_key_values_length)

        hidden_states = inputs_embeds + positions
        hidden_states = self.decoder.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.decoder.dropout, training=self.decoder.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.decoder.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.decoder.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.decoder.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)
            if self.decoder.training and (dropout_probability < self.decoder.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.decoder.gradient_checkpointing and self.decoder.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            ######################### new #################################
            # Pass each layer's output through its trainable adapter.
            hidden_states = self.adapters[idx](hidden_states)
            ######################### new #################################

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


class BartPlus(BartModel):
    def __init__(self, keyBart: BartForConditionalGeneration, adapter_hid_dim: int) -> None:
        super().__init__(keyBart.model.config)
        self.config = keyBart.model.config

        # Reuse the pretrained embeddings and encoder unchanged; only the decoder is
        # wrapped so that adapter layers are inserted after every decoder layer.
        # self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
        self.shared = keyBart.model.shared
        # self.encoder = BartEncoder(config, self.shared)
        self.encoder = keyBart.model.encoder
        # self.decoder = BartDecoder(config, self.shared)
        # self.decoder = keyBart.model.decoder
        self.decoder = BartDecoderPlus(keyBart, adapter_hid_dim=adapter_hid_dim)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        # different to other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
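

# A minimal usage sketch. It assumes the package containing this module (and adapter.py)
# is importable, e.g. run as `python -m <package>.<module>`, and that the KeyBART
# tokenizer is available under the same hub name as the model ("bloomberg/KeyBART");
# the bottleneck size 64 is an arbitrary example value.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART")
    model = KeyBartAdapter(adapter_hid_dim=64)

    # With the KeyBART weights frozen, only the adapter parameters should be counted here.
    print(f"trainable parameters: {count_trainable_parameters(model)}")

    text = "Adapters add small trainable bottleneck layers to a frozen pretrained model."
    inputs = tokenizer(text, return_tensors="pt")

    # Greedy generation with the frozen KeyBART weights plus the (untrained) adapters.
    generated = model.generate(**inputs, max_length=32)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))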