|
|
""" |
|
|
Full definition of a GPT Language Model, all of it in this single file. |
|
|
References: |
|
|
1) the official GPT-2 TensorFlow implementation released by OpenAI: |
|
|
https://github.com/openai/gpt-2/blob/master/src/model.py |
|
|
2) huggingface/transformers PyTorch implementation: |
|
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py |
|
|
""" |
|
|
|
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Literal |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
from .hook_utils import ( |
|
|
hook_namespace, |
|
|
hook_save, |
|
|
torch_recompute_preserving_hook_context, |
|
|
) |
|
|
|
|
|
|
|
|
def sample_top_k(*, n: int, k: int, shape: tuple[int, ...]): |
|
|
"""Fallback sampler used only when sparse kernels are enabled.""" |
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
return torch.randn(shape, device=device, dtype=torch.float32) |
|
|
|
|
|
|
|
|
class AbsTopK(nn.Module): |
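    """Keep only the k largest-magnitude entries along the last dimension; zero out the rest."""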
|
|
def __init__(self, k): |
|
|
super().__init__() |
|
|
self.k = k |
|
|
|
|
|
def forward(self, x): |
|
|
        _, inds = torch.topk(x.abs(), self.k, dim=-1, sorted=False)
|
|
ret = torch.zeros_like(x) |
|
|
ret.scatter_(-1, inds, x.gather(-1, inds)) |
|
|
return ret |
|
|
|
|
|
|
|
|
def barrier():
    """No-op."""
    pass
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module): |
|
|
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" |
|
|
|
|
|
def __init__(self, ndim, bias): |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(ndim)) |
|
|
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None |
|
|
|
|
|
def forward(self, input): |
|
|
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) |
|
|
|
|
|
|
|
|
class CausalSelfAttention(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
assert config.d_model % config.n_head == 0 |
|
|
|
|
|
self.c_attn = config.Linear( |
|
|
config.d_model, 3 * config.d_head * config.n_head, bias=config.bias |
|
|
) |
|
|
|
|
|
self.c_proj = config.Linear(config.d_head * config.n_head, config.d_model, bias=config.bias) |
|
|
|
|
|
self.attn_dropout = nn.Dropout(config.dropout) |
|
|
self.resid_dropout = nn.Dropout(config.dropout) |
|
|
self.n_head = config.n_head |
|
|
self.d_head = config.d_head |
|
|
self.d_model = config.d_model |
|
|
self.dropout = config.dropout |
|
|
|
|
|
self.config = config |
|
|
|
|
|
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash |
|
|
|
|
|
if self.flash: |
|
|
self.attn_imp = ( |
|
|
SDPAWithSink(config.n_head) if config.sink else F.scaled_dot_product_attention |
|
|
) |
|
|
|
|
|
if not self.flash: |
|
|
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") |
|
|
|
|
|
self.register_buffer( |
|
|
"bias", |
|
|
torch.tril(torch.ones(config.block_size, config.block_size)).view( |
|
|
1, 1, config.block_size, config.block_size |
|
|
), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
B, T, C = x.size() |
|
|
|
|
|
x = self.config.maybe_activation_sparsity(x, "attn_in") |
|
|
x = hook_save("act_in", x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in input" |
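        # Project to q, k, v in a single matmul, then split the last dim into three equal chunks.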
|
|
|
|
|
|
|
|
q, k, v = self.c_attn(x).split(self.n_head * self.d_head, dim=2) |
|
|
|
|
|
k = self.config.maybe_activation_sparsity(k, "attn_k") |
|
|
q = self.config.maybe_activation_sparsity(q, "attn_q") |
|
|
v = self.config.maybe_activation_sparsity(v, "attn_v") |
|
|
|
|
|
k = hook_save("k", k) |
|
|
q = hook_save("q", q) |
|
|
v = hook_save("v", v) |
|
|
|
|
|
k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2) |
|
|
q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2) |
|
|
v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert q.isfinite().all(), "nan in query" |
|
|
assert k.isfinite().all(), "nan in key" |
|
|
assert v.isfinite().all(), "nan in value" |
|
|
|
|
|
attention_scale = 1.0 / math.sqrt(k.size(-1)) |
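        # Either use fused scaled-dot-product attention (optionally with the learned sink),
        # or fall back to the explicit masked-softmax implementation below.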
|
|
|
|
|
|
|
|
if self.flash: |
|
|
|
|
|
y = self.attn_imp( |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
dropout_p=self.dropout if self.training else 0, |
|
|
is_causal=True, |
|
|
scale=attention_scale, |
|
|
) |
|
|
else: |
|
|
|
|
|
att = (q @ k.transpose(-2, -1)) * attention_scale |
|
|
att = att.masked_fill( |
|
|
self.bias[:, :, :T, :T] == 0, torch.finfo(att.dtype).min |
|
|
) |
|
|
|
|
|
att = F.softmax(att, dim=-1) |
|
|
att = self.attn_dropout(att) |
|
|
y = att @ v |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert y.isfinite().all(), "nan in attention output" |
|
|
|
|
|
y = ( |
|
|
y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.d_head) |
|
|
) |
|
|
|
|
|
|
|
|
y = hook_save("y", y) |
|
|
|
|
|
|
|
|
y = self.resid_dropout(self.c_proj(y)) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert y.isfinite().all(), "nan in attention output 2" |
|
|
|
|
|
y = self.config.maybe_activation_sparsity(y, "attn_out") |
|
|
return y |
|
|
|
|
|
|
|
|
class MLP(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.c_fc = config.Linear(config.d_model, config.d_mlp, bias=config.bias) |
|
|
self.act_fn = { |
|
|
"gelu": nn.GELU(), |
|
|
"relu": nn.ReLU(), |
|
|
}[config.activation_type] |
|
|
self.c_proj = config.Linear(config.d_mlp, config.d_model, bias=config.bias) |
|
|
self.dropout = nn.Dropout(config.dropout) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.config.maybe_activation_sparsity(x, "mlp_in") |
|
|
x = hook_save("act_in", x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in mlp input" |
|
|
|
|
|
x = self.c_fc(x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in mlp after c_fc" |
|
|
|
|
|
x = self.act_fn(x) |
|
|
x = self.config.maybe_activation_sparsity(x, "mlp_neuron") |
|
|
x = hook_save("post_act", x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in mlp after act" |
|
|
|
|
|
x = self.c_proj(x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in mlp after c_proj" |
|
|
x = self.dropout(x) |
|
|
|
|
|
x = self.config.maybe_activation_sparsity(x, "mlp_out") |
|
|
return x |
|
|
|
|
|
|
|
|
class SDPAWithSink(nn.Module): |
|
|
""" |
|
|
Adds a learnable denominator-only term ("attention sink") to SDPA by |
|
|
concatenating a dummy KV slot whose logit is b and whose V is zero. |
|
|
""" |
|
|
|
|
|
def __init__(self, num_heads: int, init_logit: float = 0.0): |
|
|
super().__init__() |
|
|
shape = (num_heads,) |
|
|
self.sink_logit = nn.Parameter(torch.full(shape, init_logit)) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
*, |
|
|
dropout_p: float = 0.0, |
|
|
is_causal: bool = False, |
|
|
scale: float | None = None, |
|
|
) -> torch.Tensor: |
|
|
B, H, Lq, D = q.shape |
|
|
_, _, Lk, _ = k.shape |
|
|
Dv = v.size(-1) |
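        # Prepend a dummy "sink" slot: its key is zero (so its raw dot product is 0 and its logit is
        # just the learned bias b added via the mask) and its value is zero (it absorbs probability
        # mass in the softmax without contributing to the output).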
|
|
|
|
|
|
|
|
k_sink = torch.zeros((B, H, 1, D), dtype=q.dtype, device=q.device) |
|
|
v_sink = torch.zeros((B, H, 1, Dv), dtype=v.dtype, device=v.device) |
|
|
k_aug = torch.cat([k_sink, k], dim=2) |
|
|
v_aug = torch.cat([v_sink, v], dim=2) |
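        # Build a boolean "allow" mask: column 0 (the sink) is visible to every query position,
        # while the remaining Lk columns follow the usual causal (lower-triangular) pattern.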
|
|
|
|
|
|
|
|
|
|
|
|
|
|
allow = torch.zeros((Lq, Lk + 1), dtype=torch.bool, device=q.device) |
|
|
allow[:, 0] = True |
|
|
|
|
|
real = torch.ones((Lq, Lk), dtype=torch.bool, device=q.device).tril() |
|
|
allow[:, 1:] = real |
|
|
|
|
|
|
|
|
allow = allow.view(1, 1, Lq, Lk + 1).expand(B, H, Lq, Lk + 1) |
|
|
|
|
|
|
|
|
neg_inf = torch.finfo(q.dtype).min |
|
|
base_mask = torch.where( |
|
|
allow, |
|
|
torch.zeros((), dtype=q.dtype, device=q.device), |
|
|
torch.full((), neg_inf, dtype=q.dtype, device=q.device), |
|
|
) |
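        # Add the learned sink logit b only to the sink column (index 0); real key positions
        # receive no additive bias beyond the causal mask.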
|
|
|
|
|
|
|
|
if self.sink_logit.numel() == H: |
|
|
b = self.sink_logit.to(dtype=q.dtype, device=q.device).view(1, H, 1, 1) |
|
|
else: |
|
|
b = self.sink_logit.to(dtype=q.dtype, device=q.device).view(1, 1, 1, 1) |
|
|
|
|
|
sink_bias_mask = torch.zeros((1, 1, 1, Lk + 1), dtype=q.dtype, device=q.device) |
|
|
sink_bias_mask[..., 0] = 1.0 |
|
|
attn_mask = base_mask + sink_bias_mask * b |
|
|
|
|
|
|
|
|
out = F.scaled_dot_product_attention( |
|
|
q, |
|
|
k_aug, |
|
|
v_aug, |
|
|
attn_mask=attn_mask, |
|
|
dropout_p=dropout_p, |
|
|
is_causal=False, |
|
|
scale=scale, |
|
|
) |
|
|
return out |
|
|
|
|
|
|
|
|
class Block(nn.Module): |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
self.ln_1 = ( |
|
|
nn.RMSNorm(config.d_model) |
|
|
if config.rms_norm |
|
|
else LayerNorm(config.d_model, bias=config.ln_bias) |
|
|
) |
|
|
self.attn = CausalSelfAttention(config) |
|
|
self.ln_2 = ( |
|
|
nn.RMSNorm(config.d_model) |
|
|
if config.rms_norm |
|
|
else LayerNorm(config.d_model, bias=config.ln_bias) |
|
|
) |
|
|
self.mlp = MLP(config) |
|
|
|
|
|
def forward_attn_block(self, x): |
|
|
x = hook_save("resid_in", x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in blk input" |
|
|
|
|
|
with hook_namespace("attn"): |
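            # Under gradient checkpointing, recompute attn(ln_1(x)) in the backward pass via
            # torch_recompute_preserving_hook_context, which keeps the hook context intact
            # across the recompute.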
|
|
if self.config.grad_checkpointing: |
|
|
x = x + hook_save( |
|
|
"resid_delta", |
|
|
torch_recompute_preserving_hook_context( |
|
|
lambda x: self.attn(self.ln_1(x)), x, use_reentrant=False |
|
|
), |
|
|
) |
|
|
else: |
|
|
x = x + hook_save("resid_delta", self.attn(self.ln_1(x))) |
|
|
|
|
|
if self.config.residual_activation_type == "relu": |
|
|
x = torch.relu(x) |
|
|
x = self.config.maybe_activation_sparsity(x, "resid_post_attn") |
|
|
|
|
|
return x |
|
|
|
|
|
def forward_mlp_block(self, x): |
|
|
x = hook_save("resid_mid", x) |
|
|
with hook_namespace("mlp"): |
|
|
if self.config.grad_checkpointing: |
|
|
x = x + hook_save( |
|
|
"resid_delta", |
|
|
torch_recompute_preserving_hook_context( |
|
|
lambda x: self.mlp(self.ln_2(x)), x, use_reentrant=False |
|
|
), |
|
|
) |
|
|
else: |
|
|
x = x + hook_save("resid_delta", self.mlp(self.ln_2(x))) |
|
|
|
|
|
if self.config.residual_activation_type == "relu": |
|
|
x = torch.relu(x) |
|
|
x = self.config.maybe_activation_sparsity(x, "resid_post_mlp") |
|
|
return x |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.forward_attn_block(x) |
|
|
x = self.forward_mlp_block(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class CausalSelfAttentionCatPosEmb(CausalSelfAttention): |
|
|
def __init__(self, config): |
|
|
|
|
|
super().__init__(config) |
|
|
assert config.d_model % config.n_head == 0 |
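        # Rebuild the projections so c_attn accepts d_model_in features
        # (the residual stream with the positional embedding concatenated on the last dim).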
|
|
|
|
|
self.c_attn = config.Linear( |
|
|
config.d_model_in, 3 * config.d_head * config.n_head, bias=config.bias |
|
|
) |
|
|
|
|
|
self.c_proj = config.Linear(config.d_head * config.n_head, config.d_model, bias=config.bias) |
|
|
|
|
|
self.attn_dropout = nn.Dropout(config.dropout) |
|
|
self.resid_dropout = nn.Dropout(config.dropout) |
|
|
self.n_head = config.n_head |
|
|
self.d_head = config.d_head |
|
|
self.d_model_in = config.d_model_in |
|
|
self.d_model = config.d_model |
|
|
self.dropout = config.dropout |
|
|
self.config = config |
|
|
|
|
|
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash |
|
|
|
|
|
if not self.flash: |
|
|
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") |
|
|
|
|
|
self.register_buffer( |
|
|
"bias", |
|
|
torch.tril(torch.ones(config.block_size, config.block_size)).view( |
|
|
1, 1, config.block_size, config.block_size |
|
|
), |
|
|
) |
|
|
|
|
|
def forward(self, x, pos_emb_to_cat): |
|
|
|
|
|
if pos_emb_to_cat is not None and pos_emb_to_cat.size(0) == 1 and x.size(0) != 1: |
|
|
pos_emb_to_cat = pos_emb_to_cat.expand(x.size(0), -1, -1) |
|
|
x = torch.cat([x, pos_emb_to_cat], dim=-1) |
|
|
return super().forward(x) |
|
|
|
|
|
|
|
|
class MLPCatPosEmb(MLP): |
|
|
def __init__(self, config): |
|
|
|
|
|
super().__init__(config) |
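        # Rebuild c_fc so it accepts d_model_in features (residual stream + concatenated positional embedding).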
|
|
self.config = config |
|
|
self.c_fc = config.Linear(config.d_model_in, config.d_mlp, bias=config.bias) |
|
|
self.act_fn = { |
|
|
"gelu": nn.GELU(), |
|
|
"relu": nn.ReLU(), |
|
|
}[config.activation_type] |
|
|
self.c_proj = config.Linear(config.d_mlp, config.d_model, bias=config.bias) |
|
|
self.dropout = nn.Dropout(config.dropout) |
|
|
|
|
|
def forward(self, x, pos_emb_to_cat): |
|
|
|
|
|
if pos_emb_to_cat is not None and pos_emb_to_cat.size(0) == 1 and x.size(0) != 1: |
|
|
pos_emb_to_cat = pos_emb_to_cat.expand(x.size(0), -1, -1) |
|
|
x = torch.cat([x, pos_emb_to_cat], dim=-1) |
|
|
x = super().forward(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class BlockCatPosEmb(Block): |
|
|
|
|
|
def __init__(self, config): |
|
|
|
|
|
super().__init__(config) |
|
|
self.ln_p1 = ( |
|
|
nn.RMSNorm(config.d_pos_emb) |
|
|
if config.rms_norm |
|
|
else LayerNorm(config.d_pos_emb, bias=config.ln_bias) |
|
|
) |
|
|
self.ln_p2 = ( |
|
|
nn.RMSNorm(config.d_pos_emb) |
|
|
if config.rms_norm |
|
|
else LayerNorm(config.d_pos_emb, bias=config.ln_bias) |
|
|
) |
|
|
self.attn = CausalSelfAttentionCatPosEmb(config) |
|
|
self.mlp = MLPCatPosEmb(config) |
|
|
|
|
|
def forward_attn_block(self, x, p): |
|
|
x = hook_save("resid_in", x) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in blk input" |
|
|
|
|
|
with hook_namespace("attn"): |
|
|
if self.config.grad_checkpointing: |
|
|
x = x + hook_save( |
|
|
"resid_delta", |
|
|
torch_recompute_preserving_hook_context( |
|
|
lambda x, p: self.attn(self.ln_1(x), self.ln_p1(p)), |
|
|
x, |
|
|
p, |
|
|
use_reentrant=False, |
|
|
), |
|
|
) |
|
|
else: |
|
|
x = x + hook_save("resid_delta", self.attn(self.ln_1(x), self.ln_p1(p))) |
|
|
|
|
|
if self.config.residual_activation_type == "relu": |
|
|
x = torch.relu(x) |
|
|
x = self.config.maybe_activation_sparsity(x, "resid_post_attn") |
|
|
|
|
|
return x |
|
|
|
|
|
def forward_mlp_block(self, x, p): |
|
|
x = hook_save("resid_mid", x) |
|
|
with hook_namespace("mlp"): |
|
|
if self.config.grad_checkpointing: |
|
|
x = x + hook_save( |
|
|
"resid_delta", |
|
|
torch_recompute_preserving_hook_context( |
|
|
lambda x, p: self.mlp(self.ln_2(x), self.ln_p2(p)), |
|
|
x, |
|
|
p, |
|
|
use_reentrant=False, |
|
|
), |
|
|
) |
|
|
else: |
|
|
x = x + hook_save("resid_delta", self.mlp(self.ln_2(x), self.ln_p2(p))) |
|
|
|
|
|
if self.config.residual_activation_type == "relu": |
|
|
x = torch.relu(x) |
|
|
x = self.config.maybe_activation_sparsity(x, "resid_post_mlp") |
|
|
return x |
|
|
|
|
|
def forward(self, x, pos_emb_to_cat): |
|
|
x = self.forward_attn_block(x, pos_emb_to_cat) |
|
|
x = self.forward_mlp_block(x, pos_emb_to_cat) |
|
|
return x |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class GPTConfig: |
|
|
block_size: int = 1024 |
|
|
vocab_size: int = 50304 |
|
|
n_layer: int = 12 |
|
|
n_head: int = 12 |
|
|
d_head: int | None = None |
|
|
d_model: int = 768 |
|
|
dropout: float = 0.0 |
|
|
    bias: bool = True  # use bias in the Linear projections
    ln_bias: bool = True  # use bias in LayerNorm
|
|
rms_norm: bool = False |
|
|
residual_activation_type: Literal["identity", "relu"] = "identity" |
|
|
activation_type: Literal["gelu", "relu"] = "gelu" |
|
|
    afrac: float | None = None  # if set, keep only this fraction of activations (by magnitude)
    afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out"  # comma-separated locations to sparsify
|
|
debug_nans: bool = False |
|
|
tied_unembed: bool = True |
|
|
|
|
|
tokenizer_name: str = "tinypython_2k" |
|
|
|
|
|
grad_checkpointing: bool = True |
|
|
d_mlp: int | None = None |
|
|
|
|
|
enable_bigram_table: bool = False |
|
|
learnable_bigram_table: bool = False |
|
|
d_pos_emb: int | None = None |
|
|
dropout_cat_pos_emb: bool = False |
|
|
sinusoidal_cat_pos_emb: bool = False |
|
|
enable_sparse_kernels: bool = False |
|
|
|
|
|
flash: bool = True |
|
|
sink: bool = False |
|
|
|
|
|
@property |
|
|
def cat_pos_emb(self): |
|
|
return self.d_pos_emb is not None |
|
|
|
|
|
@property |
|
|
def d_model_in(self): |
|
|
return self.d_model + self.d_pos_emb if self.cat_pos_emb else self.d_model |
|
|
|
|
|
def __post_init__(self): |
|
|
assert self.d_model % self.n_head == 0 |
|
|
assert self.residual_activation_type in ["identity", "relu"] |
|
|
assert self.activation_type in ["gelu", "relu"] |
|
|
|
|
|
if self.d_mlp is None: |
|
|
self.d_mlp = 4 * self.d_model |
|
|
if self.d_head is None: |
|
|
self.d_head = self.d_model // self.n_head |
|
|
|
|
|
@property |
|
|
def Linear(self): |
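        """Linear layer class used for all projections; kept as a property so variants can swap in a different implementation."""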
|
|
return nn.Linear |
|
|
|
|
|
def maybe_activation_sparsity(self, x, loctype): |
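        """If afrac is set and loctype is listed in afrac_loctypes, keep only the top afrac
        fraction of entries (by absolute value) along the last dimension and zero the rest."""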
|
|
if self.afrac is not None and loctype in self.afrac_loctypes.split(","): |
|
|
|
|
|
def keep_abstopk(x, k): |
|
|
ret = torch.zeros_like(x) |
|
|
_, topk_inds = torch.topk(x.abs(), k, dim=-1, sorted=False) |
|
|
ret.scatter_(-1, topk_inds, x.gather(-1, topk_inds)) |
|
|
return ret |
|
|
|
|
|
x = keep_abstopk( |
|
|
x, |
|
|
k=int(self.afrac * x.shape[-1]), |
|
|
) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class GPT(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
assert config.vocab_size is not None |
|
|
assert config.block_size is not None |
|
|
self.config = config |
|
|
|
|
|
if config.cat_pos_emb: |
|
|
block_cls = BlockCatPosEmb |
|
|
else: |
|
|
block_cls = Block |
|
|
|
|
|
self.transformer = nn.ModuleDict( |
|
|
dict( |
|
|
wte=nn.Embedding(config.vocab_size, config.d_model), |
|
|
wpe=nn.Embedding(config.block_size, config.d_pos_emb or config.d_model), |
|
|
drop=nn.Dropout(config.dropout), |
|
|
h=nn.ModuleList([(block_cls(config)) for _ in range(config.n_layer)]), |
|
|
ln_f=nn.RMSNorm(config.d_model) |
|
|
if config.rms_norm |
|
|
else LayerNorm(config.d_model, bias=config.ln_bias), |
|
|
) |
|
|
) |
|
|
|
|
|
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) |
|
|
|
|
|
self.register_buffer( |
|
|
"final_logits_bias", torch.zeros(config.vocab_size, dtype=torch.float32) |
|
|
) |
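        # Optional bigram table: a (vocab_size, vocab_size) matrix of extra logit biases indexed by
        # the current input token, either learnable or a fixed zero-initialized buffer.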
|
|
|
|
|
if self.config.enable_bigram_table: |
|
|
if self.config.learnable_bigram_table: |
|
|
|
|
|
self.bigram_table = nn.Parameter( |
|
|
torch.zeros(config.vocab_size, config.vocab_size, dtype=torch.float32) |
|
|
) |
|
|
else: |
|
|
self.register_buffer( |
|
|
"bigram_table", |
|
|
torch.zeros(config.vocab_size, config.vocab_size, dtype=torch.float32), |
|
|
) |
|
|
else: |
|
|
self.bigram_table = None |
|
|
|
|
|
|
|
|
        config.tied_unembed = False  # weight tying between wte and lm_head is not used in this implementation
|
|
|
|
|
|
|
|
self.apply(self._init_weights) |
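        # As in GPT-2, scale the init of the residual projections (c_proj) by 1/sqrt(2 * n_layer).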
|
|
|
|
|
for pn, p in self.named_parameters(): |
|
|
if pn.endswith("c_proj.weight"): |
|
|
if p.is_sparse: |
|
|
num_nonzero = p._nnz() |
|
|
p._values().data = ( |
|
|
sample_top_k(n=p.numel(), k=num_nonzero, shape=(num_nonzero,)) |
|
|
* 0.02 |
|
|
/ math.sqrt(2 * config.n_layer) |
|
|
) |
|
|
else: |
|
|
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) |
|
|
|
|
|
|
|
|
if config.cat_pos_emb and config.sinusoidal_cat_pos_emb: |
|
|
assert config.d_pos_emb is not None, ( |
|
|
"sinusoidal_cat_pos_emb requires cat_pos_emb (d_pos_emb must be set)" |
|
|
) |
|
|
with torch.no_grad(): |
|
|
T = config.block_size |
|
|
D = config.d_pos_emb |
|
|
device = self.transformer.wpe.weight.device |
|
|
dtype = self.transformer.wpe.weight.dtype |
|
|
positions = torch.arange(T, device=device, dtype=dtype).unsqueeze(1) |
|
|
d_half = max(1, D // 2) |
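                # Log-spaced periods from 4 tokens up to the block size; sin/cos pairs fill the
                # first 2*d_half channels and any remainder is zero-padded (or truncated) to D.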
|
|
|
|
|
T_float = float(T) |
|
|
p_min = 4.0 |
|
|
p_max = max(p_min, T_float) |
|
|
periods = torch.logspace( |
|
|
math.log10(p_min), math.log10(p_max), steps=d_half, device=device, dtype=dtype |
|
|
) |
|
|
freqs = 2 * math.pi / periods |
|
|
angles = positions * freqs |
|
|
sinv = torch.sin(angles) |
|
|
cosv = torch.cos(angles) |
|
|
enc = torch.cat([sinv, cosv], dim=1) |
|
|
if enc.shape[1] < D: |
|
|
pad = torch.zeros(T, D - enc.shape[1], device=device, dtype=dtype) |
|
|
enc = torch.cat([enc, pad], dim=1) |
|
|
elif enc.shape[1] > D: |
|
|
enc = enc[:, :D] |
|
|
self.transformer.wpe.weight.copy_(enc) |
|
|
self.transformer.wpe.weight.requires_grad_(False) |
|
|
|
|
|
|
|
|
print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,)) |
|
|
|
|
|
@torch.no_grad() |
|
|
def _initialize_weights(self, module: nn.Module) -> None: |
|
|
""" |
|
|
Compatibility shim for newer `transformers` versions. |
|
|
|
|
|
`transformers.PreTrainedModel.initialize_weights()` will treat any submodule that |
|
|
defines `_init_weights` as a nested "sub-model" and will recursively call that |
|
|
submodule's `_initialize_weights`. Our core `GPT` module historically only |
|
|
implemented `_init_weights`, so we provide this wrapper to match HF's contract. |
|
|
""" |
|
|
if getattr(module, "_is_hf_initialized", False): |
|
|
return |
|
|
self._init_weights(module) |
|
|
module._is_hf_initialized = True |
|
|
|
|
|
|
|
|
def get_num_params(self, non_embedding=True): |
|
|
""" |
|
|
Return the number of parameters in the model. |
|
|
For non-embedding count (default), the position embeddings get subtracted. |
|
|
The token embeddings would too, except due to the parameter sharing these |
|
|
params are actually used as weights in the final layer, so we include them. |
|
|
""" |
|
|
n_params = sum(p.numel() for p in self.parameters()) |
|
|
if non_embedding: |
|
|
n_params -= self.transformer.wpe.weight.numel() |
|
|
return n_params |
|
|
|
|
|
def _init_weights(self, module): |
|
|
if isinstance(module, nn.Linear): |
|
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
if module.bias is not None: |
|
|
torch.nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.Embedding): |
|
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
|
|
|
def forward(self, idx, targets=None, include_resid_mid=False): |
|
|
device = idx.device |
|
|
b, t = idx.size() |
|
|
|
|
|
assert t <= self.config.block_size, ( |
|
|
f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
tok_emb = self.transformer.wte(idx) |
|
|
|
|
|
pos_emb = self.transformer.wpe.weight[:t].unsqueeze(0) |
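        # When cat_pos_emb is enabled the positional embedding is not added to the residual stream;
        # instead it is concatenated onto each block's inputs (see BlockCatPosEmb).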
|
|
if self.config.cat_pos_emb: |
|
|
x = self.transformer.drop(tok_emb) |
|
|
else: |
|
|
x = self.transformer.drop(tok_emb + pos_emb) |
|
|
|
|
|
if self.config.debug_nans: |
|
|
assert x.isfinite().all(), "nan in initial post-embedding" |
|
|
|
|
|
if self.config.enable_bigram_table: |
|
|
|
|
|
additional_logits_bias = F.embedding(idx, self.bigram_table, padding_idx=-1) |
|
|
additional_logits_bias = additional_logits_bias.to(x.dtype) |
|
|
else: |
|
|
additional_logits_bias = None |
|
|
|
|
|
if self.config.cat_pos_emb: |
|
|
pos_emb_to_cat = pos_emb |
|
|
if self.config.dropout_cat_pos_emb: |
|
|
pos_emb_to_cat = self.transformer.drop(pos_emb) |
|
|
else: |
|
|
pos_emb_to_cat = None |
|
|
|
|
|
return self.forward_tail( |
|
|
x, |
|
|
n=0, |
|
|
targets=targets, |
|
|
additional_logits_bias=additional_logits_bias, |
|
|
include_resid_mid=include_resid_mid, |
|
|
pos_emb_to_cat=pos_emb_to_cat, |
|
|
) |
|
|
|
|
|
def forward_tail( |
|
|
self, |
|
|
x, |
|
|
n, |
|
|
targets=None, |
|
|
additional_logits_bias=None, |
|
|
include_resid_mid=False, |
|
|
pos_emb_to_cat=None, |
|
|
): |
|
|
|
|
|
hs = [] |
|
|
blks = list(self.transformer.h) |
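        # With include_resid_mid, each block is split into its attention half and its MLP half so
        # that the residual stream between them is also appended to hs.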
|
|
|
|
|
if include_resid_mid: |
|
|
blks = list_join( |
|
|
[ |
|
|
[ |
|
|
blk.forward_attn_block, |
|
|
blk.forward_mlp_block, |
|
|
] |
|
|
for blk in blks |
|
|
] |
|
|
) |
|
|
|
|
|
assert n <= len(blks) |
|
|
|
|
|
for i, block_fn in enumerate(blks[n:]): |
|
|
            # Record the index of the block currently executing in a module-level global.
            global curlayer
            curlayer = i
|
|
with hook_namespace(f"{i // 2}") if include_resid_mid else hook_namespace(f"{i}"): |
|
|
hs.append(x) |
|
|
if self.config.cat_pos_emb: |
|
|
x = block_fn(x, pos_emb_to_cat) |
|
|
else: |
|
|
x = block_fn(x) |
|
|
|
|
|
x = hook_save("final_resid", x) |
|
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
logits = ( |
|
|
self.lm_head(x) |
|
|
+ self.final_logits_bias |
|
|
+ (additional_logits_bias if additional_logits_bias is not None else 0) |
|
|
) |
|
|
if targets is not None: |
|
|
loss = F.cross_entropy( |
|
|
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 |
|
|
) |
|
|
else: |
|
|
loss = torch.zeros(1, device=x.device) |
|
|
|
|
|
return logits, loss, hs |
|
|
|
|
|
def crop_block_size(self, block_size): |
|
|
|
|
|
|
|
|
|
|
|
assert block_size <= self.config.block_size |
|
|
self.config.block_size = block_size |
|
|
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) |
|
|
for block in self.transformer.h: |
|
|
if hasattr(block.attn, "bias"): |
|
|
block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): |
|
|
""" |
|
|
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete |
|
|
the sequence max_new_tokens times, feeding the predictions back into the model each time. |
|
|
Most likely you'll want to make sure to be in model.eval() mode of operation for this. |
|
|
""" |
|
|
for _ in range(max_new_tokens): |
|
|
|
|
|
idx_cond = ( |
|
|
idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] |
|
|
) |
|
|
|
|
|
logits, _, _ = self(idx_cond) |
|
|
|
|
|
logits = logits[:, -1, :] / temperature |
|
|
|
|
|
if top_k is not None: |
|
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
|
|
logits[logits < v[:, -1:]] = -float("Inf") |
|
|
|
|
|
probs = F.softmax(logits, dim=-1) |
|
|
|
|
|
idx_next = torch.multinomial(probs, num_samples=1) |
|
|
|
|
|
idx = torch.cat((idx, idx_next), dim=1) |
|
|
|
|
|
return idx |
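    # Example call (a sketch; assumes `idx` is a LongTensor of token ids of shape (B, T)):
    #   model.eval()
    #   out = model.generate(idx, max_new_tokens=32, temperature=0.8, top_k=50)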
|
|
|
|
|
def is_mlp_param(self, p): |
|
|
return id(p) in list_join( |
|
|
[ |
|
|
[ |
|
|
id(self.transformer.h[i].mlp.c_fc.weight), |
|
|
id(self.transformer.h[i].mlp.c_proj.weight), |
|
|
] |
|
|
for i in range(self.config.n_layer) |
|
|
] |
|
|
) |
|
|
|
|
|
def is_param_embed(self, p): |
|
|
return p is self.transformer.wte.weight or p is self.transformer.wpe.weight |
|
|
|
|
|
def is_attn_param(self, p): |
|
|
return id(p) in list_join( |
|
|
[ |
|
|
[ |
|
|
id(self.transformer.h[i].attn.c_attn.weight), |
|
|
id(self.transformer.h[i].attn.c_proj.weight), |
|
|
] |
|
|
for i in range(self.config.n_layer) |
|
|
] |
|
|
) |
|
|
|
|
|
def is_bias(self, p): |
|
|
return id(p) in list_join( |
|
|
[ |
|
|
[ |
|
|
id(self.transformer.h[i].attn.c_attn.bias), |
|
|
id(self.transformer.h[i].attn.c_proj.bias), |
|
|
id(self.transformer.h[i].mlp.c_fc.bias), |
|
|
id(self.transformer.h[i].mlp.c_proj.bias), |
|
|
] |
|
|
for i in range(self.config.n_layer) |
|
|
] |
|
|
) |
|
|
|
|
|
def is_ln_param(self, p): |
|
|
return id(p) in list_join( |
|
|
[ |
|
|
[ |
|
|
id(self.transformer.h[i].ln_1.weight), |
|
|
id(self.transformer.h[i].ln_2.weight), |
|
|
] |
|
|
for i in range(self.config.n_layer) |
|
|
] |
|
|
) + [ |
|
|
id(self.transformer.ln_f.weight), |
|
|
] |
|
|
|
|
|
def is_sparse_param(self, p, dense_embeddings=None, dense_unembed=None, dense_biases=None): |
|
|
|
|
|
|
|
|
if dense_embeddings is None: |
|
|
assert p is not self.transformer.wte.weight and p is not self.transformer.wpe.weight |
|
|
if dense_unembed is None: |
|
|
assert p is not self.lm_head.weight |
|
|
if dense_biases is None: |
|
|
assert not self.is_bias(p) |
|
|
|
|
|
if p is self.transformer.wte.weight or p is self.transformer.wpe.weight: |
|
|
return not dense_embeddings |
|
|
if p is self.lm_head.weight: |
|
|
return not dense_unembed |
|
|
if self.is_bias(p): |
|
|
return not dense_biases |
|
|
|
|
|
return id(p) in list_join( |
|
|
[ |
|
|
[ |
|
|
id(self.transformer.h[i].attn.c_attn.weight), |
|
|
id(self.transformer.h[i].attn.c_proj.weight), |
|
|
id(self.transformer.h[i].mlp.c_fc.weight), |
|
|
id(self.transformer.h[i].mlp.c_proj.weight), |
|
|
] |
|
|
for i in range(self.config.n_layer) |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
def list_join(xss: list[list]) -> list: |
|
|
"""monadic join for lists""" |
|
|
return [x for xs in xss for x in xs] |
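# Minimal usage sketch (an illustration, not part of the training pipeline). Assumes this module is
# imported as part of its package (so the relative hook_utils import resolves) and that the hook
# helpers act as pass-throughs when no capture context is active:
#
#   config = GPTConfig(n_layer=2, n_head=2, d_model=64, block_size=128, vocab_size=512,
#                      grad_checkpointing=False)
#   model = GPT(config)
#   idx = torch.randint(0, config.vocab_size, (1, 16))
#   logits, loss, hs = model(idx, targets=idx)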
|
|
|