# src/models/gated_deltaproduct/modeling_gated_deltaproduct.py
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from fla.layers.attn import Attention
from fla.models.utils import Cache
from fla.modules import GatedMLP as GatedDeltaProductMLP
from fla.modules import RMSNorm

from src.models.gated_deltaproduct.configuration_gated_deltaproduct import (
    GatedDeltaProductConfig,
)
from src.models.gated_deltaproduct.gated_deltaproduct import GatedDeltaProduct

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack


class GatedDeltaProductBlock(nn.Module):
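    """Transformer block pairing a token-mixing layer with a gated MLP.

    Depending on the layer index, the token mixer is either standard softmax
    attention or a GatedDeltaProduct recurrence; both are wrapped in pre-norm
    residual connections. When ``config.fuse_norm`` is set, the fused ``fla``
    RMSNorm kernel is used in place of ``nn.RMSNorm``.
    """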
    def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
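        # Layers listed in config.attn["layers"] use standard softmax attention;
        # every other layer uses the linear-time GatedDeltaProduct mixer.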
        if config.attn is not None and layer_idx in config.attn["layers"]:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn["num_heads"],
                num_kv_heads=config.attn["num_kv_heads"],
                qkv_bias=config.attn["qkv_bias"],
                window_size=config.attn["window_size"],
                rope_theta=config.attn["rope_theta"],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx,
            )
        else:
            self.attn = GatedDeltaProduct(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_forget_gate=config.use_forget_gate,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                norm_eps=config.norm_eps,
                allow_neg_eigval=config.allow_neg_eigval,
                num_householder=config.num_householder,
                layer_idx=layer_idx,
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
        self.mlp = GatedDeltaProductMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | list[torch.FloatTensor] | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        initial_state: torch.FloatTensor | None = None,
        **kwargs: Unpack[dict],
    ) -> tuple[
        torch.FloatTensor,
        tuple[torch.FloatTensor, torch.FloatTensor] | None,
        Cache | list[torch.FloatTensor] | None,
    ]:
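        """Apply pre-norm token mixing followed by a pre-norm gated MLP.

        Returns the updated hidden states, the attention weights (if any),
        and the updated cache.
        """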
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            initial_state=initial_state,
            **kwargs,
        )
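        # With fuse_norm, the fla RMSNorm fuses the residual add with the
        # normalization and returns (normalized_states, updated_residual).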
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states
        outputs = (hidden_states, attentions, past_key_values)
        return outputs
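

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes that
# GatedDeltaProductConfig provides defaults for every field referenced above
# and that the fla Triton kernels are available on a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = GatedDeltaProductConfig()  # assumption: all fields have defaults
    block = GatedDeltaProductBlock(config, layer_idx=0).to("cuda", torch.bfloat16)
    x = torch.randn(2, 32, config.hidden_size, device="cuda", dtype=torch.bfloat16)
    out, _attn, _cache = block(x, use_cache=False)
    assert out.shape == x.shape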