IFMedTechdemo commited on
Commit
09f77fb
·
verified ·
1 Parent(s): bd8f002

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -34
main.py CHANGED
@@ -16,47 +16,47 @@ PY_MODULES = {
16
  "texture.py": "TextureDetector",
17
  "skin_tone.py": "SkinToneDetector",
18
  "oiliness.py": "OilinessDetector",
19
- "wrinkle.py": "WrinkleDetector",
20
  "age.py": "AgePredictor"
21
  }
22
- def load_model(token):
23
- """Download and load ONNX model"""
24
- model_path = hf_hub_download(
25
- repo_id=REPO_ID,
26
- filename="model/wrinkle_model.onnx.data", # Adjust path if needed
27
- token=token
28
- )
29
 
30
- # Create ONNX Runtime session
31
- session = ort.InferenceSession(
32
- model_path,
33
- providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
34
- if torch.cuda.is_available() else ['CPUExecutionProvider']
35
- )
36
 
37
- device = "cuda" if torch.cuda.is_available() else "cpu"
38
- return session, device
39
- # def load_model(token):
40
- # repo_id = "IFMedTech/Skin-Analysis"
41
- # filename = "model/wrinkles_unet_v1.pth"
42
- # # token = os.environ.get("HUGGINGFACE_HUB_TOKEN") # Set this env var with your token
43
 
44
- # # if not token:
45
- # # raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is required for private repo access.")
46
 
47
- # model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
48
 
49
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
- # model = smp.Unet(
51
- # encoder_name="resnet34",
52
- # encoder_weights=None,
53
- # in_channels=3,
54
- # classes=1
55
- # )
56
- # model.load_state_dict(torch.load(model_path, map_location=device))
57
- # model.to(device)
58
- # model.eval()
59
- # return model, device
60
 
61
  def dynamic_import(module_path, class_name):
62
  spec = importlib.util.spec_from_file_location(class_name, module_path)
 
16
  "texture.py": "TextureDetector",
17
  "skin_tone.py": "SkinToneDetector",
18
  "oiliness.py": "OilinessDetector",
19
+ "wrinkle_unet.py": "WrinkleDetector",
20
  "age.py": "AgePredictor"
21
  }
22
+ # def load_model(token):
23
+ # """Download and load ONNX model"""
24
+ # model_path = hf_hub_download(
25
+ # repo_id=REPO_ID,
26
+ # filename="model/wrinkle_model.onnx.data", # Adjust path if needed
27
+ # token=token
28
+ # )
29
 
30
+ # # Create ONNX Runtime session
31
+ # session = ort.InferenceSession(
32
+ # model_path,
33
+ # providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
34
+ # if torch.cuda.is_available() else ['CPUExecutionProvider']
35
+ # )
36
 
37
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ # return session, device
39
+ def load_model(token):
40
+ repo_id = "IFMedTech/Skin-Analysis"
41
+ filename = "model/wrinkles_unet_v1.pth"
42
+ # token = os.environ.get("HUGGINGFACE_HUB_TOKEN") # Set this env var with your token
43
 
44
+ # if not token:
45
+ # raise ValueError("HUGGINGFACE_HUB_TOKEN environment variable is required for private repo access.")
46
 
47
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
48
 
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ model = smp.Unet(
51
+ encoder_name="resnet34",
52
+ encoder_weights=None,
53
+ in_channels=3,
54
+ classes=1
55
+ )
56
+ model.load_state_dict(torch.load(model_path, map_location=device))
57
+ model.to(device)
58
+ model.eval()
59
+ return model, device
60
 
61
  def dynamic_import(module_path, class_name):
62
  spec = importlib.util.spec_from_file_location(class_name, module_path)