File size: 2,712 Bytes
144afae
a8246e3
03bafc0
 
9de67ae
03bafc0
144afae
a8246e3
144afae
 
 
03bafc0
144afae
03bafc0
 
 
 
 
 
9de67ae
03bafc0
 
144afae
9de67ae
03bafc0
 
 
 
9de67ae
03bafc0
a8246e3
03bafc0
 
a8246e3
9de67ae
03bafc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9de67ae
 
03bafc0
 
 
9de67ae
03bafc0
 
 
 
 
 
 
 
 
 
 
 
a8246e3
03bafc0
a8246e3
03bafc0
9de67ae
 
03bafc0
 
 
 
 
a8246e3
03bafc0
 
144afae
03bafc0
 
9de67ae
03bafc0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import cv2
import torch
import gc

from .segmenter import YOLOWorldDetector, SAM2Predictor
from .matcher import CLIPMatcher
from .painter import SDXLInpainter

class ObjectRemovalPipeline:
    """Orchestrate object removal as detect -> segment -> rank -> inpaint.

    Each heavy model (detector, segmenter, matcher, inpainter) is
    instantiated only for its own stage and deleted immediately after,
    with an explicit GC/VRAM flush, to keep peak memory low on
    memory-constrained machines.
    """

    def __init__(self):
        pass

    def _clear_ram(self):
        """Force-release Python garbage and cached CUDA allocations."""
        gc.collect()
        # Guard so CPU-only hosts never touch the CUDA runtime.
        # (empty_cache is a no-op without an initialized context, but the
        # guard makes the intent explicit and avoids driver probing.)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_candidates(self, image, text_query):
        """Detect objects matching ``text_query``, segment each detection,
        and rank every mask variation by CLIP similarity.

        Args:
            image: input image (format expected by the project detectors).
            text_query: free-text description of the object to find.

        Returns:
            Tuple ``(ranked_candidates, status_message)``; the list is
            empty when nothing was detected or segmented.
        """
        # Stage 1: open-vocabulary detection (bounding boxes only).
        detector = YOLOWorldDetector()
        try:
            box_candidates = detector.detect(image, text_query)
        finally:
            del detector
            self._clear_ram()

        if not box_candidates:
            return [], "No objects detected."

        # Stage 2: SAM mask proposals for the best detections.
        segmenter = SAM2Predictor()
        segments_to_score = []
        try:
            segmenter.set_image(image)
            # Cap at the top 3 boxes to bound segmentation cost.
            for cand in box_candidates[:3]:
                bbox = cand['bbox']
                mask_variations = segmenter.predict_from_box(bbox)
                for i, (mask, sam_score) in enumerate(mask_variations):
                    segments_to_score.append({
                        'mask': mask,
                        'bbox': bbox,
                        'area': mask.sum(),
                        'label': f"{cand['label']} (Var {i+1})"
                    })
        finally:
            segmenter.clear_memory()
            del segmenter
            self._clear_ram()

        # SAM can yield zero masks even when boxes exist; bail out rather
        # than asking the matcher to rank an empty list with k=0.
        if not segments_to_score:
            return [], "No objects detected."

        # Stage 3: CLIP ranking of all mask variations against the query.
        matcher = CLIPMatcher()
        try:
            ranked_candidates = matcher.get_top_k_segments(
                image,
                segments_to_score,
                text_query,
                k=len(segments_to_score)
            )
        finally:
            del matcher
            self._clear_ram()

        return ranked_candidates, f"Found {len(ranked_candidates)} options."

    def inpaint_selected(self, image, selected_mask, inpaint_prompt="", shadow_expansion=0):
        """Inpaint the region covered by ``selected_mask``.

        Args:
            image: input image to edit.
            selected_mask: binary mask of the object to remove.
            inpaint_prompt: optional text prompt passed to the inpainter.
            shadow_expansion: extra downward-biased dilation (in pixels)
                to also cover a shadow cast by the object; 0 disables it.

        Returns:
            The inpainted image produced by ``SDXLInpainter``.
        """
        # Optional asymmetric dilation (taller than wide) to sweep up the
        # shadow region below/around the object.
        if shadow_expansion > 0:
            # max(1, ...) — int(shadow_expansion * 0.5) is 0 for
            # shadow_expansion == 1, and a zero-width kernel makes
            # cv2.dilate raise.
            kernel_h = max(1, int(shadow_expansion * 1.5))
            kernel_w = max(1, int(shadow_expansion * 0.5))
            kernel = np.ones((kernel_h, kernel_w), np.uint8)
            selected_mask = cv2.dilate(selected_mask.astype(np.uint8), kernel, iterations=1)

        # Uniform 10px safety margin so the inpainter fully covers the
        # mask boundary and leaves no halo of the removed object.
        kernel = np.ones((10, 10), np.uint8)
        final_mask = cv2.dilate(selected_mask.astype(np.uint8), kernel, iterations=1)

        inpainter = SDXLInpainter()
        try:
            result = inpainter.inpaint(image, final_mask, prompt=inpaint_prompt)
        finally:
            del inpainter
            self._clear_ram()

        return result