from typing import List, Optional, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel

from .configuration_vila import VILAConfig


class DownSample3x3BlockFix(nn.Module):
    """Merge every non-overlapping 3x3 window of a square token grid into one token.

    The nine feature vectors of a window are concatenated along the last
    dimension, so the sequence shrinks by a factor of nine while the feature
    width grows by a factor of nine. The module has no parameters.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).

        Raises:
            ValueError: If sequence_length is not a perfect square.
        """

        batch_size, sequence_length, hidden_size = x.shape

        # The flat sequence is interpreted as a square (feat_size x feat_size) grid.
        feat_size = int(sequence_length**0.5)
        if feat_size**2 != sequence_length:
            raise ValueError(f"Cannot take square root: sequence_length {sequence_length} is not a perfect square")

        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        # Zero-pad the trailing edge of both grid axes up to the next multiple of 3.
        # F.pad lists pads from the last dimension backwards: (hidden, width, height).
        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after

        # Split each grid axis into (blocks, 3), bring the two block axes together,
        # then flatten every 3x3 window into a single vector of 9 * hidden_size.
        features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 9 * hidden_size)

        return features


class MultimodalProjector(nn.Module):
    """Map vision features to vectors of width ``config.hidden_size``.

    Only ``config.mm_projector_type == "mlp_downsample_3x3_fix"`` is supported:
    a 3x3 spatial downsample followed by LayerNorm/Linear/GELU stages.
    """

    layers: nn.Sequential

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        """Build the projector layers described by ``config``.

        Args:
            config: Supplies ``mm_projector_type``, ``mm_hidden_size``,
                ``vision_config.hidden_size``, ``hidden_size`` and ``torch_dtype``.
            *args: Forwarded to ``nn.Module.__init__``.
            **kwargs: Forwarded to ``nn.Module.__init__``.

        Raises:
            NotImplementedError: If ``config.mm_projector_type`` is not supported.
        """
        super().__init__(*args, **kwargs)

        if config.mm_projector_type == "mlp_downsample_3x3_fix":
            self.layers = nn.Sequential(
                DownSample3x3BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 9),
                nn.Linear(
                    config.mm_hidden_size * 9,
                    config.mm_hidden_size * 3,
                ),
                nn.GELU(),
                # NOTE(review): this stage sizes itself from vision_config.hidden_size while
                # the previous Linear emits mm_hidden_size * 3 — presumably the two are
                # equal; confirm against VILAConfig.
                nn.LayerNorm(config.vision_config.hidden_size * 3),
                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")

        # Module.type(None) does not cast: Tensor.type(None) returns the type name as a
        # string, which makes Module._apply fail. Only cast when a dtype is configured;
        # otherwise the layers keep the dtype they were created with.
        if config.torch_dtype is not None:
            self.layers.type(config.torch_dtype)

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter of this module."""
        return next(self.parameters()).dtype

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, hidden_size).
        """

        # The input is moved to this module's own device/dtype before projection.
        return self.layers(x.to(device=self.device, dtype=self.dtype))


class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Image-text-to-text model combining a vision tower, a projector and a causal LM.

    Image tokens in the prompt are replaced, during prefill, by the projected
    hidden states of the vision tower; the resulting embeddings are fed to the
    language model.
    """

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModelForImageTextToText"
    _no_split_modules: List[str] = ["MultimodalProjector"]
    _skip_keys_device_placement: List[str] = ["past_key_values"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2: bool = True
    _supports_sdpa = True

    config: VILAConfig

    llm: Qwen2ForCausalLM
    mm_projector: MultimodalProjector
    vision_tower: SiglipVisionModel

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        """Instantiate the three sub-modules from ``config``.

        Args:
            config: Model configuration; ``text_config`` and ``vision_config``
                configure the language model and the vision tower respectively.
            *args: Forwarded to ``PreTrainedModel.__init__`` and to both ``_from_config`` calls.
            **kwargs: Forwarded to ``PreTrainedModel.__init__`` and to both ``_from_config`` calls.
        """
        super().__init__(config, *args, **kwargs)

        self.llm = Qwen2ForCausalLM._from_config(config.text_config, *args, **kwargs)
        self.mm_projector = MultimodalProjector(config)
        self.vision_tower = SiglipVisionModel._from_config(config.vision_config, *args, **kwargs)

        self.post_init()

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values: Optional[Cache] = None,
        pixel_values: Optional[Tensor] = None,
        position_ids: Optional[LongTensor] = None,
        logits_to_keep: Union[int, Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the language model, embedding images on the prefill step.

        All arguments are keyword-only. ``pixel_values`` is consumed only when
        ``past_key_values`` is None and ``input_ids`` is given; in every other
        case it is ignored and the inputs go to the language model unchanged.

        Args:
            attention_mask: Forwarded to the language model.
            input_ids: Token ids. Mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed embeddings. Mutually exclusive with ``input_ids``.
            past_key_values: Cache forwarded to the language model; None marks prefill.
            pixel_values: Image input for the vision tower.
            position_ids: Forwarded to the language model.
            logits_to_keep: Forwarded to the language model.
            **kwargs: Forwarded to the language model.

        Returns:
            The language model's output.

        Raises:
            ValueError: If both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if past_key_values is None and input_ids is not None:  # Prefill
            # Token ids are replaced by embeddings that already contain the image features.
            inputs_embeds = self._embed(input_ids, pixel_values)
            input_ids = None

        # Every tensor argument is moved to the language model's device before the call.
        outputs = self.llm.__call__(
            attention_mask=(attention_mask.to(device=self.llm.device) if attention_mask is not None else None),
            input_ids=(input_ids.to(device=self.llm.device) if input_ids is not None else None),
            inputs_embeds=(
                inputs_embeds.to(device=self.llm.device, dtype=self.llm.dtype) if inputs_embeds is not None else None
            ),
            past_key_values=past_key_values,
            position_ids=(position_ids.to(device=self.llm.device) if position_ids is not None else None),
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        return outputs

    def get_output_embeddings(self) -> nn.Module:
        """Return the output embeddings of the wrapped language model."""
        return self.llm.get_output_embeddings()

    def _embed(
        self,
        input_ids: Tensor,
        pixel_values: Optional[Tensor],
    ) -> Tensor:
        """Gets the embedding of the input ids and pixel values.

        Args:
            input_ids: The input ids.
            pixel_values: The pixel values.

        Returns:
            The embedding of the input ids and pixel values.

        Raises:
            ValueError: If ``input_ids`` contains ``config.video_token_id``.
        """

        if torch.any(input_ids == self.config.video_token_id):
            raise ValueError("Video token ids should not be present in the input ids.")

        image_token_mask = input_ids == self.config.image_token_id

        # Image-token positions are looked up as id 0 (ids multiplied by the inverted
        # mask); when pixel_values is given they are overwritten below.
        text_embedding: Tensor = self.llm.get_input_embeddings().__call__(input_ids * ~image_token_mask)

        if pixel_values is None:
            return text_embedding

        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
            pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
            output_hidden_states=True,
        )

        mm_projector_input = self._vision_tower_output_to_mm_projector_input(vision_tower_output)

        image_embedding: Tensor = self.mm_projector.__call__(
            mm_projector_input.to(device=self.mm_projector.device, dtype=self.mm_projector.dtype)
        )

        # Collapse all leading dimensions so image features form one flat list of vectors.
        image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])

        # In-place scatter: image features are written, in order, into the masked positions.
        # NOTE(review): masked_scatter_ needs at least as many source elements as masked
        # positions and leaves any surplus unused, so a count mismatch in that direction
        # is not reported here.
        text_embedding.masked_scatter_(
            image_token_mask.to(device=text_embedding.device, dtype=torch.bool).unsqueeze(-1),
            image_embedding.to(device=text_embedding.device, dtype=text_embedding.dtype).flatten(),
        )

        return text_embedding

    def _vision_tower_output_to_mm_projector_input(
        self,
        vision_tower_output: BaseModelOutputWithPooling,
    ) -> Tensor:
        """Select the vision-tower hidden state that feeds the projector.

        Args:
            vision_tower_output: Vision tower output; its ``hidden_states`` must be populated.

        Returns:
            The hidden state at index ``config.mm_vision_select_layer``, unchanged.

        Raises:
            NotImplementedError: If ``config.mm_vision_select_feature`` is not ``"cls_patch"``.
        """
        assert vision_tower_output.hidden_states is not None
        selected_layer_hidden_states = vision_tower_output.hidden_states[self.config.mm_vision_select_layer]

        if self.config.mm_vision_select_feature == "cls_patch":
            return selected_layer_hidden_states
        else:
            raise NotImplementedError(f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}")