| from __future__ import annotations |
|
|
import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import urllib.request

import gradio as gr
import PIL.Image
import PIL.ImageOps
import slugify
import torch
from huggingface_hub import HfApi
|
|
| from app_upload import LoRAModelUploader |
| from utils import save_model_card |
|
|
# Share URL of the ``lora-library`` organization on the Hugging Face Hub.
# ``Trainer.join_lora_library_org`` POSTs to it with the user's bearer token.
URL_TO_JOIN_LORA_LIBRARY_ORG = (
    'https://huggingface.co/organizations/lora-library/share/'
    'hjetHAcKjnPHXhHfbeEcqnBqmhgilFfpOL'
)
|
|
|
|
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad *image* with a black border so that it becomes square.

    The original content is centred along its shorter axis. An image that is
    already square is returned unchanged (the same object).

    NOTE: for palette ('P') images the source palette is not copied onto the
    padded canvas; convert such images to RGB before calling this.

    Args:
        image: Input image.

    Returns:
        A square image of the same mode whose side length equals the longer
        edge of *image*.
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    # Single-band modes ('L', '1', 'I', 'F', 'P', ...) reject an RGB tuple as
    # the fill colour, so use a scalar zero there. Multi-band modes keep the
    # (0, 0, 0) fill, which for RGBA is opaque black.
    fill = 0 if len(image.getbands()) == 1 else (0, 0, 0)
    new_image = PIL.Image.new(image.mode, (side, side), fill)
    # One of the two offsets is always 0: only the shorter axis is padded.
    new_image.paste(image, ((side - w) // 2, (side - h) // 2))
    return new_image
|
|
|
|
class Trainer:
    """Run DreamBooth LoRA training via an external script and optionally
    upload the result to the Hugging Face Hub."""

    def __init__(self, hf_token: str | None = None):
        """Create the Hub clients.

        Args:
            hf_token: Hugging Face access token. It is used by the ``HfApi``
                client, by the LoRA model uploader and by the request that
                joins the ``lora-library`` organization.
        """
        self.hf_token = hf_token
        self.api = HfApi(token=hf_token)
        self.model_uploader = LoRAModelUploader(hf_token)

    def prepare_dataset(self, instance_images: list, resolution: int,
                        instance_data_dir: pathlib.Path) -> None:
        """Write the uploaded images as square RGB JPEGs.

        *instance_data_dir* is deleted and recreated first, so afterwards it
        contains only the images from this call, named ``000.jpg``,
        ``001.jpg``, ... in upload order.

        Args:
            instance_images: Uploaded files. Each item may be a string or
                ``os.PathLike`` path, or an object exposing the path as
                ``.name`` (e.g. a temporary-file wrapper).
            resolution: Side length in pixels of every output image.
            instance_data_dir: Destination directory.
        """
        shutil.rmtree(instance_data_dir, ignore_errors=True)
        instance_data_dir.mkdir(parents=True)
        for i, item in enumerate(instance_images):
            # Plain paths (including str subclasses that also carry `.name`)
            # are used as-is; file wrappers expose their path via `.name`.
            path = item if isinstance(item, (str, os.PathLike)) else item.name
            with PIL.Image.open(path) as raw:
                # Honour the EXIF orientation tag so camera/phone photos are
                # not trained on sideways. Converting to RGB up front means
                # grayscale, palette and RGBA inputs are all padded the same
                # way; any alpha channel is dropped. `convert` loads the
                # pixels, so the file can be closed afterwards.
                image = PIL.ImageOps.exif_transpose(raw).convert('RGB')
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            out_path = instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

    def join_lora_library_org(self) -> None:
        """Ask the Hub to add the token's user to the ``lora-library`` org.

        This is best effort: network and HTTP errors are printed and otherwise
        ignored. Nothing is sent when no token is configured.
        """
        if not self.hf_token:
            return
        # Send the request in-process rather than shelling out to `curl`.
        # The bearer token then never appears in a child process's argv
        # (readable by other users via `ps`), and no external binary is needed.
        request = urllib.request.Request(
            URL_TO_JOIN_LORA_LIBRARY_ORG,
            method='POST',
            headers={
                'Authorization': f'Bearer {self.hf_token}',
                'Content-Type': 'application/json',
            },
        )
        try:
            with urllib.request.urlopen(request, timeout=30):
                pass
        except OSError as err:  # URLError, HTTPError and timeouts are OSErrors
            print(f'Could not join the lora-library organization: {err}')

    @staticmethod
    def _build_train_command(
        *,
        base_model: str,
        instance_data_dir: pathlib.Path,
        output_dir: pathlib.Path,
        instance_prompt: str,
        resolution: int,
        gradient_accumulation: int,
        learning_rate: float,
        n_steps: int,
        checkpointing_steps: int,
        validation_prompt: str,
        validation_epochs: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        use_wandb: bool,
    ) -> list[str]:
        """Return the ``accelerate launch`` argv for the training script.

        Each value is its own argv element, so prompts or paths containing
        quotes or spaces reach the script verbatim, with no shell parsing.
        """
        command = [
            'accelerate', 'launch', 'train_dreambooth_lora.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--instance_data_dir={instance_data_dir}',
            f'--output_dir={output_dir}',
            f'--instance_prompt={instance_prompt}',
            f'--resolution={resolution}',
            '--train_batch_size=1',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            f'--learning_rate={learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=0',
            f'--max_train_steps={n_steps}',
            f'--checkpointing_steps={checkpointing_steps}',
            f'--validation_prompt={validation_prompt}',
            f'--validation_epochs={validation_epochs}',
            f'--seed={seed}',
        ]
        if fp16:
            command += ['--mixed_precision', 'fp16']
        if use_8bit_adam:
            command.append('--use_8bit_adam')
        if use_wandb:
            command += ['--report_to', 'wandb']
        return command

    def _release_gpu(self) -> None:
        """Request ``cpu-basic`` hardware for the current Space, if any.

        The Space is identified by the ``SPACE_ID`` environment variable;
        nothing happens when it is unset.
        """
        space_id = os.getenv('SPACE_ID')
        if space_id:
            self.api.request_space_hardware(repo_id=space_id,
                                            hardware='cpu-basic')

    def run(
        self,
        instance_images: list | None,
        instance_prompt: str,
        output_model_name: str,
        overwrite_existing_model: bool,
        validation_prompt: str,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        learning_rate: float,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        use_wandb: bool,
        validation_epochs: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
        remove_gpu_after_training: bool,
    ) -> str:
        """Train a DreamBooth LoRA on the uploaded images and optionally
        upload it.

        The steps are:
        1. Preprocess the images into ``training_data/<model name>``.
        2. Run ``train_dreambooth_lora.py`` via ``accelerate launch``, writing
           to ``experiments/<model name>``. The exact command is saved there
           as ``train.sh``.
        3. Write a model card.
        4. Optionally upload the folder to the Hub.

        Args:
            instance_images: Uploaded image files (see ``prepare_dataset``).
            instance_prompt: Prompt describing the training subject.
            output_model_name: Model/repo name. A timestamped name is used
                when it is empty; the name is always slugified.
            overwrite_existing_model: Delete an existing local output
                directory of the same name instead of failing.
            validation_prompt: Passed to the script and recorded in the model
                card as the test prompt.
            base_model: Pretrained model id or path.
            resolution_s: Training resolution as a string, e.g. ``'512'``.
            n_steps, learning_rate, gradient_accumulation, seed, fp16,
            use_8bit_adam, checkpointing_steps, use_wandb, validation_epochs:
                Forwarded to the training script.
            upload_to_hub: Upload the result with ``LoRAModelUploader``. This
                also joins the ``lora-library`` org first and implies
                overwriting the local output directory.
            use_private_repo, delete_existing_repo, upload_to: Forwarded to
                the uploader.
            remove_gpu_after_training: Once training has been started, switch
                the current Space to ``cpu-basic`` hardware at the end, even
                if training or the upload fails.

        Returns:
            A human-readable status message.

        Raises:
            gr.Error: CUDA is unavailable, a required input is missing, the
                output directory exists and may not be overwritten, or the
                training script exits with a non-zero status.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if not instance_images:
            raise gr.Error('You need to upload images.')
        if not instance_prompt:
            raise gr.Error('The instance prompt is missing.')
        if not validation_prompt:
            raise gr.Error('The validation prompt is missing.')

        resolution = int(resolution_s)

        if not output_model_name:
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            output_model_name = f'lora-dreambooth-{timestamp}'
        output_model_name = slugify.slugify(output_model_name)

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / 'experiments' / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        try:
            output_dir.mkdir(parents=True)
        except FileExistsError as err:
            raise gr.Error(
                f'The output directory {output_dir} already exists. Allow '
                'overwriting the existing model or choose another name.'
            ) from err

        instance_data_dir = repo_dir / 'training_data' / output_model_name
        self.prepare_dataset(instance_images, resolution, instance_data_dir)

        if upload_to_hub:
            self.join_lora_library_org()

        command = self._build_train_command(
            base_model=base_model,
            instance_data_dir=instance_data_dir,
            output_dir=output_dir,
            instance_prompt=instance_prompt,
            resolution=resolution,
            gradient_accumulation=gradient_accumulation,
            learning_rate=learning_rate,
            n_steps=n_steps,
            checkpointing_steps=checkpointing_steps,
            validation_prompt=validation_prompt,
            validation_epochs=validation_epochs,
            seed=seed,
            fp16=fp16,
            use_8bit_adam=use_8bit_adam,
            use_wandb=use_wandb,
        )
        # Save a correctly shell-quoted copy of the exact command so the run
        # can be reproduced by hand.
        (output_dir / 'train.sh').write_text(
            ' '.join(shlex.quote(arg) for arg in command), encoding='utf-8')

        try:
            result = subprocess.run(command)
            if result.returncode != 0:
                raise gr.Error(
                    f'Training failed with exit code {result.returncode}.')
            save_model_card(save_dir=output_dir,
                            base_model=base_model,
                            instance_prompt=instance_prompt,
                            test_prompt=validation_prompt,
                            test_image_dir='test_images')

            message = 'Training completed!'
            print(message)

            if upload_to_hub:
                upload_message = self.model_uploader.upload_lora_model(
                    folder_path=output_dir.as_posix(),
                    repo_name=output_model_name,
                    upload_to=upload_to,
                    private=use_private_repo,
                    delete_existing_repo=delete_existing_repo)
                print(upload_message)
                message += '\n' + upload_message
        finally:
            # Release the GPU even when training or the upload fails, so a
            # failed run does not leave paid hardware allocated.
            if remove_gpu_after_training:
                self._release_gpu()

        return message
|
|