""" Custom GroundingDINO model class for transformers compatibility. """ import torch import torch.nn as nn from transformers import PreTrainedModel from transformers.configuration_utils import PretrainedConfig class GroundingDINOConfig(PretrainedConfig): """Configuration class for GroundingDINO.""" model_type = "groundingdino" def __init__( self, num_classes=1180, num_queries=900, hidden_dim=256, num_feature_levels=4, nheads=8, enc_layers=6, dec_layers=6, dim_feedforward=2048, dropout=0.0, max_text_len=256, text_encoder_type="bert-base-uncased", backbone="swin_T_224_1k", position_embedding="sine", **kwargs ): super().__init__(**kwargs) self.num_classes = num_classes self.num_queries = num_queries self.hidden_dim = hidden_dim self.num_feature_levels = num_feature_levels self.nheads = nheads self.enc_layers = enc_layers self.dec_layers = dec_layers self.dim_feedforward = dim_feedforward self.dropout = dropout self.max_text_len = max_text_len self.text_encoder_type = text_encoder_type self.backbone = backbone self.position_embedding = position_embedding class GroundingDINOModel(PreTrainedModel): """GroundingDINO model for transformers.""" config_class = GroundingDINOConfig def __init__(self, config): super().__init__(config) self.config = config # This is a placeholder - in practice, you would load the actual model architecture # For now, we'll create a simple wrapper self.model = None def forward(self, images, text_prompts=None, return_dict=True): """ Forward pass of the model. Args: images: Input images tensor text_prompts: Text prompts for grounding return_dict: Whether to return a dictionary Returns: Model outputs """ if self.model is None: raise NotImplementedError( "Model architecture not implemented. " "Please use the original GroundingDINO implementation for inference." ) outputs = self.model(images, captions=text_prompts) if return_dict: return { "logits": outputs.get("pred_logits", torch.tensor([])), "boxes": outputs.get("pred_boxes", torch.tensor([])), "last_hidden_state": outputs.get("last_hidden_state", torch.tensor([])) } else: return outputs