ekolodin committed on
Commit
b3dc287
·
verified ·
1 Parent(s): f30ab06

3b-september-2025 upload

Browse files
config.json CHANGED
@@ -1,127 +1,96 @@
1
  {
2
- "_name_or_path": "ai-sage/Giga-Embeddings-instruct",
3
- "add_eos": true,
4
- "add_pad_token": true,
5
- "architectures": [
6
- "GigarEmbedModel"
7
- ],
8
- "auto_map": {
9
- "AutoConfig": "configuration_gigarembed.GigarEmbedConfig",
10
- "AutoModel": "modeling_gigarembed.GigarEmbedModel"
11
- },
12
- "hidden_size": 2048,
13
- "is_mask_instruction": true,
14
- "latent_attention_config": {
15
- "cross_dim_head": 2048,
16
- "hidden_dim": 2048,
17
- "latent_dim": 2048,
18
- "model_type": "latent_attention"
19
- },
20
- "mask_type": "b",
21
- "model_type": "gigarembed",
22
- "padding_side": "right",
23
- "text_config": {
24
- "_attn_implementation_autoset": false,
25
  "_name_or_path": "ai-sage/Giga-Embeddings-instruct",
 
26
  "activation_checkpoint_layers_num": null,
27
- "add_cross_attention": false,
28
  "architectures": [
29
- "LlamaForCausalLM"
30
  ],
31
- "attention_bias": false,
32
- "attention_dropout": 0.0,
33
- "attention_hidden_size": null,
34
- "attention_type": "LlamaPackedAttention",
35
- "bad_words_ids": null,
36
- "begin_suppress_tokens": null,
37
- "bos_token_id": 1,
38
- "chunk_size_feed_forward": 0,
39
- "cross_attention_hidden_size": null,
40
- "decoder_start_token_id": null,
41
- "deterministic_attention": false,
42
- "diversity_penalty": 0.0,
43
- "do_sample": false,
44
- "early_stopping": false,
45
- "encoder_no_repeat_ngram_size": 0,
46
- "eos_token_id": 2,
47
- "exponential_decay_length_penalty": null,
48
- "finetuning_task": null,
49
- "forced_bos_token_id": null,
50
- "forced_eos_token_id": null,
51
- "freeze_non_embed": false,
52
- "fused_mlp": true,
53
- "fused_mlp_checkpoint_lvl": 3,
54
- "head_dim": 128,
55
- "hidden_act": "silu",
56
- "hidden_size": 2048,
57
- "id2label": {
58
- "0": "LABEL_0",
59
- "1": "LABEL_1"
60
  },
61
- "init_device": "meta",
62
- "initializer_range": 0.02,
63
- "intermediate_size": 11008,
64
- "is_decoder": false,
65
- "is_encoder_decoder": false,
66
- "label2id": {
67
- "LABEL_0": 0,
68
- "LABEL_1": 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  },
70
- "length_penalty": 1.0,
71
- "loss_inplace_backward": true,
72
- "max_length": 20,
73
- "max_position_embeddings": 4096,
74
- "max_window_layers": 36,
75
- "min_length": 0,
76
- "mlp_bias": false,
77
- "model_type": "llama",
78
- "no_repeat_ngram_size": 0,
79
- "num_attention_heads": 16,
80
- "num_beam_groups": 1,
81
- "num_beams": 1,
82
- "num_hidden_layers": 36,
83
- "num_key_value_heads": 2,
84
- "num_return_sequences": 1,
85
- "output_attentions": false,
86
- "output_hidden_states": false,
87
- "output_scores": false,
88
- "pad_token_id": 2,
89
- "prefix": null,
90
- "pretraining_tp": 1,
91
- "problem_type": null,
92
- "pruned_heads": {},
93
- "remove_invalid_values": false,
94
- "repetition_penalty": 1.0,
95
- "return_dict": true,
96
- "return_dict_in_generate": false,
97
- "rms_norm_eps": 1e-06,
98
- "rope_scaling": null,
99
- "rope_theta": 1300,
100
- "sep_token_id": null,
101
- "sliding_window": null,
102
- "sp_split_type": "equal",
103
- "suppress_tokens": null,
104
- "task_specific_params": null,
105
- "temperature": 1.0,
106
- "tf_legacy_loss": false,
107
- "tie_encoder_decoder": false,
108
- "tie_word_embeddings": false,
109
- "tokenizer_class": null,
110
- "top_k": 50,
111
- "top_p": 1.0,
112
- "torch_dtype": "float32",
113
- "torchscript": false,
114
- "tp_group": null,
115
- "tp_size": 1,
116
- "typical_p": 1.0,
117
- "unk_token_id": 0,
118
- "use_bfloat16": false,
119
- "use_cache": true,
120
- "use_mrope": false,
121
- "use_sliding_window": false,
122
- "varlen_input": false,
123
- "vocab_size": 128256
124
- },
125
- "torch_dtype": "float32",
126
- "transformers_version": "4.46.3"
127
  }
 
1
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  "_name_or_path": "ai-sage/Giga-Embeddings-instruct",
3
+ "_non_freeze_layers_idxs": null,
4
  "activation_checkpoint_layers_num": null,
5
+ "apply_torch_compile_to_projections": true,
6
  "architectures": [
7
+ "GigarEmbedModel"
8
  ],
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_gigarembed.GigarEmbedConfig",
11
+ "AutoModel": "modeling_gigarembed.GigarEmbedModel"
12
+ },
13
+ "latent_attention_config": {
14
+ "model_type": "latent_attention",
15
+ "num_latents_value": 512,
16
+ "num_cross_heads": 8,
17
+ "cross_dim_head": 2048,
18
+ "hidden_dim": 2048,
19
+ "latent_dim": 2048,
20
+ "mult": 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  },
22
+ "hidden_size": 2048,
23
+ "text_config": {
24
+ "_name_or_path": "ai-sage/Giga-Embeddings-instruct",
25
+ "apply_qk_norm": true,
26
+ "attention_bias": false,
27
+ "attention_dropout": 0.0,
28
+ "attention_hidden_size": null,
29
+ "attention_type": "LlamaLatentAttention",
30
+ "bos_token_id": 1,
31
+ "delete_logits": true,
32
+ "deterministic_attention": false,
33
+ "enable_async_tp": false,
34
+ "eos_token_id": 2,
35
+ "freeze_non_embed": false,
36
+ "fused_mlp": true,
37
+ "fused_mlp_checkpoint_lvl": 3,
38
+ "head_dim": 64,
39
+ "hidden_act": "silu",
40
+ "hidden_size": 2048,
41
+ "ignore_index": -100,
42
+ "init_device": "meta",
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 11008,
45
+ "kv_lora_rank": 1024,
46
+ "lora_alpha": null,
47
+ "lora_r": null,
48
+ "loss_inplace_backward": false,
49
+ "max_position_embeddings": 4096,
50
+ "max_window_layers": 36,
51
+ "mla_config": {
52
+ "kv_lora_rank": 1024,
53
+ "q_lora_rank": 0,
54
+ "qk_nope_head_dim": 64,
55
+ "qk_rope_head_dim": 64,
56
+ "v_head_dim": 128
57
+ },
58
+ "mlp_bias": false,
59
+ "model_type": "gigar",
60
+ "mtp_loss_weight": 0.1,
61
+ "mtp_predictor_num": 1,
62
+ "norm_type": "LlamaRMSNorm",
63
+ "num_attention_heads": 16,
64
+ "num_hidden_layers": 36,
65
+ "num_key_value_heads": 16,
66
+ "pad_token_id": 2,
67
+ "parallel_embedding_type": "EmbeddingParallelEmbedding",
68
+ "pretraining_tp": 1,
69
+ "q_lora_rank": 0,
70
+ "qk_nope_head_dim": 64,
71
+ "qk_rope_head_dim": 64,
72
+ "rms_norm_eps": 1e-06,
73
+ "rope_scaling": null,
74
+ "rope_theta": 100000.0,
75
+ "skip_init_tp_modules": true,
76
+ "sliding_window": null,
77
+ "sp_split_type": "equal",
78
+ "tie_word_embeddings": false,
79
+ "tp_group": null,
80
+ "tp_size": 1,
81
+ "unk_token_id": 0,
82
+ "use_cache": false,
83
+ "use_cache_force": false,
84
+ "use_custom_rotary_kernel": false,
85
+ "use_liger": false,
86
+ "use_mrope": false,
87
+ "use_mtp": true,
88
+ "use_sliding_window": false,
89
+ "v_head_dim": 128,
90
+ "varlen_input": true,
91
+ "vocab_size": 128256,
92
+ "z_loss_eps": 5e-05
93
  },
94
+ "torch_dtype": "bfloat16",
95
+ "transformers_version": "4.48.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  }
configuration_gigarembed.py CHANGED
@@ -1,12 +1,238 @@
 
 
1
  from typing import Literal
2
  from transformers import AutoConfig
3
- from transformers.configuration_utils import PretrainedConfig
4
  from transformers.models.auto import CONFIG_MAPPING
5
- from transformers.models.llama import LlamaConfig
 
6
 
7
  GIGAREMBED_TYPE = "gigarembed"
8
  LATENT_ATTENTION_TYPE = "latent_attention"
9
- BIDIR_LLAMA_TYPE = "bidir_llama"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  class GigarEmbedConfig(PretrainedConfig):
12
  model_type = "gigarembed"
@@ -28,14 +254,11 @@ class GigarEmbedConfig(PretrainedConfig):
28
  latent_attention_config["model_type"] if "model_type" in latent_attention_config else LATENT_ATTENTION_TYPE
29
  )
30
  latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
31
- elif latent_attention_config is None:
32
- latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
33
 
34
  self.latent_attention_config = latent_attention_config
35
 
36
  if isinstance(text_config, dict):
37
- text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
38
- text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
39
  elif text_config is None:
40
  text_config = None
41
 
@@ -47,8 +270,6 @@ class GigarEmbedConfig(PretrainedConfig):
47
  self.mask_type = mask_type
48
  if "hidden_size" in kwargs:
49
  self.hidden_size = kwargs["hidden_size"]
50
- else:
51
- self.hidden_size = 2560
52
 
53
  super().__init__(**kwargs)
54
 
@@ -60,30 +281,26 @@ class LatentAttentionConfig(PretrainedConfig):
60
 
61
  def __init__(
62
  self,
63
- num_latents_value: int=512,
64
- num_cross_heads: int=8,
65
- output_normalize: bool=True,
66
- hidden_dim: int=2560,
67
- latent_dim: int=2560,
68
- cross_dim_head: int=2560,
69
  **kwargs,
70
  ):
71
  self.num_latents_value = num_latents_value
72
  self.num_cross_heads = num_cross_heads
73
- self.output_normalize = output_normalize
74
  self.hidden_dim = hidden_dim
75
  self.latent_dim = latent_dim
76
  self.cross_dim_head = cross_dim_head
77
- self._attn_implementation = "eager"
 
 
78
 
79
- class BidirectionalLlamaConfig(LlamaConfig):
80
- model_type = BIDIR_LLAMA_TYPE
81
- keys_to_ignore_at_inference = ["past_key_values"]
82
 
83
  AutoConfig.register(GIGAREMBED_TYPE, GigarEmbedConfig)
84
  AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
85
- AutoConfig.register(BIDIR_LLAMA_TYPE, BidirectionalLlamaConfig)
86
 
87
  GigarEmbedConfig.register_for_auto_class()
88
  LatentAttentionConfig.register_for_auto_class()
89
- BidirectionalLlamaConfig.register_for_auto_class()
 
1
+ import warnings
2
+
3
  from typing import Literal
4
  from transformers import AutoConfig
 
5
  from transformers.models.auto import CONFIG_MAPPING
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_rope_utils import rope_config_validation
8
 
9
  GIGAREMBED_TYPE = "gigarembed"
10
  LATENT_ATTENTION_TYPE = "latent_attention"
11
+
12
+
13
+ class GigarConfig(PretrainedConfig):
14
+ r"""
15
+ This is the configuration class to store the configuration of a [`GigarModel`]. It is used to instantiate a Gigar
16
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
17
+ defaults will yield a similar configuration to that of the Gigar-7B.
18
+
19
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
20
+ documentation from [`PretrainedConfig`] for more information.
21
+
22
+
23
+ Args:
24
+ vocab_size (`int`, *optional*, defaults to 32000):
25
+ Vocabulary size of the Gigar model. Defines the number of different tokens that can be represented by the
26
+ `inputs_ids` passed when calling [`GigarModel`]
27
+ hidden_size (`int`, *optional*, defaults to 4096):
28
+ Dimension of the hidden representations.
29
+ intermediate_size (`int`, *optional*, defaults to 11008):
30
+ Dimension of the MLP representations.
31
+ num_hidden_layers (`int`, *optional*, defaults to 32):
32
+ Number of hidden layers in the Transformer decoder.
33
+ num_attention_heads (`int`, *optional*, defaults to 32):
34
+ Number of attention heads for each attention layer in the Transformer decoder.
35
+ num_key_value_heads (`int`, *optional*):
36
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
37
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
38
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
39
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
40
+ by mean-pooling all the original heads within that group. For more details check out [this
41
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
42
+ `num_attention_heads`.
43
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
44
+ The non-linear activation function (function or string) in the decoder.
45
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
46
+ The maximum sequence length that this model might ever be used with. Gigar 1 supports up to 2048 tokens,
47
+ Gigar 2 up to 4096, CodeLlama up to 16384.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
51
+ The epsilon used by the rms normalization layers.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
54
+ relevant if `config.is_decoder=True`.
55
+ pad_token_id (`int`, *optional*):
56
+ Padding token id.
57
+ bos_token_id (`int`, *optional*, defaults to 1):
58
+ Beginning of stream token id.
59
+ eos_token_id (`int`, *optional*, defaults to 2):
60
+ End of stream token id.
61
+ pretraining_tp (`int`, *optional*, defaults to 1):
62
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
63
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
64
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
65
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
66
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
67
+ Whether to tie weight embeddings
68
+ rope_theta (`float`, *optional*, defaults to 10000.0):
69
+ The base period of the RoPE embeddings.
70
+ rope_scaling (`Dict`, *optional*):
71
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
72
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
73
+ accordingly.
74
+ Expected contents:
75
+ `rope_type` (`str`):
76
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
77
+ 'gigar3'], with 'default' being the original RoPE implementation.
78
+ `factor` (`float`, *optional*):
79
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
80
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
81
+ original maximum pre-trained length.
82
+ `original_max_position_embeddings` (`int`, *optional*):
83
+ Used with 'dynamic', 'longrope' and 'gigar3'. The original max position embeddings used during
84
+ pretraining.
85
+ `attention_factor` (`float`, *optional*):
86
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
87
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
88
+ `factor` field to infer the suggested value.
89
+ `beta_fast` (`float`, *optional*):
90
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
91
+ ramp function. If unspecified, it defaults to 32.
92
+ `beta_slow` (`float`, *optional*):
93
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
94
+ ramp function. If unspecified, it defaults to 1.
95
+ `short_factor` (`List[float]`, *optional*):
96
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
97
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
98
+ size divided by the number of attention heads divided by 2
99
+ `long_factor` (`List[float]`, *optional*):
100
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
101
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
102
+ size divided by the number of attention heads divided by 2
103
+ `low_freq_factor` (`float`, *optional*):
104
+ Only used with 'gigar3'. Scaling factor applied to low frequency components of the RoPE
105
+ `high_freq_factor` (`float`, *optional*):
106
+ Only used with 'gigar3'. Scaling factor applied to high frequency components of the RoPE
107
+ attention_bias (`bool`, *optional*, defaults to `False`):
108
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
109
+ attention_dropout (`float`, *optional*, defaults to 0.0):
110
+ The dropout ratio for the attention probabilities.
111
+ mlp_bias (`bool`, *optional*, defaults to `False`):
112
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
113
+ head_dim (`int`, *optional*):
114
+ The attention head dimension. If None, it will default to hidden_size // num_attention_heads
115
+
116
+ ```python
117
+ >>> from transformers import GigarModel, GigarConfig
118
+
119
+ >>> # Initializing a Gigar gigar-7b style configuration
120
+ >>> configuration = GigarConfig()
121
+
122
+ >>> # Initializing a model from the gigar-7b style configuration
123
+ >>> model = GigarModel(configuration)
124
+
125
+ >>> # Accessing the model configuration
126
+ >>> configuration = model.config
127
+ ```"""
128
+
129
+ model_type = "gigar"
130
+ keys_to_ignore_at_inference = ["past_key_values"]
131
+ # Default tensor parallel plan for base model `GigarModel`
132
+ base_model_tp_plan = {
133
+ "layers.*.self_attn.q_proj": "colwise",
134
+ "layers.*.self_attn.k_proj": "colwise",
135
+ "layers.*.self_attn.v_proj": "colwise",
136
+ "layers.*.self_attn.o_proj": "rowwise",
137
+ "layers.*.mlp.gate_proj": "colwise",
138
+ "layers.*.mlp.up_proj": "colwise",
139
+ "layers.*.mlp.down_proj": "rowwise",
140
+ }
141
+
142
+ def __init__(
143
+ self,
144
+ vocab_size=32000,
145
+ hidden_size=4096,
146
+ intermediate_size=11008,
147
+ num_hidden_layers=32,
148
+ num_attention_heads=32,
149
+ num_key_value_heads=None,
150
+ hidden_act="silu",
151
+ max_position_embeddings=2048,
152
+ initializer_range=0.02,
153
+ rms_norm_eps=1e-6,
154
+ use_cache=True,
155
+ pad_token_id=None,
156
+ bos_token_id=1,
157
+ eos_token_id=2,
158
+ pretraining_tp=1,
159
+ tie_word_embeddings=False,
160
+ rope_theta=10000.0,
161
+ rope_scaling=None,
162
+ attention_bias=False,
163
+ attention_dropout=0.0,
164
+ mlp_bias=False,
165
+ head_dim=None,
166
+ apply_qk_norm=False,
167
+ mla_config=None,
168
+ **kwargs,
169
+ ):
170
+ super().__init__(
171
+ pad_token_id=pad_token_id,
172
+ bos_token_id=bos_token_id,
173
+ eos_token_id=eos_token_id,
174
+ tie_word_embeddings=tie_word_embeddings,
175
+ **kwargs,
176
+ )
177
+
178
+ self.vocab_size = vocab_size
179
+ self.max_position_embeddings = max_position_embeddings
180
+ self.hidden_size = hidden_size
181
+ self.intermediate_size = intermediate_size
182
+ self.num_hidden_layers = num_hidden_layers
183
+ self.num_attention_heads = num_attention_heads
184
+
185
+ # for backward compatibility
186
+ if num_key_value_heads is None:
187
+ num_key_value_heads = num_attention_heads
188
+
189
+ self.num_key_value_heads = num_key_value_heads
190
+ self.hidden_act = hidden_act
191
+ self.initializer_range = initializer_range
192
+ self.rms_norm_eps = rms_norm_eps
193
+ self.pretraining_tp = pretraining_tp
194
+ self.use_cache = use_cache
195
+ self.rope_theta = rope_theta
196
+ self.rope_scaling = rope_scaling
197
+ self.attention_bias = attention_bias
198
+ self.attention_dropout = attention_dropout
199
+ self.mlp_bias = mlp_bias
200
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
201
+ # Validate the correctness of rotary position embeddings parameters
202
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
203
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
204
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
205
+ rope_config_validation(self)
206
+
207
+ self.apply_qk_norm = apply_qk_norm
208
+ self.mla_config = mla_config
209
+
210
+ self._validate_mla_config()
211
+
212
+ def _validate_mla_config(self):
213
+ if self.mla_config is None:
214
+ warnings.warn("MLA config is None!")
215
+ return
216
+
217
+ EXPECTED_KEYS = [
218
+ "qk_nope_head_dim",
219
+ "qk_rope_head_dim",
220
+ "v_head_dim",
221
+ "kv_lora_rank",
222
+ "q_lora_rank",
223
+ ]
224
+ if not all((key in self.mla_config for key in EXPECTED_KEYS)):
225
+ raise ValueError(
226
+ f"MLA config is expected to have the following keys {EXPECTED_KEYS} but got {self.mla_config.keys()}."
227
+ )
228
+
229
+ if self.mla_config["qk_nope_head_dim"] + self.mla_config["qk_rope_head_dim"] != self.mla_config["v_head_dim"]:
230
+ err_msg = (
231
+ f"QK and V head dims do not match! Got {self.mla_config['qk_nope_head_dim']} + {self.mla_config['qk_rope_head_dim']} "
232
+ f"= {self.mla_config['qk_rope_head_dim'] + self.mla_config['qk_nope_head_dim']} and {self.mla_config['v_head_dim']}."
233
+ )
234
+ raise ValueError(err_msg)
235
+
236
 
237
  class GigarEmbedConfig(PretrainedConfig):
238
  model_type = "gigarembed"
 
254
  latent_attention_config["model_type"] if "model_type" in latent_attention_config else LATENT_ATTENTION_TYPE
255
  )
256
  latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
 
 
257
 
258
  self.latent_attention_config = latent_attention_config
259
 
260
  if isinstance(text_config, dict):
261
+ text_config = GigarConfig(**text_config)
 
262
  elif text_config is None:
263
  text_config = None
264
 
 
270
  self.mask_type = mask_type
271
  if "hidden_size" in kwargs:
272
  self.hidden_size = kwargs["hidden_size"]
 
 
273
 
274
  super().__init__(**kwargs)
275
 
 
281
 
282
  def __init__(
283
  self,
284
+ num_latents_value: int,
285
+ num_cross_heads: int,
286
+ hidden_dim: int,
287
+ latent_dim: int,
288
+ cross_dim_head: int,
289
+ mult: int,
290
  **kwargs,
291
  ):
292
  self.num_latents_value = num_latents_value
293
  self.num_cross_heads = num_cross_heads
 
294
  self.hidden_dim = hidden_dim
295
  self.latent_dim = latent_dim
296
  self.cross_dim_head = cross_dim_head
297
+ self.mult = mult
298
+
299
+ super().__init__(**kwargs)
300
 
 
 
 
301
 
302
  AutoConfig.register(GIGAREMBED_TYPE, GigarEmbedConfig)
303
  AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
 
304
 
305
  GigarEmbedConfig.register_for_auto_class()
306
  LatentAttentionConfig.register_for_auto_class()
 
model-00001-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f32eedfe2127f8e9507427af1d796c6547df4c8e5795e4ea8b3a22a96e782292
3
- size 4930720644
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f58bbc242e82e65d3564fdf09a378b5fbf545e15bc8291d7a623a8873222f6f
3
+ size 4947550528
model-00002-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:edc8c0c52613a2712e8c65b3d8b4249b6e99622c695ee5aec698ca37a5a556d3
3
- size 4932780264
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe27f921338efe4e1b41a27f6a1c87ca35d6ae933ebd0458ba09c329b63c9ac9
3
+ size 4912107128
model-00003-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8cb8d3f0fb5c526162adc18efcc9e1c13d07088602775e742037a5e53d1531b9
3
- size 3045246736
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90d75a17597fe2351022a4f8638c2509e624b25b1fc08a3d5c5204ec02d7a14b
3
+ size 3938900016
model.safetensors.index.json CHANGED
@@ -1,351 +1,520 @@
1
  {
2
  "metadata": {
3
- "total_size": 12908707844
4
  },
5
  "weight_map": {
6
- "latent_attention_model.cross_attend_blocks.0.fn.to_kv.weight": "model-00001-of-00003.safetensors",
7
- "latent_attention_model.cross_attend_blocks.0.fn.to_out.weight": "model-00001-of-00003.safetensors",
8
- "latent_attention_model.cross_attend_blocks.0.fn.to_q.weight": "model-00001-of-00003.safetensors",
9
- "latent_attention_model.cross_attend_blocks.0.norm.bias": "model-00001-of-00003.safetensors",
10
- "latent_attention_model.cross_attend_blocks.0.norm.weight": "model-00001-of-00003.safetensors",
11
- "latent_attention_model.cross_attend_blocks.0.norm_context.bias": "model-00001-of-00003.safetensors",
12
- "latent_attention_model.cross_attend_blocks.0.norm_context.weight": "model-00001-of-00003.safetensors",
13
- "latent_attention_model.cross_attend_blocks.1.fn.net.0.bias": "model-00001-of-00003.safetensors",
14
- "latent_attention_model.cross_attend_blocks.1.fn.net.0.weight": "model-00001-of-00003.safetensors",
15
- "latent_attention_model.cross_attend_blocks.1.fn.net.2.bias": "model-00001-of-00003.safetensors",
16
- "latent_attention_model.cross_attend_blocks.1.fn.net.2.weight": "model-00001-of-00003.safetensors",
17
- "latent_attention_model.cross_attend_blocks.1.norm.bias": "model-00001-of-00003.safetensors",
18
- "latent_attention_model.cross_attend_blocks.1.norm.weight": "model-00001-of-00003.safetensors",
19
  "latent_attention_model.latents": "model-00001-of-00003.safetensors",
20
- "latent_attention_model.w_lexical.bias": "model-00001-of-00003.safetensors",
21
- "latent_attention_model.w_lexical.weight": "model-00001-of-00003.safetensors",
22
- "latent_attention_model.w_multi_vector.bias": "model-00001-of-00003.safetensors",
23
- "latent_attention_model.w_multi_vector.weight": "model-00001-of-00003.safetensors",
24
  "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
25
  "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
26
  "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
27
  "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
28
  "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
29
  "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
30
- "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
31
  "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
32
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
33
- "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
34
  "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
35
  "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
36
  "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
37
  "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
38
  "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
39
- "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
40
  "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
41
  "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
42
- "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
43
  "model.layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
44
  "model.layers.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
45
  "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
46
  "model.layers.10.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
47
  "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
48
- "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
49
- "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
50
- "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
51
- "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
 
 
52
  "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
53
  "model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
54
  "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
55
  "model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
56
  "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
57
- "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
58
  "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
59
  "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
60
- "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
61
  "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
62
  "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
63
  "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
64
  "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
65
  "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
66
- "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
67
  "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
68
  "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
69
- "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
70
  "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
71
  "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
72
  "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
73
  "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
74
  "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
75
- "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
76
  "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
77
  "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
78
- "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
79
  "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
80
  "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
81
  "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
82
  "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
83
  "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
84
- "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
85
  "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
86
  "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
87
- "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
88
  "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
89
  "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
90
  "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
91
  "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
92
  "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
93
- "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
94
  "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
95
  "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
96
- "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
97
  "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
98
  "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
99
  "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
100
  "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
101
  "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
102
- "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
103
  "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
104
  "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
105
- "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
106
  "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
107
  "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
108
  "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
109
  "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
110
  "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
111
- "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
112
  "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
113
  "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
114
- "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
115
  "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
116
  "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
117
  "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
118
  "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
119
  "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
120
- "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
121
  "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
122
  "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
123
- "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
124
  "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
125
  "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
126
  "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
127
  "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
128
  "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
129
- "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
130
  "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
131
  "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
132
- "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
133
  "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
134
  "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
135
  "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
136
  "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
137
  "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
138
- "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
139
  "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
140
  "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
141
- "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
142
  "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
143
  "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
144
  "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
145
  "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
146
  "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
147
- "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
148
  "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
149
  "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
150
- "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
151
  "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
152
  "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
153
  "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
154
  "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
155
  "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
156
- "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
157
  "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
158
  "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
159
- "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
160
  "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
161
  "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
162
  "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
163
  "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
164
  "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
165
- "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
166
  "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
167
  "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
168
- "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
169
  "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
170
  "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
171
  "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
172
  "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
173
  "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
174
- "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
175
  "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
176
  "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
177
- "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
178
- "model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
179
- "model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
180
- "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
181
- "model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
182
- "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
183
- "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
 
 
184
  "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
185
  "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
186
- "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
187
- "model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
188
- "model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
189
- "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
190
- "model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
191
- "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
192
- "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
193
- "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
194
- "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
195
- "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
 
 
 
 
 
196
  "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
197
  "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
198
  "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
199
  "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
200
  "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
201
- "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
202
- "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
203
- "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
204
- "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
 
 
 
 
 
205
  "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
206
  "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
207
  "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
208
  "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
209
  "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
210
- "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
211
  "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
212
  "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
213
- "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
214
  "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
215
  "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
216
  "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
217
  "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
218
  "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
219
- "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
220
  "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
221
  "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
222
- "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
223
  "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
224
  "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
225
  "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
226
  "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
227
  "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
228
- "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
229
  "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
230
  "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
231
- "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
232
  "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
233
  "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
234
  "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
235
  "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
236
  "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
237
- "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
238
  "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
239
  "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
240
- "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
241
  "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
242
  "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
243
  "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
244
  "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
245
  "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
246
- "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
247
  "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
248
  "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
249
- "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
250
  "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
251
  "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
252
  "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
253
  "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
254
  "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
255
- "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
256
  "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
257
  "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
258
- "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
259
  "model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
260
  "model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
261
  "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
262
  "model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
263
  "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
264
- "model.layers.32.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
265
  "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
266
  "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
267
- "model.layers.32.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
268
  "model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
269
  "model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
270
  "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
271
  "model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
272
  "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
273
- "model.layers.33.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
274
  "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
275
  "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
276
- "model.layers.33.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
277
  "model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
278
  "model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
279
  "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
280
  "model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
281
  "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
282
- "model.layers.34.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
283
  "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
284
  "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
285
- "model.layers.34.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
286
  "model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
287
  "model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
288
  "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
289
  "model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
290
  "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
291
- "model.layers.35.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
 
 
292
  "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
293
  "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
294
- "model.layers.35.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
 
 
 
295
  "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
296
  "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
297
  "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
298
  "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
299
  "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
300
- "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
301
  "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
302
  "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
303
- "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
304
  "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
305
  "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
306
  "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
307
  "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
308
  "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
309
- "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
310
  "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
311
  "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
312
- "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
313
  "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
314
  "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
315
  "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
316
  "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
317
  "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
318
- "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
319
  "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
320
  "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
321
- "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
322
  "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
323
  "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
324
  "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
325
  "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
326
  "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
327
- "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
328
  "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
329
  "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
330
- "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
331
  "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
332
  "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
333
  "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
334
  "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
335
  "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
336
- "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
337
  "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
338
  "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
339
- "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
340
- "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
341
- "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
342
  "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
343
- "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
344
- "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
345
- "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
 
 
346
  "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
347
  "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
348
- "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
349
  "model.norm.weight": "model-00003-of-00003.safetensors"
350
  }
351
  }
 
1
  {
2
  "metadata": {
3
+ "total_size": 13798498304
4
  },
5
  "weight_map": {
6
+ "latent_attention_model.cross_attend_blocks.0.to_kv.weight": "model-00001-of-00003.safetensors",
7
+ "latent_attention_model.cross_attend_blocks.0.to_out.weight": "model-00001-of-00003.safetensors",
8
+ "latent_attention_model.cross_attend_blocks.0.to_q.weight": "model-00001-of-00003.safetensors",
9
+ "latent_attention_model.cross_attend_blocks.1.down_proj.weight": "model-00001-of-00003.safetensors",
10
+ "latent_attention_model.cross_attend_blocks.1.gate_proj.weight": "model-00001-of-00003.safetensors",
11
+ "latent_attention_model.cross_attend_blocks.1.up_proj.weight": "model-00001-of-00003.safetensors",
 
 
 
 
 
 
 
12
  "latent_attention_model.latents": "model-00001-of-00003.safetensors",
 
 
 
 
13
  "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
14
  "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
15
  "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
16
  "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
17
  "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
18
  "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.0.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.0.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.0.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
22
  "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
23
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
24
+ "model.layers.0.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
25
+ "model.layers.0.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
26
+ "model.layers.0.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
27
+ "model.layers.0.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
28
  "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
29
  "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
30
  "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
31
  "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
32
  "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
33
+ "model.layers.1.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
34
+ "model.layers.1.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
35
+ "model.layers.1.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
36
  "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
37
  "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
38
+ "model.layers.1.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
39
+ "model.layers.1.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
40
+ "model.layers.1.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
41
+ "model.layers.1.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
42
  "model.layers.10.input_layernorm.weight": "model-00002-of-00003.safetensors",
43
  "model.layers.10.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
44
  "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
45
  "model.layers.10.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
46
  "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
47
+ "model.layers.10.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
48
+ "model.layers.10.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
49
+ "model.layers.10.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
50
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
51
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
52
+ "model.layers.10.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
53
+ "model.layers.10.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
54
+ "model.layers.10.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
55
+ "model.layers.10.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
56
  "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
57
  "model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
58
  "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
59
  "model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
60
  "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
61
+ "model.layers.11.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
62
+ "model.layers.11.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
63
+ "model.layers.11.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
64
  "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
65
  "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
66
+ "model.layers.11.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
67
+ "model.layers.11.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
68
+ "model.layers.11.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
69
+ "model.layers.11.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
70
  "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
71
  "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
72
  "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
73
  "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
74
  "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
75
+ "model.layers.12.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
76
+ "model.layers.12.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
77
+ "model.layers.12.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
78
  "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
79
  "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.12.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.12.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.12.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.12.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
84
  "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
85
  "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
86
  "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
87
  "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
88
  "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.13.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
90
+ "model.layers.13.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
91
+ "model.layers.13.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
92
  "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
93
  "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
94
+ "model.layers.13.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.13.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
96
+ "model.layers.13.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
97
+ "model.layers.13.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
98
  "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
99
  "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
100
  "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
101
  "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
102
  "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
103
+ "model.layers.14.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
104
+ "model.layers.14.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
105
+ "model.layers.14.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
106
  "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
107
  "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
108
+ "model.layers.14.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
109
+ "model.layers.14.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
110
+ "model.layers.14.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
111
+ "model.layers.14.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
112
  "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
113
  "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
114
  "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
115
  "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
116
  "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
117
+ "model.layers.15.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
118
+ "model.layers.15.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
119
+ "model.layers.15.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
120
  "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
121
  "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
122
+ "model.layers.15.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
123
+ "model.layers.15.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
124
+ "model.layers.15.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
125
+ "model.layers.15.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
126
  "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
127
  "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
128
  "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
129
  "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
130
  "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.16.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.16.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.16.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
134
  "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
135
  "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.16.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.16.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.16.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.16.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
140
  "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
141
  "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
142
  "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
143
  "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
144
  "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
145
+ "model.layers.17.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
146
+ "model.layers.17.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
147
+ "model.layers.17.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
148
  "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
149
  "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
150
+ "model.layers.17.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
151
+ "model.layers.17.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
152
+ "model.layers.17.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
153
+ "model.layers.17.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
154
  "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
155
  "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
156
  "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
157
  "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
158
  "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
159
+ "model.layers.18.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
160
+ "model.layers.18.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
161
+ "model.layers.18.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
162
  "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
163
  "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
164
+ "model.layers.18.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
165
+ "model.layers.18.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
166
+ "model.layers.18.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
167
+ "model.layers.18.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
168
  "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
169
  "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
170
  "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
171
  "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
172
  "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
173
+ "model.layers.19.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
174
+ "model.layers.19.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
175
+ "model.layers.19.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
176
  "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
177
  "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
178
+ "model.layers.19.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
179
+ "model.layers.19.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
180
+ "model.layers.19.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.19.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
182
  "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
183
  "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
184
  "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
185
  "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
186
  "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
187
+ "model.layers.2.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
188
+ "model.layers.2.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
189
+ "model.layers.2.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
190
  "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
191
  "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
192
+ "model.layers.2.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
193
+ "model.layers.2.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
194
+ "model.layers.2.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
195
+ "model.layers.2.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
196
  "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
197
  "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
198
  "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
199
  "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
200
  "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
201
+ "model.layers.20.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
202
+ "model.layers.20.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
203
+ "model.layers.20.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
204
  "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
205
  "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
206
+ "model.layers.20.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
207
+ "model.layers.20.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
208
+ "model.layers.20.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
209
+ "model.layers.20.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
210
  "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
211
  "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
212
  "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
213
  "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
214
  "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
215
+ "model.layers.21.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
216
+ "model.layers.21.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
217
+ "model.layers.21.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
218
  "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
219
  "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
220
+ "model.layers.21.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
221
+ "model.layers.21.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
222
+ "model.layers.21.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
223
+ "model.layers.21.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
224
  "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
225
  "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
226
  "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
227
  "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
228
  "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
229
+ "model.layers.22.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
230
+ "model.layers.22.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
231
+ "model.layers.22.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
232
  "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
233
  "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
234
+ "model.layers.22.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
235
+ "model.layers.22.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
236
+ "model.layers.22.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
237
+ "model.layers.22.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
238
  "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
239
  "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
240
  "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
241
  "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
242
  "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
243
+ "model.layers.23.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
244
+ "model.layers.23.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
245
+ "model.layers.23.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
246
  "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
247
  "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
248
+ "model.layers.23.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
249
+ "model.layers.23.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
250
+ "model.layers.23.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
251
+ "model.layers.23.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
252
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
253
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
254
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
255
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
256
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
257
+ "model.layers.24.self_attn.dkv_proj.weight": "model-00002-of-00003.safetensors",
258
+ "model.layers.24.self_attn.kr_proj.weight": "model-00002-of-00003.safetensors",
259
+ "model.layers.24.self_attn.kv_norm.weight": "model-00002-of-00003.safetensors",
260
  "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
261
  "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
262
+ "model.layers.24.self_attn.qk_k_norm.weight": "model-00002-of-00003.safetensors",
263
+ "model.layers.24.self_attn.qk_q_norm.weight": "model-00002-of-00003.safetensors",
264
+ "model.layers.24.self_attn.uk_proj.weight": "model-00002-of-00003.safetensors",
265
+ "model.layers.24.self_attn.uv_proj.weight": "model-00002-of-00003.safetensors",
266
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
267
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
268
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
269
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
270
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
271
+ "model.layers.25.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
272
+ "model.layers.25.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
273
+ "model.layers.25.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
274
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
275
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
276
+ "model.layers.25.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
277
+ "model.layers.25.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
278
+ "model.layers.25.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
279
+ "model.layers.25.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
280
  "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
281
  "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
282
  "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
283
  "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
284
  "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
285
+ "model.layers.26.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
286
+ "model.layers.26.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
287
+ "model.layers.26.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
288
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
289
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
290
+ "model.layers.26.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
291
+ "model.layers.26.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
292
+ "model.layers.26.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
293
+ "model.layers.26.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
294
  "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
295
  "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
296
  "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
297
  "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
298
  "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
299
+ "model.layers.27.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
300
+ "model.layers.27.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
301
+ "model.layers.27.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
302
  "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
303
  "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
304
+ "model.layers.27.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
305
+ "model.layers.27.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
306
+ "model.layers.27.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
307
+ "model.layers.27.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
308
  "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
309
  "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
310
  "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
311
  "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
312
  "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
313
+ "model.layers.28.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
314
+ "model.layers.28.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
315
+ "model.layers.28.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
316
  "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
317
  "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
318
+ "model.layers.28.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
319
+ "model.layers.28.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
320
+ "model.layers.28.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
321
+ "model.layers.28.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
322
  "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
323
  "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
324
  "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
325
  "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
326
  "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
327
+ "model.layers.29.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
328
+ "model.layers.29.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
329
+ "model.layers.29.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
330
  "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
331
  "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
332
+ "model.layers.29.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
333
+ "model.layers.29.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
334
+ "model.layers.29.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
335
+ "model.layers.29.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
336
  "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
337
  "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
338
  "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
339
  "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
340
  "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
341
+ "model.layers.3.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
342
+ "model.layers.3.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
343
+ "model.layers.3.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
344
  "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
345
  "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
346
+ "model.layers.3.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
347
+ "model.layers.3.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
348
+ "model.layers.3.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
349
+ "model.layers.3.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
350
  "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
351
  "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
352
  "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
353
  "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
354
  "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
355
+ "model.layers.30.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
356
+ "model.layers.30.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
357
+ "model.layers.30.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
358
  "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
359
  "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
360
+ "model.layers.30.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
361
+ "model.layers.30.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
362
+ "model.layers.30.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
363
+ "model.layers.30.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
364
  "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
365
  "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
366
  "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
367
  "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
368
  "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
369
+ "model.layers.31.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
370
+ "model.layers.31.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
371
+ "model.layers.31.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
372
  "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
373
  "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
374
+ "model.layers.31.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
375
+ "model.layers.31.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
376
+ "model.layers.31.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
377
+ "model.layers.31.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
378
  "model.layers.32.input_layernorm.weight": "model-00003-of-00003.safetensors",
379
  "model.layers.32.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
380
  "model.layers.32.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
381
  "model.layers.32.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
382
  "model.layers.32.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
383
+ "model.layers.32.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
384
+ "model.layers.32.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
385
+ "model.layers.32.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
386
  "model.layers.32.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
387
  "model.layers.32.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
388
+ "model.layers.32.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
389
+ "model.layers.32.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
390
+ "model.layers.32.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
391
+ "model.layers.32.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
392
  "model.layers.33.input_layernorm.weight": "model-00003-of-00003.safetensors",
393
  "model.layers.33.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
394
  "model.layers.33.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
395
  "model.layers.33.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
396
  "model.layers.33.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
397
+ "model.layers.33.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
398
+ "model.layers.33.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
399
+ "model.layers.33.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
400
  "model.layers.33.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
401
  "model.layers.33.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
402
+ "model.layers.33.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
403
+ "model.layers.33.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
404
+ "model.layers.33.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
405
+ "model.layers.33.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
406
  "model.layers.34.input_layernorm.weight": "model-00003-of-00003.safetensors",
407
  "model.layers.34.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
408
  "model.layers.34.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
409
  "model.layers.34.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
410
  "model.layers.34.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
411
+ "model.layers.34.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
412
+ "model.layers.34.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
413
+ "model.layers.34.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
414
  "model.layers.34.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
415
  "model.layers.34.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
416
+ "model.layers.34.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
417
+ "model.layers.34.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
418
+ "model.layers.34.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
419
+ "model.layers.34.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
420
  "model.layers.35.input_layernorm.weight": "model-00003-of-00003.safetensors",
421
  "model.layers.35.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
422
  "model.layers.35.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
423
  "model.layers.35.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
424
  "model.layers.35.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
425
+ "model.layers.35.self_attn.dkv_proj.weight": "model-00003-of-00003.safetensors",
426
+ "model.layers.35.self_attn.kr_proj.weight": "model-00003-of-00003.safetensors",
427
+ "model.layers.35.self_attn.kv_norm.weight": "model-00003-of-00003.safetensors",
428
  "model.layers.35.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
429
  "model.layers.35.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
430
+ "model.layers.35.self_attn.qk_k_norm.weight": "model-00003-of-00003.safetensors",
431
+ "model.layers.35.self_attn.qk_q_norm.weight": "model-00003-of-00003.safetensors",
432
+ "model.layers.35.self_attn.uk_proj.weight": "model-00003-of-00003.safetensors",
433
+ "model.layers.35.self_attn.uv_proj.weight": "model-00003-of-00003.safetensors",
434
  "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
435
  "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
436
  "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
437
  "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
438
  "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
439
+ "model.layers.4.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
440
+ "model.layers.4.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
441
+ "model.layers.4.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
442
  "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
443
  "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
444
+ "model.layers.4.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
445
+ "model.layers.4.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
446
+ "model.layers.4.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
447
+ "model.layers.4.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
448
  "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
449
  "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
450
  "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
451
  "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
452
  "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
453
+ "model.layers.5.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
454
+ "model.layers.5.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
455
+ "model.layers.5.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
456
  "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
457
  "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
458
+ "model.layers.5.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
459
+ "model.layers.5.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
460
+ "model.layers.5.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
461
+ "model.layers.5.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
462
  "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
463
  "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
464
  "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
465
  "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
466
  "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
467
+ "model.layers.6.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
468
+ "model.layers.6.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
469
+ "model.layers.6.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
470
  "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
471
  "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
472
+ "model.layers.6.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
473
+ "model.layers.6.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
474
+ "model.layers.6.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
475
+ "model.layers.6.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
476
  "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
477
  "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
478
  "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
479
  "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
480
  "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
481
+ "model.layers.7.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
482
+ "model.layers.7.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
483
+ "model.layers.7.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
484
  "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
485
  "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
486
+ "model.layers.7.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
487
+ "model.layers.7.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
488
+ "model.layers.7.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
489
+ "model.layers.7.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
490
  "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
491
  "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
492
  "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
493
  "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
494
  "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
495
+ "model.layers.8.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
496
+ "model.layers.8.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
497
+ "model.layers.8.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
498
  "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
499
  "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
500
+ "model.layers.8.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
501
+ "model.layers.8.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
502
+ "model.layers.8.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
503
+ "model.layers.8.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
504
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00003.safetensors",
505
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
506
  "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
507
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
508
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
509
+ "model.layers.9.self_attn.dkv_proj.weight": "model-00001-of-00003.safetensors",
510
+ "model.layers.9.self_attn.kr_proj.weight": "model-00001-of-00003.safetensors",
511
+ "model.layers.9.self_attn.kv_norm.weight": "model-00001-of-00003.safetensors",
512
  "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
513
  "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
514
+ "model.layers.9.self_attn.qk_k_norm.weight": "model-00001-of-00003.safetensors",
515
+ "model.layers.9.self_attn.qk_q_norm.weight": "model-00001-of-00003.safetensors",
516
+ "model.layers.9.self_attn.uk_proj.weight": "model-00001-of-00003.safetensors",
517
+ "model.layers.9.self_attn.uv_proj.weight": "model-00001-of-00003.safetensors",
518
  "model.norm.weight": "model-00003-of-00003.safetensors"
519
  }
520
  }
modeling_gigarembed.py CHANGED
@@ -1,139 +1,658 @@
1
- from typing import List, Union, Dict, Mapping, Optional, Tuple, TypedDict
 
 
 
2
  import torch
3
- import os
4
- import json
5
  import numpy as np
6
  import torch.nn.functional as F
7
 
8
- from functools import partial
9
- from contextlib import nullcontext
10
- from transformers import AutoModel, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
 
 
 
 
 
 
 
 
 
 
11
  from transformers.modeling_utils import PreTrainedModel
12
- from transformers.models.auto import AutoTokenizer
13
- from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
14
- from transformers.modeling_outputs import BaseModelOutputWithPast
15
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
16
- from transformers import LlamaModel, LlamaConfig
17
- from transformers.cache_utils import Cache, DynamicCache
18
- from transformers.utils import (
19
- add_start_docstrings_to_model_forward,
20
- logging,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
- from einops import rearrange, repeat
23
- from tqdm.auto import tqdm
24
- from datasets import Dataset
25
- from torch.utils.data import DataLoader
26
- from .configuration_gigarembed import GigarEmbedConfig, LatentAttentionConfig, BidirectionalLlamaConfig
 
 
 
 
 
 
 
27
 
28
- logger = logging.get_logger(__name__)
 
 
 
 
 
 
 
 
 
29
 
30
class GigarEmbedFeatures(TypedDict):
    """Keyword arguments that `GigarEmbedModel.forward` is called with.

    The keys mirror the dictionary built by
    `GigarEmbedModel.prepare_kwargs_from_batch`.
    """

    # Token ids of the tokenized inputs.
    input_ids: torch.Tensor
    # Attention mask produced by the tokenizer, forwarded unchanged.
    attention_mask: torch.Tensor
    # Mask used for pooling; instruction tokens may be zeroed out in it.
    pool_mask: torch.Tensor
34
 
35
- class BidirectionalLlamaModel(LlamaModel):
36
- config_class = BidirectionalLlamaConfig
37
-
38
- def __init__(self, config: LlamaConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  super().__init__(config)
40
- for layer in self.layers:
41
- layer.self_attn.is_causal = False
42
- self._attn_implementation = "eager"
43
 
44
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def forward(
46
  self,
47
  input_ids: torch.LongTensor = None,
48
  attention_mask: Optional[torch.Tensor] = None,
49
  position_ids: Optional[torch.LongTensor] = None,
50
- past_key_values: Optional[List[torch.FloatTensor]] = None,
51
  inputs_embeds: Optional[torch.FloatTensor] = None,
52
  use_cache: Optional[bool] = None,
53
  output_attentions: Optional[bool] = None,
54
  output_hidden_states: Optional[bool] = None,
55
  return_dict: Optional[bool] = None,
 
 
56
  ) -> Union[Tuple, BaseModelOutputWithPast]:
57
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
58
  output_hidden_states = (
59
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
60
  )
61
  use_cache = use_cache if use_cache is not None else self.config.use_cache
62
-
63
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
64
 
65
- # retrieve input_ids and inputs_embeds
66
- if input_ids is not None and inputs_embeds is not None:
67
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
68
- elif input_ids is not None:
69
- batch_size, seq_length = input_ids.shape
70
- elif inputs_embeds is not None:
71
- batch_size, seq_length, _ = inputs_embeds.shape
72
- else:
73
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
74
-
75
- if self.gradient_checkpointing and self.training:
76
- if use_cache:
77
- logger.warning_once(
78
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
79
- )
80
- use_cache = False
81
-
82
- past_key_values_length = 0
83
 
84
- if use_cache:
85
- use_legacy_cache = not isinstance(past_key_values, Cache)
86
- if use_legacy_cache:
87
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
88
- past_key_values_length = past_key_values.get_usable_length(seq_length)
89
-
90
- if position_ids is None:
91
- device = input_ids.device if input_ids is not None else inputs_embeds.device
92
- position_ids = torch.arange(
93
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
94
  )
95
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
96
- else:
97
- position_ids = position_ids.view(-1, seq_length).long()
98
 
99
  if inputs_embeds is None:
100
  inputs_embeds = self.embed_tokens(input_ids)
101
 
102
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
103
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
104
- if is_padding_right:
105
- raise ValueError(
106
- "You are attempting to perform batched generation with padding_side='right'"
107
- " this may lead to unexpected behaviour for Flash Attention version of Llama. Make sure to "
108
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
109
- )
110
 
111
- if self._attn_implementation == "flash_attention_2":
112
- # 2d mask is passed through the layers
113
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
114
- elif self._attn_implementation == "sdpa" and not output_attentions:
115
- # output_attentions=True can not be supported when using SDPA, and we fall back on
116
- # the manual implementation that requires a 4D causal mask in all cases.
117
- attention_mask = _prepare_4d_attention_mask_for_sdpa(
118
- attention_mask, inputs_embeds.dtype
119
- )
120
- else:
121
- # 4d mask is passed through the layers
122
- attention_mask = _prepare_4d_attention_mask(
123
- attention_mask, inputs_embeds.dtype,
124
  )
125
 
126
- hidden_states = inputs_embeds
 
127
 
 
 
 
 
128
  # create position embeddings to be shared across the decoder layers
129
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
130
 
131
  # decoder layers
132
  all_hidden_states = () if output_hidden_states else None
133
  all_self_attns = () if output_attentions else None
134
- next_decoder_cache = None
135
 
136
- for decoder_layer in self.layers:
137
  if output_hidden_states:
138
  all_hidden_states += (hidden_states,)
139
 
@@ -141,29 +660,29 @@ class BidirectionalLlamaModel(LlamaModel):
141
  layer_outputs = self._gradient_checkpointing_func(
142
  decoder_layer.__call__,
143
  hidden_states,
144
- attention_mask,
145
  position_ids,
146
  past_key_values,
147
  output_attentions,
148
  use_cache,
149
- position_embeddings=position_embeddings
 
150
  )
151
  else:
152
  layer_outputs = decoder_layer(
153
  hidden_states,
154
- attention_mask=attention_mask,
155
  position_ids=position_ids,
156
  past_key_value=past_key_values,
157
  output_attentions=output_attentions,
158
  use_cache=use_cache,
159
- position_embeddings=position_embeddings
 
 
160
  )
161
 
162
  hidden_states = layer_outputs[0]
163
 
164
- if use_cache:
165
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
166
-
167
  if output_attentions:
168
  all_self_attns += (layer_outputs[1],)
169
 
@@ -173,277 +692,463 @@ class BidirectionalLlamaModel(LlamaModel):
173
  if output_hidden_states:
174
  all_hidden_states += (hidden_states,)
175
 
176
- next_cache = None
177
- if use_cache:
178
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
179
-
180
- if not return_dict:
181
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
182
- return BaseModelOutputWithPast(
183
  last_hidden_state=hidden_states,
184
- past_key_values=next_cache,
185
  hidden_states=all_hidden_states,
186
  attentions=all_self_attns,
187
  )
 
188
 
189
def _move_to_device(maybe_tensor, device: torch.device):
    """Recursively move every tensor found in ``maybe_tensor`` to ``device``.

    Tensors are moved with ``Tensor.to``; dicts, lists, tuples and other
    mappings are rebuilt with their elements converted recursively. Any other
    object is returned unchanged.

    Args:
        maybe_tensor: A tensor, a (nested) container of tensors, or any object.
        device: Target device.

    Returns:
        An object with the same structure whose tensors live on ``device``.
    """
    if torch.is_tensor(maybe_tensor):
        # A non-blocking copy is only requested when the target is CUDA.
        return maybe_tensor.to(device, non_blocking=device.type == "cuda")
    if isinstance(maybe_tensor, dict):
        return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()}
    if isinstance(maybe_tensor, list):
        return [_move_to_device(item, device) for item in maybe_tensor]
    if isinstance(maybe_tensor, tuple):
        return tuple(_move_to_device(item, device) for item in maybe_tensor)
    if isinstance(maybe_tensor, Mapping):
        # Non-dict mappings are rebuilt through their own constructor so the
        # container type is preserved.
        return type(maybe_tensor)({key: _move_to_device(value, device) for key, value in maybe_tensor.items()})
    return maybe_tensor
202
-
203
def move_to_device(sample, device: torch.device):
    """Move all tensors in ``sample`` to ``device``.

    Args:
        sample: A (possibly nested) container of tensors.
        device: Target device.

    Returns:
        ``sample`` itself when ``device`` is the CPU, an empty dict when
        ``sample`` is empty, otherwise the result of ``_move_to_device``.
    """
    # For a CPU target the input is handed back untouched.
    if device.type == "cpu":
        return sample

    if len(sample) == 0:
        return {}

    return _move_to_device(sample, device)
210
-
211
-
212
def input_transform_func(
    tokenizer: "PreTrainedTokenizerFast",
    examples: Dict[str, List],
    max_length: int,
    instruction: str,
) -> "BatchEncoding":
    """Prefix every input text with ``instruction`` and tokenize the batch.

    Args:
        tokenizer: Callable tokenizer applied to the list of prefixed texts.
        examples: Batch dictionary; only ``examples['input_texts']`` is read.
        max_length: Maximum sequence length passed to the tokenizer.
        instruction: String prepended to every text.

    Returns:
        Whatever ``tokenizer`` returns for the prefixed texts (called with
        padding, truncation and ``return_tensors="pt"``).
    """
    # Build a new list instead of overwriting examples['input_texts'], so a
    # second call on the same dict does not prepend the instruction twice.
    prefixed_texts = [instruction + text for text in examples['input_texts']]
    return tokenizer(
        prefixed_texts,
        max_length=max_length,
        padding=True,
        return_token_type_ids=False,
        return_tensors="pt",
        truncation=True,
    )
227
-
228
-
229
class PreNorm(torch.nn.Module):
    """Apply LayerNorm to the input (and optionally the context) before ``fn``."""

    def __init__(self, dim, fn, context_dim=None):
        """
        Args:
            dim: Size of the last dimension of the main input.
            fn: Callable invoked as ``fn(normed_x, **kwargs)``.
            context_dim: When given, a second LayerNorm of this size is
                created and applied to ``kwargs['context']``.
        """
        super().__init__()
        self.fn = fn
        self.norm = torch.nn.LayerNorm(dim)
        self.norm_context = torch.nn.LayerNorm(context_dim) if context_dim is not None else None

    def forward(self, x, **kwargs):
        """Normalize ``x`` (and the context, if configured) and call ``fn``."""
        x = self.norm(x)
        if self.norm_context is not None:
            # A context norm was configured, so the 'context' keyword is required.
            kwargs.update(context=self.norm_context(kwargs['context']))
        return self.fn(x, **kwargs)
243
-
244
class GEGLU(torch.nn.Module):
    """GELU-gated linear unit over the last dimension."""

    def forward(self, x):
        """Split ``x`` in two along the last dim and gate the first half.

        Returns:
            ``values * gelu(gates)``, where ``values`` and ``gates`` are the
            two chunks of ``x``; the last dimension is halved.
        """
        values, gates = x.chunk(2, dim=-1)
        return values * F.gelu(gates)
248
 
249
class FeedForward(torch.nn.Module):
    """Two-layer feed-forward network with a GEGLU activation in between."""

    def __init__(self, dim, mult = 4):
        """
        Args:
            dim: Input and output feature size.
            mult: Expansion factor of the hidden layer.
        """
        super().__init__()
        # The first projection is twice as wide as the hidden size because
        # GEGLU splits it into a value half and a gate half.
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, 2 * dim * mult),
            GEGLU(),
            torch.nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        return self.net(x)
260
 
261
def exists(val):
    """Return True when ``val`` is not None (falsy values still count)."""
    return val is not None
263
 
264
def default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``."""
    return d if val is None else val
 
 
 
266
 
 
 
267
 
268
class Attention(torch.nn.Module):
    """Multi-head attention whose keys/values come from an optional context."""

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
        """
        Args:
            query_dim: Feature size of the query input (and of the output).
            context_dim: Feature size of the context; falls back to ``query_dim``.
            heads: Number of attention heads.
            dim_head: Feature size per head.
        """
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        # Stored but not read by forward(); no explicit scale is passed to the
        # attention call there.
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = torch.nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = torch.nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = torch.nn.Linear(inner_dim, query_dim, bias = False)

    def forward(self, x, context = None, mask = None):
        """Attend from ``x`` to ``context`` (self-attention when it is None).

        Args:
            x: Query input.
            context: Source of keys and values; defaults to ``x``.
            mask: Accepted for interface compatibility; it is not applied.

        Returns:
            The attention output projected back to ``query_dim``.
        """
        h = self.heads
        q = self.to_q(x)
        context = x if context is None else context
        # A single projection yields keys and values, split down the middle.
        k, v = self.to_kv(context).chunk(2, dim = -1)
        # Fold the head axis into the batch axis for the attention call.
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h = h) for t in (q, k, v))

        # NOTE(review): torch.backends.cuda.sdp_kernel is reported as
        # deprecated in recent PyTorch in favour of
        # torch.nn.attention.sdpa_kernel - confirm against the pinned version.
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # Undo the head folding before the output projection.
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
 
293
class LatentAttentionModel(PreTrainedModel):
    """Latent-attention pooling head that turns hidden states into embeddings."""

    config_class = LatentAttentionConfig

    def __init__(self, config: LatentAttentionConfig):
        """Build the cross-attention and feed-forward blocks and the latents."""
        super().__init__(config)
        ## cross-attention block
        num_latents = config.num_latents_value
        latent_dim = config.latent_dim
        cross_heads = config.num_cross_heads
        cross_dim_head = config.cross_dim_head
        dim = config.hidden_dim
        # init latent_attention and latents
        self.cross_attend_blocks = torch.nn.ModuleList([
            PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head),
                    context_dim = dim),
            PreNorm(latent_dim, FeedForward(latent_dim)),
        ])

        # These two projections are created here but not used by forward().
        self.w_lexical = torch.nn.Linear(latent_dim, 1)
        self.w_multi_vector = torch.nn.Linear(latent_dim, latent_dim)

        # self.output_normalize = config.output_normalize
        self.register_parameter("latents", torch.nn.Parameter(torch.randn(num_latents, latent_dim)))
        self._attn_implementation = "eager"

    def forward(self, hiddens, attention_mask: torch.Tensor=None):
        """Pool ``hiddens`` into one vector per sequence.

        Args:
            hiddens: Hidden states from the text model.
            attention_mask: Pooling mask; positions with 0 are excluded from
                the mean.

        Returns:
            The masked mean of the block output, L2-normalized, when
            ``attention_mask`` is given; otherwise the unpooled block output.
        """
        # cross-attention block
        cross_attn, cross_ff = self.cross_attend_blocks
        batch_size = hiddens.shape[0]
        # The learned latents are broadcast over the batch and used as the
        # attention context; the hidden states act as queries.
        latents = repeat(self.latents, 'n d -> b n d', b = batch_size)
        output = cross_attn(hiddens, context=latents, mask=attention_mask) + hiddens
        output = cross_ff(output) + output
        if attention_mask is not None:
            # Masked mean over the sequence axis.
            masked_sum = torch.sum(output * attention_mask.unsqueeze(-1), dim=1)
            token_count = attention_mask.sum(dim=1, keepdim=True)
            output = masked_sum / token_count
            # NOTE(review): the nesting of this normalization is assumed to be
            # inside the mask branch - confirm against the upstream file.
            output = F.normalize(output, p=2, dim=-1)
        return output
328
 
329
class GigarEmbedModel(PreTrainedModel):
    """Text encoder plus latent-attention pooling that yields sentence embeddings."""

    config_class = GigarEmbedConfig
    _no_split_modules = ["LlamaDecoderLayer", "LatentAttentionModel"]

    def __init__(self, config: GigarEmbedConfig):
        """Instantiate the pooling head, the text model and its tokenizer."""
        super().__init__(config)
        has_text_config = config.text_config is not None
        self.latent_attention_model = AutoModel.from_config(config.latent_attention_config)
        self.model = AutoModel.from_config(
            config.text_config,
        ) if has_text_config else None
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path) if has_text_config else None
        self.padding_side = config.padding_side
        self.is_mask_instruction = config.is_mask_instruction
        self.add_eos = config.add_eos
        self.mask_type = config.mask_type
        if config.add_pad_token and self.tokenizer is not None:
            self.add_pad_token()

    def add_pad_token(self):
        """Set the tokenizer's pad token id to 0 and apply the padding side."""
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = self.padding_side

    def _instruction_token_count(self, instruction: str) -> int:
        """Return how many leading tokens belong to ``instruction``.

        The count is non-zero only when padding is on the right, instruction
        masking is enabled and the instruction is non-empty.
        """
        if self.padding_side == "right" and self.is_mask_instruction and len(instruction) > 0:
            return len(self.tokenizer.tokenize(instruction))
        return 0

    def prepare_kwargs_from_batch(self, batch_dict: dict, instruction_lens: int, device: torch.device):
        """Move a tokenized batch to ``device`` and build the forward kwargs.

        Args:
            batch_dict: Tokenizer output with ``input_ids`` and ``attention_mask``.
            instruction_lens: Number of leading tokens to exclude from pooling.
            device: Device the tensors are moved to.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``pool_mask``.
        """
        batch_dict = move_to_device(batch_dict, device)
        attention_mask = batch_dict['attention_mask'].clone() if 'attention_mask' in batch_dict else None
        if (attention_mask is not None and
                self.padding_side == "right" and
                self.is_mask_instruction and
                instruction_lens > 0):
            # Mask out the instruction tokens for mean-pooling
            attention_mask[:, :instruction_lens] = 0
        input_ids = batch_dict.get('input_ids')
        features: GigarEmbedFeatures = {
            # Detached copy cast to int64; avoids calling torch.tensor() on an
            # existing tensor.
            'input_ids': input_ids.clone().detach().long(),
            'attention_mask': batch_dict['attention_mask'],
            'pool_mask': attention_mask,
        }
        return features

    @torch.no_grad()
    def _do_encode(self,
        prompts: List[str],
        batch_size: int=1,
        instruction: str="",
        max_length: int=4096,
        num_workers: int=32,
        **kwargs
    ) -> Union[np.ndarray, torch.FloatTensor]:
        """Encode ``prompts`` in batches through a DataLoader.

        Args:
            prompts: Texts to embed.
            batch_size: Number of texts per batch.
            instruction: Prefix prepended to every text.
            max_length: Tokenizer truncation length.
            num_workers: DataLoader worker count.
            **kwargs: ``return_numpy=True`` converts the result to a NumPy array.

        Returns:
            The concatenated sentence embeddings.
        """
        dataset: Dataset = Dataset.from_dict({'input_texts': prompts})
        # Tokenization happens lazily, per batch, inside the dataset transform.
        dataset.set_transform(partial(input_transform_func,
                                      self.tokenizer,
                                      max_length=max_length,
                                      instruction=instruction))

        data_collator = DataCollatorWithPadding(self.tokenizer)
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            collate_fn=data_collator,
            pin_memory=True)

        instruction_lens = self._instruction_token_count(instruction)

        encoded_embeds = []
        device = next(self.model.parameters()).device
        for batch_dict in tqdm(data_loader, desc='encoding', mininterval=10):
            features = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
            embeds = self(**features)["sentence_embeddings"].squeeze(1)
            encoded_embeds.append(embeds)
        encoded_embeds = torch.cat(encoded_embeds, dim=0)
        if kwargs.get("return_numpy"):
            encoded_embeds = encoded_embeds.cpu().detach().numpy()
        return encoded_embeds

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None,
                return_dict: bool=True, **kwargs):
        """Run the text model and pool its last hidden state.

        Args:
            input_ids: Token ids.
            attention_mask: Attention mask for the text model.
            pool_mask: Mask used for pooling; defaults to a copy of
                ``attention_mask``.
            return_dict: When False a one-element tuple is returned instead of
                a dict.
            **kwargs: Forwarded to the text model (``token_type_ids`` is dropped).

        Returns:
            ``{"sentence_embeddings": embeds}`` or ``(embeds,)``.
        """
        kwargs.pop('token_type_ids', None)

        with torch.autocast('cuda', dtype=torch.bfloat16):
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        if pool_mask is None:
            pool_mask = attention_mask.clone()

        embeds = self.latent_attention_model(outputs.last_hidden_state, pool_mask)

        if not return_dict:
            return (embeds,)
        return {"sentence_embeddings": embeds}

    @torch.no_grad()
    def encode(self, prompts: List[str], instruction: str="", max_length: int=4096, **kwargs):
        """Encode ``prompts`` as a single batch and return their embeddings.

        Args:
            prompts: Texts to embed.
            instruction: Prefix prepended to every text.
            max_length: Tokenizer truncation length.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        instruction_lens = self._instruction_token_count(instruction)

        device = next(self.model.parameters()).device
        batch_dict = input_transform_func(self.tokenizer,
                                          {"input_texts": list(prompts)},
                                          max_length=max_length,
                                          instruction=instruction)

        features: GigarEmbedFeatures = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
        return self(**features)["sentence_embeddings"].squeeze(1)
 
 
439
 
440
 
441
  ## AutoModel Register
 
442
  AutoModel.register(GigarEmbedConfig, GigarEmbedModel)
443
  AutoModel.register(LatentAttentionConfig, LatentAttentionModel)
444
- AutoModel.register(BidirectionalLlamaConfig, BidirectionalLlamaModel)
445
 
446
  ## Register for auto class
 
447
  GigarEmbedModel.register_for_auto_class("AutoModel")
448
  LatentAttentionModel.register_for_auto_class("AutoModel")
449
- BidirectionalLlamaModel.register_for_auto_class("AutoModel")
 
1
+ import copy
2
+ import logging
3
+ from typing import Callable, List, Optional, Tuple, Union, Mapping
4
+
5
  import torch
6
+ import torch.nn as nn
 
7
  import numpy as np
8
  import torch.nn.functional as F
9
 
10
+ from einops import rearrange, repeat
11
+ from transformers import AutoModel, AutoTokenizer
12
+
13
+ from transformers.cache_utils import Cache
14
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.cache_utils import DynamicCache, StaticCache
18
+ from transformers.generation import GenerationMixin
19
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
20
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
21
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
22
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.processing_utils import Unpack
25
+ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
26
+
27
+ from .configuration_gigarembed import GigarConfig, GigarEmbedConfig, LatentAttentionConfig
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+ _CONFIG_FOR_DOC = "GigarEmbedConfig"
32
+
33
+
34
class GigarMLP(nn.Module):
    """Gated MLP: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        """
        Args:
            config: Object providing ``hidden_size``, ``intermediate_size``,
                ``mlp_bias`` and ``hidden_act`` (a key of ``ACT2FN``).
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to ``hidden_size``."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
48
+
49
+
50
class GigarRMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-feature scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        GigarRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the normalized (last) dimension.
            eps: Constant added to the variance before the reciprocal sqrt.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last dimension and scale it."""
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32 and the result is cast back
        # to the input dtype before the learned scale is applied.
        as_float = hidden_states.to(torch.float32)
        mean_square = as_float.pow(2).mean(-1, keepdim=True)
        normalized = as_float * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(input_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
68
+
69
+
70
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is split at its midpoint into ``(x1, x2)`` and the
    result is ``(-x2, x1)`` concatenated along that dimension.
    """
    midpoint = x.shape[-1] // 2
    first_half = x[..., :midpoint]
    second_half = x[..., midpoint:]
    return torch.cat((-second_half, first_half), dim=-1)
75
+
76
+
77
try:
    # Optional fused kernel: the module must stay importable when flash-attn is not installed.
    from flash_attn.layers.rotary import ApplyRotaryEmb
except ImportError:
    ApplyRotaryEmb = None


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    This single definition replaces two same-named functions, of which the second
    (flash-attn based) one silently shadowed the first. The fused flash-attn kernel is used
    when it is importable, the tensors live on a CUDA device, and one cos/sin table is shared
    by the whole batch; otherwise an equivalent pure-PyTorch rotation is applied.

    Args:
        q (`torch.Tensor`): The query tensor, `[batch_size, heads, seq_len, head_dim]`.
        k (`torch.Tensor`): The key tensor, same layout as `q`.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed in the PyTorch path so that they
            broadcast against `q` and `k`. With `cos`/`sin` of shape
            `[batch_size, seq_len, head_dim]`, use 1 for `[batch_size, heads, seq_len, head_dim]`
            inputs and 2 for `[batch_size, seq_len, heads, head_dim]` inputs. The fused kernel
            is only used with the default value of 1.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary
        Position Embedding.
    """
    rotary_dim = cos.shape[-1]
    # The fused kernel takes one [seq_len, rotary_dim // 2] table, so it is only usable when
    # every leading dimension of cos/sin has size 1.
    single_table = cos.dim() >= 2 and cos.numel() == cos.shape[-2] * rotary_dim

    if ApplyRotaryEmb is not None and q.is_cuda and unsqueeze_dim == 1 and single_table:
        q = q.transpose(1, 2).to(cos.dtype)  # make it [batch_size, seqlen, nheads, headdim]
        k = k.transpose(1, 2).to(cos.dtype)  # make it [batch_size, seqlen, nheads, headdim]
        # reshape (rather than squeeze) keeps the table 2-D even when seq_len == 1.
        half_cos = cos.reshape(-1, rotary_dim)[:, : rotary_dim // 2]  # [seq_len, dim // 2]
        half_sin = sin.reshape(-1, rotary_dim)[:, : rotary_dim // 2]  # [seq_len, dim // 2]
        return (
            ApplyRotaryEmb.apply(q, half_cos, half_sin).transpose(1, 2),
            ApplyRotaryEmb.apply(k, half_cos, half_sin).transpose(1, 2),
        )

    def _rotate(x):
        # Same operation as rotate_half: negated second half followed by the first half.
        mid = x.shape[-1] // 2
        return torch.cat((-x[..., mid:], x[..., :mid]), dim=-1)

    if cos.dim() < q.dim():
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (_rotate(q) * sin)
    k_embed = (k * cos) + (_rotate(k) * sin)
    return q_embed, k_embed
113
+
114
+
115
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: each key/value head is
    repeated ``n_rep`` times, taking the input from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim). With ``n_rep == 1`` the input tensor
    itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return tiled.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
125
+
126
+
127
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain matmul/softmax attention.

    Args:
        module: Attention module; its ``num_key_value_groups`` and ``training`` attributes are read.
        query: Query tensor; multiplied against the transposed (dims 2 and 3) keys.
        key: Key tensor, repeated ``module.num_key_value_groups`` times along the head axis.
        value: Value tensor, repeated the same way as ``key``.
        attention_mask: Optional additive mask, sliced on its last axis to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability applied to the attention weights while training.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has dims 1 and 2 swapped back and
        is contiguous.
    """
    n_groups = module.num_key_value_groups
    key_states = repeat_kv(key, n_groups)
    value_states = repeat_kv(value, n_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()

    return context, probs
151
+
152
+
153
class GigarLatentAttention(nn.Module):
    """
    Multi-headed Latent Attention (MLA)

    Check out the original paper: https://arxiv.org/pdf/2405.04434,
    and the reference implementation: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

    Queries are projected either directly (``q_lora_rank == 0``) or through a down-projection,
    RMSNorm and up-projection. The non-rotary key part and the values are rebuilt from a shared,
    normalised latent of width ``kv_lora_rank``; the rotary key part is projected straight from
    the hidden states. Each query/key head is laid out as ``[nope | rope]`` and only the ``rope``
    slice receives rotary position embeddings.
    """

    def __init__(self, config: GigarConfig, layer_idx: Optional[int] = None):
        """Build the projections from ``config`` and its ``mla_config`` dictionary.

        Args:
            config: Model configuration; ``mla_config`` must provide ``qk_nope_head_dim``,
                ``qk_rope_head_dim``, ``v_head_dim``, ``kv_lora_rank`` and ``q_lora_rank``.
            layer_idx: Index of this layer, forwarded to the KV cache on update.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)

        assert config.num_attention_heads == config.num_key_value_heads, (
            "GQA for MLA is not supported (does it even make sense?)"
        )
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.apply_qk_norm = config.apply_qk_norm
        self.attention_dropout = config.attention_dropout

        assert config.mla_config is not None
        mla = config.mla_config
        self.qk_nope_head_dim = mla["qk_nope_head_dim"]
        self.qk_rope_head_dim = mla["qk_rope_head_dim"]
        self.v_head_dim = mla["v_head_dim"]  # V has no rope part
        self.kv_lora_rank = mla["kv_lora_rank"]
        self.q_lora_rank = mla["q_lora_rank"]

        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

        self.scaling = self.qk_head_dim**-0.5

        if self.q_lora_rank == 0:
            self.q_proj = nn.Linear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=config.attention_bias,
            )
        else:
            self.dq_proj = nn.Linear(
                self.hidden_size,
                self.q_lora_rank,
                bias=config.attention_bias,
            )
            self.q_norm = GigarRMSNorm(self.q_lora_rank)
            self.uq_proj = nn.Linear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=config.attention_bias,
            )

        self.kv_norm = GigarRMSNorm(self.kv_lora_rank)
        self.dkv_proj = nn.Linear(
            self.hidden_size,
            self.kv_lora_rank,
            bias=config.attention_bias,
        )
        # The up-projections consume the output of dkv_proj/kv_norm, so their input width must be
        # the same mla_config value (previously read from a separate top-level config attribute).
        self.uk_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * self.qk_nope_head_dim,
            bias=config.attention_bias,
        )
        self.uv_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * self.v_head_dim,
            bias=config.attention_bias,
        )
        self.kr_proj = nn.Linear(
            self.hidden_size,
            self.num_heads * self.qk_rope_head_dim,
            bias=config.attention_bias,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        if self.apply_qk_norm:
            self.qk_q_norm = nn.LayerNorm(self.num_heads * self.qk_head_dim, bias=False)
            self.qk_k_norm = nn.LayerNorm(self.num_heads * self.qk_head_dim, bias=False)

        # Rotary embeddings are not owned by this module: cos/sin arrive through
        # `position_embeddings` in forward().
        self.is_causal = False

    def _compute_qkv(
        self,
        hidden_states: torch.Tensor,
    ):
        """Compute query, key, and value tensors from hidden states.

        Returns:
            ``(q_nope, q_rope, k_nope, k_rope, value)``, each reshaped to
            ``[bsz, num_heads, seq_len, <per-head dim>]``.
        """
        bsz, seq_len, _ = hidden_states.size()

        if self.q_lora_rank == 0:
            query = self.q_proj(hidden_states)
        else:
            query = self.uq_proj(self.q_norm(self.dq_proj(hidden_states)))

        latent = self.dkv_proj(hidden_states)
        latent = self.kv_norm(latent)
        k_rope = self.kr_proj(hidden_states)

        k_nope = self.uk_proj(latent)
        value = self.uv_proj(latent)

        if self.apply_qk_norm:
            query = self.qk_q_norm(query).to(query.dtype)
            # Keys are normalised jointly over the concatenated nope/rope parts, then split back.
            key = self.qk_k_norm(torch.cat([k_nope, k_rope], dim=-1)).to(k_nope.dtype)
            k_nope, k_rope = torch.split(key, [k_nope.shape[-1], k_rope.shape[-1]], dim=-1)

        # Reshape tensors
        query = query.view(bsz, seq_len, self.num_heads, self.qk_head_dim).transpose(1, 2)
        k_nope = k_nope.view(bsz, seq_len, self.num_heads, self.qk_nope_head_dim).transpose(1, 2)
        k_rope = k_rope.view(bsz, seq_len, self.num_heads, self.qk_rope_head_dim).transpose(1, 2)
        value = value.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)

        q_nope, q_rope = torch.split(query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        return q_nope, q_rope, k_nope, k_rope, value

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run latent attention over ``hidden_states``.

        Args:
            hidden_states: ``[bsz, seq_len, hidden_size]``.
            position_embeddings: ``(cos, sin)`` tables applied to the rotary slice of q and k.
            attention_mask: Mask handed unchanged to the selected attention implementation
                (presumably 2D for flash attention and 4D additive otherwise; confirm against
                the caller).
            past_key_value: Optional cache, updated with this layer's key/value states.
            cache_position: Positions forwarded to the cache update.
            **kwargs: Extra arguments forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``; ``attn_output`` is ``[bsz, seq_len, hidden_size]``.
        """
        batch_size, seq_len, _ = hidden_states.size()

        q_nope, q_rope, k_nope, k_rope, value_states = self._compute_qkv(hidden_states)

        cos, sin = position_embeddings
        q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos, sin)
        query_states = torch.cat([q_nope, q_rope], dim=-1)
        key_states = torch.cat([k_nope, k_rope], dim=-1)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(batch_size, seq_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
329
+
330
+
331
class GigarDecoderLayer(nn.Module):
    """Pre-norm transformer layer: latent attention then a gated MLP, each with a residual add."""

    def __init__(self, config: GigarConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GigarLatentAttention(config, layer_idx)
        self.mlp = GigarMLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = GigarRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = GigarRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the attention and MLP sub-layers.

        All arguments other than ``hidden_states`` are forwarded to ``self.self_attn``.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, self_attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        # Attention sub-layer: normalise, attend, add the residual.
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer: normalise, project, add the residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
382
+
383
+
384
class GigarRotaryEmbedding(nn.Module):
    """Produces the rotary ``(cos, sin)`` tables for a batch of position ids."""

    def __init__(self, config: GigarConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1

        grew_past_cache = seq_len > self.max_seq_len_cached
        if grew_past_cache:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        back_in_original_range = seq_len < self.original_max_seq_len
        if back_in_original_range and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to the dtype of ``x``.

        ``x`` is only consulted for its device and dtype.
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos, sin = emb.cos(), emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
440
+
441
+
442
+ GIGAR_START_DOCSTRING = r"""
443
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
444
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
445
+ etc.)
446
+
447
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
448
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
449
+ and behavior.
450
+
451
+ Parameters:
452
+ config ([`GigarConfig`]):
453
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
454
+ load the weights associated with the model, only the configuration. Check out the
455
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
456
+ """
457
+
458
+
459
@add_start_docstrings(
    "The bare Gigar Model outputting raw hidden-states without any specific head on top.",
    GIGAR_START_DOCSTRING,
)
class GigarPreTrainedModel(PreTrainedModel):
    # Shared base class: binds the configuration type, declares the capability flags read by
    # `PreTrainedModel`, and provides the weight initialisation rule.
    config_class = GigarConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GigarDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends declared as supported.
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    # Cache types declared as supported.
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        # Linear and embedding weights are drawn from N(0, initializer_range); linear biases and
        # the embedding row at `padding_idx` are zeroed. Other module types are left untouched.
        std = self.config.initializer_range
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
486
 
 
 
 
 
487
 
488
+ GIGAR_INPUTS_DOCSTRING = r"""
489
+ Args:
490
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
491
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
492
+ it.
493
+
494
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
495
+ [`PreTrainedTokenizer.__call__`] for details.
496
+
497
+ [What are input IDs?](../glossary#input-ids)
498
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
499
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
500
+
501
+ - 1 for tokens that are **not masked**,
502
+ - 0 for tokens that are **masked**.
503
+
504
+ [What are attention masks?](../glossary#attention-mask)
505
+
506
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
507
+ [`PreTrainedTokenizer.__call__`] for details.
508
+
509
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
510
+ `past_key_values`).
511
+
512
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
513
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
514
+ information on the default strategy.
515
+
516
+ - 1 indicates the head is **not masked**,
517
+ - 0 indicates the head is **masked**.
518
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
519
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
520
+ config.n_positions - 1]`.
521
+
522
+ [What are position IDs?](../glossary#position-ids)
523
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
524
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
525
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
526
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
527
+
528
+ Two formats are allowed:
529
+ - a [`~cache_utils.Cache`] instance, see our
530
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
531
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
532
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
533
+ cache format.
534
+
535
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
536
+ legacy cache format will be returned.
537
+
538
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
539
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
540
+ of shape `(batch_size, sequence_length)`.
541
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
+ model's internal embedding lookup matrix.
545
+ use_cache (`bool`, *optional*):
546
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
547
+ `past_key_values`).
548
+ output_attentions (`bool`, *optional*):
549
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
550
+ tensors for more detail.
551
+ output_hidden_states (`bool`, *optional*):
552
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
553
+ more detail.
554
+ return_dict (`bool`, *optional*):
555
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
556
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
557
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
558
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
559
+ the complete sequence length.
560
+ """
561
def gradient_checkpointing_enable(self, *args, **kwargs):
    """Delegate to ``self.model.gradient_checkpointing_enable`` with the same arguments.

    NOTE(review): this definition follows a module-level string constant, so it appears to sit
    at module level rather than inside a class; confirm where it is meant to be attached.
    """
    wrapped = self.model
    wrapped.gradient_checkpointing_enable(*args, **kwargs)
563
+
564
+ @add_start_docstrings(
565
+ "The bare Gigar Model outputting raw hidden-states without any specific head on top.",
566
+ GIGAR_START_DOCSTRING,
567
+ )
568
+ class GigarModel(GigarPreTrainedModel):
569
+ """
570
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GigarDecoderLayer`]
571
+
572
+ Args:
573
+ config: GigarConfig
574
+ """
575
+
576
def __init__(self, config: GigarConfig):
    """Create the token embedding, the layer stack, the final norm and the rotary table module."""
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
    decoder_layers = [GigarDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
    self.layers = nn.ModuleList(decoder_layers)
    self.norm = GigarRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.rotary_emb = GigarRotaryEmbedding(config=config)
    self.gradient_checkpointing = False

    # Initialize weights and apply final processing
    self.post_init()
591
+
592
def get_input_embeddings(self):
    """Return the token embedding module stored in ``self.embed_tokens``."""
    token_embeddings = self.embed_tokens
    return token_embeddings
594
+
595
def set_input_embeddings(self, value):
    """Replace the token embedding module stored in ``self.embed_tokens`` with ``value``."""
    setattr(self, "embed_tokens", value)
597
+
598
+ @add_start_docstrings_to_model_forward(GIGAR_INPUTS_DOCSTRING)
599
  def forward(
600
  self,
601
  input_ids: torch.LongTensor = None,
602
  attention_mask: Optional[torch.Tensor] = None,
603
  position_ids: Optional[torch.LongTensor] = None,
604
+ past_key_values: Optional[Cache] = None,
605
  inputs_embeds: Optional[torch.FloatTensor] = None,
606
  use_cache: Optional[bool] = None,
607
  output_attentions: Optional[bool] = None,
608
  output_hidden_states: Optional[bool] = None,
609
  return_dict: Optional[bool] = None,
610
+ cache_position: Optional[torch.LongTensor] = None,
611
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
612
  ) -> Union[Tuple, BaseModelOutputWithPast]:
613
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
614
  output_hidden_states = (
615
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
616
  )
617
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
618
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
619
 
620
+ if (input_ids is None) ^ (inputs_embeds is not None):
621
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
+ if self.gradient_checkpointing and self.training and use_cache:
624
+ logger.warning_once(
625
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
 
 
 
 
 
 
 
626
  )
627
+ use_cache = False
 
 
628
 
629
  if inputs_embeds is None:
630
  inputs_embeds = self.embed_tokens(input_ids)
631
 
632
+ if use_cache and past_key_values is None:
633
+ past_key_values = DynamicCache()
 
 
 
 
 
 
634
 
635
+ if cache_position is None:
636
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
637
+ cache_position = torch.arange(
638
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
 
 
 
 
 
 
 
 
 
639
  )
640
 
641
+ if position_ids is None:
642
+ position_ids = cache_position.unsqueeze(0)
643
 
644
+ attention_mask = self._update_encoder_mask(attention_mask, inputs_embeds)
645
+
646
+ hidden_states = inputs_embeds
647
+
648
  # create position embeddings to be shared across the decoder layers
649
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
650
 
651
  # decoder layers
652
  all_hidden_states = () if output_hidden_states else None
653
  all_self_attns = () if output_attentions else None
 
654
 
655
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
656
  if output_hidden_states:
657
  all_hidden_states += (hidden_states,)
658
 
 
660
  layer_outputs = self._gradient_checkpointing_func(
661
  decoder_layer.__call__,
662
  hidden_states,
663
+ attention_mask, # causal_mask
664
  position_ids,
665
  past_key_values,
666
  output_attentions,
667
  use_cache,
668
+ cache_position,
669
+ position_embeddings,
670
  )
671
  else:
672
  layer_outputs = decoder_layer(
673
  hidden_states,
674
+ attention_mask=attention_mask, # causal_mask
675
  position_ids=position_ids,
676
  past_key_value=past_key_values,
677
  output_attentions=output_attentions,
678
  use_cache=use_cache,
679
+ cache_position=cache_position,
680
+ position_embeddings=position_embeddings,
681
+ **flash_attn_kwargs,
682
  )
683
 
684
  hidden_states = layer_outputs[0]
685
 
 
 
 
686
  if output_attentions:
687
  all_self_attns += (layer_outputs[1],)
688
 
 
692
  if output_hidden_states:
693
  all_hidden_states += (hidden_states,)
694
 
695
+ output = BaseModelOutputWithPast(
 
 
 
 
 
 
696
  last_hidden_state=hidden_states,
697
+ past_key_values=past_key_values if use_cache else None,
698
  hidden_states=all_hidden_states,
699
  attentions=all_self_attns,
700
  )
701
+ return output if return_dict else output.to_tuple()
702
 
703
+ def _update_encoder_mask(
704
+ self,
705
+ attention_mask: torch.Tensor,
706
+ input_tensor: torch.Tensor,
707
+ ):
708
+ # Для flash_attention_2 возвращаем исходную маску
709
+ if self.config._attn_implementation == "flash_attention_2":
710
+ if attention_mask is not None and (attention_mask == 0).any():
711
+ return attention_mask
712
+ return None
713
+
714
+ dtype, device = input_tensor.dtype, input_tensor.device
715
+ batch_size, sequence_length = input_tensor.shape[:2]
716
+
717
+ # 1. Создаём базовую маску без ограничений (все токены видят друг друга)
718
+ encoder_mask = torch.full(
719
+ (batch_size, 1, sequence_length, sequence_length),
720
+ fill_value=1.0,
721
+ dtype=dtype,
722
+ device=device
723
+ )
724
+
725
+ # 2. Применяем padding-маску если есть
726
+ if attention_mask is not None:
727
+ # Создаём 4D padding-маску [batch, 1, 1, seq_len]
728
+ padding_mask = attention_mask[:, None, None, :].to(dtype=dtype)
729
+
730
+ # Комбинируем: обнуляем позиции где padding_mask == 0
731
+ encoder_mask = encoder_mask * padding_mask
732
+
733
+ # Конвертируем в формат для softmax (0 = -inf)
734
+ min_dtype = torch.finfo(dtype).min
735
+ encoder_mask = encoder_mask.masked_fill(encoder_mask == 0.0, min_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
 
737
+ return encoder_mask
738
+
739
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """Build the causal attention mask for the configured attention implementation.

    Returns the 2D ``attention_mask`` (or ``None`` when it has no zeros) for
    ``flash_attention_2``, ``None`` when the SDPA helper reports that the mask can be ignored,
    and otherwise a 4D causal mask.
    """
    attn_impl = self.config._attn_implementation
    if attn_impl == "flash_attention_2":
        if attention_mask is not None and (attention_mask == 0.0).any():
            return attention_mask
        return None

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    using_static_cache = isinstance(past_key_values, StaticCache)

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    sdpa_without_weights = attn_impl == "sdpa" and not output_attentions
    if sdpa_without_weights and not using_static_cache:
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        ):
            return None

    dtype, device = input_tensor.dtype, input_tensor.device
    sequence_length = input_tensor.shape[1]
    if using_static_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = past_seen_tokens + sequence_length + 1

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    if sdpa_without_weights and attention_mask is not None and attention_mask.device.type == "cuda":
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(dtype).min)

    return causal_mask
803
+
804
+ @staticmethod
805
+ def _prepare_4d_causal_attention_mask_with_cache_position(
806
+ attention_mask: torch.Tensor,
807
+ sequence_length: int,
808
+ target_length: int,
809
+ dtype: torch.dtype,
810
+ device: torch.device,
811
+ cache_position: torch.Tensor,
812
+ batch_size: int,
813
+ **kwargs,
814
+ ):
815
+ """
816
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
817
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
818
+
819
+ Args:
820
+ attention_mask (`torch.Tensor`):
821
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
822
+ `(batch_size, 1, query_length, key_value_length)`.
823
+ sequence_length (`int`):
824
+ The sequence length being processed.
825
+ target_length (`int`):
826
+ The target length: when generating with static cache, the mask should be as long as the static cache,
827
+ to account for the 0 padding, the part of the cache that is not filled yet.
828
+ dtype (`torch.dtype`):
829
+ The dtype to use for the 4D attention mask.
830
+ device (`torch.device`):
831
+ The device to place the 4D attention mask on.
832
+ cache_position (`torch.Tensor`):
833
+ Indices depicting the position of the input sequence tokens in the sequence.
834
+ batch_size (`int`):
835
+ Batch size.
836
+ """
837
+ if attention_mask is not None and attention_mask.dim() == 4:
838
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
839
+ causal_mask = attention_mask
840
+ else:
841
+ min_dtype = torch.finfo(dtype).min
842
+ causal_mask = torch.full(
843
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
844
+ )
845
+ if sequence_length != 1:
846
+ causal_mask = torch.triu(causal_mask, diagonal=1)
847
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
848
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
849
+ if attention_mask is not None:
850
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
851
+ mask_length = attention_mask.shape[-1]
852
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
853
+ padding_mask = padding_mask == 0
854
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
855
+ padding_mask, min_dtype
856
+ )
857
+
858
+ return causal_mask
859
+
860
+
861
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
862
+
863
+
864
+ class GigarForCausalLM(GigarPreTrainedModel, GenerationMixin):
865
+ _tied_weights_keys = ["lm_head.weight"]
866
+ _tp_plan = {"lm_head": "colwise_rep"}
867
+
868
+ def __init__(self, config):
869
+ super().__init__(config)
870
+ self.model = GigarModel(config)
871
+ self.vocab_size = config.vocab_size
872
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
873
+
874
+ # Initialize weights and apply final processing
875
+ self.post_init()
876
+
877
+ def get_input_embeddings(self):
878
+ return self.model.embed_tokens
879
+
880
+ def set_input_embeddings(self, value):
881
+ self.model.embed_tokens = value
882
+
883
+ def get_output_embeddings(self):
884
+ return self.lm_head
885
+
886
+ def set_output_embeddings(self, new_embeddings):
887
+ self.lm_head = new_embeddings
888
+
889
+ def set_decoder(self, decoder):
890
+ self.model = decoder
891
+
892
+ def get_decoder(self):
893
+ return self.model
894
+
895
+ @add_start_docstrings_to_model_forward(GIGAR_INPUTS_DOCSTRING)
896
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
897
+ def forward(
898
+ self,
899
+ input_ids: torch.LongTensor = None,
900
+ attention_mask: Optional[torch.Tensor] = None,
901
+ position_ids: Optional[torch.LongTensor] = None,
902
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
903
+ inputs_embeds: Optional[torch.FloatTensor] = None,
904
+ labels: Optional[torch.LongTensor] = None,
905
+ use_cache: Optional[bool] = None,
906
+ output_attentions: Optional[bool] = None,
907
+ output_hidden_states: Optional[bool] = None,
908
+ return_dict: Optional[bool] = None,
909
+ cache_position: Optional[torch.LongTensor] = None,
910
+ num_logits_to_keep: int = 0,
911
+ **kwargs: Unpack[KwargsForCausalLM],
912
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
913
+ r"""
914
+ Args:
915
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
916
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
917
+ config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
918
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
919
+
920
+ num_logits_to_keep (`int`, *optional*):
921
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
922
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
923
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
924
+
925
+ Returns:
926
+
927
+ Example:
928
+
929
+ ```python
930
+ >>> from transformers import AutoTokenizer, GigarForCausalLM
931
+
932
+ >>> model = GigarForCausalLM.from_pretrained("meta-gigar/Gigar-2-7b-hf")
933
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-gigar/Gigar-2-7b-hf")
934
+
935
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
936
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
937
+
938
+ >>> # Generate
939
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
940
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
941
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
942
+ ```"""
943
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
944
+ output_hidden_states = (
945
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
946
+ )
947
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
948
+
949
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
950
+ outputs = self.model(
951
+ input_ids=input_ids,
952
+ attention_mask=attention_mask,
953
+ position_ids=position_ids,
954
+ past_key_values=past_key_values,
955
+ inputs_embeds=inputs_embeds,
956
+ use_cache=use_cache,
957
+ output_attentions=output_attentions,
958
+ output_hidden_states=output_hidden_states,
959
+ return_dict=return_dict,
960
+ cache_position=cache_position,
961
+ **kwargs,
962
+ )
963
+
964
+ hidden_states = outputs[0]
965
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
966
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
967
+
968
+ loss = None
969
+ if labels is not None:
970
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
971
+
972
+ if not return_dict:
973
+ output = (logits,) + outputs[1:]
974
+ return (loss,) + output if loss is not None else output
975
+
976
+ return CausalLMOutputWithPast(
977
+ loss=loss,
978
+ logits=logits,
979
+ past_key_values=outputs.past_key_values,
980
+ hidden_states=outputs.hidden_states,
981
+ attentions=outputs.attentions,
982
+ )
983
+
984
+
985
+ class FeedForward(nn.Module):
986
  def __init__(self, dim, mult = 4):
987
  super().__init__()
988
+ self.hidden_size = dim
989
+ self.intermediate_size = dim * mult
990
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
991
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
992
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
993
+ self.act_fn = nn.SiLU()
994
 
995
  def forward(self, x):
996
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
997
 
 
 
998
 
999
+ class Attention(nn.Module):
1000
+ def __init__(self, query_dimension, context_dimension=None, num_heads=8, head_dim=64):
1001
+ super().__init__()
1002
+ inner_dimension = head_dim * num_heads
1003
+ context_dimension = context_dimension if context_dimension is not None else query_dimension
1004
 
1005
+ self.scaling_factor = head_dim ** -0.5
1006
+ self.num_heads = num_heads
1007
 
1008
+ self.to_q = nn.Linear(query_dimension, inner_dimension, bias=False)
1009
+ self.to_kv = nn.Linear(context_dimension, inner_dimension * 2, bias=False)
1010
+ self.to_out = nn.Linear(inner_dimension, query_dimension, bias=False)
 
 
 
 
1011
 
1012
+ def forward(self, input_tensor, context=None, attention_mask=None):
1013
+ batch_size, seq_len, _ = input_tensor.shape
1014
+ num_heads = self.num_heads
1015
 
1016
+ # Project input to query
1017
+ query = self.to_q(input_tensor)
 
 
 
 
1018
 
1019
+ # Use input as context if not provided
1020
+ context = input_tensor if context is None else context
1021
+ key, value = self.to_kv(context).chunk(2, dim=-1)
1022
+
1023
+ # Rearrange for multi-head attention
1024
+ query = rearrange(query, 'b n (h d) -> (b h) n d', h=num_heads)
1025
+ key = rearrange(key, 'b n (h d) -> (b h) n d', h=num_heads)
1026
+ value = rearrange(value, 'b n (h d) -> (b h) n d', h=num_heads)
1027
+
1028
+ # Compute scaled dot-product attention
1029
+ with torch.backends.cuda.sdp_kernel(
1030
+ enable_flash=True,
1031
+ enable_math=True,
1032
+ enable_mem_efficient=True
1033
+ ):
1034
+ attention_output = F.scaled_dot_product_attention(query, key, value)
1035
+
1036
+ # Rearrange back to original shape
1037
+ attention_output = rearrange(attention_output, '(b h) n d -> b n (h d)', h=num_heads)
1038
+
1039
+ return self.to_out(attention_output)
1040
 
1041
 
1042
  class LatentAttentionModel(PreTrainedModel):
1043
  config_class = LatentAttentionConfig
1044
 
1045
+ def __init__(self, configuration: LatentAttentionConfig):
1046
+ super().__init__(configuration)
1047
+
1048
+ # Extract configuration parameters
1049
+ num_latents = configuration.num_latents_value
1050
+ latent_dimension = configuration.latent_dim
1051
+ cross_attention_heads = configuration.num_cross_heads
1052
+ cross_head_dimension = configuration.cross_dim_head
1053
+ hidden_dimension = configuration.hidden_dim
1054
+
1055
+ # Initialize cross-attention components
1056
+ self.cross_attend_blocks = nn.ModuleList([
1057
+ Attention(
1058
+ query_dimension=latent_dimension,
1059
+ context_dimension=hidden_dimension,
1060
+ num_heads=cross_attention_heads,
1061
+ head_dim=cross_head_dimension
1062
+ ),
1063
+ FeedForward(latent_dimension)
1064
  ])
1065
 
1066
+ # Register learnable latents as a model parameter
1067
+ self.latents = nn.Parameter(torch.randn(num_latents, latent_dimension))
1068
+
1069
+ def forward(self, hidden_states, attention_mask: Optional[torch.Tensor] = None):
1070
+ cross_attention, feed_forward = self.cross_attend_blocks
1071
+
1072
+ batch_size, device = hidden_states.size(0), hidden_states.device
1073
+
1074
+ # Expand latents to match batch size
1075
+ expanded_latents = self.latents.repeat(batch_size, 1, 1)
1076
+
1077
+ # Apply cross-attention with residual connection
1078
+ attended_output = cross_attention(
1079
+ hidden_states, context=expanded_latents, attention_mask=attention_mask) + hidden_states
1080
+
1081
+ # Apply feed-forward with residual connection
1082
+ processed_output = feed_forward(attended_output) + attended_output
1083
+
1084
+ return processed_output
1085
+
1086
 
1087
  class GigarEmbedModel(PreTrainedModel):
1088
  config_class = GigarEmbedConfig
1089
+ _supports_flash_attn_2 = True
1090
+ _no_split_modules = ["GigarDecoderLayer", "LatentAttentionModel"]
1091
 
1092
+ def __init__(self, configuration: GigarEmbedConfig):
1093
+ super().__init__(configuration)
1094
+
1095
+ # Initialize latent attention model
1096
+ self.latent_attention_model = AutoModel.from_config(
1097
+ configuration.latent_attention_config
1098
+ )
1099
+
1100
+ self.tokenizer, self.text_encoder = None, None
1101
+ if configuration.text_config is not None:
1102
+ # Initialize text model if provided in config
1103
+ self.model = AutoModel.from_config(configuration.text_config)
1104
+
1105
+ # Initialize tokenizer if text config is available
1106
+ self.tokenizer = AutoTokenizer.from_pretrained(
1107
+ configuration.text_config.name_or_path
1108
+ )
1109
+
1110
+ # Set configuration parameters
1111
+ self.padding_side = configuration.padding_side
1112
+ self.add_eos = configuration.add_eos
1113
+ self.mask_type = configuration.mask_type
1114
+
1115
+ # Add padding token if configured
1116
+ if configuration.add_pad_token and self.tokenizer is not None:
1117
  self.add_pad_token()
1118
 
1119
  def add_pad_token(self):
1120
  self.tokenizer.pad_token_id = 0
1121
  self.tokenizer.padding_side = self.padding_side
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1122
 
1123
+ def gradient_checkpointing_enable(self, *args, **kwargs):
1124
+ self.model.gradient_checkpointing_enable(*args, **kwargs)
1125
+
1126
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
1127
+ return_embeddings: bool = False, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1128
  kwargs.pop('token_type_ids', None)
1129
 
1130
  with torch.autocast('cuda', dtype=torch.bfloat16):
1131
  outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
1132
 
1133
+ last_hidden = self.latent_attention_model(outputs.last_hidden_state, attention_mask)
1134
 
1135
+ if return_embeddings:
1136
+ return self.mean_pool(last_hidden, attention_mask)
1137
 
1138
+ return last_hidden
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1139
 
1140
+ def mean_pool(self, last_hidden: torch.Tensor, attention_mask: torch.Tensor):
1141
+ last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
1142
+ embeddings = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
1143
+ return F.normalize(embeddings, p=2, dim=-1)
1144
 
1145
 
1146
  ## AutoModel Register
1147
+ AutoModel.register(GigarConfig, GigarModel)
1148
  AutoModel.register(GigarEmbedConfig, GigarEmbedModel)
1149
  AutoModel.register(LatentAttentionConfig, LatentAttentionModel)
 
1150
 
1151
  ## Register for auto class
1152
+ GigarModel.register_for_auto_class("AutoModel")
1153
  GigarEmbedModel.register_for_auto_class("AutoModel")
1154
  LatentAttentionModel.register_for_auto_class("AutoModel")
 
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bcb50618d6a02d4562ada12978a8aa9e0b6e31260f71acce28586072a9005d4a
3
- size 10728437
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ec0a1cffcc9192f5ee3d7b273673a062918055238bda3d23cfb6d2512e947ff
3
+ size 10728325
tokenizer_config.json CHANGED
@@ -2074,8 +2074,10 @@
2074
  }
2075
  },
2076
  "bos_token": "<s>",
 
2077
  "clean_up_tokenization_spaces": true,
2078
  "eos_token": "</s>",
 
2079
  "max_length": 512,
2080
  "model_max_length": 1000000000000000019884624838656,
2081
  "pad_to_multiple_of": null,
 
2074
  }
2075
  },
2076
  "bos_token": "<s>",
2077
+ "chat_template": "{%- for message in messages -%}{{ message['content'] }}{%- if not loop.last -%}{{ ' ' }}{%- endif -%}{%- endfor -%}",
2078
  "clean_up_tokenization_spaces": true,
2079
  "eos_token": "</s>",
2080
+ "extra_special_tokens": {},
2081
  "max_length": 512,
2082
  "model_max_length": 1000000000000000019884624838656,
2083
  "pad_to_multiple_of": null,