from torch import Tensor, nn
from torch.nn import functional as F
from transformers.models.llama import LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel

from .configuration_hyperllama import HyperLlamaConfig


class ScaledLinear(nn.Linear):
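    """``nn.Linear`` variant that divides its weight (and bias, if present) by a
    fixed ``scaling_factor`` when computing the forward pass; the stored
    parameters themselves are kept unscaled."""
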
    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None, scaling_factor: int = 1
    ):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        self.scale_factor = scaling_factor

    def forward(self, input: Tensor) -> Tensor:
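        # Apply the scaling at call time rather than mutating the stored parameters.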
        return F.linear(
            input,
            self.weight / self.scale_factor,
            self.bias / self.scale_factor if self.bias is not None else None,
        )


class HyperLlamaForCausalLM(LlamaForCausalLM):
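    """Llama causal LM whose ``lm_head`` is a ``ScaledLinear``, so output logits
    are computed with the head weight divided by
    ``config.lm_head_normalization_factor``.
    """
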
    config_class = HyperLlamaConfig

    def __init__(self, config):
        # Call PreTrainedModel.__init__ directly, bypassing LlamaForCausalLM.__init__
        # (which would build a plain nn.Linear lm_head) so the scaled head below is used.
        super(LlamaPreTrainedModel, self).__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = ScaledLinear(
            config.hidden_size, config.vocab_size, bias=False, scaling_factor=config.lm_head_normalization_factor
        )

        # Initialize weights and apply final processing
        self.post_init()
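

# Minimal usage sketch (the checkpoint path is hypothetical; it assumes a saved
# HyperLlama checkpoint whose config provides `lm_head_normalization_factor`):
#
#   from transformers import AutoTokenizer
#
#   model = HyperLlamaForCausalLM.from_pretrained("path/to/hyperllama-checkpoint")
#   tokenizer = AutoTokenizer.from_pretrained("path/to/hyperllama-checkpoint")
#   inputs = tokenizer("Hello world", return_tensors="pt")
#   logits = model(**inputs).logits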