| """Backbone components for Mimi models - shared attention transformers.""" |
|
|
| import math |
| from typing import Optional, Union |
|
|
| import torch |
| from torch import nn |
|
|
| from transformers.cache_utils import Cache, DynamicCache, StaticCache |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.utils import logging |
|
|
| try: |
| from .configuration_mimi import MimiConfig |
| from .modeling_mimi_clean import ( |
| MimiAttention, |
| MimiMLP, |
| MimiLayerScale, |
| MimiRotaryEmbedding, |
| apply_rotary_pos_emb, |
| MIMI_ATTENTION_CLASSES |
| ) |
| except ImportError: |
| from configuration_mimi import MimiConfig |
| from modeling_mimi_clean import ( |
| MimiAttention, |
| MimiMLP, |
| MimiLayerScale, |
| MimiRotaryEmbedding, |
| apply_rotary_pos_emb, |
| MIMI_ATTENTION_CLASSES |
| ) |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class CausalAttentionTransformer(nn.Module):
    """
    Decoder-only causal transformer: a stack of `config.num_hidden_layers` [`MimiTransformerLayer`] blocks, each
    containing self-attention only.

    The module consumes hidden states directly (there is no embedding table here) and builds the causal mask
    itself through `create_causal_mask`.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        # Toggled from outside; NOTE(review): `_gradient_checkpointing_func` is not defined in this class and is
        # presumably injected by the owning model when checkpointing is enabled - confirm against the caller.
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Run the layer stack over `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input embeddings or hidden states from a previous module.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Padding mask forwarded to `create_causal_mask`. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Position index of every input token. Defaults to `cache_position.unsqueeze(0)`.
            past_key_values (`Cache` or legacy tuple of tuples, *optional*):
                Previously computed key/value states. A legacy (tuple-of-tuples) cache is converted to a
                [`~cache_utils.DynamicCache`] - with a deprecation warning - regardless of `use_cache`, because
                the rest of this method relies on the `Cache` API. When `use_cache=True` and nothing is passed, a
                fresh `DynamicCache` is created.
            use_cache (`bool`, *optional*):
                Whether to return the key/value cache. Defaults to `config.use_cache`; forced to `False` while
                training with gradient checkpointing.
            output_attentions (`bool`, *optional*):
                Whether to collect the attention weights of every layer. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the input of every layer plus the final output. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Absolute positions of the input tokens. Defaults to a range starting at the number of tokens
                already held by `past_key_values`.

        Returns:
            A [`BaseModelOutputWithPast`], or - when `return_dict` is false - a tuple holding the non-`None`
            entries of `(last_hidden_state, past_key_values, hidden_states, attentions)` in that order. The cache
            that is handed back is the one returned by the last layer (never a legacy tuple built here).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Normalise the cache argument. Legacy tuples are converted even when `use_cache` is False: the
        # `get_seq_length()` call below (and the layers' cache updates) only exist on `Cache` objects, so leaving
        # a tuple in place would fail with an AttributeError.
        if past_key_values is not None and not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
            )
        elif use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional order must mirror `MimiTransformerLayer.forward`.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits after the attention weights when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Also record the output of the last layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
class MimiTransformerLayer(GradientCheckpointingLayer):
    """
    Single self-attention transformer block.

    Both sub-layers (self-attention, then MLP) normalise their input with a `LayerNorm`, pass the sub-layer
    output through a `MimiLayerScale` module and add it back onto the un-normalised input.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The attention implementation is selected by name from the shared registry.
        attention_cls = MIMI_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply the block to `hidden_states`.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*):
                Mask forwarded unchanged to the self-attention module.
            position_ids (`torch.LongTensor`, *optional*): position indices forwarded to the self-attention module.
            past_key_value (`Cache`, *optional*): key/value cache forwarded to the self-attention module.
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights are appended to the returned tuple.
            use_cache (`bool`, *optional*):
                When truthy, the cache returned by the self-attention module is appended to the returned tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Positions of the input tokens in the sequence, forwarded to the self-attention module.
            kwargs (`dict`, *optional*):
                Extra keyword arguments, passed through to the self-attention module.

        Returns:
            `(hidden_states,)`, followed by the attention weights if `output_attentions`, followed by the cache
            if `use_cache`.
        """
        # --- self-attention sub-layer -------------------------------------------------------------------------
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.self_attn_layer_scale(attn_output)

        # --- feed-forward sub-layer ---------------------------------------------------------------------------
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp_layer_scale(mlp_output)

        # Optional extras keep a fixed order: attention weights first, cache second.
        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)

        return (hidden_states, *extras)
|
|
|
|
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention with an optional monotonic mask.

    Queries are projected from the decoder hidden states; keys and values are projected from the encoder hidden
    states. When `alignment_chunk_sizes` is given, every query may only attend to a growing prefix of the keys
    (see `_create_monotonic_attention_mask`).
    """

    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection reads decoder states; key/value projections read encoder states.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        self.rotary_emb = MimiRotaryEmbedding(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Attend from decoder states to encoder states.

        Args:
            hidden_states (`torch.Tensor`): decoder states of shape `(batch, q_len, hidden_size)`.
            encoder_hidden_states (`torch.Tensor`): encoder states of shape `(batch, kv_len, hidden_size)`.
            attention_mask (`torch.Tensor`, *optional*):
                Mask that is **added** to the raw attention scores, so it must be additive and broadcastable to
                `(batch, num_heads, q_len, key_len)`.
            position_ids (`torch.LongTensor`, *optional*):
                Decoder positions. When given, rotary embeddings are applied to the queries only; the keys are
                left un-rotated.
            past_key_value (`Cache`, *optional*):
                Cache whose `update` method is called with the freshly projected keys/values for
                `self.layer_idx`; attention then runs over whatever `update` returns. NOTE(review): with a
                `DynamicCache` this presumably appends, so callers would pass only new encoder frames on later
                calls - confirm the intended streaming contract.
            output_attentions (`bool`): whether to return the attention probabilities.
            use_cache (`bool`): accepted for interface parity; caching is driven by `past_key_value` alone.
            cache_position (`torch.LongTensor`, *optional*): forwarded to the cache update.
            alignment_chunk_sizes (`torch.Tensor`, *optional*):
                Per-query chunk sizes of shape `(batch, q_len)` used to build the monotonic mask.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)` where `attn_output` has shape
            `(batch, q_len, hidden_size)`.

        Raises:
            ValueError: if the attention output does not have shape `(batch, num_heads, q_len, head_dim)`.
        """
        bsz, q_len, _ = hidden_states.size()
        _, kv_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary embedding is applied to the queries only; the second return value (the "key" slot, fed with the
        # queries as a placeholder) is discarded.
        cos = sin = None
        if position_ids is not None:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if alignment_chunk_sizes is not None:
            # Size the mask from the keys actually attended to. After a cache update the key tensor can be longer
            # than the current encoder input, and a mask built from `kv_len` would then not match `attn_weights`.
            monotonic_mask = _create_monotonic_attention_mask(
                alignment_chunk_sizes=alignment_chunk_sizes,
                query_length=q_len,
                key_length=key_states.shape[-2],
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            )
            attn_weights = attn_weights + monotonic_mask

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability, then cast back to the working dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (batch, heads, q_len, head_dim) -> (batch, q_len, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
|
|
|
class CrossAttentionLayer(GradientCheckpointingLayer):
    """
    Decoder block with three residual sub-layers applied in order: self-attention over the decoder states,
    cross-attention from the decoder states to the encoder states, and an MLP.

    Every sub-layer normalises its input with its own `LayerNorm` and passes its output through its own
    `MimiLayerScale` module before the residual addition.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Attention modules: registry-selected self-attention, then the cross-attention defined in this file.
        self_attention_cls = MIMI_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = self_attention_cls(config=config, layer_idx=layer_idx)
        self.cross_attn = CrossAttention(config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)

        # One normalisation per sub-layer.
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_cross_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

        # One output scale per sub-layer.
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.cross_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cross_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply self-attention, cross-attention and the MLP to the decoder states.

        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch, seq_len, embed_dim)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch, encoder_seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask handed to the self-attention module
            encoder_attention_mask (`torch.FloatTensor`, *optional*): mask handed to the cross-attention module
            position_ids (`torch.LongTensor`, *optional*): decoder position IDs, used by both attention modules
            past_key_value (`Cache`, *optional*): cache for the self-attention module
            cross_past_key_value (`Cache`, *optional*): cache for the cross-attention module
            output_attentions (`bool`, *optional*): whether to return both sets of attention weights
            use_cache (`bool`, *optional*): whether to return both caches
            cache_position (`torch.LongTensor`, *optional*): cache positions, used by both attention modules
            alignment_chunk_sizes (`torch.Tensor`, *optional*): forwarded to the cross-attention module, which
                uses it to build its monotonic mask
            kwargs (`dict`, *optional*): extra keyword arguments, passed to the self-attention module only

        Returns:
            `(hidden_states,)`, followed by `(self_attn_weights, cross_attn_weights)` if `output_attentions`,
            followed by `(self_attn_cache, cross_attn_cache)` if `use_cache`.
        """
        # 1) Self-attention over the decoder states.
        self_attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + self.self_attn_layer_scale(self_attn_output)

        # 2) Cross-attention: decoder queries, encoder keys/values.
        cross_attn_output, cross_attn_weights, cross_present_key_value = self.cross_attn(
            hidden_states=self.post_attention_layernorm(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
        hidden_states = hidden_states + self.cross_attn_layer_scale(cross_attn_output)

        # 3) Feed-forward.
        mlp_output = self.mlp(self.post_cross_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp_layer_scale(mlp_output)

        # Optional extras keep a fixed order: attention weights first, caches second.
        extras = []
        if output_attentions:
            extras.extend((self_attn_weights, cross_attn_weights))
        if use_cache:
            extras.extend((present_key_value, cross_present_key_value))

        return (hidden_states, *extras)
|
|
|
|
class CrossAttentionTransformer(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`CrossAttentionLayer`] blocks. Each layer runs self-attention over the
    decoder states and cross-attention from the decoder states to the encoder states.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [CrossAttentionLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation

        # Toggled from outside; NOTE(review): `_gradient_checkpointing_func` is not defined in this class and is
        # presumably injected by the owning model when checkpointing is enabled - confirm against the caller.
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        cross_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Run the layer stack over the decoder states while attending to the encoder states.

        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape
                `(batch_size, decoder_sequence_length, hidden_size)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape
                `(batch_size, encoder_sequence_length, hidden_size)`
            attention_mask (`torch.Tensor`, *optional*): padding mask for the decoder, forwarded to
                `create_causal_mask`
            encoder_attention_mask (`torch.Tensor`, *optional*): mask for the encoder positions. It is added
                unchanged to the cross-attention scores, so it must already be additive and broadcastable to
                `(batch_size, num_heads, decoder_sequence_length, key_length)`.
            position_ids (`torch.LongTensor`, *optional*): position IDs for the decoder; defaults to
                `cache_position.unsqueeze(0)`
            past_key_values (`Cache` or legacy tuple of tuples, *optional*): self-attention cache. Legacy input
                is converted to a [`~cache_utils.DynamicCache`]; a new one is created when `use_cache=True` and
                nothing is passed.
            cross_past_key_values (`Cache` or legacy tuple of tuples, *optional*): cross-attention cache, handled
                like `past_key_values`.
            use_cache (`bool`, *optional*): whether to use caching; forced to `False` while training with
                gradient checkpointing
            output_attentions (`bool`, *optional*): whether to collect attention weights
            output_hidden_states (`bool`, *optional*): whether to collect hidden states
            return_dict (`bool`, *optional*): whether to return a ModelOutput instead of a tuple
            cache_position (`torch.LongTensor`, *optional*): cache positions; defaults to a range starting at the
                number of tokens already held by `past_key_values`
            alignment_chunk_sizes (`torch.Tensor`, *optional*): per-decoder-position chunk sizes forwarded to
                every layer to enable monotonic attention, where decoder position i can attend to encoder
                positions 0 through sum(alignment_chunk_sizes[:i+1])-1.

        Returns:
            With `return_dict` false: a tuple holding the non-`None` entries of `(last_hidden_state,
            self_attn_cache, cross_attn_cache, hidden_states, self_attentions, cross_attentions)`.
            Otherwise a [`BaseModelOutputWithPast`]. That class has no field for the cross-attention cache or the
            cross-attention weights, so they are not part of the returned object; callers that need the
            cross-attention cache afterwards should pass their own `cross_past_key_values` object in.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Same guard as `CausalAttentionTransformer`.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # Normalise both caches to `Cache` objects: the layers call `.update(..., layer_idx, ...)` on what they
        # receive, which legacy tuples do not support.
        if past_key_values is not None and not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        elif use_cache and past_key_values is None:
            logger.warning_once("use_cache=True was passed, but no past_key_values were given. Creating new cache.")
            past_key_values = DynamicCache()

        if cross_past_key_values is not None and not isinstance(cross_past_key_values, Cache):
            cross_past_key_values = DynamicCache.from_legacy_cache(cross_past_key_values)
        elif use_cache and cross_past_key_values is None:
            cross_past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None
        next_decoder_cache = None
        next_cross_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Every layer receives the whole cache object and addresses its own slot through its `layer_idx`.
            # Indexing the cache here instead would hand the layer a plain (key, value) tuple, or fail outright
            # on a freshly created, still empty cache.
            if self.gradient_checkpointing and self.training:
                # Positional order must mirror `CrossAttentionLayer.forward`.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    encoder_hidden_states,
                    causal_mask,
                    encoder_attention_mask,
                    position_ids,
                    past_key_values,
                    cross_past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    alignment_chunk_sizes,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=causal_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    cross_past_key_value=cross_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    alignment_chunk_sizes=alignment_chunk_sizes,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Layer output layout: (hidden, [self_attn, cross_attn], self_cache, cross_cache).
                cache_offset = 3 if output_attentions else 1
                next_decoder_cache = layer_outputs[cache_offset]
                next_cross_cache = layer_outputs[cache_offset + 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attns += (layer_outputs[2],)

        # Also record the output of the last layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cross_cache = next_cross_cache if use_cache else None

        if not return_dict:
            candidates = [
                hidden_states,
                next_cache,
                next_cross_cache,
                all_hidden_states,
                all_self_attns,
                all_cross_attns,
            ]
            return tuple(v for v in candidates if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Takes a tensor of shape (batch, num_key_value_heads, seqlen, head_dim) and returns one of shape
    (batch, num_key_value_heads * n_rep, seqlen, head_dim), matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep). With `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking up front also enforces that the input is 4-D, even on the n_rep == 1 path.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
|
|
|
|
| def _create_monotonic_attention_mask( |
| alignment_chunk_sizes: torch.Tensor, |
| query_length: int, |
| key_length: int, |
| device: torch.device, |
| dtype: torch.dtype, |
| ) -> torch.Tensor: |
| """ |
| Create a monotonic attention mask where each query can only attend to a progressive subset of keys. |
| |
| Args: |
| alignment_chunk_sizes: Tensor of shape (batch_size, query_length) where each element represents |
| how many keys the corresponding query can attend to cumulatively. |
| query_length: Number of queries (text tokens) |
| key_length: Number of keys (speech features) |
| device: Device to create the mask on |
| dtype: Data type for the mask |
| |
| Returns: |
| Attention mask of shape (batch_size, 1, query_length, key_length) where |
| -inf masks out invalid positions, 0.0 allows attention. |
| """ |
| batch_size = alignment_chunk_sizes.shape[0] |
| |
| |
| cumulative_positions = torch.cumsum(alignment_chunk_sizes, dim=1) |
| |
| |
| cumulative_positions = torch.clamp(cumulative_positions, max=key_length) |
| |
| |
| key_positions = torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0) |
| |
| |
| cumulative_positions = cumulative_positions.unsqueeze(2) |
| |
| |
| mask = key_positions < cumulative_positions |
| |
| |
| attention_mask = torch.where(mask, 0.0, float('-inf')) |
| |
| |
| attention_mask = attention_mask.unsqueeze(1) |
| |
| return attention_mask.to(dtype) |
|
|
|
|
|
|
| __all__ = [ |
| "CausalAttentionTransformer", |
| "MimiTransformerLayer", |
| "CrossAttention", |
| "CrossAttentionLayer", |
| "CrossAttentionTransformer", |
| "_create_monotonic_attention_mask", |
| ] |
|
|