|
|
|
|
|
|
|
|
import logging |
|
|
from typing import Type, TypeVar |
|
|
|
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf |
|
|
|
|
|
logger = logging.getLogger() |
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
def set_struct_recursively(cfg, strict: bool = True): |
|
|
|
|
|
OmegaConf.set_struct(cfg, strict) |
|
|
|
|
|
|
|
|
if isinstance(cfg, DictConfig): |
|
|
for key, value in cfg.items(): |
|
|
if isinstance(value, (DictConfig, ListConfig)): |
|
|
set_struct_recursively(value, strict) |
|
|
elif isinstance(cfg, ListConfig): |
|
|
for item in cfg: |
|
|
if isinstance(item, (DictConfig, ListConfig)): |
|
|
set_struct_recursively(item, strict) |
|
|
|
|
|
|
|
|
def flatten_dict(d, parent_key="", sep="_"): |
|
|
items = [] |
|
|
for k, v in d.items(): |
|
|
new_key = f"{parent_key}{sep}{k}" if parent_key else k |
|
|
if isinstance(v, dict): |
|
|
items.extend(flatten_dict(v, new_key, sep=sep).items()) |
|
|
else: |
|
|
items.append((new_key, v)) |
|
|
return dict(items) |
|
|
|
|
|
|
|
|
def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T: |
|
|
""" |
|
|
Converts a dictionary to a dataclass instance, recursively for nested structures. |
|
|
""" |
|
|
base = OmegaConf.structured(cls()) |
|
|
OmegaConf.set_struct(base, strict) |
|
|
override = OmegaConf.create(data) |
|
|
return OmegaConf.to_object(OmegaConf.merge(base, override)) |
|
|
|
|
|
|
|
|
def dataclass_to_dict(dataclass_instance: T) -> dict: |
|
|
""" |
|
|
Converts a dataclass instance to a dictionary, recursively for nested structures. |
|
|
""" |
|
|
if isinstance(dataclass_instance, dict): |
|
|
return dataclass_instance |
|
|
|
|
|
return OmegaConf.to_container( |
|
|
OmegaConf.structured(dataclass_instance), resolve=True |
|
|
) |
|
|
|
|
|
|
|
|
def load_config_file(config_file, dataclass_cls: Type[T]) -> T: |
|
|
config = OmegaConf.to_container(OmegaConf.load(config_file), resolve=True) |
|
|
return dataclass_from_dict(dataclass_cls, config) |
|
|
|
|
|
|
|
|
def dump_config(config, path, log_config=True): |
|
|
yaml_dump = OmegaConf.to_yaml(OmegaConf.structured(config)) |
|
|
with open(path, "w") as f: |
|
|
if log_config: |
|
|
logger.info("Using the following config for this run:") |
|
|
logger.info(yaml_dump) |
|
|
f.write(yaml_dump) |
|
|
|