# app.py — commit dbb5a8c "Create app.py" (verified), author: Yeu3ui
import json
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path

import gradio as gr
import torch
# Working directories (relative to the current working directory) for
# uploaded datasets and training outputs; created eagerly at import time.
DATASET_DIR, OUTPUT_DIR = Path("./datasets"), Path("./output")
for _required_dir in (DATASET_DIR, OUTPUT_DIR):
    _required_dir.mkdir(exist_ok=True)
del _required_dir

# String path of the most recently prepared dataset. It is assigned by
# upload_and_prepare_dataset and stays None until a dataset is uploaded.
current_dataset_path = None
def check_gpu():
    """Return a one-line, user-facing summary of CUDA availability.

    When torch can see a CUDA device, the message names device 0; otherwise
    it warns that training will be slow.
    """
    if not torch.cuda.is_available():
        return "⚠️ No GPU detected - training will be slow"
    return f"✅ GPU Available: {torch.cuda.get_device_name(0)}"
# Lower-case file extensions accepted as training images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')


def _uploaded_file_path(file):
    """Return the local filesystem path for one item from a ``gr.Files`` upload.

    Depending on the Gradio version/``type`` setting, uploads arrive either as
    path strings or as temp-file wrappers exposing the path via ``.name``;
    accept both.
    """
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _sanitize_dataset_name(name):
    """Reduce a user-supplied dataset name to a single safe path component.

    Directory parts are dropped so the dataset can never be written outside
    DATASET_DIR; an empty/``.``/``..`` result falls back to a timestamped name.
    """
    candidate = Path(name.strip()).name if name else ""
    if candidate in ("", ".", ".."):
        candidate = f"dataset_{int(time.time())}"
    return candidate


def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Copy uploaded images into a dataset folder and write one caption each.

    Every accepted image gets a sibling ``.txt`` caption containing
    ``trigger_word`` (or ``"a photo"`` when no trigger word is given).
    On success the module-level ``current_dataset_path`` is updated.

    Args:
        files: Items from a ``gr.Files`` component (path strings or
            temp-file wrappers with a ``.name`` path).
        dataset_name: Folder name under DATASET_DIR. It is sanitized, and a
            timestamped name is used when it is empty.
        trigger_word: Optional token written as every image's caption.

    Returns:
        ``(status_message, dataset_path_or_None, ready_message)`` for the
        three Gradio output components.
    """
    global current_dataset_path

    if not files:
        return "❌ Please upload at least one image", None, ""

    # Filter before touching the filesystem so a fully rejected upload does
    # not leave an empty dataset folder behind.
    sources = (Path(_uploaded_file_path(item)) for item in files)
    image_paths = [src for src in sources if src.suffix.lower() in _IMAGE_EXTENSIONS]
    if not image_paths:
        return "❌ No valid images found. Upload PNG, JPG, JPEG, WEBP, or BMP files.", None, ""

    dataset_name = _sanitize_dataset_name(dataset_name)
    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    caption_text = trigger_word if trigger_word else "a photo"
    for source in image_paths:
        destination = dataset_path / source.name
        shutil.copy(source, destination)
        # The caption shares the image's stem (cat.png -> cat.txt).
        destination.with_suffix('.txt').write_text(caption_text, encoding="utf-8")

    current_dataset_path = str(dataset_path)

    lines = [
        f"✅ Successfully uploaded {len(image_paths)} images",
        f"📁 Dataset: {dataset_name}",
    ]
    if trigger_word:
        lines.append(f"🏷️ Trigger word: '{trigger_word}'")
    lines.append(f"💾 Location: {current_dataset_path}")
    return "\n".join(lines), current_dataset_path, f"Dataset ready: {dataset_name}"
# Local checkout of ostris/ai-toolkit, relative to the working directory.
_AI_TOOLKIT_DIR = Path("./ai-toolkit")
_AI_TOOLKIT_REPO = "https://github.com/ostris/ai-toolkit.git"
# Hard limit for one training subprocess, in seconds.
_TRAINING_TIMEOUT_S = 3600
# Only the tail of a failure log is shown in the UI; full logs can be huge.
_LOG_TAIL_CHARS = 4000


def _sanitize_project_name(name):
    """Reduce a user-supplied project name to a single safe path component.

    Directory parts are dropped so output can never land outside OUTPUT_DIR;
    an empty/``.``/``..`` result falls back to a timestamped ``lora_<epoch>``.
    """
    candidate = Path(name.strip()).name if name else ""
    if candidate in ("", ".", ".."):
        candidate = f"lora_{int(time.time())}"
    return candidate


def _build_training_config(project_name, output_path, dataset_path, trigger_word,
                           steps, learning_rate, lora_rank, resolution):
    """Return the AI Toolkit job configuration (a dict) for one LoRA run."""
    steps = int(steps)
    rank = int(lora_rank)
    size = int(resolution)
    # Checkpoint/sample about four times per run, never more often than
    # every 100 steps.
    interval = max(100, steps // 4)
    prefix = f"{trigger_word} " if trigger_word else ""
    return {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                # Absolute paths keep the job independent of the trainer's cwd.
                "training_folder": str(Path(output_path).resolve()),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": rank,
                    "linear_alpha": rank,
                },
                "save": {
                    "dtype": "float16",
                    "save_every": interval,
                    "max_step_saves_to_keep": 3,
                },
                "datasets": [{
                    "folder_path": str(Path(dataset_path).resolve()),
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [size, size],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": steps,
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {
                        "use_ema": True,
                        "ema_decay": 0.99,
                    },
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
                "sample": {
                    "sampler": "flowmatch",
                    "sample_every": interval,
                    "width": size,
                    "height": size,
                    "prompts": [
                        f"{prefix}high quality photo",
                        f"{prefix}beautiful scene",
                    ],
                    "neg": "",
                    "seed": 42,
                    "guidance_scale": 0.0,
                    "sample_steps": 9,
                },
            }],
        },
    }


def _ensure_ai_toolkit():
    """Clone AI Toolkit and install its requirements unless already present.

    Child processes get ``cwd=`` instead of the parent calling ``os.chdir``,
    so a failing step can never leave the whole app in the wrong directory.
    A partial clone is removed on failure so the next call retries the full
    install instead of skipping it.

    Raises:
        subprocess.CalledProcessError / OSError if any step fails.
    """
    if _AI_TOOLKIT_DIR.exists():
        return
    try:
        subprocess.run(
            ["git", "clone", _AI_TOOLKIT_REPO, str(_AI_TOOLKIT_DIR)],
            check=True,
            capture_output=True,
        )
        subprocess.run(
            ["git", "submodule", "update", "--init", "--recursive"],
            cwd=_AI_TOOLKIT_DIR,
            check=True,
            capture_output=True,
        )
        # Install into the interpreter running this app (and run.py below),
        # not whatever `pip` happens to be first on PATH.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"],
            cwd=_AI_TOOLKIT_DIR,
            check=True,
        )
    except BaseException:
        # Also on Ctrl-C: a half-installed checkout must not look "installed".
        shutil.rmtree(_AI_TOOLKIT_DIR, ignore_errors=True)
        raise


def _describe_error(exc):
    """Format *exc* for display, appending the tail of captured stderr if any."""
    message = str(exc)
    stderr = getattr(exc, "stderr", None)
    if isinstance(stderr, bytes):
        stderr = stderr.decode("utf-8", errors="replace")
    if stderr:
        message += f"\n{stderr.strip()[-_LOG_TAIL_CHARS:]}"
    return message


def _find_latest_lora(output_path):
    """Return the most recently modified ``.safetensors`` under *output_path*, or None.

    The search is recursive because the trainer may write checkpoints into a
    per-job sub-folder of ``training_folder`` (verify against AI Toolkit).
    """
    candidates = list(Path(output_path).rglob("*.safetensors"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train a Z-Image LoRA with AI Toolkit and return the resulting weights.

    The function works in four steps:
    1. Write ``config.json`` into ``OUTPUT_DIR/<project_name>``.
    2. Install AI Toolkit on first use.
    3. Run ``ai-toolkit/run.py`` as a subprocess with a one-hour timeout.
    4. Return the newest ``.safetensors`` file the run produced.

    Args:
        dataset_path: Folder of images plus ``.txt`` captions (upload tab).
        project_name: Output folder / job name. It is sanitized, and a
            timestamped name is used when it is empty.
        trigger_word: Optional activation token, also used in sample prompts.
        steps, learning_rate, lora_rank, resolution: Numeric training
            hyper-parameters from the Gradio controls.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(status_message, lora_file_path_or_None)``. Failures are reported
        in the message rather than raised.
    """
    if not dataset_path or not os.path.exists(dataset_path):
        return "❌ Please upload a dataset first!", None

    project_name = _sanitize_project_name(project_name)
    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    config = _build_training_config(
        project_name, output_path, dataset_path, trigger_word,
        steps, learning_rate, lora_rank, resolution,
    )
    config_path = output_path / "config.json"
    config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")

    progress(0.1, desc="Installing AI Toolkit...")
    try:
        _ensure_ai_toolkit()
    except Exception as e:
        return f"❌ Failed to install AI Toolkit: {_describe_error(e)}", None

    progress(0.3, desc="Starting training...")
    try:
        result = subprocess.run(
            [sys.executable, str(_AI_TOOLKIT_DIR / "run.py"), str(config_path)],
            capture_output=True,
            text=True,
            timeout=_TRAINING_TIMEOUT_S,
        )
        if result.returncode != 0:
            log_tail = (result.stderr or result.stdout or "")[-_LOG_TAIL_CHARS:]
            return f"❌ Training failed:\n{log_tail}", None

        progress(0.9, desc="Training complete! Finding LoRA file...")
        lora_file = _find_latest_lora(output_path)
        if lora_file is None:
            return "⚠️ Training completed but no LoRA file found", None

        size_mb = lora_file.stat().st_size / (1024 * 1024)
        lines = [
            "✅ Training Complete!",
            "",
            f"📦 LoRA saved: {lora_file.name}",
            f"💾 Size: {size_mb:.2f} MB",
        ]
        if trigger_word:
            lines.append(f"🏷️ Use trigger word: '{trigger_word}' in your prompts")
        return "\n".join(lines), str(lora_file)
    except subprocess.TimeoutExpired:
        return "❌ Training timeout (> 1 hour). Try reducing steps.", None
    except Exception as e:
        return f"❌ Training error: {str(e)}", None
# Gradio Interface
# Markdown bodies are indented with the code; gr.Markdown dedents them
# (inspect.cleandoc) before rendering.
with gr.Blocks(title="Z-Image LoRA Trainer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Z-Image LoRA Trainer

    Train custom LoRA models for Z-Image-Base (6B parameter model)

    **Quick Start:**

    1. Upload 10-50 images of your subject
    2. Enter a trigger word (e.g., "mycharacter", "mystyle")
    3. Click Train
    4. Download your LoRA when complete

    ⚠️ **Note:** Training takes 10-30 minutes depending on steps. Don't close this tab!
    """)

    # GPU Status (evaluated once, when the UI is built)
    gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)

    with gr.Tab("📤 Upload Dataset"):
        with gr.Row():
            with gr.Column():
                file_input = gr.Files(
                    label="Upload Images (10-50 recommended)",
                    file_types=["image"],
                    file_count="multiple"
                )
                dataset_name_input = gr.Textbox(
                    label="Dataset Name",
                    placeholder="my_dataset",
                    value="my_dataset"
                )
                trigger_word_input = gr.Textbox(
                    label="Trigger Word (optional but recommended)",
                    placeholder="e.g., mycharacter, mystyle",
                    info="A unique word to activate your LoRA"
                )
                upload_btn = gr.Button("📤 Upload Dataset", variant="primary", size="lg")
            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", lines=8)
                # Hidden textbox carrying the dataset path from upload to training.
                dataset_path_state = gr.Textbox(label="Dataset Path", visible=False)
                dataset_ready = gr.Textbox(label="Ready to Train", interactive=False)

    with gr.Tab("🚀 Train LoRA"):
        with gr.Row():
            with gr.Column():
                project_name_input = gr.Textbox(
                    label="Project Name",
                    placeholder="my_lora",
                    value="my_lora"
                )
                gr.Markdown("### Training Settings")
                steps_input = gr.Slider(
                    label="Training Steps",
                    minimum=100,
                    maximum=3000,
                    value=1000,
                    step=100,
                    info="More steps = better quality but slower. Start with 1000."
                )
                learning_rate_input = gr.Slider(
                    label="Learning Rate",
                    minimum=0.00001,
                    maximum=0.001,
                    value=0.0001,
                    step=0.00001,
                    info="Default 0.0001 works well for most cases"
                )
                lora_rank_input = gr.Slider(
                    label="LoRA Rank",
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    info="Higher = more detail but larger file. 16 is balanced."
                )
                resolution_input = gr.Radio(
                    label="Resolution",
                    choices=[512, 768, 1024],
                    value=1024,
                    info="Z-Image native resolution is 1024x1024"
                )
                train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
            with gr.Column():
                training_status = gr.Textbox(label="Training Status", lines=15)
                lora_output = gr.File(label="Download Trained LoRA")

    with gr.Tab("ℹ️ Help"):
        gr.Markdown("""
        ## 📚 How to Use

        ### Step 1: Prepare Your Images

        - **10-50 images** of your subject (more is better for complex subjects)
        - **Consistent subject** across images
        - **Good variety** in poses, angles, lighting
        - **High quality** photos (clear, well-lit)

        ### Step 2: Upload Dataset

        - Choose a descriptive **dataset name**
        - Add a **trigger word** (e.g., "sks person", "mystyle")
        - Upload your images

        ### Step 3: Configure Training

        - **Project name**: Name for your LoRA
        - **Steps**:
          - 500-1000 for simple subjects
          - 1000-2000 for complex subjects/styles
        - **Learning rate**: Keep default (0.0001)
        - **LoRA Rank**: 16 is good for most cases

        ### Step 4: Train

        - Click "Start Training"
        - Wait 10-30 minutes (don't close tab)
        - Download your LoRA when complete

        ### Step 5: Use Your LoRA

        - Load in ComfyUI, Automatic1111, or other Z-Image tools
        - Use your trigger word in prompts
        - Example: "a photo of [trigger_word] in a forest"

        ## 🎯 Tips for Best Results

        - **Good dataset** = good results
        - **Consistent subject** across images
        - **Unique trigger word** (not common words)
        - **Start with 1000 steps**, adjust if needed
        - **Don't overtrain** (if quality decreases, reduce steps)

        ## ⚠️ Troubleshooting

        **Training fails with OOM error:**

        - Reduce resolution to 768 or 512
        - Use fewer steps
        - Upload fewer images

        **LoRA doesn't look like subject:**

        - Upload more images (20-30+)
        - Increase steps to 1500-2000
        - Ensure images are consistent

        **LoRA is too strong/weak:**

        - Adjust LoRA weight in your inference tool (0.5-1.5)

        ## 📖 Resources

        - **Z-Image Model**: [Tongyi-MAI/Z-Image-Base](https://huggingface.co/Tongyi-MAI/Z-Image-Base)
        - **AI Toolkit**: [github.com/ostris/ai-toolkit](https://github.com/ostris/ai-toolkit)
        - **Training Adapter**: [ostris/zimage_turbo_training_adapter](https://huggingface.co/ostris/zimage_turbo_training_adapter)
        """)

    # Event handlers
    upload_btn.click(
        fn=upload_and_prepare_dataset,
        inputs=[file_input, dataset_name_input, trigger_word_input],
        outputs=[upload_status, dataset_path_state, dataset_ready]
    )
    train_btn.click(
        fn=train_lora,
        inputs=[
            dataset_path_state,
            project_name_input,
            trigger_word_input,
            steps_input,
            learning_rate_input,
            lora_rank_input,
            resolution_input
        ],
        outputs=[training_status, lora_output]
    )
if __name__ == "__main__":
    # Start the Gradio web server only when run as a script, not on import.
    demo.launch()