| """Text and Image processor for CASA models using Qwen2.5_VL image encoder""" |
|
|
| from math import ceil |
| from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload |
| from typing import cast as type_cast |
|
|
| import torch |
| import torchvision.transforms.v2 as T |
| from einops import rearrange |
| from PIL import Image |
| from torchvision.transforms import InterpolationMode |
| from torchvision.transforms.functional import to_tensor as pil_to_tensor |
| from torchvision.transforms.v2 import functional as F |
| from transformers.image_processing_utils import BaseImageProcessor |
| from transformers.processing_utils import ProcessorMixin |
|
|
| if TYPE_CHECKING: |
| from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer |
| from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
|
|
|
|
| ImageMessage = TypedDict( |
| "ImageMessage", |
| { |
| "type": Literal["image"], |
| "image": str | Image.Image | None, |
| }, |
| ) |
|
|
| TextMessage = TypedDict( |
| "TextMessage", |
| { |
| "type": Literal["text"], |
| "text": str, |
| }, |
| ) |
|
|
| MessageContent = list[ImageMessage | TextMessage] |
|
|
| Message = TypedDict( |
| "Message", |
| { |
| "role": Literal["system", "user", "assistant"], |
| "content": MessageContent, |
| }, |
| ) |
|
|
| ProcessorInput = list[list[Message]] | list[Message] |
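# A single conversation is a list of `Message` dicts and a batch is a list of conversations.
# Illustrative example (the image path is hypothetical):
#   [
#       {
#           "role": "user",
#           "content": [
#               {"type": "image", "image": "path/to/image.png"},
#               {"type": "text", "text": "Describe this image."},
#           ],
#       }
#   ]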
__INTERP_NAME_TO_MODE__ = {
    "nearest": InterpolationMode.NEAREST,
    # "nearest_exact" is accepted by QwenImageProcessor, so map it here as well.
    "nearest_exact": InterpolationMode.NEAREST_EXACT,
    "bilinear": InterpolationMode.BILINEAR,
    "bicubic": InterpolationMode.BICUBIC,
    "lanczos": InterpolationMode.LANCZOS,
}

# Integer codes follow PIL's resampling constants (NEAREST=0, LANCZOS=1, BILINEAR=2, ...).
__INTERP_INT_TO_MODE__ = {
    0: InterpolationMode.NEAREST,
    1: InterpolationMode.LANCZOS,
    2: InterpolationMode.BILINEAR,
    3: InterpolationMode.BICUBIC,
    4: InterpolationMode.BOX,
    5: InterpolationMode.HAMMING,
}


@overload
def universal_resize(
    img: Image.Image,
    size: tuple[int, int],
    interpolation: str | InterpolationMode | int = "bilinear",
    antialias: bool = True,
) -> Image.Image: ...
@overload
def universal_resize(
    img: torch.Tensor,
    size: tuple[int, int],
    interpolation: str | InterpolationMode | int = "bilinear",
    antialias: bool = True,
) -> torch.Tensor: ...
def universal_resize(
    img: Image.Image | torch.Tensor,
    size: tuple[int, int],
    interpolation: str | InterpolationMode | int = "bilinear",
    antialias: bool = True,
) -> Image.Image | torch.Tensor:
    """Resize that works for a PIL.Image, a CHW tensor, or a BCHW tensor."""
    if isinstance(interpolation, str):
        interpolation = __INTERP_NAME_TO_MODE__[interpolation]
    elif isinstance(interpolation, int):
        interpolation = __INTERP_INT_TO_MODE__[interpolation]

    return F.resize(
        img, size, interpolation=cast(InterpolationMode, interpolation), antialias=antialias
    )
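# Illustrative: both calls return a 224x224 result of the same type as the input.
#   universal_resize(Image.new("RGB", (640, 480)), (224, 224), "bicubic")
#   universal_resize(torch.rand(3, 480, 640), (224, 224), InterpolationMode.BICUBIC)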
@overload
def convert_to_rgb(img: Image.Image) -> Image.Image: ...
@overload
def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ...
def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor:
    """Convert an image to RGB without triggering a PIL warning; tensors pass through unchanged."""
    if isinstance(img, torch.Tensor):
        return img
    if img.mode == "RGB":
        return img
    if img.mode == "P":
        # Palette images may carry transparency, so go through RGBA first.
        return img.convert("RGBA").convert("RGB")
    return img.convert("RGB")


class QwenImageProcessor(BaseImageProcessor):
    """Resizing for the Qwen2.5-VL encoder. Note that normalization is handled
    by the image encoder in the model forward pass."""

    def __init__(
        self,
        img_size: int = 448,
        interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic",
        max_ratio: int = 10,
        round_to_patch_size: int = 56,
        use_fast: bool = True,
        **kwargs: Any,
    ) -> None:
        # 588 = 3 channels * 14 * 14 pixels, i.e. one flattened RGB patch.
        self._num_target_channels = 588
        self._merge_size = 2
        self._patch_size = 14
        super().__init__(
            use_fast=use_fast,
            do_normalize=False,
            **kwargs,
        )
        self.img_size = img_size
        self.interpolation = interpolation
        self.max_ratio = max_ratio
        self.round_to_patch_size = round_to_patch_size

    def resize_transform(
        self, img: Image.Image | torch.Tensor, img_size: int | None = None
    ) -> Image.Image | torch.Tensor:
        if img_size is None:
            img_size = self.img_size
        max_area = img_size**2
        if isinstance(img, Image.Image):
            img = convert_to_rgb(img)
            w_og, h_og = img.size
        else:
            h_og, w_og = img.shape[-2:]
        w, h = w_og, h_og

        # Cap the aspect ratio: neither side may exceed max_ratio times the other.
        if self.max_ratio > 0:
            w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio)

        # Downscale so the total area stays within the img_size**2 pixel budget.
        current_area = w * h
        if current_area > max_area:
            scale = (max_area / current_area) ** 0.5
            w, h = int(w * scale), int(h * scale)

        # Round both sides up to a multiple of the merged patch size.
        if self.round_to_patch_size > 0:
            w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size
            h = ceil(h / self.round_to_patch_size) * self.round_to_patch_size

        if w != w_og or h != h_og:
            img = universal_resize(img, (h, w), self.interpolation)
        if isinstance(img, torch.Tensor):
            img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img))
        return img
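    # Worked example (illustrative), using the defaults img_size=448, max_ratio=10,
    # round_to_patch_size=56: a 1000x700 input has area 700000 > 448**2 = 200704, so it
    # is scaled by sqrt(200704 / 700000) ~= 0.535 to 535x374, then rounded up to
    # multiples of 56, yielding a 560x392 output.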
    def __process_one__(
        self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None
    ) -> torch.Tensor:
        """Resize and patchify a single image or frame stack (same result as
        ``__process_one_with_processor__``, but without going through numpy)."""
        video_or_img = self.resize_transform(video_or_img, img_size)
        if isinstance(video_or_img, Image.Image):
            video_or_img = pil_to_tensor(video_or_img)
        assert isinstance(video_or_img, torch.Tensor)
        if video_or_img.ndim == 3:
            video_or_img = video_or_img[None]
        assert video_or_img.ndim == 4 and video_or_img.shape[1] in (1, 3, 4), (
            f"Invalid shape {video_or_img.shape}."
        )
        t, c, h, w = video_or_img.shape
        p = self._patch_size
        m = self._merge_size

        # Coerce grayscale and RGBA inputs to three channels.
        if c == 1:
            video_or_img = video_or_img.expand((-1, 3, -1, -1))
        if c == 4:
            video_or_img = video_or_img[:, :3]
        c = video_or_img.shape[1]
        assert c == 3, "Expecting an RGB image in QwenImageProcessor"

        # Grid size in 14x14 patches along each spatial axis.
        h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p
        rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m)

        # Each row becomes one flattened 14x14 RGB patch (588 values), grouped in 2x2 merge blocks.
        video_or_img = rearrange(
            video_or_img,
            "t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)",
            **rearrange_dict,
        )
        assert video_or_img.shape[-1] == self._num_target_channels, (
            f"{video_or_img.shape[-1]} != {self._num_target_channels}"
        )
        video_or_img = video_or_img.view((-1, h, w, self._num_target_channels))
        return video_or_img

    @overload
    def process_images(
        self, image: Image.Image | torch.Tensor, img_size: int | None = None
    ) -> torch.Tensor: ...
    @overload
    def process_images(
        self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None
    ) -> list[torch.Tensor]: ...
    def process_images(
        self,
        image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor],
        img_size: int | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        if isinstance(image, list):
            return [self.__process_one__(_x, img_size) for _x in image]
        return self.__process_one__(image, img_size)
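# Minimal usage sketch (illustrative; the file name is hypothetical):
#   processor = QwenImageProcessor(img_size=448)
#   patches = processor.process_images(Image.open("photo.jpg"))
#   patches.shape  # (frames, grid_h, grid_w, 588) with grid_h = H // 14 and grid_w = W // 14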
class ProcessorOutput(dict):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    image_embeds_insertion_points: list[torch.Tensor] | None
    pixel_values: torch.Tensor | list[torch.Tensor] | None

    def to(
        self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16
    ) -> "ProcessorOutput":
        return ProcessorOutput(
            {
                "input_ids": self["input_ids"].to(device),
                "attention_mask": self["attention_mask"].to(device),
                "image_embeds_insertion_points": self["image_embeds_insertion_points"],
                "pixel_values": (
                    self["pixel_values"].to(dtype).to(device)
                    if isinstance(self["pixel_values"], torch.Tensor)
                    else [x.to(dtype).to(device) for x in self["pixel_values"]]
                    if self["pixel_values"] is not None
                    else None
                ),
            }
        )
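# Illustrative (`outputs` being a ProcessorOutput): move a batch to the model device
# before the forward pass.
#   outputs = outputs.to("cuda", dtype=torch.bfloat16)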
class BaseProcessor(ProcessorMixin):
    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer",
        pre_image_tokens: tuple[int, ...] = (),
        post_image_tokens: tuple[int, ...] = (),
        system_start_tokens: tuple[int, ...] = (),
        system_end_tokens: tuple[int, ...] = (),
        user_start_tokens: tuple[int, ...] = (),
        user_end_tokens: tuple[int, ...] = (),
        asst_start_tokens: tuple[int, ...] = (),
        asst_end_tokens: tuple[int, ...] = (),
        allow_system_prompt: bool = True,
        pad_token: int = 0,
        bos_token: int | None = None,
    ) -> None:
        self.pre_image_tokens = list(pre_image_tokens)
        self.post_image_tokens = list(post_image_tokens)
        self.system_start_tokens = list(system_start_tokens)
        self.system_end_tokens = list(system_end_tokens)
        self.user_start_tokens = list(user_start_tokens)
        self.user_end_tokens = list(user_end_tokens)
        self.asst_start_tokens = list(asst_start_tokens)
        self.asst_end_tokens = list(asst_end_tokens)
        self._allow_system_prompt = allow_system_prompt
        self.tokenizer = tokenizer
        self._image_processor: QwenImageProcessor | None = None
        self._pad_token = pad_token
        self.bos_token = bos_token

    @property
    def image_processor(self) -> QwenImageProcessor:
        assert self._image_processor is not None, "Subclasses must set _image_processor."
        return self._image_processor

    def _process_content(
        self,
        message_content: MessageContent,
        role: Literal["system", "user", "assistant"],
        tokenized_messages: list[torch.Tensor],
        insertion_points: list[int],
        image_list: list[torch.Tensor | None],
        token_count: int,
        img_size: int | None = None,
        **kwargs: Any,
    ) -> int:
        mapping = {
            "user": (self.user_start_tokens, self.user_end_tokens),
            "assistant": (self.asst_start_tokens, self.asst_end_tokens),
            "system": (self.system_start_tokens, self.system_end_tokens),
        }
        if role.lower() not in mapping:
            raise ValueError(f"Unknown role '{role}' encountered in messages.")
        start_tokens, end_tokens = mapping[role.lower()]

        if start_tokens:
            tokenized_messages.append(torch.tensor(start_tokens, dtype=torch.long))
            token_count += len(start_tokens)

        for part in message_content:
            elt_type = part["type"]
            if elt_type == "image":
                part = cast(ImageMessage, part)
                self._process_image_message(
                    part,
                    tokenized_messages,
                    image_list,
                    img_size=img_size,
                )
                # The image embeddings are inserted right after the pre-image tokens.
                token_count += len(self.pre_image_tokens)
                insertion_points.append(token_count)
                token_count += len(self.post_image_tokens)
            else:
                part = cast(TextMessage, part)
                self._process_text_message(
                    part["text"],
                    role=role,
                    token_list=tokenized_messages,
                    **kwargs,
                )
                token_count += tokenized_messages[-1].size(0)

        if end_tokens:
            tokenized_messages.append(torch.tensor(end_tokens, dtype=torch.long))
            token_count += len(end_tokens)
        return token_count
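    # Per message, the emitted token layout is:
    #   [role start] ([pre_image] * [post_image] or [text tokens]) ... [role end]
    # where "*" marks the insertion point recorded for the image embeddings, and image/text
    # parts repeat in the order they appear in the content list.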
    def _process_text_message(
        self,
        message: str,
        role: Literal["system", "user", "assistant"],
        token_list: list[torch.Tensor],
        **kwargs: Any,
    ) -> None:
        if role.lower() == "system" and not self._allow_system_prompt:
            raise ValueError("System prompts are not allowed in this tokenizer configuration.")
        tokens = self.tokenizer.encode(
            message, add_special_tokens=False, return_tensors="pt", **kwargs
        )
        tokens = cast(torch.Tensor, tokens)
        token_list.append(tokens.flatten().to(torch.long))

    def _process_image_message(
        self,
        message: ImageMessage,
        token_list: list[torch.Tensor],
        image_list: list[torch.Tensor | None],
        img_size: int | None = None,
    ) -> None:
        img = message["image"]
        if img is None:
            image_list.append(None)
        else:
            image_list.append(
                self.image_processor.process_images(
                    self._load_image(img), img_size=img_size
                ).squeeze(0)
            )
        if self.pre_image_tokens:
            token_list.append(torch.tensor(self.pre_image_tokens, dtype=torch.long))
        if self.post_image_tokens:
            token_list.append(torch.tensor(self.post_image_tokens, dtype=torch.long))

    def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image:
        if isinstance(image_path_or_image, str):
            return Image.open(image_path_or_image).convert("RGB")
        return image_path_or_image

    def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor:
        return torch.nn.functional.pad(
            tokens,
            (0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0),
            value=pad_value,
        )

    def pad_tokenized_messages(
        self,
        tokenized_messages_batch: list[torch.Tensor],
        image_insertion_points_batch: list[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
        max_len = max(len(x) for x in tokenized_messages_batch)
        # With left padding, the image insertion points shift by the amount of padding.
        if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left":
            image_insertion_points_batch = [
                x + max_len - len(tokenized_messages_batch[idx])
                for idx, x in enumerate(image_insertion_points_batch)
            ]
        input_ids = torch.stack(
            [
                self._maybe_pad(s, max_len - s.size(0), self._pad_token)
                for s in tokenized_messages_batch
            ],
            dim=0,
        )
        attention_mask = torch.stack(
            [
                self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0)
                for s in tokenized_messages_batch
            ],
            dim=0,
        )
        return input_ids, attention_mask, image_insertion_points_batch
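    # Example (illustrative): with sequences of lengths 3 and 5 and left padding, the first
    # sequence is padded to length 5 and its insertion points shift by 2; with right padding
    # the insertion points are unchanged.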
    def tokenize_messages(
        self,
        messages: ProcessorInput,
        suppress_bos_token: bool = False,
        **kwargs: Any,
    ) -> ProcessorOutput | None:
        """Tokenize a batch of messages into token IDs suitable for the Helium1 CASA model.

        Args:
            messages (list[list[Message]] | list[Message]): Batch of conversations (or a single
                conversation), where each message is a dictionary with 'role' and 'content' keys.
            suppress_bos_token (bool, optional): If True, the beginning-of-sequence token is not
                added. Defaults to False.
            **kwargs: Additional keyword arguments passed to the underlying encode method.
        """
        if not messages:
            return None
        if isinstance(messages[0], dict):
            messages = [messages]
        messages = cast(list[list[Message]], messages)

        image_insertion_points_batch = []
        tokenized_messages_batch = []
        image_list: list[torch.Tensor | None] = []
        for msgs in messages:
            tokenized_messages = []
            if not suppress_bos_token and self.bos_token is not None:
                tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long))
            insertion_points = []
            token_count = 0
            for msg in msgs:
                token_count = self._process_content(
                    msg["content"],
                    role=msg["role"],
                    tokenized_messages=tokenized_messages,
                    insertion_points=insertion_points,
                    image_list=image_list,
                    token_count=token_count,
                    **kwargs,
                )
            tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long))
            image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long))

            if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant":
                # Drop the assistant end tokens so generation can continue the final message.
                end_token_len = len(self.asst_end_tokens)
                tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len]
            if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user":
                # Append the assistant start tokens to prompt a reply.
                tokenized_messages_batch[-1] = torch.cat(
                    [
                        tokenized_messages_batch[-1],
                        torch.tensor(self.asst_start_tokens, dtype=torch.long),
                    ]
                )

        input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages(
            tokenized_messages_batch, image_insertion_points_batch
        )

        pixel_values: None | torch.Tensor | list[torch.Tensor] = None
        if image_list:
            assert sum(img is None for img in image_list) in (0, len(image_list)), (
                "Either all or none of the images must be None."
            )
            if image_list[0] is not None:
                pixel_values = cast(list[torch.Tensor], image_list)
        return ProcessorOutput(
            input_ids=input_ids,
            image_embeds_insertion_points=image_embeds_insertion_points,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )
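# End-to-end usage sketch (illustrative; a concrete subclass is expected to set
# `_image_processor` and the real chat-format token IDs, which are made up here):
#   processor = BaseProcessor(tokenizer, user_start_tokens=(1,), user_end_tokens=(2,),
#                             asst_start_tokens=(3,), asst_end_tokens=(4,), bos_token=0)
#   processor._image_processor = QwenImageProcessor()
#   batch = processor.tokenize_messages([
#       {"role": "user", "content": [
#           {"type": "image", "image": "path/to/image.png"},
#           {"type": "text", "text": "Describe this image."},
#       ]},
#   ])
#   batch["input_ids"], batch["attention_mask"], batch["pixel_values"]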