Spaces:
Running
Running
| import argparse | |
| from ui.components import create_main_demo_ui | |
| from pipeline_ace_step import ACEStepPipeline | |
| from data_sampler import DataSampler | |
| import os | |
# Absolute directory of this script, used to build the default storage paths.
APP_ROOT = os.path.dirname(os.path.abspath(__file__))


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"`` and ``"0"``) as True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
# Default checkpoint directory: the 'checkpoints' folder under the app root.
parser.add_argument(
    "--checkpoint_path", type=str, default=os.path.join(APP_ROOT, "checkpoints")
)
parser.add_argument("--server_name", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--device_id", type=int, default=0)
parser.add_argument("--share", action="store_true", default=False)
# bf16 is on by default. `--bf16` is kept for backward compatibility (it is a
# no-op); `--no_bf16` is the switch that falls back to float32.
parser.add_argument("--bf16", dest="bf16", action="store_true", default=True)
parser.add_argument("--no_bf16", dest="bf16", action="store_false")
# Accepts a bare `--torch_compile`, or an explicit `--torch_compile true|false`.
parser.add_argument(
    "--torch_compile", type=_str2bool, nargs="?", const=True, default=False
)
args = parser.parse_args()

# Restrict the process to the selected GPU. CUDA reads this variable when it
# initialises, so it must be set before any CUDA use.
# NOTE(review): the project imports above may already import torch — confirm
# nothing initialises CUDA before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

# Persistent storage lives in the 'persistent_data' folder under the app root.
persistent_storage_path = os.path.join(APP_ROOT, "persistent_data")
def main(args):
    """Build the ACE-Step pipeline and data sampler, then serve the Gradio demo.

    Args:
        args: Parsed command-line namespace. Reads ``checkpoint_path``,
            ``bf16``, ``torch_compile``, ``server_name``, ``port`` and ``share``.
            The storage location comes from the module-level
            ``persistent_storage_path``.
    """
    print(f"Using checkpoint path: {args.checkpoint_path}")
    print(f"Using persistent storage path: {persistent_storage_path}")

    precision = "bfloat16" if args.bf16 else "float32"
    pipeline = ACEStepPipeline(
        checkpoint_dir=args.checkpoint_path,
        dtype=precision,
        persistent_storage_path=persistent_storage_path,
        torch_compile=args.torch_compile,
    )
    sampler = DataSampler()

    # Wire the pipeline and the sampler's callbacks into the UI.
    ui = create_main_demo_ui(
        text2music_process_func=pipeline.__call__,
        sample_data_func=sampler.sample,
        load_data_func=sampler.load_json,
    )

    # Host/port/share all come from the command line.
    launch_options = dict(
        server_name=args.server_name,
        server_port=args.port,
        share=args.share,
    )
    ui.queue(default_concurrency_limit=8).launch(**launch_options)
# Script entry point: launch the demo only when this file is executed directly.
if __name__ == "__main__":
    main(args=args)