import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

import srsly
from click import NoSuchOption
from click.parser import split_arg_string
from confection import Config
from wasabi import msg

from ..cli.main import PROJECT_FILE
from ..schemas import ProjectConfigSchema, validate
from .environment import ENV_VARS
from .frozen import SimpleFrozenDict
from .logging import logger
from .validation import show_validation_error, validate_project_commands


def parse_config_overrides(
    args: List[str], env_var: Optional[str] = ENV_VARS.CONFIG_OVERRIDES
) -> Dict[str, Any]:
    """Generate a dictionary of config overrides based on the extra arguments
    provided on the CLI, e.g. --training.batch_size to override
    "training.batch_size". Arguments without a "." are considered invalid,
    since the config only allows top-level sections to exist.

    Overrides are also read from the environment variable named by env_var
    (if set). When the same key is given in both places, the value from the
    environment variable wins.

    args (List[str]): The extra CLI arguments to parse.
    env_var (Optional[str]): Optional environment variable to read from.
    RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
    """
    env_string = os.environ.get(env_var, "") if env_var else ""
    env_overrides = _parse_overrides(split_arg_string(env_string))
    cli_overrides = _parse_overrides(args, is_cli=True)
    if cli_overrides:
        # Only report the CLI keys that aren't shadowed by the env variable
        keys = [k for k in cli_overrides if k not in env_overrides]
        logger.debug("Config overrides from CLI: %s", keys)
    if env_overrides:
        logger.debug("Config overrides from env variables: %s", list(env_overrides))
    # Later entries win in the merge, so env overrides take precedence
    return {**cli_overrides, **env_overrides}


def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
    """Parse a flat list of "--section.name value" / "--section.name=value"
    arguments into a dict of overrides. A flag that isn't followed by a value
    is set to true. The provided list is not modified.

    args (List[str]): The arguments to parse.
    is_cli (bool): Whether the arguments come from the CLI. If so, a name
        without a "." raises click's NoSuchOption instead of printing an
        error and exiting.
    RETURNS (Dict[str, Any]): The parsed overrides, keyed by option name.
    """
    result: Dict[str, Any] = {}
    position = 0
    # Walk the list by index rather than popping from the front, so the
    # caller's list is left untouched
    while position < len(args):
        orig_opt = args[position]
        position += 1
        err = f"Invalid config override '{orig_opt}'"
        if orig_opt.startswith("--"):  # new argument
            # Only strip the leading "--" and split off the value before
            # validating, so a "--" or "." inside the value can neither be
            # removed from the value nor make a top-level name look nested
            name, has_value, value = orig_opt[2:].partition("=")
            if "." not in name:
                if is_cli:
                    raise NoSuchOption(orig_opt)
                else:
                    msg.fail(f"{err}: can't override top-level sections", exits=1)
            if has_value:  # we have --opt=value
                # NOTE: hyphens are only converted to underscores in this
                # form, not in the space-separated "--opt value" form
                name = name.replace("-", "_")
            elif position < len(args) and not args[position].startswith("--"):
                # we have --opt value
                value = args[position]
                position += 1
            else:  # flag with no value
                value = "true"
            result[name] = _parse_override(value)
        else:
            msg.fail(f"{err}: name should start with --", exits=1)
    return result


def _parse_override(value: Any) -> Any:
    """Parse a single override value as JSON, falling back to a string.

    value (Any): The raw value, typically a string from the CLI or env.
    RETURNS (Any): The JSON-decoded value, or str(value) if decoding raises
        a ValueError.
    """
    # Just like we do in the config, we're calling json.loads on the
    # values. But since they come from the CLI, it'd be unintuitive to
    # explicitly mark strings with escaped quotes. So we're working
    # around that here by falling back to a string if parsing fails.
    # TODO: improve logic to handle simple types like list of strings?
    try:
        return srsly.json_loads(value)
    except ValueError:
        return str(value)


def load_project_config(
    path: Path, interpolate: bool = True, overrides: Dict[str, Any] = SimpleFrozenDict()
) -> Dict[str, Any]:
    """Load the project.yml file from a directory and validate it. Also make
    sure that all directories defined in the config exist.

    path (Path): The path to the project directory.
    interpolate (bool): Whether to substitute project variables.
    overrides (Dict[str, Any]): Optional config overrides.
    RETURNS (Dict[str, Any]): The loaded project.yml.
    """
    config_path = path / PROJECT_FILE
    if not config_path.exists():
        msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
    invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
    try:
        config = srsly.read_yaml(config_path)
    except ValueError as e:
        msg.fail(invalid_err, e, exits=1)
    errors = validate(ProjectConfigSchema, config)
    if errors:
        msg.fail(invalid_err)
        print("\n".join(errors))
        sys.exit(1)
    validate_project_commands(config)
    if interpolate:
        err = f"{PROJECT_FILE} validation error"
        with show_validation_error(title=err, hint_fill=False):
            config = substitute_project_variables(config, overrides)
    # Make sure directories defined in config exist
    for subdir in config.get("directories", []):
        dir_path = path / subdir
        if not dir_path.exists():
            # exist_ok guards against the directory being created by someone
            # else between the check above and this call
            dir_path.mkdir(parents=True, exist_ok=True)
    return config


def substitute_project_variables(
    config: Dict[str, Any],
    overrides: Dict[str, Any] = SimpleFrozenDict(),
    key: str = "vars",
    env_key: str = "env",
) -> Dict[str, Any]:
    """Interpolate variables in the project file using the config system.
    Note that the provided config is modified in place: missing variable
    sections are added and the env section's values are replaced with the
    values read from the environment.

    config (Dict[str, Any]): The project config.
    overrides (Dict[str, Any]): Optional config overrides.
    key (str): Key containing variables in project config.
    env_key (str): Key containing environment variable mapping in project config.
    RETURNS (Dict[str, Any]): The interpolated project config.
    """
    config.setdefault(key, {})
    config.setdefault(env_key, {})
    # Substitute references to env vars with their values. Unset env vars
    # are read as an empty string.
    for config_var, env_var in config[env_key].items():
        config[env_key][config_var] = _parse_override(os.environ.get(env_var, ""))
    # Need to put variables in the top scope again so we can have a top-level
    # section "project" (otherwise, a list of commands in the top scope
    # wouldn't be allowed by Thinc's config system)
    cfg = Config({"project": config, key: config[key], env_key: config[env_key]})
    cfg = Config().from_str(cfg.to_str(), overrides=overrides)
    interpolated = cfg.interpolate()
    return dict(interpolated["project"])