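"""Gradio Space app for skin analysis.

Downloads detector modules and model weights from the private Hugging Face
repo IFMedTech/Skin-Analysis and exposes a simple image-analysis demo.
Requires the HUGGINGFACE_TOKEN environment variable (set as a repo secret).
"""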
import gradio as gr
import numpy as np
import cv2
import os
from huggingface_hub import hf_hub_download
import torch
import segmentation_models_pytorch as smp
import importlib.util
# import onnxruntime as ort

REPO_ID = "IFMedTech/Skin-Analysis"

# Python files in the repo and their corresponding detector class names
PY_MODULES = {
    "dark_circles.py": "DarkCircleDetector",
    "inflammation.py": "RednessDetector",
    "texture.py": "TextureDetector",
    "skin_tone.py": "SkinToneDetector",
    "oiliness.py": "OilinessDetector",
    "wrinkle_unet.py": "WrinkleDetector",
    "age.py": "AgePredictor",
}
# Previous ONNX-based loader, kept for reference:
# def load_model(token):
#     """Download and load ONNX model"""
#     model_path = hf_hub_download(
#         repo_id=REPO_ID,
#         filename="model/wrinkle_model.onnx.data",  # Adjust path if needed
#         token=token
#     )
#     # Create ONNX Runtime session
#     session = ort.InferenceSession(
#         model_path,
#         providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
#         if torch.cuda.is_available() else ['CPUExecutionProvider']
#     )
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     return session, device
def load_model(token):
    """Download the wrinkle U-Net weights from the Hub and load them."""
    repo_id = "IFMedTech/Skin-Analysis"
    filename = "model/wrinkles_unet_v1.pth"
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,
        in_channels=3,
        classes=1
    )
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model, device
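# Optional addition (not in the original app): a tiny cache so the wrinkle
# model is downloaded and loaded only once instead of on every "Wrinkles"
# request. This is just a sketch of one way to avoid repeated loads; the
# Gradio handler below still calls load_model() directly.
_wrinkle_model_cache = None

def get_wrinkle_model(token):
    """Return a cached (model, device) pair, loading it on first use."""
    global _wrinkle_model_cache
    if _wrinkle_model_cache is None:
        _wrinkle_model_cache = load_model(token)
    return _wrinkle_model_cache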
def dynamic_import(module_path, class_name):
    """Import a class from a Python file downloaded from the Hub."""
    spec = importlib.util.spec_from_file_location(class_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)
# Dynamically download and import the detector modules
detector_classes = {}
token = os.environ.get("HUGGINGFACE_TOKEN")
if not token:
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable in repo secrets!")

for py_file, class_name in PY_MODULES.items():
    py_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=py_file,
        token=token
    )
    detector_classes[class_name] = dynamic_import(py_path, class_name)
# --- Skin analysis function using the downloaded detectors ---
def analyze_skin(image: np.ndarray, analysis_type: str) -> np.ndarray:
    output = image.copy()

    if analysis_type == "Dark Circles":
        detector = detector_classes["DarkCircleDetector"](image)
        result = detector.predict_json()
        output = detector.draw_json()

    elif analysis_type == "Redness":
        detector = detector_classes["RednessDetector"](image)
        result = detector.predict_json()
        output = result.get("overlay_image")

    elif analysis_type == "Texture":
        detector = detector_classes["TextureDetector"](image)
        result = detector.predict_json()
        # print(result)
        output = result.get("overlay_image")

    elif analysis_type == "Skin Tone":
        detector = detector_classes["SkinToneDetector"](image)
        result = detector.predict_json()
        output = result.get("output_image")

    elif analysis_type == "Oiliness":
        detector = detector_classes["OilinessDetector"](image)
        result = detector.predict_json()
        if result.get("detected"):
            output = result.get("overlay_image")
            # print(f"Oiliness scores: {result.get('scores')}")
        # else:
        #     print(f"Oiliness detection error: {result.get('error')}")

    elif analysis_type == "Wrinkles":
        # The wrinkle model is loaded on each request (see the optional cache helper above)
        model, device = load_model(token)
        detector = detector_classes["WrinkleDetector"](image, model, device)
        result = detector.predict_json()
        if result.get("detected") is not None:
            output = detector.draw_json(result)

    elif analysis_type == "Skin Age":
        detector = detector_classes["AgePredictor"](image)
        result = detector.predict_json()
        output = detector.draw_json(result)

    return output
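# Illustrative usage outside Gradio (the filename "face.jpg" is hypothetical).
# gr.Image(type="numpy") passes RGB arrays, so an OpenCV-loaded image should be
# converted from BGR first:
#
#   img = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)
#   annotated = analyze_skin(img, "Dark Circles")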
# --- Gradio interface ---
app = gr.Interface(
    fn=analyze_skin,
    inputs=[
        gr.Image(type="numpy", label="Upload your face image"),
        gr.Radio(
            ["Dark Circles", "Redness", "Texture", "Skin Tone", "Oiliness", "Wrinkles", "Skin Age"],
            label="Select Skin Analysis Type"
        ),
    ],
    outputs=gr.Image(type="numpy", label="Analyzed Image"),
    title="Skin Analysis Demo",
    description="Upload an image and choose a skin analysis parameter."
)
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)