XiangpengYang commited on
Commit
91db0c0
·
1 Parent(s): 2c4b825
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -26,6 +26,14 @@ from videox_fun.data.dataset_image_video import derive_ground_object_from_instru
26
  from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
27
  from videox_fun.utils.utils import save_videos_grid, timer
28
 
 
 
 
 
 
 
 
 
29
  # Redefine create_height_width to remove Chinese and specific defaults if needed,
30
  # although we will mostly ignore sliders if we use input resolution.
31
  # We will create a custom version here to avoid modifying the library file if possible,
@@ -150,7 +158,6 @@ def preload_models(controller, default_model_path, default_lora_name, acc_lora_p
150
  torch.cuda.empty_cache()
151
 
152
  class VideoCoF_Controller(Wan_Controller):
153
- @spaces.GPU(duration=300)
154
  @timer
155
  def generate(
156
  self,
@@ -362,6 +369,8 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
362
  config_path=config_path, compile_dit=compile_dit,
363
  weight_dtype=weight_dtype
364
  )
 
 
365
 
366
  with gr.Blocks() as demo:
367
  gr.Markdown("# VideoCoF Demo")
@@ -472,7 +481,7 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
472
 
473
  # Event handlers
474
  generate_button.click(
475
- fn=controller.generate,
476
  inputs=[
477
  diffusion_transformer_dropdown,
478
  base_model_dropdown,
 
26
  from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
27
  from videox_fun.utils.utils import save_videos_grid, timer
28
 
29
+ global_controller = None
30
+
31
+ @spaces.GPU(duration=300)
32
+ @timer
33
+ def generate_wrapper(*args):
34
+ global global_controller
35
+ return global_controller.generate(*args)
36
+
37
  # Redefine create_height_width to remove Chinese and specific defaults if needed,
38
  # although we will mostly ignore sliders if we use input resolution.
39
  # We will create a custom version here to avoid modifying the library file if possible,
 
158
  torch.cuda.empty_cache()
159
 
160
  class VideoCoF_Controller(Wan_Controller):
 
161
  @timer
162
  def generate(
163
  self,
 
369
  config_path=config_path, compile_dit=compile_dit,
370
  weight_dtype=weight_dtype
371
  )
372
+ global global_controller
373
+ global_controller = controller
374
 
375
  with gr.Blocks() as demo:
376
  gr.Markdown("# VideoCoF Demo")
 
481
 
482
  # Event handlers
483
  generate_button.click(
484
+ fn=generate_wrapper,
485
  inputs=[
486
  diffusion_transformer_dropdown,
487
  base_model_dropdown,