prithivMLmods committed
Commit 7dc9ea8 · verified · 1 Parent(s): d8886fd

update app [.]

Files changed (1): app.py (+73 -28)
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import sys
 import random
 import uuid
 import json
@@ -16,17 +17,7 @@ import torch
 import numpy as np
 from PIL import Image, ImageDraw, ImageOps
 import requests
-
-# Import spaces if available, otherwise mock it
-try:
-    import spaces
-except ImportError:
-    class spaces:
-        @staticmethod
-        def GPU(func):
-            def wrapper(*args, **kwargs):
-                return func(*args, **kwargs)
-            return wrapper
+from huggingface_hub import snapshot_download
 
 from transformers import (
     AutoModel,
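
Review note: this hunk deletes the local fallback for the `spaces` decorator, yet `@spaces.GPU` is still applied to `run_model` further down, so the app now assumes the real `spaces` package imports successfully (true on HF Spaces hardware, not necessarily in a local checkout). If local runs matter, a guarded fallback like the deleted one could be kept; a minimal sketch, assuming only the bare `@spaces.GPU` form is used, as in this app:

# Minimal sketch of a local stand-in for the Hugging Face `spaces` decorator.
try:
    import spaces
except ImportError:
    import functools

    class spaces:  # stand-in namespace, not the real package
        @staticmethod
        def GPU(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper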
@@ -35,7 +26,8 @@ from transformers import (
     AutoProcessor,
     TextIteratorStreamer,
     HunYuanVLForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+    GenerationConfig
 )
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
@@ -52,7 +44,6 @@ if torch.cuda.is_available():
     print("current device:", torch.cuda.current_device())
     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
 
-# --- Theme Definition ---
 colors.steel_blue = colors.Color(
     name="steel_blue",
     c50="#EBF3F8",
@@ -148,7 +139,7 @@ print(f"Loading {MODEL_HUNYUAN}...")
 processor_hy = AutoProcessor.from_pretrained(MODEL_HUNYUAN, use_fast=False)
 model_hy = HunYuanVLForConditionalGeneration.from_pretrained(
     MODEL_HUNYUAN,
-    attn_implementation="eager", # Use eager to avoid SDPA issues if old torch
+    attn_implementation="eager",
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
     device_map="auto"
 ).eval()
@@ -161,12 +152,40 @@ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     MODEL_ID_X,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto" # or .to(device)
+    device_map="auto"
 ).eval()
 
-print("✅ All models loaded successfully.")
+# 5. NVIDIA-Nemotron-Parse-v1.1
+print("Downloading NVIDIA-Nemotron snapshot to ensure all scripts are present...")
+try:
+    NEMO_DIR = snapshot_download(repo_id="nvidia/NVIDIA-Nemotron-Parse-v1.1")
+    print(f"Model downloaded to: {NEMO_DIR}")
+    sys.path.append(NEMO_DIR)
+
+    # Import postprocessing from the downloaded directory
+    # Note: Using try/except in case imports fail, though usually required for this model
+    try:
+        from postprocessing import extract_classes_bboxes, transform_bbox_to_original, postprocess_text
+    except ImportError:
+        print("Warning: Could not import Nemotron postprocessing scripts. Fallback to raw decode.")
+
+    MODEL_NEMO = "nvidia/NVIDIA-Nemotron-Parse-v1.1"
+    print(f"Loading {MODEL_NEMO}...")
+    processor_nemo = AutoProcessor.from_pretrained(NEMO_DIR, trust_remote_code=True)
+    model_nemo = AutoModel.from_pretrained(
+        NEMO_DIR,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    ).to(device).eval()
+
+    # Load generation config
+    gen_config_nemo = GenerationConfig.from_pretrained(NEMO_DIR, trust_remote_code=True)
+    NEMO_AVAILABLE = True
+except Exception as e:
+    print(f"Error loading NVIDIA-Nemotron: {e}")
+    NEMO_AVAILABLE = False
 
-# --- Helper Functions ---
+print("✅ All models loaded successfully.")
 
 def clean_repeated_substrings(text):
     """Clean repeated substrings in text (for Hunyuan)"""
@@ -193,8 +212,6 @@ def find_result_image(path):
         print(f"Error opening result image: {e}")
     return None
 
-# --- Main Inference Logic ---
-
 @spaces.GPU
 def run_model(
     model_choice,
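
Review note: `run_model` is a streaming generator behind the bare `@spaces.GPU` decorator, which on ZeroGPU hardware reserves a GPU for the default time window. With the max-tokens slider going up to 8192, long full-page OCR runs may want the parameterized form; a sketch (the 120-second figure is an arbitrary assumption, not from the commit):

# Sketch: request a longer ZeroGPU window for long OCR generations.
@spaces.GPU(duration=120)  # duration in seconds
def run_model(model_choice, *args, **kwargs):
    ...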
@@ -359,7 +376,6 @@ def run_model(
             }
         ]
 
-        # Prepare inputs for Qwen2.5-VL based architecture
         text = processor_x.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
         inputs = processor_x(
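
Review note: the deleted comment was the only hint at what `messages` holds here. The closing `}` / `]` context lines in this hunk close a standard Qwen2.5-VL chat message; a sketch of the shape `apply_chat_template` expects (the variable names `image` and `text_query` are assumptions, not visible in the hunk):

# Sketch: the usual Qwen2.5-VL message structure consumed by apply_chat_template.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},    # PIL image from the UI (assumed name)
            {"type": "text", "text": text_query}  # user prompt (assumed name)
        ],
    }
]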
@@ -388,7 +404,33 @@ def run_model(
             buffer += new_text.replace("<|im_end|>", "")
             yield buffer, None
 
-# --- Gradio UI ---
+    # === NVIDIA-Nemotron-Parse-v1.1 Logic ===
+    elif model_choice == "NVIDIA-Nemotron-Parse-v1.1":
+        if not NEMO_AVAILABLE:
+            yield "Nemotron model failed to load. Check logs.", None
+            return
+
+        # Default Prompt for Nemotron markdown extraction
+        task_prompt = "</s><s><predict_bbox><predict_classes><output_markdown>"
+
+        # If user provides a custom prompt, we might want to use it,
+        # but Nemotron is highly specialized. Let's stick to the default strict prompt
+        # unless we want to support just raw text. For this demo, we use the standard full pipeline.
+        inputs = processor_nemo(images=[image], text=task_prompt, return_tensors="pt").to(model_nemo.device)
+
+        with torch.no_grad():
+            outputs = model_nemo.generate(
+                **inputs,
+                generation_config=gen_config_nemo,
+                max_new_tokens=max_new_tokens
+            )
+
+        generated_text = processor_nemo.batch_decode(outputs, skip_special_tokens=True)[0]
+
+        # The output might contain the prompt or special tokens depending on exact decoding
+        # The prompt used </s><s> which usually gets stripped by skip_special_tokens=True
+
+        yield generated_text, None
 
 image_examples = [
     ["examples/1.jpg"],
@@ -398,13 +440,19 @@ image_examples = [
 
 with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
     gr.Markdown("# **Super-OCRs-Demo**", elem_id="main-title")
-    gr.Markdown("Compare DeepSeek-OCR, Dots.OCR, HunyuanOCR, and Nanonets-OCR2-3B in one space.")
+    gr.Markdown("Compare DeepSeek-OCR, Dots.OCR, HunyuanOCR, Nanonets-OCR2-3B, and NVIDIA-Nemotron-Parse-v1.1")
 
     with gr.Row():
         with gr.Column(scale=1):
             # Global Inputs
             model_choice = gr.Dropdown(
-                choices=["HunyuanOCR", "DeepSeek-OCR-Latest-BF16.I64", "Dots.OCR-Latest-BF16", "Nanonets-OCR2-3B"],
+                choices=[
+                    "DeepSeek-OCR-Latest-BF16.I64",
+                    "Dots.OCR-Latest-BF16",
+                    "HunyuanOCR",
+                    "Nanonets-OCR2-3B",
+                    "NVIDIA-Nemotron-Parse-v1.1"
+                ],
                 label="Select Model",
                 value="DeepSeek-OCR-Latest-BF16.I64"
             )
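
Review note: the strings in this dropdown must stay byte-identical to the `elif model_choice == ...` dispatch in `run_model` and to `update_visibility`; a typo in any copy silently breaks one model. A hypothetical refactor giving the names a single source of truth:

# Sketch: one constant shared by the dropdown, the dispatch, and visibility logic.
MODEL_CHOICES = [
    "DeepSeek-OCR-Latest-BF16.I64",
    "Dots.OCR-Latest-BF16",
    "HunyuanOCR",
    "Nanonets-OCR2-3B",
    "NVIDIA-Nemotron-Parse-v1.1",
]
model_choice = gr.Dropdown(choices=MODEL_CHOICES, label="Select Model", value=MODEL_CHOICES[0])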
@@ -414,7 +462,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
             with gr.Group(visible=True) as ds_group:
                 ds_model_size = gr.Dropdown(
                     choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
-                    value="Large", label="DeepSeek Resolution"
+                    value="Gundam (Recommended)", label="DeepSeek Resolution"
                 )
                 ds_task_type = gr.Dropdown(
                     choices=["Free OCR", "Convert to Markdown", "Parse Figure", "Locate Object by Reference"],
@@ -422,9 +470,8 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
                 )
                 ds_ref_text = gr.Textbox(label="Reference Text (for 'Locate' task only)", placeholder="e.g., the title, red car...", visible=False)
 
-            # General Prompt (for Dots/Hunyuan/Nanonets)
             with gr.Group(visible=False) as prompt_group:
-                custom_prompt = gr.Textbox(label="Custom Query / Prompt", placeholder="Extract text...", lines=2, value="OCR the content precisely")
+                custom_prompt = gr.Textbox(label="Custom Query / Prompt", placeholder="Extract text...", lines=2, value="Convert to Markdown precisely.")
 
             with gr.Accordion("Advanced Settings", open=False):
                 max_new_tokens = gr.Slider(minimum=128, maximum=8192, value=2048, step=128, label="Max New Tokens")
@@ -440,8 +487,6 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
             output_text = gr.Textbox(label="Recognized Text / Markdown", lines=15, show_copy_button=True)
             output_image = gr.Image(label="Visual Grounding Result (DeepSeek Only)", type="pil")
 
-    # --- UI Event Logic ---
-
     def update_visibility(model):
         is_ds = (model == "DeepSeek-OCR-Latest-BF16.I64")
         return gr.Group(visible=is_ds), gr.Group(visible=not is_ds)
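
Review note: `update_visibility` still shows the free-form prompt group for every non-DeepSeek model, but the new Nemotron branch never reads `custom_prompt`; it always sends the fixed `task_prompt`. Hiding the box for Nemotron would stop implying the prompt has an effect there; a sketch:

# Sketch: hide the custom prompt for models that ignore it.
def update_visibility(model):
    is_ds = (model == "DeepSeek-OCR-Latest-BF16.I64")
    uses_prompt = model not in ("DeepSeek-OCR-Latest-BF16.I64", "NVIDIA-Nemotron-Parse-v1.1")
    return gr.Group(visible=is_ds), gr.Group(visible=uses_prompt)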
 