# Kernels / activation / tests/test_mla_rope_grad.py
# Commit 0c42208 by 3v324v23:
#   "test: numerical parity for MLA RoPE fused kernels vs PyTorch reference"
"""Numerical parity test: activation MLA RoPE kernels vs PyTorch reference.
The activation package exposes two Triton kernels for Motif3 MLA attention:
* fused_q_rope_inplace — in-place RoPE on q's rope section
* fused_kv_split_rope_cat — split kv_latent + register-broadcast k_pe to H heads + cat
This test runs both the fused path and a pure-PyTorch reference over identical
inputs (forward + backward) and compares all outputs and input gradients.
Self-contained: the reference RoPE implementation lives in this file (no
upstream model code dependency).
"""
import pytest
import torch
import activation
from .utils import assert_close
# Realistic motif3_seq per-GPU shapes (B=local_batch_size, H_q/H_kv per MLA spec).
# Field order matches the unpacking in test_mla_rope_fused_vs_reference.
SHAPES = [
    (
        8,     # B      - local batch size
        4096,  # S      - sequence length (rows of freqs_cis)
        80,    # H_q    - query heads
        16,    # H_kv   - key/value heads
        128,   # D_nope - non-rotary channels per head
        64,    # D_rope - rotary channels per head
        128,   # D_v    - value channels per head
    ),
]
# Only bf16 is exercised; tolerances in the test are tuned for it.
DTYPES = [torch.bfloat16]
SEEDS = [0]
# ------------------------------------------------------------------ reference
def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, dtype=torch.float32)
freqs = torch.outer(t, freqs)
return torch.polar(torch.ones_like(freqs), freqs) # complex64
def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""[B, S, H, D] interleaved → rotated, in interleaved layout."""
x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis[: x_.shape[1]].view(1, x_.shape[1], 1, x_.shape[3])
out = torch.view_as_real(x_ * freqs_cis).flatten(3)
return out.type_as(x)
def _reorder_headdim_elements_rope(qk: torch.Tensor, B: int, S: int, rope_dim: int) -> torch.Tensor:
"""Interleaved [r0,i0,r1,i1,...] → contiguous [r0,r1,...,i0,i1,...]."""
qk = qk.view(B, S, -1, rope_dim // 2, 2)
qk = qk.transpose(3, 4)
return qk.reshape(B, S, -1, rope_dim)
def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference for the MLA RoPE + KV assembly.

    Returns ``(q_total, k_full, v)``:
      * q_total: ``q`` with its trailing ``D_rope`` channels RoPE-rotated and
        reordered from interleaved to split real/imag layout. The leading
        ``D_nope`` channels are unchanged.
      * k_full: the ``k_nope`` slice of ``kv_latent``, concatenated with the
        single rotated ``k_pe`` head expanded to ``H_kv`` heads.
      * v: the trailing ``D_v`` channels of ``kv_latent``.
    """
    # Query: rotate only the rope slice; the nope slice passes through.
    q_nope, q_rope = torch.split(q, [D_nope, D_rope], dim=-1)
    q_rope = _reorder_headdim_elements_rope(
        _apply_rotary_emb_single(q_rope, freqs_cis), B, S, D_rope
    )
    q_total = torch.cat([q_nope, q_rope], dim=-1)

    # k_pe is head-shared: give it a size-1 head axis before rotating.
    k_rope = _reorder_headdim_elements_rope(
        _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis), B, S, D_rope
    )

    # KV: split off v, then append the head-broadcast k_rope to k_nope.
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_rope.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v
def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Fused counterpart of ``vanilla_path``, built on the activation kernels.

    ``activation.fused_q_rope_inplace`` handles the query rope section in
    place. ``activation.fused_kv_split_rope_cat`` splits ``kv_latent`` and
    concatenates the broadcast ``k_pe``.

    ``H_kv`` is unused here. It is accepted because the harness passes the
    same keyword arguments to both paths.
    """
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)

    # The head-shared k_pe rotation deliberately stays in native PyTorch:
    # a standalone Triton kernel for it was launch-bound on B200 and gave
    # no measurable win (see PR #22).
    k_rope = _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis)
    k_rope = _reorder_headdim_elements_rope(k_rope, B, S, D_rope)

    k_full, v = activation.fused_kv_split_rope_cat(kv_latent, k_rope, D_nope, D_v, D_rope)
    return q_total, k_full, v
# ------------------------------------------------------------------ harness
def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
# Inputs come in as leaves; thread through a no-op so the in-place fused_q
# kernel sees a non-leaf (mirrors the real model where q is a Linear output).
q_leaf, kv_leaf, kpe_leaf = (
q.clone().detach().requires_grad_(True),
kv_latent.clone().detach().requires_grad_(True),
k_pe.clone().detach().requires_grad_(True),
)
q_in, kv_in, kpe_in = q_leaf * 1.0, kv_leaf * 1.0, kpe_leaf * 1.0
q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)
loss = (q_total.float() ** 2).sum() + (k_full.float() ** 2).sum() + (v.float() ** 2).sum()
loss.backward()
return (
q_total.detach(), k_full.detach(), v.detach(),
q_leaf.grad.detach(), kv_leaf.grad.detach(), kpe_leaf.grad.detach(),
)
# ------------------------------------------------------------------ test
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    """Fused MLA RoPE kernels must match the PyTorch reference, forward and backward."""
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    dev = "cuda"

    torch.manual_seed(seed)
    freqs_cis = _precompute_freqs_cis(D_rope, S).to(dev)

    def _scaled_randn(*dims):
        return torch.randn(*dims, device=dev, dtype=dtype) * 0.5

    # RNG draw order (q, then kv_latent, then k_pe) fixes the inputs per seed.
    q = _scaled_randn(B, S, H_q, D_nope + D_rope)
    kv_latent = _scaled_randn(B, S, H_kv, D_nope + D_v)
    k_pe = _scaled_randn(B, S, D_rope)

    dims = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)
    ref_q, ref_k, ref_v, ref_gq, ref_gkv, ref_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **dims
    )
    fus_q, fus_k, fus_v, fus_gq, fus_gkv, fus_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **dims
    )

    # q and its grad: small bf16 jitter is expected on the rope rotation
    # (Triton fp32 accumulation vs inductor fp32 complex_mul ordering).
    assert_close(fus_q.float(), ref_q.float(), atol=1e-2, rtol=1e-2)
    # KV outputs are bit-exact (slice + register broadcast + store).
    assert_close(fus_k.float(), ref_k.float(), atol=0.0, rtol=0.0)
    assert_close(fus_v.float(), ref_v.float(), atol=0.0, rtol=0.0)
    # Input gradients.
    assert_close(fus_gq.float(), ref_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(fus_gkv.float(), ref_gkv.float(), atol=0.0, rtol=0.0)
    assert_close(fus_gkpe.float(), ref_gkpe.float(), atol=0.0, rtol=0.0)