Instructions to use Motif-Technologies/activation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/activation with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("Motif-Technologies/activation")
```
- Notebooks
- Google Colab
- Kaggle
| """Numerical parity test: activation MLA RoPE kernels vs PyTorch reference. | |
| The activation package exposes two Triton kernels for Motif3 MLA attention: | |
| * fused_q_rope_inplace — in-place RoPE on q's rope section | |
| * fused_kv_split_rope_cat — split kv_latent + register-broadcast k_pe to H heads + cat | |
| This test runs both the fused path and a pure-PyTorch reference over identical | |
| inputs (forward + backward) and compares all outputs and input gradients. | |
| Self-contained: the reference RoPE implementation lives in this file (no | |
| upstream model code dependency). | |
| """ | |
| import pytest | |
| import torch | |
| import activation | |
| from .utils import assert_close | |
# Per-GPU problem sizes modelled on motif3_seq: B is the local batch size and
# the query/key-value head counts follow the MLA spec.
# Each entry is (B, S, H_q, H_kv, D_nope, D_rope, D_v).
SHAPES = [(8, 4096, 80, 16, 128, 64, 128)]
# Only bf16 inputs are exercised.
DTYPES = [torch.bfloat16]
# RNG seeds used to generate the random q / kv_latent / k_pe inputs.
SEEDS = [0]
| # ------------------------------------------------------------------ reference | |
def _precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the RoPE rotation table as unit complex numbers.

    Args:
        dim: rotary head dimension; one frequency is produced per pair of
            channels, i.e. ``dim // 2`` frequencies.
        end: number of positions (rows) to generate, starting at position 0.
        theta: RoPE base; frequency k is ``theta ** (-2k / dim)``.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[t, k]`` is
        ``exp(i * t * theta ** (-2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # (end, dim // 2), fp32
    return torch.polar(torch.ones_like(angles), angles)
def _apply_rotary_emb_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate an interleaved-layout tensor by the RoPE table.

    ``x`` is ``[B, S, H, D]`` with channels interleaved as ``[r0, i0, r1, i1, ...]``.
    The rotation is done in fp32 complex arithmetic and the result keeps the
    interleaved layout and is cast back to ``x``'s dtype.

    ``freqs_cis`` must have at least ``S`` rows and exactly ``D // 2`` columns.
    Only the first ``S`` rows are used, broadcast over batch and heads.
    """
    seq_len = x.shape[1]
    as_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(as_pairs)  # [B, S, H, D // 2]
    rotation = freqs_cis[:seq_len].view(1, seq_len, 1, x_complex.shape[3])
    rotated = torch.view_as_real(x_complex * rotation).flatten(3)
    return rotated.type_as(x)
def _reorder_headdim_elements_rope(qk: torch.Tensor, B: int, S: int, rope_dim: int) -> torch.Tensor:
    """De-interleave the rope channels of every head.

    Converts ``[r0, i0, r1, i1, ...]`` into ``[r0, r1, ..., i0, i1, ...]`` along
    the last axis. The result has shape ``[B, S, H, rope_dim]``, with ``H``
    inferred from ``qk``.
    """
    half = rope_dim // 2
    # Split each head into (half, 2) pairs, swap the pair axis to the front,
    # and flatten back so all real parts precede all imaginary parts.
    return (
        qk.view(B, S, -1, half, 2)
        .transpose(3, 4)
        .reshape(B, S, -1, rope_dim)
    )
def vanilla_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Pure-PyTorch reference for the MLA query/key/value preparation.

    Returns ``(q_total, k_full, v)``:
      * ``q_total``: ``q`` with its trailing ``D_rope`` channels rotated and
        de-interleaved; the leading ``D_nope`` channels pass through.
      * ``k_full``: ``kv_latent``'s first ``D_nope`` channels concatenated with
        the rotated, de-interleaved ``k_pe``, broadcast to ``H_kv`` heads.
      * ``v``: the remaining ``D_v`` channels of ``kv_latent``.
    """

    def _rope_then_reorder(t):
        # Interleaved RoPE rotation followed by the half-split reordering.
        return _reorder_headdim_elements_rope(
            _apply_rotary_emb_single(t, freqs_cis), B, S, D_rope
        )

    # Query: only the rope slice is rotated.
    q_nope, q_pe = torch.split(q, [D_nope, D_rope], dim=-1)
    q_total = torch.cat([q_nope, _rope_then_reorder(q_pe)], dim=-1)

    # k_pe is shared across heads: rotate once with a singleton head axis,
    # then expand (no copy) to H_kv heads before concatenating.
    k_rope_shared = _rope_then_reorder(k_pe.unsqueeze(2))
    k_nope, v = torch.split(kv_latent, [D_nope, D_v], dim=-1)
    k_full = torch.cat([k_nope, k_rope_shared.expand(-1, -1, H_kv, -1)], dim=-1)
    return q_total, k_full, v
def fused_path(q, kv_latent, k_pe, freqs_cis, B, S, H_kv, D_nope, D_rope, D_v):
    """Fused-kernel counterpart of ``vanilla_path`` with the same signature and outputs.

    The query rope slice goes through ``activation.fused_q_rope_inplace``. The
    key/value split, k_pe broadcast and concat go through
    ``activation.fused_kv_split_rope_cat``. ``H_kv`` is not used here and is
    kept only for signature parity with ``vanilla_path``.
    """
    # Query: rope section is rotated by the in-place Triton kernel.
    q_total = activation.fused_q_rope_inplace(q, freqs_cis, D_nope, D_rope)

    # k_pe RoPE stays PyTorch native (head-shared; a standalone Triton kernel
    # was launch-bound on B200 with no measurable win — see PR #22).
    k_pe_roped = _reorder_headdim_elements_rope(
        _apply_rotary_emb_single(k_pe.unsqueeze(2), freqs_cis), B, S, D_rope
    )

    k_full, v = activation.fused_kv_split_rope_cat(kv_latent, k_pe_roped, D_nope, D_v, D_rope)
    return q_total, k_full, v
| # ------------------------------------------------------------------ harness | |
def _run_with_grad(path_fn, q, kv_latent, k_pe, freqs_cis, **shape_kwargs):
    """Run ``path_fn`` forward and backward on private copies of the inputs.

    Each input is cloned into a fresh leaf that requires grad, so the caller's
    tensors are never modified or given gradients. The loss is the fp32 sum of
    squares of the three outputs.

    Returns a 6-tuple of detached tensors:
    ``(q_total, k_full, v, grad_q, grad_kv_latent, grad_k_pe)``.
    """
    leaves = [t.clone().detach().requires_grad_(True) for t in (q, kv_latent, k_pe)]
    # Multiply by 1.0 so path_fn receives non-leaf tensors. This mirrors the
    # real model, where q is a Linear output, and lets the in-place fused_q
    # kernel run under autograd.
    q_in, kv_in, kpe_in = (leaf * 1.0 for leaf in leaves)

    q_total, k_full, v = path_fn(q_in, kv_in, kpe_in, freqs_cis, **shape_kwargs)
    loss = sum((out.float() ** 2).sum() for out in (q_total, k_full, v))
    loss.backward()

    leaf_grads = tuple(leaf.grad.detach() for leaf in leaves)
    return (q_total.detach(), k_full.detach(), v.detach(), *leaf_grads)
| # ------------------------------------------------------------------ test | |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
def test_mla_rope_fused_vs_reference(shape, dtype, seed):
    """Check fused MLA RoPE kernels against the pure-PyTorch reference.

    Both paths run forward + backward on identical random inputs via
    ``_run_with_grad``. The test then compares q/k/v outputs and the gradients
    of q, kv_latent and k_pe. The q rope path gets a small tolerance; the KV
    path must match exactly.

    Parameters come from ``SHAPES`` / ``DTYPES`` / ``SEEDS``. Without the
    parametrize decorators pytest would look for fixtures named
    ``shape``/``dtype``/``seed`` and error out. The test is skipped on hosts
    without CUDA because all inputs live on ``"cuda"``.
    """
    B, S, H_q, H_kv, D_nope, D_rope, D_v = shape
    D_qk = D_nope + D_rope
    device = "cuda"
    torch.manual_seed(seed)

    freqs_cis = _precompute_freqs_cis(D_rope, S).to(device)
    q = torch.randn(B, S, H_q, D_qk, device=device, dtype=dtype) * 0.5
    kv_latent = torch.randn(B, S, H_kv, D_nope + D_v, device=device, dtype=dtype) * 0.5
    k_pe = torch.randn(B, S, D_rope, device=device, dtype=dtype) * 0.5
    kw = dict(B=B, S=S, H_kv=H_kv, D_nope=D_nope, D_rope=D_rope, D_v=D_v)

    van_q, van_k, van_v, van_gq, van_gkv, van_gkpe = _run_with_grad(
        vanilla_path, q, kv_latent, k_pe, freqs_cis, **kw
    )
    our_q, our_k, our_v, our_gq, our_gkv, our_gkpe = _run_with_grad(
        fused_path, q, kv_latent, k_pe, freqs_cis, **kw
    )

    # Forward outputs: small bf16 jitter expected on the q rope rotation
    # (Triton fp32 accum vs inductor fp32 complex_mul order).
    assert_close(our_q.float(), van_q.float(), atol=1e-2, rtol=1e-2)
    # KV path is bit-exact (just slice + register broadcast + store).
    assert_close(our_k.float(), van_k.float(), atol=0.0, rtol=0.0)
    assert_close(our_v.float(), van_v.float(), atol=0.0, rtol=0.0)

    # Input grads.
    assert_close(our_gq.float(), van_gq.float(), atol=1e-2, rtol=1e-2)
    assert_close(our_gkv.float(), van_gkv.float(), atol=0.0, rtol=0.0)
    # NOTE(review): the k_pe grad reduces over H_kv heads (the backward of the
    # broadcast). Exact equality assumes the fused backward sums in the same
    # order and precision as the expand backward — confirm if this ever flakes.
    assert_close(our_gkpe.float(), van_gkpe.float(), atol=0.0, rtol=0.0)