File size: 3,759 Bytes
144afae
 
c3f0641
 
144afae
 
 
c3f0641
 
 
 
 
 
 
 
 
 
dec259d
144afae
 
 
03bafc0
144afae
 
03bafc0
144afae
 
 
c3f0641
03bafc0
 
 
144afae
03bafc0
 
 
 
 
 
 
 
 
 
c3f0641
 
03bafc0
144afae
03bafc0
144afae
03bafc0
 
144afae
c3f0641
03bafc0
 
 
 
 
 
 
 
 
144afae
 
 
c3f0641
144afae
 
dec259d
03bafc0
 
 
 
144afae
 
 
 
03bafc0
 
144afae
 
03bafc0
 
 
 
 
 
 
 
 
 
 
 
 
 
144afae
03bafc0
 
 
 
 
 
 
 
144afae
03bafc0
144afae
03bafc0
 
 
 
144afae
 
 
c3f0641
dec259d
c3f0641
 
dec259d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import numpy as np
import argparse
import os
from src.pipeline import ObjectRemovalPipeline
from src.utils import visualize_mask

try:
    import spaces
except ImportError:
    class spaces:
        """No-op stand-in used when the ``spaces`` package is not installed.

        Only ``GPU`` is provided; it leaves the decorated function untouched
        so the handlers in this file run unchanged without the package.
        """

        @staticmethod
        def GPU(func=None, duration=120, **_options):
            """Return the decorated function unchanged.

            Works both bare (``@spaces.GPU``) and called
            (``@spaces.GPU(duration=120)``). ``duration`` and any other
            keyword options are accepted and ignored.
            """
            # Bare usage: the function itself arrives as the first argument.
            if callable(func):
                return func

            def decorator(wrapped):
                return wrapped
            return decorator

# Single module-level pipeline instance, created at import time and shared by
# the step1_detect and step2_remove handlers below.
pipeline = ObjectRemovalPipeline()

def ensure_uint8(image):
    """Convert an image-like object to a ``uint8`` NumPy array.

    Args:
        image: Anything ``np.array`` accepts, or ``None``.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise a ``uint8`` array.
        Input that is already ``uint8`` is returned as a copy. Any other
        dtype is multiplied by 255 when its maximum is at most 1.0 (treated
        as normalised data), then clipped to ``[0, 255]`` and cast.
    """
    if image is None:
        return None

    pixels = np.array(image)
    if pixels.dtype == np.uint8:
        return pixels

    # ``max()`` of a zero-size array raises ValueError, so only probe the
    # value range when there is at least one element.
    if pixels.size and pixels.max() <= 1.0:
        pixels = pixels * 255.0
    return np.clip(pixels, 0, 255).astype(np.uint8)

@spaces.GPU(duration=120)
def step1_detect(image, text_query):
    """Find candidate objects matching ``text_query`` and build gallery previews.

    Returns a 3-tuple ``(masks, gallery_items, status_message)``. Both lists
    are empty when an input is missing or the pipeline yields no candidates.
    Each candidate is read as a mapping with a ``'mask'`` entry and an
    optional ``'weighted_score'`` entry (defaulting to 0).
    """
    inputs_missing = image is None or not text_query
    if inputs_missing:
        return [], [], "Please upload image and enter text."

    found, message = pipeline.get_candidates(image, text_query)
    if not found:
        return [], [], f"Error: {message}"

    masks = [entry['mask'] for entry in found]

    previews = []
    for position, (entry, mask) in enumerate(zip(found, masks), start=1):
        overlay = visualize_mask(image, mask)
        confidence = entry.get('weighted_score', 0)
        caption = f"Option {position} (Score: {confidence:.2f})"
        # Gallery items are (image, caption) pairs.
        previews.append((ensure_uint8(overlay), caption))

    return masks, previews, "Select the best match below."

def on_select(evt: gr.SelectData):
    """Return the index carried by a Gradio selection event."""
    chosen = evt.index
    return chosen

@spaces.GPU(duration=120)
def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
    """Inpaint the selected candidate mask out of ``image``.

    Args:
        image: Input image; ``None`` when nothing is uploaded.
        masks: Candidate masks produced by the detection step.
        selected_idx: Index into ``masks`` of the chosen candidate.
        prompt: Text passed to the pipeline as the inpainting prompt.
        shadow_exp: Forwarded to the pipeline as ``shadow_expansion``.

    Returns:
        A 2-tuple ``(result_image, status_message)``. ``result_image`` is
        ``None`` whenever the request cannot be carried out.
    """
    if image is None:
        return None, "Please upload an image first."

    if not masks or selected_idx is None:
        return None, "Please select an object first."

    # The stored index can outlive the candidate list it was chosen from
    # (e.g. a later detection returned fewer candidates), so report that
    # instead of raising from the lookup.
    try:
        target_mask = masks[selected_idx]
    except (IndexError, TypeError):
        return None, "Selection is out of date. Please select an object again."

    result = pipeline.inpaint_selected(
        image, target_mask, prompt, shadow_expansion=shadow_exp
    )

    return ensure_uint8(result), "Success!"

# Custom stylesheet for the app, assembled one rule per literal.
css = (
    "\n"
    # Remove the minimum height on the outer Gradio container.
    ".gradio-container {min-height: 0px !important}\n"
    # Use object-fit: contain on gallery item buttons.
    "button.gallery-item {object-fit: contain !important}\n"
)

# Build the two-step UI: detect candidates, then remove the selected one.
with gr.Blocks(title="TextEraser") as demo:
    # Holds the list of masks returned by step1_detect.
    mask_state = gr.State([])
    # Index of the selected gallery item, written by on_select; starts at 0.
    # NOTE(review): idx_state is not among btn_detect's outputs, so a previous
    # selection persists across detections.
    idx_state = gr.State(0) 

    gr.Markdown("## TextEraser: Interactive Object Removal")
    
    # Top row: detection inputs on the left, candidate gallery on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy", height=400)
            text_query = gr.Textbox(label="What to remove?", placeholder="e.g. 'bottle', 'shadow'")
            btn_detect = gr.Button("1. Detect Objects", variant="primary")
        
        with gr.Column(scale=1):
            gallery = gr.Gallery(
                label="Candidates (Select One)", 
                columns=2, 
                height=400, 
                allow_preview=True, 
                object_fit="contain" 
            )
            # Shared status line: both button handlers write their message here.
            status = gr.Textbox(label="Status", interactive=False)

    # Bottom row: removal controls on the left, result image on the right.
    with gr.Row():
        with gr.Column(scale=1):
            shadow_slider = gr.Slider(0, 40, value=10, label="Shadow Fix (Expand Mask Downwards)")
            inpaint_prompt = gr.Textbox(label="Background Description", value="background")
            btn_remove = gr.Button("2. Remove Selected", variant="stop")
            
        with gr.Column(scale=1):
            output_image = gr.Image(label="Final Result", height=400)

    # Step 1: detection fills the mask state, the gallery and the status line.
    btn_detect.click(
        fn=step1_detect,
        inputs=[input_image, text_query],
        outputs=[mask_state, gallery, status]
    )
    
    # Selecting a gallery item stores its index in idx_state.
    gallery.select(fn=on_select, inputs=None, outputs=idx_state)
    
    # Step 2: removal reads the image, masks, stored index, prompt and slider.
    btn_remove.click(
        fn=step2_remove,
        inputs=[input_image, mask_state, idx_state, inpaint_prompt, shadow_slider],
        outputs=[output_image, status]
    )

if __name__ == "__main__":
    # Script entry point: the only command-line option is --share, which is
    # forwarded to launch().
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()

    # NOTE(review): passing css to launch() assumes the installed Gradio
    # version accepts it there — confirm against the pinned version.
    queued_app = demo.queue()
    queued_app.launch(share=options.share, css=css)