Pre-Compiled AoTI

#23
by multimodalart (HF Staff) — opened
Files changed (2)
  1. app.py +8 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,10 +5,10 @@ import torch
5
  import spaces
6
  from PIL import Image
7
  from diffusers import FlowMatchEulerDiscreteScheduler
8
- from optimization import optimize_pipeline_
9
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
10
- from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
11
- from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
12
  import math
13
 
14
  # --- Model Loading ---
@@ -51,9 +51,11 @@ pipe.load_lora_weights(
51
  )
52
  pipe.fuse_lora(lora_scale=1.0)
53
 
54
- pipe.transformer.__class__ = QwenImageTransformer2DModel
55
- pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
56
- optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
 
 
57
 
58
  # --- Constants ---
59
  MAX_SEED = np.iinfo(np.int32).max
 
5
  import spaces
6
  from PIL import Image
7
  from diffusers import FlowMatchEulerDiscreteScheduler
8
+ # from optimization import optimize_pipeline_
9
  from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
10
+ # from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
11
+ # from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
12
  import math
13
 
14
  # --- Model Loading ---
 
51
  )
52
  pipe.fuse_lora(lora_scale=1.0)
53
 
54
+ spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Qwen-Image", variant="fa3")
55
+
56
+ #pipe.transformer.__class__ = QwenImageTransformer2DModel
57
+ #pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
58
+ #optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
59
 
60
  # --- Constants ---
61
  MAX_SEED = np.iinfo(np.int32).max
requirements.txt CHANGED
@@ -8,4 +8,5 @@ dashscope
8
  kernels
9
  torchvision
10
  peft
11
- torchao==0.11.0
 
 
8
  kernels
9
  torchvision
10
  peft
11
+ torchao==0.11.0
12
+ torch==2.8