prithivMLmods committed on
Commit
e985b9a
·
verified ·
1 Parent(s): 07b42b2

update app files

Browse files
Files changed (2) hide show
  1. app.py +406 -0
  2. requirements.txt +18 -0
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import random
4
+ import uuid
5
+ import json
6
+ import time
7
+ from threading import Thread
8
+ from typing import Iterable
9
+ from huggingface_hub import snapshot_download
10
+
11
+ import gradio as gr
12
+ import spaces
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+ import cv2
17
+
18
+ from transformers import (
19
+ Qwen2_5_VLForConditionalGeneration,
20
+ Qwen3VLForConditionalGeneration,
21
+ AutoModelForImageTextToText,
22
+ AutoModelForCausalLM,
23
+ AutoProcessor,
24
+ TextIteratorStreamer,
25
+ )
26
+
27
+ from transformers.image_utils import load_image
28
+ from gradio.themes import Soft
29
+ from gradio.themes.utils import colors, fonts, sizes
30
+
31
# Register a custom "steel blue" palette on Gradio's color registry so the
# theme class below can reference it as `colors.steel_blue`.
# Shades run light (c50) -> dark (c950).
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",  # the classic CSS "steelblue"
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
45
+
46
class SteelBlueTheme(Soft):
    """Soft-theme variant: gray primary, steel-blue secondary, Outfit font.

    All keyword arguments mirror `gradio.themes.Soft.__init__` and are
    keyword-only; callers may override any hue/size/font at construction.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        # Build the base Soft theme first; set() below overrides individual
        # CSS variables and must run after the palettes exist.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Gradient backgrounds and button styling; "*_dark" keys are the
        # dark-mode counterparts of the same variables.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )
96
+
97
# Theme instance passed to launch() at the bottom of the file.
steel_blue_theme = SteelBlueTheme()

# Global CSS: title sizing, the RadioAnimated widget, and the GPU-duration card.
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.2em !important;
}

/* RadioAnimated Styles */
.ra-wrap{ width: fit-content; }
.ra-inner{
position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
}
.ra-input{ display: none; }
.ra-label{
position: relative; z-index: 2; padding: 8px 16px;
font-family: inherit; font-size: 14px; font-weight: 600;
color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
}
.ra-highlight{
position: absolute; z-index: 1; top: 6px; left: 6px;
height: calc(100% - 12px); border-radius: 9999px;
background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
transition: transform 0.2s, width 0.2s;
}
.ra-input:checked + .ra-label{ color: black; }

/* Dark mode adjustments for Radio */
.dark .ra-inner { background: var(--neutral-800); }
.dark .ra-label { color: var(--neutral-400); }
.dark .ra-highlight { background: var(--neutral-600); }
.dark .ra-input:checked + .ra-label { color: white; }

#gpu-duration-container {
padding: 10px;
border-radius: 8px;
background: var(--background-fill-secondary);
border: 1px solid var(--border-color-primary);
margin-top: 10px;
}
"""

# Generation limits for the UI slider; MAX_INPUT_TOKEN_LENGTH is
# overridable via environment variable (defined here but not referenced
# in this file — presumably consumed elsewhere or reserved; verify).
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Prefer the first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Startup diagnostics written to the Space logs.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

print("Using device:", device)
158
+
159
class RadioAnimated(gr.HTML):
    """Pill-style animated radio group rendered as custom HTML.

    Renders hidden ``<input type="radio">`` elements plus a sliding white
    highlight; the selected choice (a string) is exposed as the component's
    value and a 'change' event fires on selection.

    NOTE(review): relies on ``html_template``/``js_on_load`` keyword support
    in ``gr.HTML`` and on ``element``/``props``/``trigger`` being in scope
    for the JS hook — confirm against the installed Gradio version's API.
    """

    def __init__(self, choices, value=None, **kwargs):
        # A radio group only makes sense with at least two options.
        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]  # default selection: first choice

        # Unique radio-group name so multiple instances on one page
        # don't share browser radio state.
        uid = uuid.uuid4().hex[:8]
        group_name = f"ra-{uid}"

        # One hidden input + visible label per choice; the label text is
        # the choice string itself.
        inputs_html = "\n".join(
            f"""
            <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
            <label class="ra-label" for="{group_name}-{i}">{c}</label>
            """
            for i, c in enumerate(choices)
        )

        html_template = f"""
        <div class="ra-wrap" data-ra="{uid}">
        <div class="ra-inner">
        <div class="ra-highlight"></div>
        {inputs_html}
        </div>
        </div>
        """

        # Client-side hook: positions the sliding highlight and forwards
        # user selections back into the component value. Raw string: the
        # ${...} inside are JS template literals, not Python formatting.
        js_on_load = r"""
        (() => {
        const wrap = element.querySelector('.ra-wrap');
        const inner = element.querySelector('.ra-inner');
        const highlight = element.querySelector('.ra-highlight');
        const inputs = Array.from(element.querySelectorAll('.ra-input'));

        if (!inputs.length) return;

        const choices = inputs.map(i => i.value);

        function setHighlightByIndex(idx) {
        const n = choices.length;
        const pct = 100 / n;
        highlight.style.width = `calc(${pct}% - 6px)`;
        highlight.style.transform = `translateX(${idx * 100}%)`;
        }

        function setCheckedByValue(val, shouldTrigger=false) {
        const idx = Math.max(0, choices.indexOf(val));
        inputs.forEach((inp, i) => { inp.checked = (i === idx); });
        setHighlightByIndex(idx);

        props.value = choices[idx];
        if (shouldTrigger) trigger('change', props.value);
        }

        setCheckedByValue(props.value ?? choices[0], false);

        inputs.forEach((inp) => {
        inp.addEventListener('change', () => {
        setCheckedByValue(inp.value, true);
        });
        });
        })();
        """

        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )
229
+
230
def apply_gpu_duration(val: str) -> int:
    """Convert the RadioAnimated-selected GPU duration string to an int.

    Used as the `.change` handler that mirrors the animated radio into the
    hidden `gpu_duration_state` Number component.

    Args:
        val: the selected choice, e.g. "60", "90", ... "300".

    Returns:
        The duration in seconds; falls back to 60 on a missing/malformed
        value instead of raising inside the UI event handler (mirrors the
        default used by calc_timeout_image).
    """
    try:
        return int(val)
    except (TypeError, ValueError):
        return 60
232
+
233
# --- Model zoo: four OCR-oriented vision-language models, loaded once at
# startup and kept in eval mode. Each gets its own processor.
# NOTE(review): dtypes are mixed — fp16 for Chandra/olmOCR, bf16 for
# Nanonets/Dots — confirm this is deliberate rather than accidental.

# Chandra OCR (Qwen3-VL architecture), fp16, pinned to `device`.
MODEL_ID_V = "datalab-to/chandra"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    attn_implementation="kernels-community/flash-attn2",
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Nanonets OCR2 (Qwen2.5-VL architecture), bf16.
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_X,
    attn_implementation="kernels-community/flash-attn2",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()

# Dots.OCR — loaded through the generic causal-LM entry point with
# device_map="auto" (the only model here not pinned via .to(device)).
MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" # -> alt of [rednote-hilab/dots.ocr]
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="kernels-community/flash-attn2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
).eval()

# olmOCR 2 (Qwen2.5-VL architecture), fp16.
MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    attn_implementation="kernels-community/flash-attn2",
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
269
+
270
def calc_timeout_image(model_name: str, text: str, image: "Image.Image",
                       max_new_tokens: int, temperature: float, top_p: float,
                       top_k: int, repetition_penalty: float, gpu_timeout: int):
    """Calculate GPU timeout duration (seconds) for image inference.

    Passed as `duration=` to @spaces.GPU, so it receives the same arguments
    as generate_image; only `gpu_timeout` is actually used.

    Returns:
        int(gpu_timeout) when convertible, otherwise the 60-second default.
    """
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError):
        # Narrowed from a bare `except:` which also swallowed SystemExit
        # and KeyboardInterrupt; int() can only raise these two here.
        return 60
278
+
279
# duration accepts a callable invoked with the same args as the wrapped
# function; calc_timeout_image reads only gpu_timeout.
@spaces.GPU(duration=calc_timeout_image)
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, gpu_timeout: int = 60):
    """
    Generates responses using the selected model for image input.
    Yields raw text and Markdown-formatted text.

    Generation runs on a background thread; this generator yields the
    accumulated output twice per step (raw textbox + markdown pane).
    gpu_timeout is consumed only by the @spaces.GPU duration callable.
    """
    # Dispatch to the processor/model pair matching the UI radio choice.
    if model_name == "olmOCR-2-7B-1025":
        processor = processor_m
        model = model_m
    elif model_name == "Nanonets-OCR2-3B":
        processor = processor_x
        model = model_x
    elif model_name == "Chandra-OCR":
        processor = processor_v
        model = model_v
    elif model_name == "Dots.OCR":
        processor = processor_d
        model = model_d
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    # Chat-template message; the image placeholder is bound to the actual
    # image by the processor call below (images=[image]).
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ]
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True).to(device)

    # NOTE(review): the streamer is given the processor rather than its
    # tokenizer — works when the processor proxies decode(); verify.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,  # sampling always on; sliders control its shape
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # Generate on a worker thread so we can stream tokens out of this
    # generator as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Strip the chat end marker if it leaks into the stream.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)  # tiny pause keeps the UI stream smooth
        yield buffer, buffer
341
+
342
# Example (query, image-path) pairs surfaced under the input widgets.
image_examples = [
    ["Convert to Markdown.", "examples/3.jpg"],
    ["Perform OCR on the image. [Markdown]", "examples/1.jpg"],
    ["Extract the contents. [Markdown].", "examples/2.jpg"],
]

# UI layout: inputs on the left (scale=2), outputs + model/GPU controls on
# the right (scale=3). Nesting reconstructed from stripped indentation —
# NOTE(review): verify widget nesting against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
    with gr.Row():
        with gr.Column(scale=2):
            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
            image_upload = gr.Image(type="pil", label="Upload Image", height=290)

            image_submit = gr.Button("Submit", variant="primary")
            gr.Examples(
                examples=image_examples,
                inputs=[image_query, image_upload]
            )

            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)

        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=15)
            with gr.Accordion("(Result.md)", open=False):
                markdown_output = gr.Markdown(label="(Result.Md)")

            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "Chandra-OCR", "Dots.OCR", "olmOCR-2-7B-1025"],
                label="Select Model",
                value="Nanonets-OCR2-3B"
            )

            with gr.Row(elem_id="gpu-duration-container"):
                with gr.Column():
                    gr.Markdown("**GPU Duration (seconds)**")
                    radioanimated_gpu_duration = RadioAnimated(
                        choices=["60", "90", "120", "180", "240", "300"],
                        value="60",
                        elem_id="radioanimated_gpu_duration"
                    )
                    # Hidden numeric mirror of the animated radio; feeds the
                    # gpu_timeout argument of generate_image.
                    gpu_duration_state = gr.Number(value=60, visible=False)

            gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")

    # Keep the hidden Number in sync with the animated radio.
    # NOTE(review): `api_visibility` kwarg — confirm the installed Gradio
    # version supports it on event listeners.
    radioanimated_gpu_duration.change(
        fn=apply_gpu_duration,
        inputs=radioanimated_gpu_duration,
        outputs=[gpu_duration_state],
        api_visibility="private"
    )

    # Main inference wiring: streams (raw text, markdown) into both panes.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
        outputs=[output, markdown_output]
    )
404
+
405
if __name__ == "__main__":
    # queue() bounds pending jobs; mcp_server exposes the app as an MCP tool.
    # NOTE(review): css/theme are passed to launch() rather than gr.Blocks()
    # — confirm the installed Gradio version accepts them on launch().
    demo.queue(max_size=50).launch(css=css, theme=steel_blue_theme, mcp_server=True, ssr_mode=False, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git@v4.57.6
2
+ git+https://github.com/huggingface/accelerate.git
3
+ git+https://github.com/huggingface/peft.git
4
+ transformers-stream-generator
5
+ huggingface_hub
6
+ qwen-vl-utils
7
+ sentencepiece
8
+ opencv-python
9
+ torch==2.8.0
10
+ torchvision
11
+ matplotlib
12
+ requests
13
+ kernels
14
+ hf_xet
15
+ spaces
16
+ pillow
17
+ gradio # - gradio@6.3.0
18
+ av