Commit: d1bffba
Parent(s): 297686d

adding app with CLIP image segmentation

Files changed:
- app.py +93 -0
- images/image2.png +0 -0
- images/room.jpg +0 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,93 @@
import os
import gradio as gr
import numpy as np
from PIL import Image
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# Load the CLIPSeg processor and model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def create_mask(image, image_mask, alpha=0.7):
    # Copy the single-channel mask into every color channel of the image.
    mask = np.zeros_like(image)
    for i in range(3):
        mask[:, :, i] = image_mask.copy()
    # Blend the mask over the image.
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image


def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # Map a (min_row, min_col, max_row, max_col) box from the model's output
    # resolution back to the original image resolution.
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]


def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    predicted_images = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    for i, prompt in enumerate(prompts):
        # Per-pixel probabilities for this prompt, then a binary mask at the threshold.
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        # Extract connected components and their bounding boxes from the mask.
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        prompt = prompt.lower()
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
        # cv2.resize expects (width, height).
        predicted_images[prompt] = cv2.resize(predicted_image, (image.shape[1], image.shape[0]))
    return model_detections, predicted_images


def visualize_images(image, detections, predicted_images, prompt):
    alpha = 0.7
    prompt = prompt.lower()
    image_copy = image.copy()

    if prompt not in detections:
        print("prompt not in query ..")
        return image_copy

    # Overlay the mask for the selected category, then draw its boxes and label.
    mask_image = create_mask(image=image_copy, image_mask=predicted_images[prompt])
    for bbox in detections[prompt]:
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
        cv2.putText(image_copy, str(prompt), (int(bbox[1]), int(bbox[0])), cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
    final_image = cv2.addWeighted(image_copy, alpha, mask_image, 1 - alpha, 0)
    return final_image


def shot(image, labels_text, selected_category):
    # Parse the comma-separated category list into individual prompts.
    prompts = labels_text.split(",")
    prompts = list(map(lambda x: x.strip(), prompts))

    model_detections, predicted_images = detect_using_clip(image, prompts=prompts)

    category_image = visualize_images(
        image=image,
        detections=model_detections,
        predicted_images=predicted_images,
        prompt=selected_category,
    )
    return category_image


iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", "text"],
    outputs="image",
    description="Add an image and a comma-separated list of categories to be detected",
    title="Zero-shot Image Segmentation with Prompts",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window", "plant"],
        ["images/image2.png", "banner, building, door, sign", "sign"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
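For context, the core call in app.py follows the standard transformers usage for CLIPSeg: the processor pairs each text prompt with a copy of the image, and the model returns one low-resolution logit map per prompt. A minimal standalone sketch of that call (the example image URL and prompts below are illustrative, not part of this commit):

import requests
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Same model as app.py; the image URL and prompts are placeholders for illustration.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompts = ["a cat", "a remote control"]

inputs = processor(text=prompts, images=[image] * len(prompts),
                   padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One low-resolution logit map per prompt; sigmoid turns it into a
# per-pixel probability mask (352x352 for clipseg-rd64-refined).
masks = torch.sigmoid(outputs.logits)
print(masks.shape)  # expected: torch.Size([2, 352, 352])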
images/image2.png
ADDED
images/room.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,11 @@
transformers
torch
sentencepiece
huggingface_hub
numpy
scikit-image
opencv-python
Pillow
requests
urllib3<2
git+https://github.com/facebookresearch/segment-anything.git
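Among these dependencies, scikit-image provides the label/regionprops step that app.py uses to turn each thresholded CLIPSeg mask into bounding boxes. A toy sketch of that post-processing step (the array values are made up for illustration):

import numpy as np
from skimage.measure import label, regionprops

# Toy 6x6 "probability map" with one bright blob (illustrative values only).
prob = np.zeros((6, 6), dtype=np.float32)
prob[1:4, 2:5] = 0.9

# Same post-processing as detect_using_clip: threshold the map, label
# connected components, read each region's bounding box.
binary = np.where(prob > 0.4, 255, 0).astype(np.uint8)
boxes = [region.bbox for region in regionprops(label(binary))]
print(boxes)  # [(1, 2, 4, 5)] as (min_row, min_col, max_row, max_col)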