# Spaces: Running
import json
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path

import gradio as gr
import torch
# Setup directories
# Both folders are relative to the process working directory and are created
# at import time so the handlers below can write into them straight away.
DATASET_DIR = Path("./datasets")
OUTPUT_DIR = Path("./output")
DATASET_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

# Global variable to store dataset path
# (set by upload_and_prepare_dataset to the most recently prepared dataset;
# None until the first successful upload).
current_dataset_path = None
def check_gpu() -> str:
    """Check if GPU is available.

    Returns:
        A one-line status message: the name of CUDA device 0 when
        ``torch.cuda.is_available()`` is true, otherwise a warning that
        training will be slow.
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        return f"β GPU Available: {gpu_name}"
    return "β οΈ No GPU detected - training will be slow"
def upload_and_prepare_dataset(files, dataset_name, trigger_word):
    """Upload images and prepare dataset.

    Copies every uploaded image into ``DATASET_DIR / dataset_name`` and writes
    a same-named ``.txt`` caption file next to each copy.

    Args:
        files: Uploaded files. Each item may be a path (``str`` or
            ``os.PathLike``) or an object exposing its path as ``.name``.
        dataset_name: Name of the dataset folder. Only the final path
            component is used; when empty, ``dataset_<unix time>`` is used.
        trigger_word: Caption written for every image; ``"a photo"`` is used
            when empty.

    Returns:
        A 3-tuple ``(status_message, dataset_path, ready_message)``. On
        failure ``dataset_path`` is ``None`` and ``ready_message`` is ``""``.

    Side effects:
        Sets the module-level ``current_dataset_path`` on success.
    """
    global current_dataset_path

    if not files:
        return "β Please upload at least one image", None, ""

    # Keep only the last path component so a name such as "../x" or "/tmp/x"
    # cannot place the dataset outside DATASET_DIR.
    dataset_name = Path(str(dataset_name or "").strip()).name
    if not dataset_name or dataset_name == "..":
        dataset_name = f"dataset_{int(time.time())}"

    # Create dataset directory
    dataset_path = DATASET_DIR / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    image_extensions = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
    caption_text = trigger_word if trigger_word else "a photo"

    # Save images
    image_count = 0
    for file in files:
        # Accept both plain paths and wrapper objects that carry the path in
        # ``.name`` (Path objects are checked first: their ``.name`` is only
        # the basename).
        if isinstance(file, (str, os.PathLike)):
            source = os.fspath(file)
        else:
            source = file.name

        if not source.lower().endswith(image_extensions):
            continue

        destination = dataset_path / Path(source).name
        shutil.copy(source, destination)

        # Create simple caption file; explicit UTF-8 so non-ASCII trigger
        # words are written identically on every platform.
        caption_file = destination.with_suffix('.txt')
        with open(caption_file, 'w', encoding='utf-8') as f:
            f.write(caption_text)
        image_count += 1

    if image_count == 0:
        return "β No valid images found. Upload PNG, JPG, JPEG, or WEBP files.", None, ""

    current_dataset_path = str(dataset_path)

    status = f"β Successfully uploaded {image_count} images\n"
    status += f"π Dataset: {dataset_name}\n"
    if trigger_word:
        status += f"π·οΈ Trigger word: '{trigger_word}'\n"
    status += f"πΎ Location: {current_dataset_path}"
    return status, current_dataset_path, f"Dataset ready: {dataset_name}"
def _build_training_config(
    project_name,
    output_path,
    dataset_path,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
):
    """Return the AI Toolkit job configuration dict for one LoRA training run."""
    steps = int(steps)
    resolution = int(resolution)
    lora_rank = int(lora_rank)
    # Checkpoints and sample images are produced on the same schedule:
    # roughly four times per run, but never more often than every 100 steps.
    interval = max(100, steps // 4)

    return {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": lora_rank,
                    "linear_alpha": lora_rank,
                },
                "save": {
                    "dtype": "float16",
                    "save_every": interval,
                    "max_step_saves_to_keep": 3,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [resolution, resolution],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": steps,
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {
                        "use_ema": True,
                        "ema_decay": 0.99,
                    },
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
                "sample": {
                    "sampler": "flowmatch",
                    "sample_every": interval,
                    "width": resolution,
                    "height": resolution,
                    "prompts": [
                        f"{trigger_word} high quality photo" if trigger_word else "high quality photo",
                        f"{trigger_word} beautiful scene" if trigger_word else "beautiful scene",
                    ],
                    "neg": "",
                    "seed": 42,
                    "guidance_scale": 0.0,
                    "sample_steps": 9,
                },
            }]
        }
    }


def _install_ai_toolkit(toolkit_dir):
    """Clone AI Toolkit into *toolkit_dir* and install its requirements."""
    subprocess.run(
        ["git", "clone", "https://github.com/ostris/ai-toolkit.git", str(toolkit_dir)],
        check=True,
        capture_output=True,
    )
    # ``cwd=`` is used instead of os.chdir so a failing step cannot leave the
    # whole process stranded inside the checkout.
    subprocess.run(
        ["git", "submodule", "update", "--init", "--recursive"],
        cwd=toolkit_dir,
        check=True,
        capture_output=True,
    )
    # Install with the interpreter that runs this app, so the packages land
    # in the same environment the training subprocess will use.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"],
        cwd=toolkit_dir,
        check=True,
    )


def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train LoRA model.

    Writes a training config to ``OUTPUT_DIR / project_name / config.json``,
    installs AI Toolkit into ``./ai-toolkit`` when that folder is missing,
    runs ``ai-toolkit/run.py`` on the config in a subprocess (1 hour timeout)
    and looks for the resulting ``.safetensors`` file.

    Args:
        dataset_path: Folder containing the images and caption files.
        project_name: Name of the run and of its output folder. Only the
            final path component is used; when empty, ``lora_<unix time>``.
        trigger_word: Trigger word stored in the config and prepended to the
            sample prompts; may be empty.
        steps: Number of training steps (converted with ``int``).
        learning_rate: Optimizer learning rate (converted with ``float``).
        lora_rank: LoRA rank, also used as alpha (converted with ``int``).
        resolution: Square training/sample resolution (converted with ``int``).
        progress: Gradio progress tracker used to report coarse stages.

    Returns:
        A 2-tuple ``(status_message, lora_file_path)``; ``lora_file_path`` is
        ``None`` whenever no LoRA file could be produced.
    """
    if not dataset_path or not os.path.exists(dataset_path):
        return "β Please upload a dataset first!", None

    # Keep only the last path component so the project cannot be written
    # outside OUTPUT_DIR.
    project_name = Path(str(project_name or "").strip()).name
    if not project_name or project_name == "..":
        project_name = f"lora_{int(time.time())}"

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    # Create training config
    config = _build_training_config(
        project_name,
        output_path,
        dataset_path,
        trigger_word,
        steps,
        learning_rate,
        lora_rank,
        resolution,
    )

    # Save config
    config_path = output_path / "config.json"
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    progress(0.1, desc="Installing AI Toolkit...")

    # Install AI Toolkit if not exists
    toolkit_dir = Path("./ai-toolkit")
    if not toolkit_dir.exists():
        try:
            _install_ai_toolkit(toolkit_dir)
        except (subprocess.SubprocessError, OSError) as e:
            # Drop the partial checkout; otherwise the existence check above
            # would skip installation on every later attempt.
            shutil.rmtree(toolkit_dir, ignore_errors=True)
            return f"β Failed to install AI Toolkit: {str(e)}", None

    progress(0.3, desc="Starting training...")

    # Run training
    try:
        result = subprocess.run(
            [sys.executable, str(toolkit_dir / "run.py"), str(config_path)],
            capture_output=True,
            text=True,
            timeout=3600  # 1 hour timeout
        )
        if result.returncode != 0:
            return f"β Training failed:\n{result.stderr}", None

        progress(0.9, desc="Training complete! Finding LoRA file...")

        # Find the trained LoRA file. The search is recursive because the
        # trainer may write checkpoints into a sub-folder of training_folder
        # (assumption - confirm against AI Toolkit's save layout).
        lora_files = list(output_path.rglob("*.safetensors"))
        if not lora_files:
            return "β οΈ Training completed but no LoRA file found", None

        # Glob order is unspecified, so pick the newest file explicitly.
        lora_file = max(lora_files, key=lambda candidate: candidate.stat().st_mtime)
        success_msg = f"β Training Complete!\n\n"
        success_msg += f"π¦ LoRA saved: {lora_file.name}\n"
        success_msg += f"πΎ Size: {lora_file.stat().st_size / (1024*1024):.2f} MB\n"
        success_msg += f"π·οΈ Use trigger word: '{trigger_word}' in your prompts"
        return success_msg, str(lora_file)
    except subprocess.TimeoutExpired:
        return "β Training timeout (> 1 hour). Try reducing steps.", None
    except Exception as e:
        # UI boundary: report any unexpected failure as a status message
        # instead of raising into Gradio.
        return f"β Training error: {str(e)}", None
# Gradio Interface
# The Markdown bodies are kept flush-left in module-level constants so that no
# source indentation ends up inside the rendered text.
_INTRO_MARKDOWN = """
# π¨ Z-Image LoRA Trainer

Train custom LoRA models for Z-Image-Base (6B parameter model)

**Quick Start:**
1. Upload 10-50 images of your subject
2. Enter a trigger word (e.g., "mycharacter", "mystyle")
3. Click Train
4. Download your LoRA when complete

β οΈ **Note:** Training takes 10-30 minutes depending on steps. Don't close this tab!
"""

_HELP_MARKDOWN = """
## π How to Use

### Step 1: Prepare Your Images
- **10-50 images** of your subject (more is better for complex subjects)
- **Consistent subject** across images
- **Good variety** in poses, angles, lighting
- **High quality** photos (clear, well-lit)

### Step 2: Upload Dataset
- Choose a descriptive **dataset name**
- Add a **trigger word** (e.g., "sks person", "mystyle")
- Upload your images

### Step 3: Configure Training
- **Project name**: Name for your LoRA
- **Steps**:
  - 500-1000 for simple subjects
  - 1000-2000 for complex subjects/styles
- **Learning rate**: Keep default (0.0001)
- **LoRA Rank**: 16 is good for most cases

### Step 4: Train
- Click "Start Training"
- Wait 10-30 minutes (don't close tab)
- Download your LoRA when complete

### Step 5: Use Your LoRA
- Load in ComfyUI, Automatic1111, or other Z-Image tools
- Use your trigger word in prompts
- Example: "a photo of [trigger_word] in a forest"

## π― Tips for Best Results
- **Good dataset** = good results
- **Consistent subject** across images
- **Unique trigger word** (not common words)
- **Start with 1000 steps**, adjust if needed
- **Don't overtrain** (if quality decreases, reduce steps)

## β οΈ Troubleshooting

**Training fails with OOM error:**
- Reduce resolution to 768 or 512
- Use fewer steps
- Upload fewer images

**LoRA doesn't look like subject:**
- Upload more images (20-30+)
- Increase steps to 1500-2000
- Ensure images are consistent

**LoRA is too strong/weak:**
- Adjust LoRA weight in your inference tool (0.5-1.5)

## π Resources
- **Z-Image Model**: [Tongyi-MAI/Z-Image-Base](https://huggingface.co/Tongyi-MAI/Z-Image-Base)
- **AI Toolkit**: [github.com/ostris/ai-toolkit](https://github.com/ostris/ai-toolkit)
- **Training Adapter**: [ostris/zimage_turbo_training_adapter](https://huggingface.co/ostris/zimage_turbo_training_adapter)
"""

with gr.Blocks(title="Z-Image LoRA Trainer", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # GPU Status (evaluated once, when the interface is built)
    gpu_status = gr.Textbox(label="GPU Status", value=check_gpu(), interactive=False)

    with gr.Tab("π€ Upload Dataset"):
        with gr.Row():
            # Left column: upload inputs
            with gr.Column():
                file_input = gr.Files(
                    label="Upload Images (10-50 recommended)",
                    file_types=["image"],
                    file_count="multiple"
                )
                dataset_name_input = gr.Textbox(
                    label="Dataset Name",
                    placeholder="my_dataset",
                    value="my_dataset"
                )
                trigger_word_input = gr.Textbox(
                    label="Trigger Word (optional but recommended)",
                    placeholder="e.g., mycharacter, mystyle",
                    info="A unique word to activate your LoRA"
                )
                upload_btn = gr.Button("π€ Upload Dataset", variant="primary", size="lg")
            # Right column: upload results
            with gr.Column():
                upload_status = gr.Textbox(label="Upload Status", lines=8)
                # Hidden field that carries the prepared dataset path from
                # the upload handler to the training handler.
                dataset_path_state = gr.Textbox(label="Dataset Path", visible=False)
                dataset_ready = gr.Textbox(label="Ready to Train", interactive=False)

    with gr.Tab("π Train LoRA"):
        with gr.Row():
            # Left column: training parameters
            with gr.Column():
                project_name_input = gr.Textbox(
                    label="Project Name",
                    placeholder="my_lora",
                    value="my_lora"
                )
                gr.Markdown("### Training Settings")
                steps_input = gr.Slider(
                    label="Training Steps",
                    minimum=100,
                    maximum=3000,
                    value=1000,
                    step=100,
                    info="More steps = better quality but slower. Start with 1000."
                )
                learning_rate_input = gr.Slider(
                    label="Learning Rate",
                    minimum=0.00001,
                    maximum=0.001,
                    value=0.0001,
                    step=0.00001,
                    info="Default 0.0001 works well for most cases"
                )
                lora_rank_input = gr.Slider(
                    label="LoRA Rank",
                    minimum=4,
                    maximum=128,
                    value=16,
                    step=4,
                    info="Higher = more detail but larger file. 16 is balanced."
                )
                resolution_input = gr.Radio(
                    label="Resolution",
                    choices=[512, 768, 1024],
                    value=1024,
                    info="Z-Image native resolution is 1024x1024"
                )
                train_btn = gr.Button("π Start Training", variant="primary", size="lg")
            # Right column: training results
            with gr.Column():
                training_status = gr.Textbox(label="Training Status", lines=15)
                lora_output = gr.File(label="Download Trained LoRA")

    with gr.Tab("βΉοΈ Help"):
        gr.Markdown(_HELP_MARKDOWN)

    # Event handlers
    # The order of ``inputs`` must match each handler's positional parameters,
    # and the order of ``outputs`` must match the tuple it returns.
    upload_btn.click(
        fn=upload_and_prepare_dataset,
        inputs=[file_input, dataset_name_input, trigger_word_input],
        outputs=[upload_status, dataset_path_state, dataset_ready]
    )
    train_btn.click(
        fn=train_lora,
        inputs=[
            dataset_path_state,
            project_name_input,
            trigger_word_input,
            steps_input,
            learning_rate_input,
            lora_rank_input,
            resolution_input
        ],
        outputs=[training_status, lora_output]
    )
# Start the Gradio server only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()