TextEraser / app.py
lxzcpro's picture
fix bugs
dec259d
import gradio as gr
import numpy as np
import argparse
import os
from src.pipeline import ObjectRemovalPipeline
from src.utils import visualize_mask
try:
    import spaces
except ImportError:
    class spaces:  # noqa: N801 - deliberately mimics the ``spaces`` module name
        """Local no-op stand-in for Hugging Face's ``spaces`` package.

        Used when ``spaces`` is not installed (e.g. running the app locally), so
        the ``@spaces.GPU`` decorators below become pass-throughs.
        """

        @staticmethod
        def GPU(func=None, duration=120, **kwargs):
            """Return a pass-through decorator mirroring ``spaces.GPU``.

            Supports both the bare form (``@spaces.GPU``) and the factory form
            (``@spaces.GPU(duration=120)``). ``duration`` and any other keyword
            arguments are accepted for signature compatibility and ignored.

            Args:
                func: The decorated function when used bare; otherwise ``None``
                    (or a positional ``duration`` value, as the old shim allowed).
                duration: Ignored.

            Returns:
                ``func`` unchanged when used bare, otherwise an identity decorator.
            """
            if callable(func):
                # Bare ``@spaces.GPU`` usage: the function itself was passed in.
                return func

            def decorator(wrapped):
                return wrapped

            return decorator
# Shared object-removal pipeline, created once when this module is imported
# and used by both Gradio callbacks: ``get_candidates`` in step 1 (detection)
# and ``inpaint_selected`` in step 2 (removal).
# NOTE(review): any model loading / device placement happens inside
# ObjectRemovalPipeline() — not visible from this file.
pipeline = ObjectRemovalPipeline()
def ensure_uint8(image):
    """Convert an image-like value to a ``uint8`` NumPy array for display.

    Any non-``uint8`` input whose maximum is <= 1.0 is assumed to be normalized
    to [0, 1] and is scaled by 255. Values are then clipped to [0, 255] and cast
    (fractional parts are truncated). Float NaNs are mapped to 0.

    Args:
        image: ``None`` or anything ``np.array`` accepts (ndarray, PIL image,
            nested lists, ...).

    Returns:
        ``None`` if ``image`` is ``None``, otherwise a new ``uint8`` array.
    """
    if image is None:
        return None
    image = np.array(image)
    if image.dtype == np.uint8:
        return image
    if image.size == 0:
        # ``max()`` has no identity for zero-size arrays and would raise.
        return image.astype(np.uint8)
    if np.issubdtype(image.dtype, np.floating):
        # A NaN makes ``max()`` return NaN (so the range check below fails)
        # and casting NaN to uint8 is undefined; treat NaNs as black.
        image = np.nan_to_num(image, nan=0.0)
    if image.max() <= 1.0:
        image = image * 255.0
    return np.clip(image, 0, 255).astype(np.uint8)
@spaces.GPU(duration=120)
def step1_detect(image, text_query):
    """Find candidate objects matching ``text_query`` and build gallery previews.

    Args:
        image: Value of the ``gr.Image`` input (``type="numpy"``), or ``None``
            when nothing has been uploaded.
        text_query: Free-text description of what to remove.

    Returns:
        ``(masks, gallery_items, status)``: the candidate masks (stored in
        ``mask_state``), ``(preview_image, caption)`` pairs for the gallery,
        and a status message. Both lists are empty on failure.
    """
    query = (text_query or "").strip()
    # Whitespace-only queries are treated like an empty query rather than
    # being sent to the detector.
    if image is None or not query:
        return [], [], "Please upload image and enter text."

    candidates, msg = pipeline.get_candidates(image, query)
    if not candidates:
        return [], [], f"Error: {msg}"

    masks = [c['mask'] for c in candidates]
    gallery_imgs = []
    for option_no, (candidate, mask) in enumerate(zip(candidates, masks), start=1):
        score = candidate.get('weighted_score')
        # A missing or None score would break the ``:.2f`` format below.
        if score is None:
            score = 0
        viz = visualize_mask(image, mask)
        label = f"Option {option_no} (Score: {score:.2f})"
        gallery_imgs.append((ensure_uint8(viz), label))
    return masks, gallery_imgs, "Select the best match below."
def on_select(evt: gr.SelectData):
    """Report which gallery item the user clicked.

    Attached to ``gallery.select``; the returned index is written into
    ``idx_state`` for the removal step.
    """
    clicked_index = evt.index
    return clicked_index
@spaces.GPU(duration=120)
def step2_remove(image, masks, selected_idx, prompt, shadow_exp):
    """Inpaint the region covered by the selected candidate mask.

    Args:
        image: The input image (the same component value used for detection).
        masks: Candidate masks produced by ``step1_detect`` (``mask_state``).
        selected_idx: Index of the chosen gallery item (``idx_state``).
        prompt: Background description forwarded to the inpainter.
        shadow_exp: Shadow-fix slider value, forwarded as ``shadow_expansion``.

    Returns:
        ``(result_image, status)``; ``result_image`` is ``None`` on failure.
    """
    if not masks or selected_idx is None:
        return None, "Please select an object first."
    # ``idx_state`` may be left over from an earlier detection: an index past
    # the end would raise IndexError and a negative one would silently pick
    # the wrong mask.
    if not 0 <= selected_idx < len(masks):
        return None, "Selection is out of date; please select an object again."
    if image is None:
        # The image may have been cleared after detection ran.
        return None, "Please upload an image first."

    target_mask = masks[selected_idx]
    result = pipeline.inpaint_selected(image, target_mask, prompt, shadow_expansion=shadow_exp)
    if result is None:
        # Don't report "Success!" with an empty output.
        return None, "Error: inpainting returned no image."
    return ensure_uint8(result), "Success!"
# Extra CSS for the UI; only applied when the app is started via the
# ``__main__`` block, which passes it to ``launch(css=css)``.
# NOTE(review): ``object-fit`` set on the ``button`` element may not affect the
# thumbnail ``img`` inside it — confirm against the rendered Gallery DOM.
css = """
.gradio-container {min-height: 0px !important}
button.gallery-item {object-fit: contain !important}
"""
with gr.Blocks(title="TextEraser") as demo:
    # Per-session state: candidate masks from step 1 and the chosen gallery index.
    mask_state = gr.State([])
    idx_state = gr.State(0)

    gr.Markdown("## TextEraser: Interactive Object Removal")

    # Row 1: input + query on the left, detected candidates on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy", height=400)
            text_query = gr.Textbox(label="What to remove?", placeholder="e.g. 'bottle', 'shadow'")
            btn_detect = gr.Button("1. Detect Objects", variant="primary")
        with gr.Column(scale=1):
            gallery = gr.Gallery(
                label="Candidates (Select One)",
                columns=2,
                height=400,
                allow_preview=True,
                object_fit="contain",
            )
            status = gr.Textbox(label="Status", interactive=False)

    # Row 2: removal controls on the left, result on the right.
    with gr.Row():
        with gr.Column(scale=1):
            shadow_slider = gr.Slider(0, 40, value=10, label="Shadow Fix (Expand Mask Downwards)")
            inpaint_prompt = gr.Textbox(label="Background Description", value="background")
            btn_remove = gr.Button("2. Remove Selected", variant="stop")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Final Result", height=400)

    # Step 1: detect candidates. Afterwards the selection index is reset to its
    # initial value so an index chosen in a previous gallery is never applied
    # to the new set of masks.
    btn_detect.click(
        fn=step1_detect,
        inputs=[input_image, text_query],
        outputs=[mask_state, gallery, status],
    ).then(fn=lambda: 0, inputs=None, outputs=idx_state)

    # Clicking a gallery item records its index.
    gallery.select(fn=on_select, inputs=None, outputs=idx_state)

    # Step 2: inpaint the selected mask.
    btn_remove.click(
        fn=step2_remove,
        inputs=[input_image, mask_state, idx_state, inpaint_prompt, shadow_slider],
        outputs=[output_image, status],
    )
if __name__ == "__main__":
    # Command-line entry point: `--share` requests a public Gradio link.
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()

    # Enable request queuing, then start the server with the custom CSS.
    app = demo.queue()
    app.launch(share=options.share, css=css)