# notrito's picture
# col matrx
# aeb4847
"""
Token Journey Visualizer - Transformation Inspector
"""
import gradio as gr
import numpy as np
from utils import extract_single_transformation, get_token_choices
# Model identifier handed to get_token_choices() and
# extract_single_transformation() as the model name.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Module-level state shared by the UI callbacks: process_text() stores the
# tokenized text under the 'text' key and visualize_transformation() reads it.
cached_data = {}
def format_vector(vector, title):
    """Render a 1-D vector as a single-row HTML table.

    Vectors with at least 100 entries are abbreviated to their first 50 and
    last 50 values with a "..." cell in between. Shorter vectors are shown in
    full, each value exactly once and without the separator (such input used
    to repeat the overlapping values or fail with ``IndexError``).

    Args:
        vector: One-dimensional sequence of numbers; anything accepted by
            ``np.asarray``.
        title: Heading text placed in an ``<h3>`` above the table. It is
            inserted verbatim, without HTML escaping.

    Returns:
        str: HTML fragment consisting of the heading and a horizontally
        scrollable one-row table with values formatted to four decimals.
    """
    edge = 50  # number of values kept from each end when abbreviating
    values = np.asarray(vector)
    total = len(values)

    number_style = (
        "padding: 4px; text-align: right; color: gray; border: 1px solid #eee;"
    )

    def number_cell(val):
        return f"<td style='{number_style}'>{val:.4f}</td>"

    if total >= 2 * edge:
        cells = [number_cell(val) for val in values[:edge]]
        cells.append(
            "<td style='padding: 4px; text-align: center; font-weight: bold;'>...</td>"
        )
        # Slice with an explicit start so the tail never overlaps the head.
        cells.extend(number_cell(val) for val in values[total - edge:])
    else:
        # Too short to abbreviate: show every value exactly once.
        cells = [number_cell(val) for val in values]

    parts = [
        f"<h3>{title}</h3>",
        "<div style='font-family: monospace; font-size: 11px; max-width: 100%; overflow-x: scroll;'>",
        "<table style='border-collapse: collapse;'>",
        "<tr>",
        *cells,
        "</tr>",
        "</table></div>",
    ]
    return "".join(parts)
def format_matrix(matrix, title):
    """Render a 2-D matrix as an HTML table, abbreviating large dimensions.

    A dimension with at least 100 entries is reduced to its first 50 and last
    50 indices. Abbreviated columns get a "..." cell in every data row, and
    abbreviated rows get one separator row of "⋮" cells (with "⋱" where the
    column separator sits). A dimension with fewer than 100 entries is shown
    in full and without a separator (such input used to repeat overlapping
    entries or fail with ``IndexError``).

    Args:
        matrix: Two-dimensional array of numbers; anything accepted by
            ``np.asarray``.
        title: Heading text placed in an ``<h3>`` above the table. It is
            inserted verbatim, without HTML escaping.

    Returns:
        str: HTML fragment consisting of the heading and a scrollable table
        with values formatted to four decimals.

    Raises:
        ValueError: If ``matrix`` is not two-dimensional (the shape cannot be
            unpacked into rows and columns).
    """
    edge = 50  # number of rows/columns kept from each end when abbreviating
    values = np.asarray(matrix)
    n_rows, n_cols = values.shape

    def split_indices(size):
        """Return (head indices, tail indices, abbreviated?) for one axis."""
        if size >= 2 * edge:
            return range(edge), range(size - edge, size), True
        return range(size), range(0), False

    row_head, row_tail, rows_cut = split_indices(n_rows)
    col_head, col_tail, cols_cut = split_indices(n_cols)

    number_style = (
        "padding: 2px; text-align: right; color: gray; border: 1px solid #eee;"
    )
    marker_style = "padding: 2px; text-align: center; font-weight: bold;"

    def number_cell(val):
        return f"<td style='{number_style}'>{val:.4f}</td>"

    def marker_cell(symbol):
        return f"<td style='{marker_style}'>{symbol}</td>"

    def data_row(row_index):
        """Build one <tr> holding the displayed columns of a single row."""
        row = values[row_index]
        cells = [number_cell(row[col]) for col in col_head]
        if cols_cut:
            cells.append(marker_cell("..."))
        cells.extend(number_cell(row[col]) for col in col_tail)
        return "<tr>" + "".join(cells) + "</tr>"

    parts = [
        f"<h3>{title}</h3>",
        "<div style='font-family: monospace; font-size: 9px; max-height: 600px; overflow: scroll;'>",
        "<table style='border-collapse: collapse;'>",
    ]
    parts.extend(data_row(row_index) for row_index in row_head)

    if rows_cut:
        # Separator row: one marker per displayed column so it lines up
        # with the data rows above and below.
        markers = [marker_cell("⋮")] * len(col_head)
        if cols_cut:
            markers.append(marker_cell("⋱"))
        markers.extend([marker_cell("⋮")] * len(col_tail))
        parts.append("<tr>" + "".join(markers) + "</tr>")

    parts.extend(data_row(row_index) for row_index in row_tail)
    parts.append("</table></div>")
    return "".join(parts)
def process_text(text):
    """Tokenize *text* and return an update for the token dropdown.

    On success the text is remembered in ``cached_data['text']`` and the
    dropdown receives the token choices with the first one selected. For
    empty or whitespace-only input, or when tokenization fails, the dropdown
    is cleared and the cached text is dropped so that a stale text/token
    pair cannot be visualized afterwards.

    Args:
        text: Raw text from the input box; ``None`` is treated as empty.

    Returns:
        A ``gr.update(...)`` payload for the token dropdown. Every branch
        returns the same kind of object.
    """
    if not text or not text.strip():
        cached_data.pop('text', None)
        return gr.update(choices=[], value=None)
    try:
        # The second return value (token indices) is not needed here.
        choices, _indices = get_token_choices(text, MODEL_NAME)
        cached_data['text'] = text
        print(f"✅ Tokenized: {len(choices)} tokens")
        return gr.update(choices=choices, value=choices[0] if choices else None)
    except Exception as e:
        # UI boundary: report the failure and fall back to an empty dropdown
        # instead of letting the callback crash.
        print(f"❌ Error in process_text: {e}")
        import traceback
        traceback.print_exc()
        cached_data.pop('text', None)
        return gr.update(choices=[], value=None)
def visualize_transformation(token_choice, layer):
    """Build the HTML views for one token's ``q_proj`` transformation.

    Args:
        token_choice: Selected dropdown entry. The text before the first
            ``":"`` is parsed as the integer token index.
        layer: Layer index from the slider; normalized to ``int`` before it
            is passed on.

    Returns:
        tuple[str, str, str]: HTML for the input vector, the weight matrix
        and the output vector. When nothing can be shown, the first element
        carries a warning or error message and the other two are empty.
    """
    print(f"Visualize called: token={token_choice}, layer={layer}")
    if not token_choice:
        print("⚠️ No token selected")
        return "⚠️ Select a token first", "", ""
    if 'text' not in cached_data:
        print("⚠️ No text in cache")
        return "⚠️ Process text first", "", ""
    try:
        token_index = int(token_choice.split(":")[0])
        result = extract_single_transformation(
            text=cached_data['text'],
            token_index=token_index,
            component="q_proj",
            layer=int(layer),
            model_name=MODEL_NAME,
        )
        input_html = format_vector(result['input_vector'], "Input Vector (100 dims)")
        matrix_html = format_matrix(result['weight_matrix'], "W_q Matrix (100×100)")
        output_html = format_vector(result['output_vector'], "Output Vector (100 dims)")
        return input_html, matrix_html, output_html
    except Exception as e:
        # UI boundary: log the full traceback (as process_text does) and show
        # the message in the page. The message is escaped because the output
        # component renders raw HTML and exception text may contain "<" or "&".
        import html
        import traceback
        print(f"❌ Error in visualize_transformation: {e}")
        traceback.print_exc()
        return f"Error: {html.escape(str(e))}", "", ""
# Assemble the Gradio UI and wire the two buttons to their callbacks.
# NOTE(review): the original indentation was lost, so which widgets sit
# inside each gr.Row() below is reconstructed - confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Token Transformation Inspector")
    # Text entry plus the button that tokenizes it.
    with gr.Row():
        text_input = gr.Textbox(label="Text", value="The cat sat on the mat")
        process_btn = gr.Button("Process")
    # Token and layer selection plus the button that renders the result.
    with gr.Row():
        token_dropdown = gr.Dropdown(label="Token", choices=[])
        layer_slider = gr.Slider(0, 21, value=0, step=1, label="Layer")
        visualize_btn = gr.Button("Visualize")
    # Output areas filled by visualize_transformation(): input vector,
    # weight matrix and output vector, in that order.
    input_display = gr.HTML()
    matrix_display = gr.HTML()
    output_display = gr.HTML()
    # "Process" tokenizes the text and repopulates the token dropdown.
    process_btn.click(process_text, text_input, token_dropdown)
    # "Visualize" renders the three HTML views for the chosen token and layer.
    visualize_btn.click(visualize_transformation, [token_dropdown, layer_slider],
                        [input_display, matrix_display, output_display])
# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()