Update train.py
--- a/train.py
+++ b/train.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # coding=utf-8
-# Copyright
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
 from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -50,18 +50,21 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.
+check_min_version("0.36.0.dev0")

 logger = get_logger(__name__)
+if is_torch_npu_available():
+    import torch_npu

+    torch.npu.config.allow_internal_format = False

 DATASET_NAME_MAPPING = {
-    "lambdalabs/
+    "lambdalabs/naruto-blip-captions": ("image", "text"),
 }

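Note on the new NPU block: `is_torch_npu_available()` is a diffusers utility that returns True only when Ascend's `torch_npu` extension is importable. A minimal sketch of the same guard used for device selection; the helper name `pick_device` is illustrative and not part of the script:

    import torch

    from diffusers.utils.import_utils import is_torch_npu_available

    def pick_device() -> torch.device:
        # Prefer an Ascend NPU when torch_npu is present, then CUDA, then CPU.
        if is_torch_npu_available():
            return torch.device("npu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    print(pick_device())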
@@ -389,7 +392,7 @@ def parse_args(input_args=None):
         type=float,
         default=None,
         help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
-        "More details here: https://
+        "More details here: https://huggingface.co/papers/2303.09556.",
     )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
     parser.add_argument(
@@ -460,10 +463,22 @@
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
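The `choices` expression enumerates torchvision's `InterpolationMode` members at runtime, so valid values typically include "nearest", "bilinear", "bicubic", and "lanczos". A self-contained sketch of the new flag; only `--image_interpolation_mode` matches the script, the rest is scaffolding:

    import argparse

    from torchvision import transforms

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_interpolation_mode",
        type=str,
        default="lanczos",
        # Lower-cased enum member names collected the same way the script does.
        choices=[
            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
        ],
    )
    args = parser.parse_args(["--image_interpolation_mode", "bicubic"])
    print(args.image_interpolation_mode)  # bicubic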
@@ -477,7 +492,6 @@
     # Sanity checks
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")
-
     if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
         raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

@@ -536,6 +550,9 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
+
+    # There might be a slight performance improvement
+    # by changing model_input.cpu() to accelerator.gather(model_input)
     return {"model_input": model_input.cpu()}

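For context, the surrounding function encodes preprocessed pixel batches into scaled latents. A stand-alone sketch assuming the SDXL VAE; the random tensor stands in for a real image batch:

    import torch

    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
    pixel_values = torch.randn(1, 3, 256, 256)  # placeholder for a preprocessed batch
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents * vae.config.scaling_factor  # 0.13025 for the SDXL VAE
    print(latents.shape)  # torch.Size([1, 4, 32, 32])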
@@ -584,7 +601,7 @@ def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
-            " Please use `
+            " Please use `hf auth login` to authenticate with the Hub."
         )

     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -716,7 +733,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
-
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -742,7 +764,8 @@ def main(args):
                 model.save_pretrained(os.path.join(output_dir, "unet"))

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

     def load_model_hook(models, input_dir):
         if args.use_ema:
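The `if weights:` guard matters because accelerate can pass the hook an empty `weights` list on ranks that hold no state dicts (for example under DeepSpeed), where a bare `weights.pop()` would raise IndexError. A minimal sketch of the hook protocol, with a stand-in model; paths and names are illustrative:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    def save_model_hook(models, weights, output_dir):
        for _ in models:
            # Popping an entry tells accelerate this hook handled saving itself;
            # guard against the empty list some backends pass in.
            if weights:
                weights.pop()

    accelerator.register_save_state_pre_hook(save_model_hook)
    model = accelerator.prepare(torch.nn.Linear(4, 4))
    accelerator.save_state("checkpoint-0")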
@@ -809,9 +832,7 @@ def main(args):
     if args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            args.dataset_name,
-            args.dataset_config_name,
-            cache_dir=args.cache_dir,
+            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
         )
     else:
         data_files = {}
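The collapsed call also forwards `data_dir`, so `--train_data_dir` now reaches Hub builders that read local files. A sketch using the dataset from this script's mapping; the positional None is the config name:

    from datasets import load_dataset

    dataset = load_dataset(
        "lambdalabs/naruto-blip-captions",  # args.dataset_name
        None,                               # args.dataset_config_name
        cache_dir=None,                     # args.cache_dir
        data_dir=None,                      # args.train_data_dir (unused for pure Hub datasets)
    )
    print(dataset["train"].column_names)  # ['image', 'text']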
@@ -849,7 +870,10 @@ def main(args):
     )

     # Preprocessing the datasets.
-    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+    if interpolation is None:
+        raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
+    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
     train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
     train_flip = transforms.RandomHorizontalFlip(p=1.0)
     train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
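The lookup turns the CLI string into a torchvision enum member; `getattr` with a None default converts an unknown name into an explicit error instead of an AttributeError. A stand-alone sketch:

    from torchvision import transforms

    mode_name = "lanczos"  # e.g. args.image_interpolation_mode
    interpolation = getattr(transforms.InterpolationMode, mode_name.upper(), None)
    if interpolation is None:
        raise ValueError(f"Unsupported interpolation mode {mode_name}.")
    resize = transforms.Resize(512, interpolation=interpolation)
    print(resize)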
@@ -907,14 +931,14 @@ def main(args):
         # fingerprint used by the cache for the other processes to load the result
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
-        new_fingerprint_for_vae = Hasher.hash(vae_path)
+        new_fingerprint_for_vae = Hasher.hash((vae_path, args))
         train_dataset_with_embeddings = train_dataset.map(
             compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
         )
         train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
-            batch_size=args.train_batch_size * accelerator.num_processes,
+            batch_size=args.train_batch_size,
             new_fingerprint=new_fingerprint_for_vae,
         )
         precomputed_dataset = concatenate_datasets(
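Hashing `(vae_path, args)` instead of `vae_path` alone keys the cached VAE encodings on every CLI argument, so changing e.g. `--resolution` invalidates stale latents. A sketch with illustrative values; the script hashes the actual argparse Namespace:

    from datasets.fingerprint import Hasher

    vae_path = "madebyollin/sdxl-vae-fp16-fix"
    args = {"resolution": 1024, "center_crop": False}  # stand-in for the Namespace
    print(Hasher.hash(vae_path))          # old cache key: VAE path only
    print(Hasher.hash((vae_path, args)))  # new cache key: path plus all arguments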
@@ -925,7 +949,10 @@ def main(args):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    elif torch.cuda.is_available():
+        torch.cuda.empty_cache()

     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
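The same backend-aware cache flush appears twice more later in the diff. A hypothetical helper (the name is illustrative) capturing the pattern:

    import torch

    from diffusers.utils.import_utils import is_torch_npu_available

    def free_accelerator_cache() -> None:
        # Release cached allocator blocks on whichever accelerator is present.
        if is_torch_npu_available():
            import torch_npu

            torch_npu.npu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()

    free_accelerator_cache()  # no-op on CPU-only machines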
@@ -1074,15 +1101,14 @@ def main(args):

                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)

                 # time ids
                 def compute_time_ids(original_size, crops_coords_top_left):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                     target_size = (args.resolution, args.resolution)
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids])
-                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
                     return add_time_ids

                 add_time_ids = torch.cat(
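The two-step construction collapses into one `torch.tensor(..., device=..., dtype=...)` call, avoiding an intermediate CPU tensor and a copy. A stand-alone version of the helper with an illustrative device and dtype:

    import torch

    def compute_time_ids(original_size, crops_coords_top_left, resolution=1024):
        # SDXL conditions on (original_size, crop_top_left, target_size), flattened.
        target_size = (resolution, resolution)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        return torch.tensor([add_time_ids], device="cpu", dtype=torch.float32)

    print(compute_time_ids((1024, 1024), (0, 0)))
    # tensor([[1024., 1024., 0., 0., 1024., 1024.]])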
@@ -1091,7 +1117,7 @@ def main(args):

                 # Predict the noise residual
                 unet_added_conditions = {"time_ids": add_time_ids}
-                prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+                prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
                 pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
@@ -1122,7 +1148,7 @@ def main(args):
                 if args.snr_gamma is None:
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 else:
-                    # Compute loss-weights as per Section 3.4 of https://
+                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(noise_scheduler, timesteps)
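For reference, the min-SNR weighting the comment cites clamps the per-timestep SNR at `snr_gamma` and, for epsilon prediction, divides by the SNR. A worked sketch with made-up SNR values:

    import torch

    snr = torch.tensor([0.5, 2.0, 10.0, 40.0])  # illustrative per-timestep SNRs
    snr_gamma = 5.0
    mse_loss_weights = (
        torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    )
    print(mse_loss_weights)  # tensor([1.0000, 1.0000, 0.5000, 0.1250])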
@@ -1160,7 +1186,8 @@ def main(args):
                     accelerator.log({"train_loss": train_loss}, step=global_step)
                     train_loss = 0.0

-                    if accelerator.is_main_process:
+                    # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                         if global_step % args.checkpointing_steps == 0:
                             # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                             if args.checkpoints_total_limit is not None:
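This is why the import hunk above adds `DistributedType`: under DeepSpeed, `save_state` must run on every rank because each rank holds its own optimizer and parameter shards, so gating on `is_main_process` alone would drop state. A minimal sketch of the condition:

    from accelerate import Accelerator
    from accelerate.utils import DistributedType

    accelerator = Accelerator()
    should_save = (
        accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process
    )
    print(should_save)  # True on a single process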
@@ -1226,7 +1253,11 @@ def main(args):
                 pipeline.set_progress_bar_config(disable=True)

                 # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+                generator = (
+                    torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                    if args.seed is not None
+                    else None
+                )
                 pipeline_args = {"prompt": args.validation_prompt}

                 with autocast_ctx:
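Switching the guard to `if args.seed is not None` makes seed 0 behave as a real seed instead of falsily disabling determinism. A quick demonstration:

    import torch

    seed = 0  # falsy, but a perfectly valid seed
    generator = torch.Generator(device="cpu").manual_seed(seed) if seed is not None else None
    print(torch.randn(2, generator=generator))  # reproducible across runs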
@@ -1250,7 +1281,10 @@ def main(args):
                 )

                 del pipeline
-                torch.cuda.empty_cache()
+                if is_torch_npu_available():
+                    torch_npu.npu.empty_cache()
+                elif torch.cuda.is_available():
+                    torch.cuda.empty_cache()

                 if args.use_ema:
                     # Switch back to the original UNet parameters.
|
|
| 1287 |
images = []
|
| 1288 |
if args.validation_prompt and args.num_validation_images > 0:
|
| 1289 |
pipeline = pipeline.to(accelerator.device)
|
| 1290 |
-
generator =
|
|
|
|
|
|
|
| 1291 |
|
| 1292 |
with autocast_ctx:
|
| 1293 |
images = [
|
|
@@ -1331,5 +1367,4 @@ def main(args):

 if __name__ == "__main__":
     args = parse_args()
-    main(args)
-    raise RuntimeError("The script is finished.")
+    main(args)