# Repository URL to install this package:
# Version: 1.1.3
import logging
from argparse import Namespace
from typing import Any, Dict, List
from omegaconf import DictConfig, OmegaConf
# Module-level logger named after this module; used by log_config to emit the
# resolved config.
logger = logging.getLogger(__name__)
def log_config(recipe_name: str, cfg: DictConfig) -> None:
    """
    Logs the resolved config (merged YAML file and CLI overrides) to rank zero.

    When ``torch.distributed`` is importable, available, and a process group
    has been initialized, only the process with global rank 0 emits the log.
    In every other case (single process, no process group, torch not
    installed) the config is logged unconditionally.

    Args:
        recipe_name (str): name of the recipe to display
        cfg (DictConfig): parsed config object
    """
    # Log the config only on rank 0. torch is imported lazily and optionally so
    # this helper keeps working where torch (or its distributed backend) is absent.
    try:
        import torch.distributed as dist
    except ImportError:
        rank = 0
    else:
        # is_available() must be checked first: builds without distributed
        # support do not provide is_initialized()/get_rank().
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank != 0:
        return

    cfg_str = OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True)
    logger.info("Running %s with resolved config:\n\n%s", recipe_name, cfg_str)
def _merge_yaml_and_cli_args(
    yaml_args: Namespace, cli_args: List[str]
) -> DictConfig:
    """
    Takes the direct output of argparse's parse_known_args which returns known
    args as a Namespace and unknown args as a dotlist (in our case, yaml args and
    cli args, respectively) and merges them into a single OmegaConf DictConfig.

    If a cli arg overrides a top-level yaml arg that has a _component_ field, the
    cli arg can be specified with the parent field directly, e.g.,
    model=sarus_llm.models.lora_llama2_7b instead of
    model._component_=sarus_llm.models.lora_llama2_7b. Nested fields within the
    component should be specified with dot notation, e.g., model.lora_rank=16.

    A cli arg of the form ~key (or ~key.subkey) removes that key from the yaml
    config before merging.

    Example:
        config.yaml::

            a: 1
            b:
              _component_: torchtune.models.my_model
              c: 3

        Command line::

            tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4

        >>> yaml_args, cli_args = parser.parse_known_args()
        >>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
        >>> print(conf)
        {'a': 1, 'b': {'_component_': 'torchtune.models.other_model', 'c': 4}}

    Args:
        yaml_args (Namespace): Namespace containing args from yaml file, components
            should have _component_ fields
        cli_args (List[str]): List of key=value strings, optionally mixed with
            ~key removal flags

    Returns:
        DictConfig: OmegaConf DictConfig containing merged args

    Raises:
        ValueError: If a cli override is not in the form of key=value, if a ~key
            removal path contains _component_, or if a ~key removal targets a key
            that cannot be found in the yaml config

    Note:
        ``vars(yaml_args)`` is the Namespace's own ``__dict__``, so ~key removals
        are applied to ``yaml_args`` in place.
    """
    # Convert Namespace to simple dict (this is the Namespace's own __dict__)
    yaml_kwargs = vars(yaml_args)
    cli_dotlist = []
    for arg in cli_args:
        # If CLI override uses the remove flag (~), remove the key from the yaml config
        if arg.startswith("~"):
            dotpath = arg[1:].split("=")[0]
            if "_component_" in dotpath:
                raise ValueError(
                    f"Removing components from CLI is not supported: ~{dotpath}"
                )
            try:
                _remove_key_by_dotpath(yaml_kwargs, dotpath)
            except (KeyError, ValueError, TypeError):
                # TypeError arises when an intermediate path segment is a scalar
                # or list instead of a mapping; the key cannot exist either way.
                raise ValueError(
                    f"Could not find key {dotpath} in yaml config to remove"
                ) from None
            continue

        # Other overrides must be key=value. Split on the first "=" only, like
        # OmegaConf's own dotlist parsing, so values may themselves contain "=".
        k, sep, v = arg.partition("=")
        if not sep:
            raise ValueError(
                f"Command-line overrides must be in the form of key=value, got {arg}"
            )

        # Component shorthand: `b=pkg.fn` on a yaml node that has a _component_
        # field targets `b._component_`, so the node's other fields survive the
        # merge instead of being replaced by a plain string.
        node = yaml_kwargs.get(k)
        if (isinstance(node, dict) or OmegaConf.is_dict(node)) and "_component_" in node:
            k = f"{k}._component_"

        cli_dotlist.append(f"{k}={v}")

    # Merge the args
    cli_conf = OmegaConf.from_dotlist(cli_dotlist)
    yaml_conf = OmegaConf.create(yaml_kwargs)

    # CLI takes precedence over yaml args
    return OmegaConf.merge(yaml_conf, cli_conf)
def _remove_key_by_dotpath(nested_dict: Dict[str, Any], dotpath: str) -> None:
"""
Removes a key specified by dotpath from a nested dict. Errors should handled by
the calling function.
Args:
d (Dict[str, Any]): Dict to remove key from
dotpath (str): dotpath of key to remove, e.g., "a.b.c"
"""
path = dotpath.split(".")
def recurse_and_delete(d: Dict[str, Any], path: List[str]) -> None:
if len(path) == 1:
del d[path[0]]
else:
recurse_and_delete(d[path[0]], path[1:])
if not d[path[0]]:
del d[path[0]]
recurse_and_delete(nested_dict, path)