# dcode / app.py (Hugging Face Space by twarner)
# Commit 4624764: "Use Gradio Default theme for proper light/dark switching"
"""dcode - Text to Polargraph Gcode via Stable Diffusion"""
import re
import os
import json
import gradio as gr
import torch
import torch.nn as nn
from pathlib import Path
import spaces
# Machine work area limits, centred on the origin (841 x 1189 mm per the UI footer).
BOUNDS = dict(left=-420.5, right=420.5, top=594.5, bottom=-594.5)
# Populated lazily by get_model() with the loaded pipeline/decoder/tokenizer bundle.
_model = None
# ============================================================================
# V3 DECODER ARCHITECTURE
# ============================================================================
class GcodeDecoderConfigV3:
    """Hyper-parameters for the v3 gcode decoder (CNN latent projector + transformer).

    A plain attribute container: every constructor argument is stored verbatim
    on the instance under the same name.
    """

    def __init__(
        self,
        latent_channels: int = 4,
        latent_size: int = 64,
        hidden_size: int = 1024,
        num_layers: int = 12,
        num_heads: int = 16,
        vocab_size: int = 8192,
        max_seq_len: int = 2048,
        dropout: float = 0.1,
        ffn_mult: int = 4,
    ):
        # Shape of the incoming diffusion latent.
        self.latent_channels, self.latent_size = latent_channels, latent_size
        # Transformer width, depth and attention heads.
        self.hidden_size, self.num_layers, self.num_heads = hidden_size, num_layers, num_heads
        # Token space: vocabulary size and maximum sequence length.
        self.vocab_size, self.max_seq_len = vocab_size, max_seq_len
        # Regularisation and feed-forward expansion factor.
        self.dropout, self.ffn_mult = dropout, ffn_mult
class CNNLatentProjector(nn.Module):
    """CNN-based latent projector preserving spatial structure.

    Four 3x3 / stride-2 / padding-1 convolutions, each followed by LayerNorm and
    GELU, halve the spatial grid per stage (n -> ceil(n / 2)) while widening the
    channels 64 -> 128 -> 256 -> ``hidden_size``. The final grid is flattened into
    a sequence of memory tokens, and a learned positional embedding is added.

    The LayerNorm shapes and the memory-token count are derived from
    ``config.latent_size``. A 64x64 latent therefore gets 32/16/8/4 grids and 16
    tokens, with the same parameter names and shapes as the checkpoint. Other
    latent sizes also work instead of failing inside LayerNorm.
    """

    # Output channels of the first three conv stages; the last stage emits hidden_size.
    _STAGE_CHANNELS = (64, 128, 256)

    def __init__(self, config: "GcodeDecoderConfigV3"):
        super().__init__()
        widths = [config.latent_channels, *self._STAGE_CHANNELS, config.hidden_size]
        size = config.latent_size
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1))
            # A 3x3 conv with stride 2 and padding 1 maps a side of n to ceil(n / 2).
            size = (size + 1) // 2
            stages.append(nn.LayerNorm([out_ch, size, size]))
            stages.append(nn.GELU())
        # Same Sequential indices (conv, norm, act) x 4 as before, so checkpoint keys match.
        self.cnn = nn.Sequential(*stages)
        self.num_memory_tokens = size * size
        self.memory_pos = nn.Parameter(torch.randn(1, self.num_memory_tokens, config.hidden_size) * 0.02)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Map a latent of shape (B, C, H, W) to memory of shape (B, num_memory_tokens, hidden_size)."""
        feats = self.cnn(latent)
        # (B, hidden, h, w) -> (B, h*w, hidden)
        tokens = feats.flatten(2).transpose(1, 2)
        # memory_pos has batch dim 1 and broadcasts over B.
        return tokens + self.memory_pos
class GcodeDecoderV3(nn.Module):
    """Autoregressive transformer that emits gcode tokens conditioned on a diffusion latent (v3).

    ``CNNLatentProjector`` turns the latent into memory tokens. Token embeddings
    plus learned absolute position embeddings (followed by dropout) then pass
    through a stack of pre-norm ``nn.TransformerDecoderLayer`` blocks that attend
    to that memory under a causal mask.

    Submodule attribute names form the keys of the ``gcode_decoder.*`` checkpoint
    entries, so they must not be renamed.
    """

    def __init__(self, config: GcodeDecoderConfigV3):
        super().__init__()
        self.config = config
        width = config.hidden_size

        self.latent_proj = CNNLatentProjector(config)
        self.token_embed = nn.Embedding(config.vocab_size, width)
        self.pos_embed = nn.Embedding(config.max_seq_len, width)
        self.embed_drop = nn.Dropout(config.dropout)

        def make_block():
            # Pre-norm (norm_first) decoder block with a GELU feed-forward of width * ffn_mult.
            return nn.TransformerDecoderLayer(
                d_model=width,
                nhead=config.num_heads,
                dim_feedforward=width * config.ffn_mult,
                dropout=config.dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )

        self.layers = nn.ModuleList(make_block() for _ in range(config.num_layers))
        self.ln_f = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, T, vocab_size) for ``input_ids`` of shape (B, T)."""
        seq_len = input_ids.size(1)
        memory = self.latent_proj(latent)
        pos = torch.arange(seq_len, device=input_ids.device)
        hidden = self.embed_drop(self.token_embed(input_ids) + self.pos_embed(pos))
        # Additive causal mask in the latent's dtype so it matches fp16 activations.
        mask = nn.Transformer.generate_square_subsequent_mask(
            seq_len, device=input_ids.device, dtype=latent.dtype
        )
        for block in self.layers:
            hidden = block(hidden, memory, tgt_mask=mask)
        return self.lm_head(self.ln_f(hidden))
# ============================================================================
# V2 DECODER ARCHITECTURE (for backwards compatibility)
# ============================================================================
class GcodeDecoderConfigV2:
    """Hyper-parameters for the legacy v2 gcode decoder (MLP latent projector + transformer).

    Stores each constructor argument under the same name, plus the derived
    ``latent_dim`` (the flattened latent length fed to the MLP projector).
    """

    def __init__(
        self,
        latent_channels: int = 4,
        latent_size: int = 64,
        hidden_size: int = 768,
        num_layers: int = 6,
        num_heads: int = 12,
        vocab_size: int = 32128,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
    ):
        self.latent_channels, self.latent_size = latent_channels, latent_size
        # channels * size * size: length of the flattened latent.
        self.latent_dim = latent_channels * latent_size * latent_size
        self.hidden_size, self.num_layers, self.num_heads = hidden_size, num_layers, num_heads
        self.vocab_size, self.max_seq_len, self.dropout = vocab_size, max_seq_len, dropout
class GcodeDecoderV2(nn.Module):
    """Legacy (v2) gcode decoder kept so older checkpoints still load.

    The whole latent is flattened and pushed through an MLP that produces
    ``16 * hidden_size`` values, which are reshaped into 16 memory tokens.
    Token and position embeddings then run through pre-norm transformer decoder
    blocks under a causal mask. ``lm_head`` shares its weight with
    ``token_embed``. Attribute names and ``latent_proj`` indices are
    checkpoint keys.
    """

    def __init__(self, config: GcodeDecoderConfigV2):
        super().__init__()
        self.config = config
        width = config.hidden_size

        self.latent_proj = nn.Sequential(
            nn.Linear(config.latent_dim, width * 4),
            nn.GELU(),
            nn.Linear(width * 4, width * 16),
            nn.LayerNorm(width * 16),
        )
        self.token_embed = nn.Embedding(config.vocab_size, width)
        self.pos_embed = nn.Embedding(config.max_seq_len, width)

        def make_block():
            return nn.TransformerDecoderLayer(
                d_model=width,
                nhead=config.num_heads,
                dim_feedforward=width * 4,
                dropout=config.dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )

        self.layers = nn.ModuleList(make_block() for _ in range(config.num_layers))
        self.ln_f = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        # Weight tying: output projection reuses the input embedding matrix.
        self.lm_head.weight = self.token_embed.weight

    def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (B, T, vocab_size) for ``input_ids`` of shape (B, T)."""
        bsz, seq_len = input_ids.shape
        dev = input_ids.device
        projected = self.latent_proj(latent.view(bsz, -1))
        memory = projected.view(bsz, 16, self.config.hidden_size)
        pos = torch.arange(seq_len, device=dev)
        hidden = self.token_embed(input_ids) + self.pos_embed(pos)
        mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=dev, dtype=latent.dtype)
        for block in self.layers:
            hidden = block(hidden, memory, tgt_mask=mask)
        return self.lm_head(self.ln_f(hidden))
# ============================================================================
# MODEL LOADING
# ============================================================================
def _strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with only that leading prefix removed."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def get_model():
    """Load (once) and cache every component that ``generate`` needs.

    Steps:
      * Download ``config.json`` and ``pytorch_model.bin`` from the
        ``twarner/dcode-sd-gcode-v3`` Hub repo.
      * Build the Stable Diffusion pipeline named by ``config["sd_model_id"]``.
      * Overlay any fine-tuned ``text_encoder.*`` / ``unet.*`` weights from the
        checkpoint.
      * Build a v3 or v2 gcode decoder (chosen from ``config["gcode_decoder"]``)
        and load its ``gcode_decoder.*`` weights.
      * Load the gcode tokenizer, falling back to ``google/flan-t5-base``.

    Returns:
        dict with keys ``pipe``, ``gcode_decoder``, ``gcode_tokenizer``,
        ``device``, ``dtype``, ``num_inference_steps`` and ``is_v3``. After the
        first successful call the same dict is returned from the module-level
        ``_model`` cache.
    """
    global _model
    if _model is not None:
        return _model

    # Heavy imports are deferred until the model is actually needed.
    from diffusers import StableDiffusionPipeline
    from transformers import AutoTokenizer, PreTrainedTokenizerFast
    from huggingface_hub import hf_hub_download

    repo_id = "twarner/dcode-sd-gcode-v3"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print("Loading SD-Gcode model...")

    # Download config and weights from v3 model
    config_path = hf_hub_download(repo_id, "config.json")
    weights_path = hf_hub_download(repo_id, "pytorch_model.bin")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Determine model version: v3 configs carry ffn_mult and/or a >= 1024-wide decoder.
    gcode_cfg = config.get("gcode_decoder", {})
    is_v3 = gcode_cfg.get("ffn_mult") is not None or gcode_cfg.get("hidden_size", 768) >= 1024
    print(f"Model version: {'v3' if is_v3 else 'v2'}")

    # Load SD pipeline
    sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
    print(f"Loading SD from {sd_model_id}...")
    pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)

    # Build decoder based on version
    if is_v3:
        decoder_config = GcodeDecoderConfigV3(
            latent_channels=gcode_cfg.get("latent_channels", 4),
            latent_size=gcode_cfg.get("latent_size", 64),
            hidden_size=gcode_cfg.get("hidden_size", 1024),
            num_layers=gcode_cfg.get("num_layers", 12),
            num_heads=gcode_cfg.get("num_heads", 16),
            vocab_size=gcode_cfg.get("vocab_size", 8192),
            max_seq_len=gcode_cfg.get("max_seq_len", 2048),
            ffn_mult=gcode_cfg.get("ffn_mult", 4),
        )
        gcode_decoder = GcodeDecoderV3(decoder_config).to(device, dtype)
    else:
        decoder_config = GcodeDecoderConfigV2(
            latent_channels=gcode_cfg.get("latent_channels", 4),
            latent_size=gcode_cfg.get("latent_size", 64),
            hidden_size=gcode_cfg.get("hidden_size", 768),
            num_layers=gcode_cfg.get("num_layers", 6),
            num_heads=gcode_cfg.get("num_heads", 12),
            vocab_size=gcode_cfg.get("vocab_size", 32128),
            max_seq_len=gcode_cfg.get("max_seq_len", 1024),
        )
        gcode_decoder = GcodeDecoderV2(decoder_config).to(device, dtype)

    # Load weights
    print("Loading finetuned weights...")
    # Keep the checkpoint on CPU. load_state_dict copies (and casts) each tensor
    # into the already-placed parameters, so a second full copy of every weight
    # never has to sit in GPU memory during loading.
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects from a
    # downloaded file. Switch to weights_only=True if the checkpoint is a plain
    # tensor state dict.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)

    # Load SD components if present (non-strict: the checkpoint may hold a subset).
    text_encoder_state = _strip_prefix(state_dict, "text_encoder.")
    if text_encoder_state:
        pipe.text_encoder.load_state_dict(text_encoder_state, strict=False)
        print(f"Loaded {len(text_encoder_state)} text encoder weights")

    unet_state = _strip_prefix(state_dict, "unet.")
    if unet_state:
        pipe.unet.load_state_dict(unet_state, strict=False)
        print(f"Loaded {len(unet_state)} UNet weights")

    # Load decoder weights
    decoder_state = _strip_prefix(state_dict, "gcode_decoder.")
    if decoder_state:
        try:
            gcode_decoder.load_state_dict(decoder_state, strict=True)
            print(f"Loaded {len(decoder_state)} decoder weights (strict)")
        except RuntimeError as e:
            # load_state_dict reports key mismatches as RuntimeError. A non-strict
            # retry tolerates missing/unexpected keys, but still raises on shape mismatches.
            print(f"Strict load failed: {e}")
            result = gcode_decoder.load_state_dict(decoder_state, strict=False)
            print(f"Loaded {len(decoder_state)} decoder weights (non-strict)")
            print(f"  missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}")
    else:
        print("WARNING: checkpoint has no gcode_decoder.* weights; decoder keeps its random init")
    gcode_decoder.eval()
    # Release the CPU copy of the checkpoint now that everything has been copied in.
    del state_dict, text_encoder_state, unet_state, decoder_state

    # Load gcode tokenizer
    try:
        # Try loading custom tokenizer from v3 model
        tokenizer_path = hf_hub_download(repo_id, "gcode_tokenizer/tokenizer.json")
        gcode_tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_path,
            pad_token="<pad>",
            unk_token="<unk>",
            bos_token="<s>",
            eos_token="</s>",
        )
        # Verify special tokens
        print(f"Loaded custom gcode tokenizer (vocab={gcode_tokenizer.vocab_size})")
        print(f"  BOS='{gcode_tokenizer.bos_token}' (id={gcode_tokenizer.bos_token_id})")
        print(f"  EOS='{gcode_tokenizer.eos_token}' (id={gcode_tokenizer.eos_token_id})")
        print(f"  PAD='{gcode_tokenizer.pad_token}' (id={gcode_tokenizer.pad_token_id})")
        # Test encode/decode
        test = "G0 X100 Y200\nG1 X150 Y250"
        enc = gcode_tokenizer.encode(test)
        dec = gcode_tokenizer.decode(enc)
        print(f"  Test encode: {len(enc)} tokens")
        print(f"  Test decode: '{dec[:50]}...'")
    except Exception as e:
        # Deliberate best-effort: any failure falls back to the T5 tokenizer.
        print(f"Failed to load custom tokenizer: {e}")
        import traceback
        traceback.print_exc()
        # Fallback to T5 tokenizer
        gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        print("Using fallback T5 tokenizer")

    _model = {
        "pipe": pipe,
        "gcode_decoder": gcode_decoder,
        "gcode_tokenizer": gcode_tokenizer,
        "device": device,
        "dtype": dtype,
        "num_inference_steps": config.get("num_inference_steps", 20),
        "is_v3": is_v3,
    }
    print("Model loaded!")
    return _model
# ============================================================================
# GCODE PROCESSING
# ============================================================================
def is_valid_coord(s: str) -> bool:
    """Return True if ``s`` parses as a float strictly inside (-1000, 1000).

    Non-numeric input (ValueError / TypeError from ``float``) counts as invalid.
    """
    try:
        v = float(s)
        return -1000 < v < 1000  # Reasonable bounds
    except (ValueError, TypeError):
        return False


def clean_gcode(gcode: str) -> str:
    """Clean up generated gcode: fix formatting and remove garbage lines.

    Steps, in order:
      * Turn literal ``<newline>`` tokens into newlines. If the text then has
        fewer than 10 newlines, split it before every ``G``/``M`` command word.
      * Insert the missing space in words such as ``G0X``, ``G1Z`` or ``G1F``.
      * Per line:
          - drop blanks, known caption/metadata lines, and lines with mangled axis
            prefixes (``X-Y-`` etc.);
          - collapse doubled minus signs (``X--5`` -> ``X-5``);
          - drop lines whose first X or Y value is not a number inside (-1000, 1000).
      * Drop a move to an (X, Y) position (rounded to 0.1) that is still in the
        window of the 50 most recently *first-seen* positions, once more than 5
        positions are tracked. This guards against the decoder getting stuck.
      * Keep only lines starting with ``G``, ``M`` (either case) or ``;``.

    Returns the kept lines joined by newlines (no trailing newline).
    """
    recent_window = 50  # number of distinct positions remembered
    min_tracked = 5     # repeats are only dropped once more than this many are tracked

    # Replace any remaining <newline> tokens
    gcode = gcode.replace("<newline>", "\n")

    # If no/few newlines, split on command boundaries
    if gcode.count("\n") < 10:
        # Split before each gcode command
        gcode = re.sub(r'([GM]\d+)', r'\n\1', gcode)

    # Add spaces after G0/G1 if missing: G0X -> G0 X
    gcode = re.sub(r'(G[01])([XYZ])', r'\1 \2', gcode)
    gcode = re.sub(r'(G[01])F', r'\1 F', gcode)

    cleaned_lines = []
    # dict keeps insertion order, so the first key is always the oldest position.
    # A plain set has no order, so "keep the last N" cannot be done with one.
    recent_coords = {}

    for line in gcode.split("\n"):
        line = line.strip()
        if not line:
            continue

        # Skip garbage/metadata lines
        if line.lower() in ["dcode", "gcode", "code", "output"]:
            continue
        if line.startswith(("Source:", ";Generated", "Workarea:", "Algorithm:")):
            continue

        # Skip lines with mixed axis prefixes: Y-X-288 or X-Y-100
        if re.search(r'X-Y-|Y-X-|X-X-|Y-Y-', line):
            continue

        # Fix double negatives: X--411 -> X-411
        line = re.sub(r'X--(\d)', r'X-\1', line)
        line = re.sub(r'Y--(\d)', r'Y-\1', line)

        # Fix missing spaces: G1X -> G1 X
        line = re.sub(r'(G[01])X', r'\1 X', line)
        line = re.sub(r'(G[01])Y', r'\1 Y', line)

        # Validate coordinates - skip malformed lines
        x_match = re.search(r'X([-\d.]+)', line)
        y_match = re.search(r'Y([-\d.]+)', line)
        if x_match and not is_valid_coord(x_match.group(1)):
            continue
        if y_match and not is_valid_coord(y_match.group(1)):
            continue

        # Check for stuck coordinates (repeated positions).
        if x_match and y_match:
            # Both values passed is_valid_coord, so float() cannot fail here.
            coord = (round(float(x_match.group(1)), 1), round(float(y_match.group(1)), 1))
            if coord in recent_coords and len(recent_coords) > min_tracked:
                continue
            recent_coords.setdefault(coord, None)
            if len(recent_coords) > recent_window:
                # Evict the oldest position.
                del recent_coords[next(iter(recent_coords))]

        # Only keep lines starting with valid gcode commands
        if line[0] in "GMgm;":
            cleaned_lines.append(line)

    result = "\n".join(cleaned_lines)
    print(f"Cleaned gcode: {len(cleaned_lines)} lines")
    return result
def center_and_scale_gcode(gcode: str, bounds: "dict[str, float] | None" = None, fill: float = 0.8) -> str:
    """Center the drawing on the workplane and scale it to fill ``fill`` (default 80%) of it.

    The bounding box comes from lines that have both an X and a Y value inside
    (-1000, 1000). The first X and the first Y word of every line
    (case-insensitive) are then mapped with a single uniform scale, which keeps
    the aspect ratio. The box center lands on the workplane center, and the
    result is clamped to ``bounds`` and written with two decimals.

    A drawing that is flat along one axis (e.g. one horizontal line) is scaled
    by the other axis alone. The text is returned unchanged when there are
    fewer than two coordinate pairs, or when the box is smaller than 1 unit
    along both axes.

    Args:
        gcode: newline-separated gcode text.
        bounds: workplane limits with keys ``left``, ``right``, ``top`` and
            ``bottom``. Defaults to the module-level ``BOUNDS``.
        fill: fraction of the workplane width/height the drawing may span.
    """
    bounds = BOUNDS if bounds is None else bounds
    lines = gcode.split("\n")

    # Extract all valid coordinates (filter outliers)
    coords = []
    for line in lines:
        x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
        if not (x_match and y_match):
            continue
        try:
            x = float(x_match.group(1))
            y = float(y_match.group(1))
        except ValueError:
            continue
        # Only include reasonable coordinates
        if -1000 < x < 1000 and -1000 < y < 1000:
            coords.append((x, y))

    if len(coords) < 2:
        return gcode  # Nothing to transform

    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)
    width = max_x - min_x
    height = max_y - min_y

    if width < 1 and height < 1:
        return gcode  # Degenerate case: essentially a single point

    target_width = (bounds["right"] - bounds["left"]) * fill
    target_height = (bounds["top"] - bounds["bottom"]) * fill

    # Uniform scale (maintains aspect ratio). An axis with zero extent places no
    # constraint; at least one extent is >= 1 here, so the sequence is non-empty.
    scale = min(t / e for t, e in ((target_width, width), (target_height, height)) if e > 0)

    # Center of current drawing
    cx = (min_x + max_x) / 2
    cy = (min_y + max_y) / 2
    # Center of workplane
    target_cx = (bounds["left"] + bounds["right"]) / 2
    target_cy = (bounds["bottom"] + bounds["top"]) / 2

    print(f"Centering: bbox=({min_x:.0f},{min_y:.0f})-({max_x:.0f},{max_y:.0f}), scale={scale:.2f}")

    def remap(line, axis, center, target_center, low, high):
        # Rewrite the first <axis><number> word of the line; leave it alone if it does not parse.
        match = re.search(axis + r"([-\d.]+)", line, re.IGNORECASE)
        if not match:
            return line
        try:
            value = float(match.group(1))
        except ValueError:
            return line
        new_value = max(low, min(high, (value - center) * scale + target_center))
        return re.sub(axis + r"[-\d.]+", f"{axis}{new_value:.2f}", line, count=1, flags=re.IGNORECASE)

    result = []
    for line in lines:
        new_line = remap(line, "X", cx, target_cx, bounds["left"], bounds["right"])
        new_line = remap(new_line, "Y", cy, target_cy, bounds["bottom"], bounds["top"])
        result.append(new_line)
    return "\n".join(result)
def validate_gcode(gcode: str, bounds: "dict[str, float] | None" = None) -> str:
    """Clamp coordinates to machine bounds.

    For each line, the first X word and the first Y word (case-insensitive) are
    parsed, clamped to ``bounds`` and rewritten with two decimals. Only that
    first word is replaced, matching ``center_and_scale_gcode``. Without
    ``count=1``, every X/Y word on the line (e.g. an ``x-3`` inside a trailing
    comment) would be overwritten with the first word's value. Words that do
    not parse as numbers are left untouched.

    Args:
        gcode: newline-separated gcode text.
        bounds: limits with keys ``left``, ``right``, ``top`` and ``bottom``.
            Defaults to the module-level ``BOUNDS``.
    """
    bounds = BOUNDS if bounds is None else bounds
    limits = (("X", bounds["left"], bounds["right"]), ("Y", bounds["bottom"], bounds["top"]))
    out = []
    for line in gcode.split("\n"):
        corrected = line
        for axis, low, high in limits:
            match = re.search(axis + r"([-\d.]+)", line, re.IGNORECASE)
            if not match:
                continue
            try:
                value = float(match.group(1))
            except ValueError:
                continue
            value = max(low, min(high, value))
            corrected = re.sub(axis + r"[-\d.]+", f"{axis}{value:.2f}", corrected, count=1, flags=re.IGNORECASE)
        out.append(corrected)
    return "\n".join(out)
def gcode_to_svg(gcode: str, bounds: "dict[str, float] | None" = None) -> str:
    """Convert gcode to SVG for visual preview (returned as an HTML/SVG string).

    The text is split into commands on newlines (and literal ``<newline>``
    tokens) and before every ``G``/``M`` command word; ``;`` comments are
    skipped. Pen state follows ``M280 ... S<angle>`` servo commands: an angle
    below 50 means the pen is down.

    Stroke tracking:
      * A stroke starts at the head position where the pen is lowered.
      * Every X/Y move while the pen is down extends the stroke.
      * The stroke is closed when the pen is raised (or at end of input). The
        current stroke is always discarded on pen-up, so a leftover point can
        never be joined to the next stroke.
      * Strokes with fewer than two points are not drawn.

    Args:
        gcode: gcode text; may be empty.
        bounds: work area with keys ``left``, ``right``, ``top`` and ``bottom``.
            Defaults to the module-level ``BOUNDS``.
    """
    bounds = BOUNDS if bounds is None else bounds
    paths = []
    current_path = []
    # Head position; assumed to be the origin until the first X/Y word.
    x, y = 0.0, 0.0
    pen_down = False

    # Replace newline tokens with actual newlines
    gcode = gcode.replace("<newline>", "\n")

    # Split on explicit newlines, then on command boundaries (G0, G1, M280, ...)
    commands = []
    for raw_line in gcode.split("\n"):
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        for part in re.split(r'(?=[GM]\d)', raw_line):
            part = part.strip()
            if part and not part.startswith(";") and part[0] in "GMgm":
                commands.append(part)

    for line in commands:
        if "M280" in line.upper():
            match = re.search(r"S(\d+)", line, re.IGNORECASE)
            if match:
                lowered = int(match.group(1)) < 50
                if lowered and not pen_down:
                    # The pen touches down where the head currently is.
                    current_path = [(x, y)]
                elif pen_down and not lowered:
                    if len(current_path) > 1:
                        paths.append(current_path)
                    current_path = []
                pen_down = lowered

        x_match = re.search(r"X([-\d.]+)", line, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", line, re.IGNORECASE)
        if x_match:
            try:
                x = float(x_match.group(1))
            except ValueError:
                pass
        if y_match:
            try:
                y = float(y_match.group(1))
            except ValueError:
                pass
        if pen_down and (x_match or y_match):
            current_path.append((x, y))

    if len(current_path) > 1:
        paths.append(current_path)

    w = bounds["right"] - bounds["left"]
    h = bounds["top"] - bounds["bottom"]
    padding = 20

    # SVG with theme-aware colors using CSS variables. Y is negated because SVG's
    # y axis points down while the machine's points up.
    parts = [f'''<svg xmlns="http://www.w3.org/2000/svg"
viewBox="{bounds["left"] - padding} {-bounds["top"] - padding} {w + 2*padding} {h + 2*padding}"
class="gcode-preview"
style="width: 100%; height: 480px; border-radius: 8px; border: 1px solid var(--block-border-color); background: var(--block-background-fill);">
<defs>
<style>
.gcode-preview .work-area {{ fill: var(--background-fill-primary); stroke: var(--block-border-color); }}
.gcode-preview .draw-path {{ stroke: var(--body-text-color); }}
.gcode-preview .info-text {{ fill: var(--body-text-color-subdued); }}
</style>
</defs>
<rect class="work-area" x="{bounds["left"]}" y="{-bounds["top"]}" width="{w}" height="{h}" stroke-width="1"/>
''']

    for path in paths:
        d = " ".join(f"{'M' if i == 0 else 'L'}{p[0]:.1f},{-p[1]:.1f}" for i, p in enumerate(path))
        parts.append(f'<path class="draw-path" d="{d}" fill="none" stroke-width="1" stroke-linecap="round" stroke-linejoin="round"/>')

    total_points = sum(len(p) for p in paths)
    parts.append(f'''
<text class="info-text" x="{bounds["left"] + 8}" y="{-bounds["top"] + 20}" font-family="monospace" font-size="12">
{len(paths)} paths / {total_points} points
</text>
''')
    parts.append("</svg>")
    return "".join(parts)
# ============================================================================
# GENERATION
# ============================================================================
def enhance_prompt(prompt: str) -> str:
    """Rewrite a user prompt into the BLIP caption style seen in the training data.

    BLIP captions look like "a drawing of a horse", "a sketch of a cat" or
    "an illustration of a flower". The prompt is stripped and lower-cased, then
    handled in one of three ways:
      * kept as-is if it already starts with "a ", "an " or "the ";
      * prefixed with "a " if it mentions a medium word ("drawing", "sketch",
        "illustration", "image");
      * otherwise wrapped as "a drawing of a <prompt>".
    A fixed black-and-white / line-art style suffix is always appended.
    """
    text = prompt.strip().lower()
    style_words = ("drawing", "sketch", "illustration", "image")

    if text.startswith(("a ", "an ", "the ")):
        base = text
    elif any(word in text for word in style_words):
        base = "a " + text
    else:
        base = "a drawing of a " + text

    return base + ", black and white, simple lines, sketch style"
def _penalize_repeats(logits: torch.Tensor, token_ids, penalty: float = 2.0) -> torch.Tensor:
    """Lower the score of every token in ``token_ids`` in place, and return ``logits``.

    Positive logits are divided by ``penalty`` and negative ones are multiplied
    by it (the CTRL / Hugging Face ``repetition_penalty`` rule). Multiplying
    every logit by ``1 / penalty`` would pull negative logits toward zero and
    so make those tokens *more* likely.
    """
    unique_ids = sorted(set(token_ids))
    if not unique_ids:
        return logits
    idx = torch.tensor(unique_ids, dtype=torch.long, device=logits.device)
    picked = logits[:, idx]
    logits[:, idx] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return logits


def _sample_top_k_top_p(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.92) -> torch.Tensor:
    """Sample one token id per row after top-k then nucleus (top-p) filtering.

    Args:
        logits: (batch, vocab) scores for the next position.
    Returns:
        (batch, 1) LongTensor of sampled vocabulary ids.
    """
    k = min(top_k, logits.shape[-1])
    # topk returns values sorted in descending order, which nucleus filtering relies on.
    top_logits, top_ids = torch.topk(logits, k, dim=-1)
    cumulative = torch.cumsum(torch.softmax(top_logits, dim=-1), dim=-1)
    # Drop a token once the mass of the tokens ranked above it already exceeds
    # top_p; the best token is always kept.
    drop = cumulative > top_p
    drop[:, 1:] = drop[:, :-1].clone()
    drop[:, 0] = False
    top_logits = top_logits.masked_fill(drop, float("-inf"))
    choice = torch.multinomial(torch.softmax(top_logits, dim=-1), num_samples=1)
    return top_ids.gather(-1, choice)


@spaces.GPU
def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float, seed: int = -1):
    """Generate gcode from text prompt.

    Pipeline:
      enhance_prompt -> Stable Diffusion (``output_type="latent"``)
      -> autoregressive gcode decoder (temperature, repetition penalty,
         top-k / top-p sampling)
      -> clean_gcode -> center_and_scale_gcode -> validate_gcode -> gcode_to_svg.

    Args:
        prompt: user text. Blank input returns a hint message.
        temperature: divisor applied to decoder logits (floored at 1e-5).
        max_tokens: maximum number of new tokens, also capped by the decoder's
            ``max_seq_len``.
        num_steps: diffusion inference steps.
        guidance: classifier-free guidance scale.
        seed: diffusion RNG seed. Negative or empty (None) means unseeded.

    Returns:
        ``(gcode_text, svg_html)``. On any exception the traceback is printed,
        and ``("; Error: ...", empty preview)`` is returned.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")

    try:
        m = get_model()
        pipe = m["pipe"]
        gcode_decoder = m["gcode_decoder"]
        gcode_tokenizer = m["gcode_tokenizer"]
        device = m["device"]
        dtype = m["dtype"]
        is_v3 = m.get("is_v3", False)

        # Enhance prompt for better line drawing generation
        enhanced = enhance_prompt(prompt)
        print(f"Enhanced prompt: {enhanced}")

        # Set seed for reproducibility (gr.Number yields None when cleared)
        generator = None
        if seed is not None and seed >= 0:
            generator = torch.Generator(device=device).manual_seed(int(seed))
            print(f"Using seed: {seed}")

        # Text -> Latent via SD diffusion
        with torch.no_grad():
            # Use negative prompt to avoid unwanted styles
            result = pipe(
                enhanced,
                negative_prompt="color, shading, gradient, photorealistic, 3d, complex, detailed texture",
                num_inference_steps=int(num_steps),
                guidance_scale=guidance,
                output_type="latent",
                generator=generator,
            )
            latent = result.images.to(dtype)

        print(f"Latent shape: {latent.shape}, dtype: {latent.dtype}")

        # Latent -> Gcode via trained decoder
        with torch.no_grad():
            bos_id = gcode_tokenizer.bos_token_id
            eos_id = gcode_tokenizer.eos_token_id
            pad_id = gcode_tokenizer.pad_token_id
            # Look the id up instead of hard-coding 1: in the T5 fallback
            # tokenizer id 1 is </s> (EOS), and suppressing it would stop EOS
            # from ever being emitted.
            unk_id = gcode_tokenizer.unk_token_id

            # For v3, start with BOS + gcode header; for v2, a single ';' token
            if is_v3:
                start_text = "G21\nG90\nM280 P0 S90\nG28\n"
                start_tokens = gcode_tokenizer.encode(start_text, add_special_tokens=False)
                if bos_id is not None:
                    start_tokens = [bos_id] + start_tokens
                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=device)
            else:
                start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
                start_id = start_tokens[0] if start_tokens else 0
                input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)

            print(f"Starting with {input_ids.shape[1]} tokens, BOS={bos_id}, EOS={eos_id}")

            max_gen = min(int(max_tokens), gcode_decoder.config.max_seq_len - input_ids.shape[1])
            temperature = max(float(temperature), 1e-5)

            # Track for repetition detection
            recent_tokens = []

            for step in range(max_gen):
                logits = gcode_decoder(latent, input_ids)
                # Sample in float32: the decoder may run in fp16, where
                # softmax/cumsum are less precise.
                next_logits = logits[:, -1, :].float() / temperature

                # Suppress pad and unk tokens
                for banned in (pad_id, unk_id):
                    if banned is not None:
                        next_logits[:, banned] = float('-inf')

                # Repetition penalty over the last 50 generated tokens
                _penalize_repeats(next_logits, recent_tokens[-50:])

                next_token = _sample_top_k_top_p(next_logits, top_k=50, top_p=0.92)

                input_ids = torch.cat([input_ids, next_token], dim=1)
                token = next_token.item()
                recent_tokens.append(token)

                # Debug first few tokens
                if step < 5:
                    tok_str = gcode_tokenizer.decode([token])
                    print(f"  Step {step}: token={token}, str='{tok_str}'")

                # Check EOS
                if eos_id is not None and token == eos_id:
                    print(f"Hit EOS at step {step}")
                    break

                # Early stop on repetition: fewer than 5 distinct ids in the last 30
                if len(recent_tokens) > 30 and len(set(recent_tokens[-30:])) < 5:
                    print(f"Stopping due to repetition at step {step}")
                    break

            print(f"Generated {input_ids.shape[1]} total tokens")

            # Decode WITHOUT skipping special tokens (so we keep <newline>)
            gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=False)
            # Manually remove the special tokens we don't want, but keep <newline>
            gcode = gcode.replace("<pad>", "").replace("<s>", "").replace("</s>", "").replace("<unk>", "")
            # Now convert <newline> to actual newlines
            gcode = gcode.replace("<newline>", "\n")
            print(f"Raw decoded (first 300 chars): {repr(gcode[:300])}")

        # Clean up the gcode
        gcode = clean_gcode(gcode)

        # Center and scale to fill workplane
        gcode = center_and_scale_gcode(gcode)
        gcode = validate_gcode(gcode)

        line_count = len([l for l in gcode.split("\n") if l.strip()])
        svg = gcode_to_svg(gcode)

        # Collapse all whitespace in the echoed prompt. A newline in it would
        # otherwise start a non-comment line in the output, i.e. an executable
        # gcode command that bypassed the clamping above.
        safe_prompt = " ".join(prompt.split())
        header = f"; dcode output\n; prompt: {safe_prompt}\n; {line_count} commands\n\n"
        return header + gcode, svg

    except Exception as e:
        # UI boundary: report the error in the output box instead of crashing the app.
        import traceback
        traceback.print_exc()
        return f"; Error: {e}", gcode_to_svg("")
# ============================================================================
# UI
# ============================================================================
css = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500&display=swap');
* {
font-family: 'IBM Plex Mono', monospace !important;
}
.gradio-container {
max-width: 900px !important;
margin: auto;
}
footer {
display: none !important;
}
"""
# Gradio UI. The left column holds the prompt, collapsible generation settings,
# the generate button and example prompts. The right column holds the SVG
# preview and a collapsible gcode listing. Layout follows the order in which
# components are created inside the nested `with` contexts.
with gr.Blocks(css=css, theme=gr.themes.Default()) as demo:
    gr.Markdown("# dcode")
    gr.Markdown("text → polargraph gcode via stable diffusion")
    with gr.Row():
        # Inputs
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="prompt",
                placeholder="describe what to draw...",
                lines=2,
                show_label=True,
            )
            # These values are passed positionally to generate() as
            # temperature, max_tokens, num_steps, guidance and seed.
            with gr.Accordion("settings", open=False):
                temperature = gr.Slider(0.3, 1.2, value=0.7, label="temperature", step=0.1)
                max_tokens = gr.Slider(256, 2048, value=2048, step=256, label="max tokens")
                num_steps = gr.Slider(20, 75, value=50, step=5, label="diffusion steps")
                guidance = gr.Slider(5.0, 20.0, value=12.0, step=0.5, label="guidance")
                seed = gr.Number(value=-1, label="seed (-1 = random)", precision=0)
            generate_btn = gr.Button("generate", variant="secondary")
            # Clicking an example fills the prompt box.
            gr.Examples(
                examples=[
                    ["a drawing of a horse"],
                    ["a sketch of a cat"],
                    ["a simple flower drawing"],
                    ["a drawing of a tree"],
                    ["abstract lines"],
                    ["a portrait sketch"],
                ],
                inputs=prompt,
                label=None,
                examples_per_page=6,
            )
        # Outputs
        with gr.Column(scale=2):
            # Empty work-area preview shown before the first generation.
            preview = gr.HTML(value=gcode_to_svg(""))
            with gr.Accordion("gcode", open=False):
                gcode_output = gr.Code(label=None, language=None, lines=12)
    gr.Markdown("---")
    gr.Markdown("machine: 841×1189mm / pen servo 40-90° / [github](https://github.com/Twarner491/dcode) / [model](https://huggingface.co/twarner/dcode-sd-gcode-v3) / mit")
    # The button and pressing Enter in the prompt box run the same handler with the same inputs/outputs.
    generate_btn.click(generate, [prompt, temperature, max_tokens, num_steps, guidance, seed], [gcode_output, preview])
    prompt.submit(generate, [prompt, temperature, max_tokens, num_steps, guidance, seed], [gcode_output, preview])
if __name__ == "__main__":
demo.launch()