Spaces:
Sleeping
Sleeping
| import click | |
| import os | |
| import sys | |
| import importlib | |
| import importlib.util | |
| import json | |
| from click.core import Context, Option | |
| from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__ | |
| from ding.framework import Parallel | |
| from ding.entry.cli_parsers import PLATFORM_PARSERS | |
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Click callback that prints version/author information, then exits.

    Args:
        ctx: The active Click context; ``ctx.exit()`` is called after printing.
        param: The option this callback is attached to (not used).
        value: The option's value; nothing is printed unless it is truthy.

    Nothing is printed while Click is in resilient-parsing mode
    (``ctx.resilient_parsing``).
    """
    if value and not ctx.resilient_parsing:
        version_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        click.echo(version_line)
        click.echo(author_line)
        ctx.exit()
# Click context settings: accept both the short and long help flags.
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}
def cli_ditask(*args, **kwargs):
    """Public entry point.

    A thin pass-through: every positional and keyword argument is forwarded
    unchanged to ``_cli_ditask``, and its result is returned as-is.
    """
    forwarded_result = _cli_ditask(*args, **kwargs)
    return forwarded_result
def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
    """Turn platform-specific launch configuration into ``_cli_ditask`` kwargs.

    Args:
        platform: Key into ``PLATFORM_PARSERS`` selecting which parser to use.
        platform_spec: Either a path to a ``.json`` file or an inline JSON
            string. When empty/None it is passed to the parser unchanged.
        all_args: The caller's keyword arguments. ``platform`` and
            ``platform_spec`` are popped from it (the dict is mutated in
            place) and the remaining entries are forwarded to the parser.

    Returns:
        Whatever the selected platform parser returns; the caller uses it as
        keyword arguments for ``_cli_ditask``.

    Exits with status 1 (``sys.exit``) if ``platform_spec`` cannot be read or
    decoded, or if ``platform`` is not a known parser. Exceptions raised by the
    parser itself are echoed and re-raised with their original traceback.
    """
    if platform_spec:
        try:
            # splitext returns a (root, ext) tuple and ext keeps its leading
            # dot, so compare the extension element against ".json".
            if os.path.splitext(platform_spec)[1].lower() == ".json":
                with open(platform_spec, encoding="utf-8") as f:
                    platform_spec = json.load(f)
            else:
                platform_spec = json.loads(platform_spec)
        except (OSError, ValueError, TypeError):
            # OSError: unreadable file; ValueError: invalid JSON (including
            # JSONDecodeError / UnicodeDecodeError); TypeError: non-str spec.
            # Narrow on purpose so Ctrl-C and SystemExit are not swallowed.
            click.echo("platform_spec is not a valid json!")
            sys.exit(1)
    if platform not in PLATFORM_PARSERS:
        click.echo("platform type is invalid! type: {}".format(platform))
        sys.exit(1)
    # These two are consumed here; the parser only receives the remaining args.
    all_args.pop("platform", None)
    all_args.pop("platform_spec", None)
    try:
        parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
    except Exception as e:
        click.echo("error when parse platform spec configure: {}".format(e))
        raise
    return parsed_args
| def _cli_ditask( | |
| package: str, | |
| main: str, | |
| parallel_workers: int, | |
| protocol: str, | |
| ports: str, | |
| attach_to: str, | |
| address: str, | |
| labels: str, | |
| node_ids: str, | |
| topology: str, | |
| mq_type: str, | |
| redis_host: str, | |
| redis_port: int, | |
| startup_interval: int, | |
| local_rank: int = 0, | |
| platform: str = None, | |
| platform_spec: str = None, | |
| ): | |
| # Parse entry point | |
| all_args = locals() | |
| if platform: | |
| parsed_args = _parse_platform_args(platform, platform_spec, all_args) | |
| return _cli_ditask(**parsed_args) | |
| if not package: | |
| package = os.getcwd() | |
| sys.path.append(package) | |
| if main is None: | |
| mod_name = os.path.basename(package) | |
| mod_name, _ = os.path.splitext(mod_name) | |
| func_name = "main" | |
| else: | |
| mod_name, func_name = main.rsplit(".", 1) | |
| root_mod_name = mod_name.split(".", 1)[0] | |
| sys.path.append(os.path.join(package, root_mod_name)) | |
| mod = importlib.import_module(mod_name) | |
| main_func = getattr(mod, func_name) | |
| # Parse arguments | |
| ports = ports or 50515 | |
| if not isinstance(ports, int): | |
| ports = ports.split(",") | |
| ports = list(map(lambda i: int(i), ports)) | |
| ports = ports[0] if len(ports) == 1 else ports | |
| if attach_to: | |
| attach_to = attach_to.split(",") | |
| attach_to = list(map(lambda s: s.strip(), attach_to)) | |
| if labels: | |
| labels = labels.split(",") | |
| labels = set(map(lambda s: s.strip(), labels)) | |
| if node_ids and not isinstance(node_ids, int): | |
| node_ids = node_ids.split(",") | |
| node_ids = list(map(lambda i: int(i), node_ids)) | |
| Parallel.runner( | |
| n_parallel_workers=parallel_workers, | |
| ports=ports, | |
| protocol=protocol, | |
| topology=topology, | |
| attach_to=attach_to, | |
| address=address, | |
| labels=labels, | |
| node_ids=node_ids, | |
| mq_type=mq_type, | |
| redis_host=redis_host, | |
| redis_port=redis_port, | |
| startup_interval=startup_interval | |
| )(main_func) | |