Antuke
committed on
Commit
·
972535f
1
Parent(s):
576fd2d
fix
Browse files- src/model.py +4 -3
src/model.py
CHANGED
|
@@ -18,7 +18,7 @@ DROPOUT_P = 0.5
|
|
| 18 |
|
| 19 |
|
| 20 |
class MTLModel(nn.Module):
|
| 21 |
-
def __init__(self, backbone, tasks: List[Task],
|
| 22 |
rank: int = 64,
|
| 23 |
use_lora: bool = True,
|
| 24 |
truncate_idx: int = 22,
|
|
@@ -28,7 +28,8 @@ class MTLModel(nn.Module):
|
|
| 28 |
use_deep_head:bool = False,
|
| 29 |
use_batch_norm:bool = True,
|
| 30 |
use_mtl_attn_pool: bool = True,
|
| 31 |
-
use_dora:bool = True
|
|
|
|
| 32 |
|
| 33 |
super().__init__()
|
| 34 |
self.use_mtl_attn_pool=use_mtl_attn_pool
|
|
@@ -52,7 +53,7 @@ class MTLModel(nn.Module):
|
|
| 52 |
self.ln_post = backbone.ln_post
|
| 53 |
|
| 54 |
# save the attention pooling, as we need its weight values to seed the task-specific attention pooling layers
|
| 55 |
-
orig_attn_pool = backbone.attn_pool.to(
|
| 56 |
|
| 57 |
self.backbone.truncate(layer_idx=truncate_idx) # 23th block becomes the last (the idx is 22)
|
| 58 |
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class MTLModel(nn.Module):
|
| 21 |
+
def __init__(self, backbone, tasks: List[Task], device
|
| 22 |
rank: int = 64,
|
| 23 |
use_lora: bool = True,
|
| 24 |
truncate_idx: int = 22,
|
|
|
|
| 28 |
use_deep_head:bool = False,
|
| 29 |
use_batch_norm:bool = True,
|
| 30 |
use_mtl_attn_pool: bool = True,
|
| 31 |
+
use_dora:bool = True,
|
| 32 |
+
):
|
| 33 |
|
| 34 |
super().__init__()
|
| 35 |
self.use_mtl_attn_pool=use_mtl_attn_pool
|
|
|
|
| 53 |
self.ln_post = backbone.ln_post
|
| 54 |
|
| 55 |
# save the attention pooling, as we need its weight values to seed the task-specific attention pooling layers
|
| 56 |
+
orig_attn_pool = backbone.attn_pool.to(device)
|
| 57 |
|
| 58 |
self.backbone.truncate(layer_idx=truncate_idx) # 23th block becomes the last (the idx is 22)
|
| 59 |
|