| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import importlib |
| | import os |
| | import pkgutil |
| | import sys |
| | from dataclasses import fields as dataclass_fields |
| | from dataclasses import is_dataclass |
| | from typing import Any, Dict, Optional |
| |
|
| | import attr |
| | import attrs |
| | from hydra import compose, initialize |
| | from hydra.core.config_store import ConfigStore |
| | from omegaconf import DictConfig, OmegaConf |
| |
|
| | from .log import log |
| | from .config import Config |
| | from .inference import * |
| |
|
| |
|
| | def is_attrs_or_dataclass(obj) -> bool: |
| | """ |
| | Check if the object is an instance of an attrs class or a dataclass. |
| | |
| | Args: |
| | obj: The object to check. |
| | |
| | Returns: |
| | bool: True if the object is an instance of an attrs class or a dataclass, False otherwise. |
| | """ |
| | return is_dataclass(obj) or attr.has(type(obj)) |
| |
|
| |
|
| | def get_fields(obj): |
| | """ |
| | Get the fields of an attrs class or a dataclass. |
| | |
| | Args: |
| | obj: The object to get fields from. Must be an instance of an attrs class or a dataclass. |
| | |
| | Returns: |
| | list: A list of field names. |
| | |
| | Raises: |
| | ValueError: If the object is neither an attrs class nor a dataclass. |
| | """ |
| | if is_dataclass(obj): |
| | return [field.name for field in dataclass_fields(obj)] |
| | elif attr.has(type(obj)): |
| | return [field.name for field in attr.fields(type(obj))] |
| | else: |
| | raise ValueError("The object is neither an attrs class nor a dataclass.") |
| |
|
| |
|
| | def override(config: Config, overrides: Optional[list[str]] = None) -> Config: |
| | """ |
| | :param config: the instance of class `Config` (usually from `make_config`) |
| | :param overrides: list of overrides for config |
| | :return: the composed instance of class `Config` |
| | """ |
| | |
| | |
| |
|
| | |
| | config_dict = attrs.asdict(config) |
| | config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) |
| | |
| | if overrides: |
| | if overrides[0] != "--": |
| | raise ValueError('Hydra config overrides must be separated with a "--" token.') |
| | overrides = overrides[1:] |
| | |
| | cs = ConfigStore.instance() |
| | cs.store(name="config", node=config_omegaconf) |
| | with initialize(version_base=None): |
| | config_omegaconf = compose(config_name="config", overrides=overrides) |
| | OmegaConf.resolve(config_omegaconf) |
| |
|
| | def config_from_dict(ref_instance: Any, kwargs: Any) -> Any: |
| | """ |
| | Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data |
| | |
| | Args: |
| | ref_instance: The reference instance to determine the type and fields when needed |
| | kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data |
| | |
| | Returns: |
| | Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data |
| | |
| | Raises: |
| | AssertionError: If the fields do not match or if extra keys are found. |
| | Exception: If there is an error constructing the new instance. |
| | """ |
| | is_type = is_attrs_or_dataclass(ref_instance) |
| | if not is_type: |
| | return kwargs |
| | else: |
| | ref_fields = set(get_fields(ref_instance)) |
| | assert isinstance(kwargs, dict) or isinstance( |
| | kwargs, DictConfig |
| | ), "kwargs must be a dictionary or a DictConfig" |
| | keys = set(kwargs.keys()) |
| |
|
| | |
| | extra_keys = keys - ref_fields |
| | assert ref_fields == keys or keys.issubset( |
| | ref_fields |
| | ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}" |
| |
|
| | resolved_kwargs: Dict[str, Any] = {} |
| | for f in keys: |
| | resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f]) |
| | try: |
| | new_instance = type(ref_instance)(**resolved_kwargs) |
| | except Exception as e: |
| | log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}") |
| | log.error(e) |
| | raise e |
| | return new_instance |
| |
|
| | config = config_from_dict(config, config_omegaconf) |
| |
|
| | return config |
| |
|
| |
|
| | def get_config_module(config_file: str) -> str: |
| | if not config_file.endswith(".py"): |
| | log.error("Config file cannot be specified as module.") |
| | log.error("Please provide the path to the Python config file (relative to the Cosmos root).") |
| | assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found." |
| | |
| | config_module = config_file.replace("/", ".").replace(".py", "") |
| | return config_module |
| |
|
| |
|
| | def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None: |
| | """ |
| | Import all modules from the specified package path recursively. |
| | |
| | This function is typically used in conjunction with Hydra to ensure that all modules |
| | within a specified package are imported, which is necessary for registering configurations. |
| | |
| | Example usage: |
| | ```python |
| | import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False) |
| | ``` |
| | |
| | Args: |
| | package_path (str): The dotted path to the package from which to import all modules. |
| | reload (bool): Flag to determine whether to reload modules if they're already imported. |
| | skip_underscore (bool): If True, skips importing modules that start with an underscore. |
| | """ |
| | return |
| | log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}") |
| | package = importlib.import_module(package_path) |
| | package_directory = package.__path__ |
| |
|
| | def import_modules_recursively(directory: str, prefix: str) -> None: |
| | """ |
| | Recursively imports or reloads all modules in the given directory. |
| | |
| | Args: |
| | directory (str): The file system path to the current package directory. |
| | prefix (str): The module prefix (e.g., 'cosmos1.models.diffusion.config'). |
| | """ |
| | for _, module_name, is_pkg in pkgutil.iter_modules([directory]): |
| | if skip_underscore and module_name.startswith("_"): |
| | log.debug(f"Skipping module {module_name} as it starts with an underscore") |
| | continue |
| |
|
| | full_module_name = f"{prefix}.{module_name}" |
| | log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}") |
| |
|
| | if full_module_name in sys.modules and reload: |
| | importlib.reload(sys.modules[full_module_name]) |
| | else: |
| | importlib.import_module(full_module_name) |
| |
|
| | if is_pkg: |
| | sub_package_directory = os.path.join(directory, module_name) |
| | import_modules_recursively(sub_package_directory, full_module_name) |
| |
|
| | for directory in package_directory: |
| | import_modules_recursively(directory, package_path) |
| |
|