Antuke committed
Commit 73f8599 · 1 Parent(s): 0d814f5
Files changed (1):
  1. app.py +3 -1
app.py CHANGED
@@ -72,8 +72,9 @@ def scan_checkpoints(ckpt_dir):
 
 def load_model(device,ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"):
     """Load and configure model."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     backbone, transform, _ = get_backbone_pe(version='PE-Core-L14-336', apply_migration_flag=True, pretrained=False)
-    model = MTLModel(backbone,tasks=TASKS,use_lora=True,use_deep_head=True,
+    model = MTLModel(backbone,device=device,tasks=TASKS,use_lora=True,use_deep_head=True,
                      use_mtl_lora=('mtlora' in ckpt_dir),
                      )
     print(f'loading from {ckpt_dir}')
@@ -201,6 +202,7 @@ def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
     # Load the perception encoder
     model, transform = load_model(ckpt_dir= ckpt_dir,device= device)
     model.eval()
+    print(device)
     model.to(device)
 
 
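For context, below is a minimal, self-contained sketch of the device-selection pattern this commit introduces (pick CUDA when available, move the model and tensors to that device). It uses plain PyTorch only: nn.Linear stands in for the repo's MTLModel, and get_backbone_pe / TASKS are omitted, so everything here is a placeholder illustration, not the project's actual API.

import torch
import torch.nn as nn

# Pick CUDA when available, otherwise fall back to CPU
# (same check the commit adds inside load_model).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in model; in app.py this role is played by MTLModel(backbone, device=device, ...).
model = nn.Linear(4, 2)
model.eval()

print(device)        # mirrors the debug print added to init_model
model.to(device)

# Quick check that inputs and weights live on the same device.
x = torch.randn(1, 4, device=device)
with torch.no_grad():
    y = model(x)
print(y.shape)       # torch.Size([1, 2])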