File size: 3,123 Bytes
7a343d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8735204
7a343d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import io
import base64
import time
import logging
from typing import Optional

import gradio as gr
from gradio_client import Client
from PIL import Image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Passed as ``hf_token`` to the backend Client; the app refuses to start
# when it is missing or empty.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("HF_TOKEN environment variable is required")

# Shared, mutable record of the most recent backend connection attempt.
backend_status = dict(
    client=None,        # cached gradio_client.Client, or None when not connected
    connected=False,    # whether a usable client is currently cached
    last_check=None,    # time.time() of the latest attempt; None before any
    error_message="",   # str() of the last connection error; "" on success
)

def check_backend_connection():
    """Try to (re)connect to the SAKS backend Space and record the outcome.

    On success a fresh ``Client`` is cached in ``backend_status["client"]``
    and the backend is marked connected; on failure the cached client is
    cleared and the error text is stored in ``backend_status["error_message"]``.
    ``backend_status["last_check"]`` is set to ``time.time()`` either way.

    Returns:
        A ``(connected, message)`` tuple, where ``message`` is a short
        human-readable status line.
    """
    try:
        test_client = Client("SnapwearAI/Saks-backend-new", hf_token=HF_TOKEN)
    except Exception as e:
        backend_status.update({
            "client": None,
            "connected": False,
            "last_check": time.time(),
            "error_message": str(e),
        })
        # Previously failures were only stored, never logged.
        logger.warning("SAKS Backend connection failed: %s", e)
        err = str(e).lower()
        # Any timeout (e.g. "read operation timed out") is reported as the
        # backend still starting up rather than as a hard error.
        if "timeout" in err or "timed out" in err:
            return False, "🟡 Backend starting up..."
        return False, f"🔴 Backend error: {e}"

    backend_status.update({
        "client": test_client,
        "connected": True,
        "error_message": "",
        "last_check": time.time(),
    })
    logger.info("✅ SAKS Backend connection established")
    return True, "🟢 Backend is ready"


# Attempt a connection once at import time; process_image() reuses the cached
# client and only retries while the backend is marked disconnected. The
# returned status tuple is intentionally discarded here.
check_backend_connection()

def image_to_base64(image: Optional[Image.Image]) -> str:
    """Encode a PIL image as base64 PNG text for sending to the backend.

    Images in any mode other than ``"RGB"`` are converted to RGB first, so an
    alpha channel, if present, is not preserved.

    Args:
        image: The image to encode, or ``None``.

    Returns:
        The base64 text of the PNG-encoded image, or ``""`` when *image* is
        ``None``.
    """
    if image is None:
        return ""
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Close the in-memory buffer deterministically once the bytes are taken.
    with io.BytesIO() as buf:
        image.save(buf, format="PNG")
        png_bytes = buf.getvalue()
    return base64.b64encode(png_bytes).decode("ascii")

def base64_to_image(b64: str) -> Optional[Image.Image]:
    """Decode base64 image bytes into an RGB PIL image.

    Returns ``None`` when *b64* is empty, and also (after logging the error)
    when the payload cannot be base64-decoded or opened as an image.
    """
    if b64:
        try:
            raw_bytes = base64.b64decode(b64)
            decoded = Image.open(io.BytesIO(raw_bytes))
            return decoded.convert("RGB")
        except Exception as e:
            logger.error(f"Failed to decode base64: {e}")
    return None

def process_image(input_image: Optional[Image.Image]):
    """Send *input_image* through the SAKS backend and return the result image.

    If no backend client is cached, one connection attempt is made first.
    Every failure (backend unavailable, prediction error, empty or
    undecodable response) is logged and yields ``None`` rather than raising,
    so the Gradio output simply stays empty.

    Args:
        input_image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        The decoded RGB output image, or ``None`` on any failure.
    """
    if input_image is None:
        return None

    try:
        if not backend_status["connected"]:
            check_backend_connection()
            if not backend_status["connected"]:
                logger.warning(
                    "SAKS backend unavailable: %s", backend_status["error_message"]
                )
                return None

        client = backend_status["client"]
        img_b64 = image_to_base64(input_image)

        try:
            result = client.predict(
                img_b64,
                api_name="/predict",
            )
        except Exception:
            # Otherwise a client that has gone bad would be reused forever;
            # marking it disconnected makes the next call reconnect.
            backend_status["connected"] = False
            raise

        # gradio_client returns a tuple for multi-output endpoints but a bare
        # value for single-output ones; indexing a bare string would take only
        # its first character, so accept both shapes.
        if isinstance(result, (list, tuple)):
            output_b64 = result[0] if result else None
        else:
            output_b64 = result
        return base64_to_image(output_b64)

    except Exception as e:
        logger.exception("Processing error: %s", e)
        return None

with gr.Blocks(title="SAKS") as demo:
    # Page header.
    gr.Markdown("# SAKS Jewelry Detection")

    # Input and output panes side by side; the input is delivered as PIL.
    with gr.Row():
        input_img = gr.Image(label="Input Image", type="pil")
        output_img = gr.Image(label="Output Image")

    process_btn = gr.Button("Process", variant="primary")

    # Clicking runs the uploaded image through process_image and shows the
    # returned image (or clears the pane when it returns None).
    process_btn.click(
        process_image,
        inputs=[input_img],
        outputs=[output_img],
    )


if __name__ == "__main__":
    demo.launch()