nroggendorff committed
Commit 4381bf7 · verified · 1 parent: 61dc47b

Update train.py

Files changed (1): train.py (+63 −28)
train.py CHANGED
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
 from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -50,18 +50,21 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.28.0.dev0")
+check_min_version("0.36.0.dev0")

 logger = get_logger(__name__)
+if is_torch_npu_available():
+    import torch_npu

+    torch.npu.config.allow_internal_format = False

 DATASET_NAME_MAPPING = {
-    "lambdalabs/pokemon-blip-captions": ("image", "text"),
+    "lambdalabs/naruto-blip-captions": ("image", "text"),
 }

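Note: the NPU support above is gated behind a runtime probe, so the script keeps working on CUDA-only machines where `torch_npu` cannot be imported. A minimal sketch of the pattern (the `else` branch and comments are ours, not part of the commit):

    import torch
    from diffusers.utils.import_utils import is_torch_npu_available

    if is_torch_npu_available():
        import torch_npu  # noqa: F401  # registers the "npu" device type with torch

        # Disable private NPU memory formats for more predictable op behavior.
        torch.npu.config.allow_internal_format = False
    else:
        print("No Ascend NPU detected; running on", "cuda" if torch.cuda.is_available() else "cpu")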
@@ -389,7 +392,7 @@ def parse_args(input_args=None):
         type=float,
         default=None,
         help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
-        "More details here: https://arxiv.org/abs/2303.09556.",
+        "More details here: https://huggingface.co/papers/2303.09556.",
     )
     parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
     parser.add_argument(
@@ -460,10 +463,22 @@
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
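Note: the choices for `--image_interpolation_mode` are enumerated from torchvision's `InterpolationMode` enum rather than hard-coded, so any mode torchvision supports is accepted automatically. A standalone sketch of how the choices list resolves:

    import argparse

    from torchvision import transforms

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_interpolation_mode",
        type=str,
        default="lanczos",
        choices=[
            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
        ],
    )
    args = parser.parse_args(["--image_interpolation_mode", "bicubic"])
    print(args.image_interpolation_mode)  # bicubic; unknown names fail at parse time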
@@ -477,7 +492,6 @@ def parse_args(input_args=None):
     # Sanity checks
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")
-
     if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
         raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

@@ -536,6 +550,9 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
+
+    # There might be a slight performance improvement
+    # from changing model_input.cpu() to accelerator.gather(model_input)
     return {"model_input": model_input.cpu()}

@@ -584,7 +601,7 @@ def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
             "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
-            " Please use `huggingface-cli login` to authenticate with the Hub."
+            " Please use `hf auth login` to authenticate with the Hub."
         )

     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -716,7 +733,12 @@
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
-
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
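Note: `enable_npu_flash_attention()` is the UNet method the commit calls to swap in Ascend flash-attention processors; it only works where `torch_npu` is importable, hence the guard. A guarded-enable sketch, using the SDXL base UNet as an example checkpoint:

    from diffusers import UNet2DConditionModel
    from diffusers.utils.import_utils import is_torch_npu_available

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
    )
    if is_torch_npu_available():
        unet.enable_npu_flash_attention()  # switch attention processors to the NPU kernel
    else:
        print("torch_npu not available; keeping the default attention processors")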
@@ -742,7 +764,8 @@
                     model.save_pretrained(os.path.join(output_dir, "unet"))

                     # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
+                    if weights:
+                        weights.pop()

         def load_model_hook(models, input_dir):
             if args.use_ema:
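Note: accelerate can invoke this save hook with an empty `weights` list (e.g. under DeepSpeed, which manages the state dicts itself), so the unconditional `weights.pop()` could raise `IndexError`. The guard simply relies on empty lists being falsy:

    weights = []  # what the hook may receive under DeepSpeed
    if weights:   # an empty list is falsy, so pop() is skipped safely
        weights.pop()
    # a bare weights.pop() here would raise IndexError: pop from empty list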
@@ -809,9 +832,7 @@
     if args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            args.dataset_name,
-            args.dataset_config_name,
-            cache_dir=args.cache_dir,
+            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
         )
     else:
         data_files = {}
@@ -849,7 +870,10 @@
     )

     # Preprocessing the datasets.
-    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+    if interpolation is None:
+        raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
+    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
     train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
     train_flip = transforms.RandomHorizontalFlip(p=1.0)
     train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
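Note: the lowercase CLI value is mapped back onto the enum with `getattr`, failing with a `ValueError` instead of an `AttributeError`. (Small nit: by the time the error fires, `interpolation` is `None`, so the message prints `interpolation=None` rather than the offending name; interpolating `args.image_interpolation_mode` would be more informative.) Standalone sketch:

    from torchvision import transforms

    mode_name = "lanczos"  # stands in for args.image_interpolation_mode
    interpolation = getattr(transforms.InterpolationMode, mode_name.upper(), None)
    if interpolation is None:
        raise ValueError(f"Unsupported interpolation mode {mode_name!r}.")
    resize = transforms.Resize(1024, interpolation=interpolation)
    print(resize)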
@@ -907,14 +931,14 @@
         # fingerprint used by the cache for the other processes to load the result
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
-        new_fingerprint_for_vae = Hasher.hash(vae_path)
+        new_fingerprint_for_vae = Hasher.hash((vae_path, args))
         train_dataset_with_embeddings = train_dataset.map(
             compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
         )
         train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
-            batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
+            batch_size=args.train_batch_size,
             new_fingerprint=new_fingerprint_for_vae,
         )
         precomputed_dataset = concatenate_datasets(
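Note: `datasets.map` reuses cached results keyed by `new_fingerprint`, so hashing only `vae_path` let runs with different preprocessing arguments collide on the same cached VAE latents; hashing the `(vae_path, args)` tuple ties the cache entry to the full configuration. The smaller `batch_size` also stops scaling the VAE encoding batch by world size and gradient-accumulation steps, which could exhaust memory on large runs. Fingerprint sketch:

    from datasets.fingerprint import Hasher

    vae_path = "madebyollin/sdxl-vae-fp16-fix"  # illustrative
    fp_old = Hasher.hash(vae_path)
    fp_new = Hasher.hash((vae_path, {"resolution": 1024}))  # stand-in for args
    print(fp_old != fp_new)  # True: any config change produces a fresh cache key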
@@ -925,7 +949,10 @@
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    elif torch.cuda.is_available():
+        torch.cuda.empty_cache()

     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
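Note: the CUDA-only `torch.cuda.empty_cache()` is now dispatched per backend. A small helper capturing the same logic, runnable on any machine (the function name is ours, not the script's):

    import gc

    import torch
    from diffusers.utils.import_utils import is_torch_npu_available

    def empty_device_cache() -> None:
        """Release cached allocator blocks on whichever accelerator is present."""
        gc.collect()
        if is_torch_npu_available():
            import torch_npu  # noqa: F401

            torch_npu.npu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()

    empty_device_cache()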
@@ -1074,15 +1101,14 @@

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
-                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)

                # time ids
                def compute_time_ids(original_size, crops_coords_top_left):
                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                    target_size = (args.resolution, args.resolution)
                    add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids])
-                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
                    return add_time_ids

                add_time_ids = torch.cat(
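Note: under `--mixed_precision`, `weight_dtype` is fp16/bf16 while the scheduler returns fp32 tensors, so casting `noisy_model_input` (and building `add_time_ids` on-device in the target dtype in one step) avoids dtype mismatches inside the UNet. Sketch of the SDXL time-id construction with illustrative values (device omitted so it runs on CPU):

    import torch

    original_size = (1024, 1024)    # stands in for an entry of batch["original_sizes"]
    crops_coords_top_left = (0, 0)  # stands in for an entry of batch["crop_top_lefts"]
    target_size = (1024, 1024)      # (args.resolution, args.resolution)
    weight_dtype = torch.float16

    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)
    print(add_time_ids.shape, add_time_ids.dtype)  # torch.Size([1, 6]) torch.float16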
@@ -1091,7 +1117,7 @@

                # Predict the noise residual
                unet_added_conditions = {"time_ids": add_time_ids}
-                prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+                prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
                pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
                unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                model_pred = unet(
@@ -1122,7 +1148,7 @@
                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
-                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(noise_scheduler, timesteps)
@@ -1160,7 +1186,8 @@
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

-                if accelerator.is_main_process:
+                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
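Note: DeepSpeed shards model and optimizer state across ranks, so `accelerator.save_state()` must run on every process; gating the whole block on `is_main_process` caused the issues the new comment mentions. The new condition in isolation:

    from accelerate import Accelerator
    from accelerate.utils import DistributedType

    accelerator = Accelerator()
    should_checkpoint = (
        accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process
    )
    print(should_checkpoint)  # True on every rank under DeepSpeed, main rank only otherwise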
@@ -1226,7 +1253,11 @@
                    pipeline.set_progress_bar_config(disable=True)

                    # run inference
-                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+                    generator = (
+                        torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                        if args.seed is not None
+                        else None
+                    )
                    pipeline_args = {"prompt": args.validation_prompt}

                    with autocast_ctx:
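Note: the old `if args.seed` treated a seed of 0 as "no seed" because `0` is falsy in Python; `if args.seed is not None` makes `--seed 0` reproducible. The final validation block near the end of main() receives the identical fix. A quick demonstration:

    import torch

    seed = 0
    old_style = torch.Generator().manual_seed(seed) if seed else None
    new_style = torch.Generator().manual_seed(seed) if seed is not None else None
    print(old_style)                 # None: the seed was silently dropped
    print(type(new_style).__name__)  # Generator: seed 0 is honored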
@@ -1250,7 +1281,10 @@
                    )

                    del pipeline
-                    torch.cuda.empty_cache()
+                    if is_torch_npu_available():
+                        torch_npu.npu.empty_cache()
+                    elif torch.cuda.is_available():
+                        torch.cuda.empty_cache()

                if args.use_ema:
                    # Switch back to the original UNet parameters.
@@ -1287,7 +1321,9 @@
        images = []
        if args.validation_prompt and args.num_validation_images > 0:
            pipeline = pipeline.to(accelerator.device)
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+            generator = (
+                torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+            )

            with autocast_ctx:
                images = [
@@ -1331,5 +1367,4 @@

 if __name__ == "__main__":
     args = parse_args()
-    main(args)
-    raise RuntimeError("The script is finished.")
+    main(args)