#!/usr/bin/env python3
# app.py - Chat inference for AGILLM2 (HF Spaces friendly)
# - Auto-detects non-interactive env (Spaces) and launches Gradio UI
# - Loads final.pt from repo OpenTransformer/AGILLM2-fast-training
# - Qwen tokenizer + chat template
# - Optional local CLI REPL when run in a terminal
# - Adds a "Raw transcript" tab with "User:" / "Assistant:" lines
from __future__ import annotations

import os, sys, time, math, argparse
from typing import Optional, Tuple, List, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

# Silence the PyTorch TF32 deprecation nag in Spaces logs (optional)
import warnings
warnings.filterwarnings("ignore", message="Please use the new API settings to control TF32 behavior")

from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, logging as hf_log
hf_log.set_verbosity_error()

# ───────────────────────── Config ─────────────────────────
MODEL_REPO = os.getenv("MODEL_REPO", "OpenTransformer/AGILLM2-fast-training")
CKPT_NAME = os.getenv("CKPT_NAME", "final.pt")  # e.g., step04121612.pt
TOKENIZER_ID = os.getenv("TOKENIZER_ID", "Qwen/Qwen3-235B-A22B-Thinking-2507")

DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# ───────────────────────── Tokenizer ─────────────────────────
tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "[PAD]"})
VOCAB = max(tok.get_vocab().values()) + 1
BLANK = tok.pad_token_id
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
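
# Illustrative sketch, not called at import time: round-trips an arbitrary string through the
# tokenizer to show how VOCAB / BLANK / EOS above relate to the ids the model will see.
def _tokenizer_smoke_test(sample: str = "Hello, world!") -> None:
    ids = tok.encode(sample)
    assert max(ids) < VOCAB, "token ids must fit inside the embedding table"
    print(f"{len(ids)} tokens | pad_id={BLANK} eos_id={EOS} | decoded={tok.decode(ids)!r}")
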
# ───────────────────────── AMP helper ─────────────────────────
try:
    from torch.amp import autocast as _ac, GradScaler  # noqa
except Exception:
    from torch.cuda.amp import autocast as _ac, GradScaler  # noqa

def _supports_fp8() -> bool:
    return hasattr(torch, "float8_e4m3fn")

def _auto_amp_dtype(prefer_fp8: bool = False):
    if DEV.type != "cuda":
        return torch.float32
    if prefer_fp8 and _supports_fp8():
        return torch.float8_e4m3fn
    try:
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    except Exception:
        return torch.float16

def amp(enabled: bool, prefer_fp8: bool = False):
    if not (enabled and DEV.type == "cuda"):
        from contextlib import nullcontext
        return nullcontext()
    return _ac(device_type="cuda", dtype=_auto_amp_dtype(prefer_fp8=prefer_fp8))
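
# Minimal usage sketch for amp(): autocasts a matmul when CUDA is available and degrades to a
# null context on CPU. Shapes here are arbitrary; the dtype is chosen by _auto_amp_dtype.
def _amp_usage_example() -> torch.Tensor:
    a = torch.randn(8, 16, device=DEV)
    b = torch.randn(16, 4, device=DEV)
    with amp(enabled=True):
        return a @ b
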
# ───────────────────────── ALiBi helpers ─────────────────────────
def _alibi_slopes(n_heads: int):
    import math as _m
    def pow2slopes(n):
        start = 2 ** (-2 ** -(_m.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if _m.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** _m.floor(_m.log2(n_heads))
        vals = pow2slopes(closest); extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)

def alibi_bias(n_heads: int, n_tokens: int):
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (j - i).clamp_min(0)
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
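
# Illustrative check, not run automatically: alibi_bias returns a (1, heads, n, n) tensor that is
# zero on and below the diagonal and non-positive above it (those upper-triangle positions are
# also removed by causal_mask during full-sequence passes).
def _alibi_sanity(n_heads: int = 4, n_tokens: int = 6) -> None:
    bias = alibi_bias(n_heads, n_tokens)
    assert bias.shape == (1, n_heads, n_tokens, n_tokens)
    assert torch.all(torch.tril(bias) == 0)
    assert torch.all(torch.triu(bias, 1) <= 0)
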
# ───────────────────────── Model (Encoder + AR head) ─────────────────────────
class LowRankMHA(nn.Module):
    def __init__(self, d: int, h: int, r: int, use_relpos: bool = True):
        super().__init__()
        assert d % h == 0, "d must be divisible by number of heads"
        self.h, self.dk = h, d // h
        self.use_relpos = use_relpos
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * r, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _proj(self, x):
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                rel_bias_tokens: Optional[int] = None,
                kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache: bool = False):
        q = self._proj(self.q(x))
        k_new = self._proj(self.k(x))
        v_new = self._proj(self.v(x))
        if kv_cache is None:
            k, v = k_new, v_new
        else:
            k, v = kv_cache
            if use_cache:
                k = torch.cat([k, k_new], dim=2)
                v = torch.cat([v, v_new], dim=2)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if q.size(2) == k.size(2):
            if self.use_relpos and rel_bias_tokens is not None:
                att = att + alibi_bias(self.h, rel_bias_tokens)
            if mask is not None:
                att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2)
        z = z.reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        return (out, (k, v)) if use_cache else out
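
# Shape sketch for LowRankMHA with toy sizes (an assumption, not the trained config): per head,
# q/k/v are projected from dk dims down to rank r via U, so the KV cache holds (B, h, N, r)
# tensors rather than (B, h, N, dk). The mask is omitted here for brevity.
def _lowrank_shape_check() -> None:
    d, h, r, B, N = 32, 4, 8, 2, 5
    mha = LowRankMHA(d, h, r).to(DEV)
    x = torch.randn(B, N, d, device=DEV)
    out, (k, v) = mha(x, mask=None, rel_bias_tokens=N, use_cache=True)
    assert out.shape == (B, N, d) and k.shape == (B, h, N, r)
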
class Block(nn.Module):
    def __init__(self, d: int, h: int, r: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.mha = LowRankMHA(d, h, r, use_relpos=True)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor],
                kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache: bool = False):
        n = x.size(1)
        if use_cache:
            y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=n if mask is not None else None,
                                 kv_cache=kv, use_cache=True)
            x = x + y
            x = x + self.ff(self.ln2(x))
            return x, new_kv
        else:
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
            return x + self.ff(self.ln2(x))
class Encoder(nn.Module):
    def __init__(self, cfg: Dict[str, int]):
        super().__init__()
        d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(l)])
        self.ln = nn.LayerNorm(d)

    def forward(self, ids: torch.Tensor, mask: Optional[torch.Tensor],
                kv_caches: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
                use_cache: bool = False):
        x = self.emb(ids)
        if not use_cache:
            for blk in self.blocks:
                x = blk(x, mask)
            return self.ln(x)
        new_kvs: List[Tuple[torch.Tensor, torch.Tensor]] = []
        for i, blk in enumerate(self.blocks):
            kv = kv_caches[i] if (kv_caches is not None) else None
            x, kv_out = blk(x, mask, kv, use_cache=True)
            new_kvs.append(kv_out)
        return self.ln(x), new_kvs
class ARHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, VOCAB)

    def forward(self, h): return self.proj(h)
# ───────────────────────── Misc ─────────────────────────
def causal_mask(n: int):
    m = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(m, 1)
def _resolve_cfg_from_ckpt(sd: dict) -> Dict[str, int]:
    if isinstance(sd, dict) and "cfg" in sd and isinstance(sd["cfg"], dict):
        return dict(sd["cfg"])
    core = sd.get("core", {})
    emb_w = core.get("emb.weight")
    if emb_w is None:
        raise RuntimeError("Checkpoint missing core.emb.weight; cannot infer d/l/h/r.")
    d = emb_w.shape[1]
    layer_ids = []
    for k in core.keys():
        if k.startswith("blocks."):
            parts = k.split(".")
            if len(parts) > 2 and parts[1].isdigit():
                layer_ids.append(int(parts[1]))
    layers = (max(layer_ids) + 1) if layer_ids else 0
    U = core.get("blocks.0.mha.U")
    if U is None:
        raise RuntimeError("Checkpoint missing blocks.0.mha.U; cannot infer rank/heads.")
    dk, r = U.shape
    h = d // dk
    return {"d": d, "layers": layers, "heads": h, "rank": r}
def load_joint_from_hub(repo_id: str, filename: str):
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename, resume_download=True)
    sd = torch.load(ckpt_path, map_location="cpu")
    cfg = _resolve_cfg_from_ckpt(sd)
    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    if "ar" in sd: ar_h.load_state_dict(sd["ar"])
    return core, ar_h, cfg
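
# Hedged variant (assumption: a checkpoint with the same "cfg"/"core"/"ar" layout already sits on
# local disk, e.g. a pre-downloaded final.pt). Handy when iterating without Hub access.
def load_joint_from_path(ckpt_path: str):
    sd = torch.load(ckpt_path, map_location="cpu")
    cfg = _resolve_cfg_from_ckpt(sd)
    core = Encoder(cfg).to(DEV)
    ar_h = ARHead(cfg["d"]).to(DEV)
    core.load_state_dict(sd["core"])
    if "ar" in sd: ar_h.load_state_dict(sd["ar"])
    return core, ar_h, cfg
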
# ───────────────────────── Chat helpers ─────────────────────────
def render_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    try:
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
    except Exception:
        # Fallback plain format
        out = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            out.append(f"{role.capitalize()}: {content}")
        if add_generation_prompt:
            out.append("Assistant:")
        return "\n".join(out)

def render_raw(history: List[Tuple[str, str]] | None, sys_prompt: str) -> str:
    lines = []
    if sys_prompt:
        lines.append(f"System: {sys_prompt}")
    for u, a in (history or []):
        lines.append(f"User: {u}")
        lines.append(f"Assistant: {a}")
    return "\n".join(lines)
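
# Illustrative example of the two rendering paths above (not executed at import time):
# render_chat produces the tokenizer's chat-template prompt (or the plain fallback), while
# render_raw builds the "System:/User:/Assistant:" transcript shown in the Raw tab.
def _render_example() -> None:
    msgs = [{"role": "system", "content": "You are terse."},
            {"role": "user", "content": "Hi"}]
    print(render_chat(msgs))
    print(render_raw([("Hi", "Hello!")], "You are terse."))
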
def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
    if n <= 0 or ids.size(1) < n - 1: return logits
    prefix = ids[0, -(n - 1):].tolist()
    banned, tokens = [], ids[0].tolist()
    for i in range(len(tokens) - n + 1):
        if tokens[i:i + n - 1] == prefix:
            banned.append(tokens[i + n - 1])
    if banned:
        banned_idx = torch.tensor(banned, device=logits.device, dtype=torch.long)
        logits[..., banned_idx] = float("-inf")
    return logits
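
# Worked toy example (token ids are arbitrary): with ids = [[5, 7, 5, 7]] and n = 3, the current
# 2-token prefix (5, 7) already occurred at positions 0-1 followed by 5, so 5 is banned next.
def _no_repeat_ngram_example() -> None:
    ids = torch.tensor([[5, 7, 5, 7]])
    logits = torch.zeros(1, VOCAB)
    out = _apply_no_repeat_ngram(logits, ids, 3)
    assert out[0, 5] == float("-inf")
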
def _apply_rep_presence_frequency(logits, ids, last_n, repetition_penalty, presence_penalty, frequency_penalty):
    if ids.numel() == 0: return logits
    hist = ids[0, -last_n:].to(torch.long) if last_n > 0 else ids[0].to(torch.long)
    if hist.numel() == 0: return logits
    uniq, counts = torch.unique(hist, return_counts=True)
    if presence_penalty != 0.0 or frequency_penalty != 0.0:
        adjust = presence_penalty + frequency_penalty * counts.to(logits.dtype)
        logits[..., uniq] = logits[..., uniq] - adjust
    if repetition_penalty and abs(repetition_penalty - 1.0) > 1e-6:
        sel = logits[..., uniq]
        sel = torch.where(sel > 0, sel / repetition_penalty, sel * repetition_penalty)
        logits[..., uniq] = sel
    return logits
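
# Worked toy example for the penalties above (values are arbitrary): a token seen twice in the
# window loses presence + 2*frequency from its logit, then positive logits of seen tokens are
# divided by the repetition penalty.
def _penalty_example() -> None:
    logits = torch.full((1, VOCAB), 2.0)
    ids = torch.tensor([[3, 3, 9]])
    out = _apply_rep_presence_frequency(logits, ids, last_n=128, repetition_penalty=1.1,
                                        presence_penalty=0.3, frequency_penalty=0.2)
    assert torch.isclose(out[0, 3], torch.tensor((2.0 - 0.7) / 1.1))
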
def _filter_top_k_top_p_min_p(logits: torch.Tensor, top_k: int, top_p: float, min_p: float, temperature: float):
    logits = logits / max(temperature, 1e-8)
    if logits.dim() == 1: logits = logits.unsqueeze(0)
    probs = logits.softmax(-1)
    V = probs.size(-1)
    if top_k and top_k < V:
        vals, idx = torch.topk(probs, top_k, dim=-1)
        mask = torch.full_like(probs, 0.0); mask.scatter_(1, idx, 1.0); probs = probs * mask
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumsum = torch.cumsum(sorted_probs, dim=-1)
        keep = cumsum <= top_p; keep[..., 0] = True
        mask = torch.zeros_like(probs); mask.scatter_(1, sorted_idx, keep.to(mask.dtype))
        probs = probs * mask
    if min_p > 0.0:
        probs = torch.where(probs >= min_p, probs, torch.zeros_like(probs))
    sums = probs.sum(-1, keepdim=True); empty = (sums == 0)
    if empty.any():
        # If filtering removed every candidate in a row, fall back to a one-hot distribution on
        # that row's argmax token; rows that still have mass are left untouched.
        fallback_idx = logits.argmax(-1, keepdim=True)
        fallback = torch.zeros_like(probs).scatter_(-1, fallback_idx, 1.0)
        probs = torch.where(empty, fallback, probs)
    probs = probs / probs.sum(-1, keepdim=True)
    return probs
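
# Minimal usage sketch for the sampling filter above: convert one row of raw logits into a
# filtered probability distribution and draw a single token id from it.
def _sampling_example() -> int:
    logits = torch.randn(VOCAB)
    probs = _filter_top_k_top_p_min_p(logits, top_k=50, top_p=0.9, min_p=0.0, temperature=0.8)
    return int(probs.multinomial(1).item())
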
@torch.no_grad()  # inference only: avoid building autograd graphs during generation
def chat_decode(core, ar_h, messages: List[Dict[str, str]], max_new: int = 200, T: float = 0.9,
                greedy: bool = False, top_k: int = 50, top_p: float = 0.9, min_p: float = 0.0,
                repetition_penalty: float = 1.1, presence_penalty: float = 0.3, frequency_penalty: float = 0.2,
                penalty_last_n: int = 128, no_repeat_ngram_size: int = 3,
                use_fp8: bool = False, fp8_fallback: bool = True) -> str:
    prompt = render_chat(messages, add_generation_prompt=True)
    ids = torch.tensor([tok.encode(prompt)], device=DEV)
    prompt_len = ids.size(1)
    with amp(use_fp8, prefer_fp8=(use_fp8 and (_supports_fp8() or fp8_fallback))):
        # Prefill: run the full prompt once and keep per-layer KV caches.
        h_full, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True)
        for _ in range(max_new):
            logits = ar_h(h_full)[:, -1]
            logits = _apply_no_repeat_ngram(logits, ids, no_repeat_ngram_size)
            logits = _apply_rep_presence_frequency(logits, ids, penalty_last_n,
                                                   repetition_penalty, presence_penalty, frequency_penalty)
            if greedy:
                nxt = logits.argmax(-1, keepdim=True)
            else:
                probs = _filter_top_k_top_p_min_p(logits.squeeze(0), top_k, top_p, min_p, T)
                nxt = probs.multinomial(1)
            ids = torch.cat([ids, nxt.unsqueeze(0) if nxt.dim() == 1 else nxt], 1)
            if EOS is not None and int(nxt.reshape(-1)[0]) == EOS:
                break  # stop early once the end-of-sequence token is emitted
            # Decode step: feed only the newest token and extend the KV caches.
            x = ids[:, -1:]
            h_full, kvs = core(x, None, kv_caches=kvs, use_cache=True)
    full_ids = ids[0].tolist()
    return tok.decode(full_ids[prompt_len:], skip_special_tokens=True).strip()
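
# Hedged usage sketch (assumes core/ar_h were already loaded via load_joint_from_hub): greedy
# decoding of a single-turn prompt, the same call the Gradio and CLI front-ends make below.
def _decode_example(core, ar_h) -> str:
    msgs = [{"role": "system", "content": "You are a helpful, concise assistant."},
            {"role": "user", "content": "Name three prime numbers."}]
    return chat_decode(core, ar_h, msgs, max_new=64, greedy=True)
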
# ───────────────────────── UI / CLI ─────────────────────────
def launch_gradio(core, ar_h):
    import gradio as gr
    with gr.Blocks() as demo:
        gr.Markdown("### OpenTransformer / AGILLM2 - Chat")
        with gr.Row():
            temp = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature")
            topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            topk = gr.Slider(0, 200, value=50, step=1, label="Top-k")
            mxnt = gr.Slider(16, 1024, value=200, step=8, label="Max new tokens")
        sys_prompt = gr.Textbox(value="You are a helpful, concise assistant.", label="System prompt")
        with gr.Tabs():
            with gr.TabItem("Chat"):
                chatbot = gr.Chatbot(height=520, label="Conversation")
                msg = gr.Textbox(placeholder="Say something useful…", label="Message")
                submit = gr.Button("Send", variant="primary")
            with gr.TabItem("Raw transcript"):
                raw = gr.Textbox(lines=24, label="Raw transcript (User:/Assistant:)", interactive=False)
        clear = gr.Button("Clear", variant="secondary")

        def _chat(history, user_msg, t, p, k, mxt, sys_p):
            # Build messages from history + new user message
            messages = [{"role": "system", "content": sys_p}]
            for u, a in history or []:
                messages.append({"role": "user", "content": u})
                messages.append({"role": "assistant", "content": a})
            messages.append({"role": "user", "content": user_msg})
            reply = chat_decode(
                core, ar_h, messages,
                max_new=int(mxt), T=float(t),
                greedy=False, top_k=int(k), top_p=float(p),
                use_fp8=False, fp8_fallback=True
            )
            history = (history or []) + [(user_msg, reply)]
            transcript = render_raw(history, sys_p)
            return history, "", transcript

        # Wire up events: submit via button or enter
        msg.submit(_chat, [chatbot, msg, temp, topp, topk, mxnt, sys_prompt], [chatbot, msg, raw], queue=False)
        submit.click(_chat, [chatbot, msg, temp, topp, topk, mxnt, sys_prompt], [chatbot, msg, raw], queue=False)

        def _clear():
            return [], "", ""

        clear.click(_clear, inputs=None, outputs=[chatbot, msg, raw], queue=False)
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
def run_cli(core, ar_h):
    print("Type to chat. Ctrl+C to exit.")
    history: List[Tuple[str, str]] = []
    while True:
        try:
            user = input("\nYou: ").strip()
            if not user: continue
            messages = [{"role": "system", "content": "You are a helpful, concise assistant."}]
            for u, a in history:
                messages.append({"role": "user", "content": u})
                messages.append({"role": "assistant", "content": a})
            messages.append({"role": "user", "content": user})
            t0 = time.time()
            reply = chat_decode(core, ar_h, messages, max_new=200, T=0.9, top_k=50, top_p=0.9)
            dt = time.time() - t0
            print(f"Bot: {reply}\n[{len(tok.encode(reply))} tok in {dt:.2f}s]")
            history.append((user, reply))
            # Also show raw transcript line by line in CLI
            print("\n--- RAW ---")
            print(render_raw(history, "You are a helpful, concise assistant."))
        except KeyboardInterrupt:
            print("\nbye."); break
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--cli", action="store_true", help="Force CLI REPL even if not a TTY")
    ap.add_argument("--gradio", action="store_true", help="Force Gradio UI")
    args = ap.parse_args()
    print(f"[init] downloading checkpoint {CKPT_NAME} from {MODEL_REPO} …", flush=True)
    core, ar_h, cfg = load_joint_from_hub(MODEL_REPO, CKPT_NAME)
    core.eval(); ar_h.eval()
    print(f"[ready] cfg={cfg} device={DEV.type} vocab={VOCAB}", flush=True)
    # Spaces have no interactive stdin. Auto-launch Gradio if not a TTY.
    in_tty = sys.stdin.isatty()
    if args.gradio or (not args.cli and not in_tty):
        launch_gradio(core, ar_h)
    else:
        run_cli(core, ar_h)

if __name__ == "__main__":
    main()