# Skin-Analysis / main.py
# Last commit by IFMedTechdemo: "Update main.py" (bb46171, verified)
import functools
import importlib.util
import os
import sys

import gradio as gr
import numpy as np
import cv2
from huggingface_hub import hf_hub_download
import torch
import segmentation_models_pytorch as smp
#import onnxruntime as ort
# Hugging Face repository that hosts both the detector sources and the weights.
REPO_ID = "IFMedTech/Skin-Analysis"

# Detector source file in REPO_ID -> name of the class that file must define.
# Entries are fetched and imported in this (insertion) order.
PY_MODULES = {
    "dark_circles.py": "DarkCircleDetector",
    "inflammation.py": "RednessDetector",
    "texture.py": "TextureDetector",
    "skin_tone.py": "SkinToneDetector",
    "oiliness.py": "OilinessDetector",
    "wrinkle_unet.py": "WrinkleDetector",
    "age.py": "AgePredictor",
}
# def load_model(token):
# """Download and load ONNX model"""
# model_path = hf_hub_download(
# repo_id=REPO_ID,
# filename="model/wrinkle_model.onnx.data", # Adjust path if needed
# token=token
# )
# # Create ONNX Runtime session
# session = ort.InferenceSession(
# model_path,
# providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
# if torch.cuda.is_available() else ['CPUExecutionProvider']
# )
# device = "cuda" if torch.cuda.is_available() else "cpu"
# return session, device
@functools.lru_cache(maxsize=1)
def load_model(token):
    """Download the wrinkle-segmentation U-Net checkpoint and build the model.

    The result is cached per *token*. The checkpoint is therefore downloaded
    and deserialized once per process, not on every "Wrinkles" request. The
    same model object is returned to every caller.

    NOTE(review): this assumes the detector using the model does not mutate it
    (e.g. switch it back to train mode). Confirm in wrinkle_unet.py.

    Args:
        token: Hugging Face access token used to read REPO_ID.

    Returns:
        Tuple ``(model, device)``:
        - ``model``: an ``smp.Unet`` with a resnet34 encoder, 3 input channels
          and 1 output class, loaded from the checkpoint, moved to ``device``
          and put in eval mode.
        - ``device``: the ``torch.device`` (CUDA if available, else CPU).
    """
    filename = "model/wrinkles_unet_v1.pth"
    model_path = hf_hub_download(repo_id=REPO_ID, filename=filename, token=token)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,  # all weights come from the checkpoint below
        in_channels=3,
        classes=1,
    )
    # torch.load unpickles its input. weights_only=True restricts it to
    # tensors and plain containers (all a state_dict needs), so a tampered
    # checkpoint cannot run arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device
def dynamic_import(module_path, class_name):
    """Execute the Python source file at *module_path* and return its *class_name*.

    The module is named *class_name*. It is registered in ``sys.modules``
    before execution, following the importlib "import a source file directly"
    recipe, so code inside it that looks its own module up by name still
    works (dataclasses with string annotations, pickling, typing helpers).
    If executing the module fails, the registration is rolled back.

    Raises:
        ImportError: if no loader exists for *module_path* (e.g. not a ``.py`` file).
        AttributeError: if the executed module does not define *class_name*.
        Any exception raised while executing the module's top-level code.
    """
    spec = importlib.util.spec_from_file_location(class_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {class_name!r} from {module_path!r}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[class_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module behind for later imports.
        sys.modules.pop(class_name, None)
        raise
    return getattr(module, class_name)
# --- Fetch every detector source file from the Hub and import its class ---
# NOTE(review): this executes Python code downloaded from REPO_ID at import
# time. It is only as safe as that repository's access control.
token = os.environ.get("HUGGINGFACE_TOKEN")
if not token:
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable in repo secrets!")

# Class name -> class object, populated in PY_MODULES order.
detector_classes = {
    cls_name: dynamic_import(
        hf_hub_download(repo_id=REPO_ID, filename=source_file, token=token),
        cls_name,
    )
    for source_file, cls_name in PY_MODULES.items()
}
# --- Skin analysis function using downloaded detectors ---
def analyze_skin(image: np.ndarray, analysis_type: str) -> np.ndarray:
output = image.copy()
if analysis_type == "Dark Circles":
detector = detector_classes["DarkCircleDetector"](image)
result = detector.predict_json()
output = detector.draw_json()
elif analysis_type == "Redness":
detector = detector_classes["RednessDetector"](image)
result = detector.predict_json()
output = result.get("overlay_image")
elif analysis_type == "Texture":
detector = detector_classes["TextureDetector"](image)
result = detector.predict_json()
# print(result)
output = result.get("overlay_image")
elif analysis_type == "Skin Tone":
detector = detector_classes["SkinToneDetector"](image)
result = detector.predict_json()
output = result.get("output_image")
elif analysis_type == "Oiliness":
detector = detector_classes["OilinessDetector"](image)
result = detector.predict_json()
if result.get("detected"):
output = result.get("overlay_image")
# print(f"Oiliness scores: {result.get('scores')}")
# else:
# print(f"Oiliness detection error: {result.get('error')}")
elif analysis_type == "Wrinkles":
model, device = load_model(token)
detector = detector_classes["WrinkleDetector"](image, model, device)
result = detector.predict_json()
if result.get("detected") is not None:
output = detector.draw_json(result)
elif analysis_type == "Skin Age":
detector = detector_classes["AgePredictor"](image)
result = detector.predict_json()
output = detector.draw_json(result)
return output
# --- Gradio UI: upload an image, pick an analysis, view the annotated result ---
app = gr.Interface(
    fn=analyze_skin,
    inputs=[
        gr.Image(label="Upload your face image", type="numpy"),
        gr.Radio(
            choices=[
                "Dark Circles",
                "Redness",
                "Texture",
                "Skin Tone",
                "Oiliness",
                "Wrinkles",
                "Skin Age",
            ],
            label="Select Skin Analysis Type",
        ),
    ],
    outputs=gr.Image(label="Analyzed Image", type="numpy"),
    title="Skin Analysis Demo",
    description="Upload an image and choose a skin analysis parameter.",
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    app.launch(server_name="0.0.0.0", server_port=7860)