Antuke committed
Commit 972535f · 1 Parent(s): 576fd2d
Files changed (1)
  1. src/model.py +4 -3
src/model.py CHANGED
@@ -18,7 +18,7 @@ DROPOUT_P = 0.5
 
 
 class MTLModel(nn.Module):
-    def __init__(self, backbone, tasks: List[Task],
+    def __init__(self, backbone, tasks: List[Task], device,
                  rank: int = 64,
                  use_lora: bool = True,
                  truncate_idx: int = 22,
@@ -28,7 +28,8 @@ class MTLModel(nn.Module):
                  use_deep_head: bool = False,
                  use_batch_norm: bool = True,
                  use_mtl_attn_pool: bool = True,
-                 use_dora: bool = True):
+                 use_dora: bool = True,
+                 ):
 
         super().__init__()
         self.use_mtl_attn_pool = use_mtl_attn_pool
@@ -52,7 +53,7 @@ class MTLModel(nn.Module):
         self.ln_post = backbone.ln_post
 
         # save the attention pooling, as we need its weight values to seed the task-specific attention pooling layers
-        orig_attn_pool = backbone.attn_pool.to('cuda')
+        orig_attn_pool = backbone.attn_pool.to(device)
 
         self.backbone.truncate(layer_idx=truncate_idx)  # the 23rd block becomes the last (its idx is 22)
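
The commit replaces the hardcoded 'cuda' target for the saved attention-pooling module with a caller-supplied device, so the model can also be built on CPU or a specific GPU. A minimal sketch of the updated call site, assuming hypothetical backbone and tasks objects (neither is defined in this diff):

import torch

# Pick a device explicitly instead of assuming CUDA is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# backbone and tasks are placeholders: MTLModel expects a vision backbone
# exposing .attn_pool, .ln_post and .truncate(), plus a List[Task].
model = MTLModel(
    backbone,
    tasks,
    device,            # new parameter; orig_attn_pool is moved here instead of 'cuda'
    rank=64,
    use_lora=True,
    truncate_idx=22,
)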