prithivMLmods committed (verified)
Commit ac4e192 · Parent(s): 589ac46

update app

Files changed (1): app.py (+70 -8)
app.py CHANGED
@@ -8,7 +8,6 @@ import re
import tempfile
import ast
import html
- import spaces
from threading import Thread
from typing import Iterable, Optional

@@ -18,6 +17,17 @@ import numpy as np
from PIL import Image, ImageDraw, ImageOps
import requests

+ # Import spaces if available, otherwise mock it
+ try:
+     import spaces
+ except ImportError:
+     class spaces:
+         @staticmethod
+         def GPU(func):
+             def wrapper(*args, **kwargs):
+                 return func(*args, **kwargs)
+             return wrapper
+
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
@@ -25,6 +35,7 @@ from transformers import (
    AutoProcessor,
    TextIteratorStreamer,
    HunYuanVLForConditionalGeneration,
+     Qwen2_5_VLForConditionalGeneration,
)
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
@@ -133,6 +144,17 @@ model_hy = HunYuanVLForConditionalGeneration.from_pretrained(
    device_map="auto"
).eval()

+ # 4. Nanonets-OCR2-3B
+ MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
+ print(f"Loading {MODEL_ID_X}...")
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_X,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto" # or .to(device)
+ ).eval()
+
print("✅ All models loaded successfully.")

# --- Helper Functions ---
@@ -289,7 +311,7 @@ def run_model(
            {
                "role": "user",
                "content": [
-                     {"type": "image", "image": image}, # The processor handles PIL images in list if passed correctly
+                     {"type": "image", "image": image},
                    {"type": "text", "text": query},
                ],
            }
@@ -305,7 +327,7 @@ def run_model(
        generated_ids = model_hy.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
-             do_sample=False # Hunyuan OCR usually preferred greedy/beam
+             do_sample=False
        )

        input_len = inputs.input_ids.shape[1]
@@ -315,6 +337,48 @@ def run_model(
        final_text = clean_repeated_substrings(output_text)
        yield final_text, None

+     # === Nanonets-OCR2-3B Logic ===
+     elif model_choice == "Nanonets-OCR2-3B":
+         query = custom_prompt if custom_prompt else "Extract the text from this image."
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": image},
+                     {"type": "text", "text": query},
+                 ],
+             }
+         ]
+
+         # Prepare inputs for Qwen2.5-VL based architecture
+         text = processor_x.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+         inputs = processor_x(
+             text=[text],
+             images=[image],
+             padding=True,
+             return_tensors="pt",
+         ).to(model_x.device)
+
+         streamer = TextIteratorStreamer(processor_x, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": int(top_k),
+         }
+
+         thread = Thread(target=model_x.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         buffer = ""
+         for new_text in streamer:
+             buffer += new_text.replace("<|im_end|>", "")
+             yield buffer, None
+
# --- Gradio UI ---

image_examples = [
@@ -325,13 +389,13 @@ image_examples = [

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Super-OCRs-Demo**", elem_id="main-title")
-     gr.Markdown("Compare **DeepSeek-OCR**, **Dots.OCR**, and **HunyuanOCR** in one space.")
+     gr.Markdown("Compare **DeepSeek-OCR**, **Dots.OCR**, **HunyuanOCR**, and **Nanonets-OCR2-3B** in one space.")

    with gr.Row():
        with gr.Column(scale=1):
            # Global Inputs
            model_choice = gr.Radio(
-                 choices=["HunyuanOCR", "DeepSeek-OCR-Latest-BF16.I64", "Dots.OCR-Latest-BF16"],
+                 choices=["HunyuanOCR", "DeepSeek-OCR-Latest-BF16.I64", "Dots.OCR-Latest-BF16", "Nanonets-OCR2-3B"],
                label="Select Model",
                value="DeepSeek-OCR-Latest-BF16.I64"
            )
@@ -339,7 +403,6 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:

            # DeepSeek Specific Options
            with gr.Group(visible=True) as ds_group:
-                 #gr.Markdown("### DeepSeek Settings")
                ds_model_size = gr.Dropdown(
                    choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                    value="Gundam (Recommended)", label="DeepSeek Resolution"
@@ -350,7 +413,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
                )
                ds_ref_text = gr.Textbox(label="Reference Text (for 'Locate' task only)", placeholder="e.g., the title, red car...", visible=False)

-             # General Prompt (for Dots/Hunyuan)
+             # General Prompt (for Dots/Hunyuan/Nanonets)
            with gr.Group(visible=False) as prompt_group:
                custom_prompt = gr.Textbox(label="Custom Query / Prompt", placeholder="Extract text...", lines=2)

@@ -365,7 +428,6 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
            gr.Examples(examples=image_examples, inputs=image_input)

        with gr.Column(scale=2):
-             #gr.Markdown("## Output", elem_id="output-title")
            output_text = gr.Textbox(label="Recognized Text / Markdown", lines=15, show_copy_button=True)
            output_image = gr.Image(label="Visual Grounding Result (DeepSeek Only)", type="pil")
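
A note on the new spaces fallback: the try/except added at the top of app.py lets the same script run outside Hugging Face Spaces by substituting a pass-through spaces.GPU decorator when the package is missing. Below is a minimal sketch of that behavior; the process function is hypothetical, not part of the commit, and note the mock only covers the bare @spaces.GPU form, not the parameterized @spaces.GPU(duration=...) variant some Spaces use:

try:
    import spaces  # real package on Hugging Face Spaces (ZeroGPU)
except ImportError:
    class spaces:  # local fallback: same attribute, no GPU scheduling
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU  # on Spaces this requests a GPU slice; locally it is a no-op
def process(x):  # hypothetical example function
    return x * 2

print(process(21))  # prints 42 in both environments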
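
The new Nanonets branch streams tokens by running generate() on a worker thread while the main thread drains a TextIteratorStreamer; that is what makes the incremental "yield buffer, None" updates possible. A minimal self-contained sketch of the same pattern, with gpt2 standing in for the OCR model purely to keep the example small:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("The quick brown fox", return_tensors="pt")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
# generate() blocks until finished, so it runs on a background thread while
# the main thread consumes partial text chunks from the streamer's queue.
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20})
thread.start()

buffer = ""
for chunk in streamer:
    buffer += chunk
    print(buffer)  # the app yields `buffer, None` to Gradio at this point
thread.join()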
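
For completeness, a standalone sketch of the newly added Nanonets-OCR2-3B path outside Gradio. The model ID, processor calls, and dtype selection mirror the commit; the input file name test.png, the prompt, and the greedy decoding (the app samples with temperature/top_p/top_k instead) are illustrative assumptions:

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_ID = "nanonets/Nanonets-OCR2-3B"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
).eval()

image = Image.open("test.png")  # illustrative input file
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "Extract the text from this image."},
    ],
}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
# decode only the newly generated tokens, as run_model does via input_len
print(processor.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))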