Update main.py
main.py
CHANGED
@@ -16,47 +16,47 @@ PY_MODULES = {
     "texture.py": "TextureDetector",
     "skin_tone.py": "SkinToneDetector",
     "oiliness.py": "OilinessDetector",
-    "
+    "wrinkle_unet.py": "WrinkleDetector",
     "age.py": "AgePredictor"
 }
-def load_model(token):
+# def load_model(token):
+#     """Download and load ONNX model"""
+#     model_path = hf_hub_download(
+#         repo_id=REPO_ID,
+#         filename="model/wrinkle_model.onnx.data",  # Adjust path if needed
+#         token=token
+#     )

+#     # Create ONNX Runtime session
+#     session = ort.InferenceSession(
+#         model_path,
+#         providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+#         if torch.cuda.is_available() else ['CPUExecutionProvider']
+#     )

+#     device = "cuda" if torch.cuda.is_available() else "cpu"
+#     return session, device
+def load_model(token):
+    repo_id = "IFMedTech/Skin-Analysis"
+    filename = "model/wrinkles_unet_v1.pth"
+    # token = os.environ.get("HUGGINGFACE_HUB_TOKEN")  # Set this env var with your token

+    # if not token:
+    #     raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is required for private repo access.")

+    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)

+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = smp.Unet(
+        encoder_name="resnet34",
+        encoder_weights=None,
+        in_channels=3,
+        classes=1
+    )
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+    model.eval()
+    return model, device

 def dynamic_import(module_path, class_name):
     spec = importlib.util.spec_from_file_location(class_name, module_path)
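
For context, here is a minimal sketch of how the model and device returned by the new load_model might be used to produce a wrinkle mask for one image. The helper name predict_wrinkle_mask, the 256x256 input size, and the ImageNet normalization are assumptions for illustration, not taken from this repo.

# Hypothetical usage sketch, not part of main.py: apply the UNet returned by
# load_model to one image. Input size (256x256), ImageNet normalization, and
# the helper name predict_wrinkle_mask are assumptions.
import cv2
import numpy as np
import torch

def predict_wrinkle_mask(model, device, image_bgr, size=256, threshold=0.5):
    # Resize and convert OpenCV's BGR image to RGB
    img = cv2.cvtColor(cv2.resize(image_bgr, (size, size)), cv2.COLOR_BGR2RGB)
    # Scale to [0, 1] and apply ImageNet mean/std (assumed preprocessing)
    x = img.astype(np.float32) / 255.0
    x = (x - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
    # HWC -> NCHW tensor on the same device as the model
    x = torch.from_numpy(np.ascontiguousarray(x.transpose(2, 0, 1))).unsqueeze(0).float().to(device)
    with torch.no_grad():
        logits = model(x)                    # (1, 1, size, size) for classes=1
        probs = torch.sigmoid(logits)[0, 0]  # single-channel probability map
    return (probs.cpu().numpy() > threshold).astype(np.uint8)

# Example call:
# model, device = load_model(token)
# mask = predict_wrinkle_mask(model, device, cv2.imread("face.jpg"))

Passing encoder_weights=None in load_model also skips downloading ImageNet encoder weights, which is reasonable here since load_state_dict immediately overwrites every parameter with the checkpoint pulled from the Hub.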
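
The new PY_MODULES entry only registers the file/class pair; the instantiation happens elsewhere in main.py. A rough sketch of how the mapping could be combined with dynamic_import, assuming dynamic_import returns the named class and the detector modules sit next to main.py (both assumptions):

# Hypothetical wiring sketch, assuming dynamic_import(module_path, class_name)
# returns the class object and detector modules live alongside main.py.
import os

def load_detector_classes(base_dir=os.path.dirname(os.path.abspath(__file__))):
    classes = {}
    for filename, class_name in PY_MODULES.items():
        module_path = os.path.join(base_dir, filename)
        # e.g. "wrinkle_unet.py" -> WrinkleDetector
        classes[class_name] = dynamic_import(module_path, class_name)
    return classes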