import copy
import inspect
import io
import re
import warnings
from configparser import (
    MAX_INTERPOLATION_DEPTH,
    ConfigParser,
    ExtendedInterpolation,
    InterpolationDepthError,
    InterpolationMissingOptionError,
    InterpolationSyntaxError,
    NoOptionError,
    NoSectionError,
    ParsingError,
)
from dataclasses import dataclass
from pathlib import Path
from types import GeneratorType
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import srsly

try:
    from pydantic.v1 import BaseModel, Extra, ValidationError, create_model
    from pydantic.v1.fields import ModelField
    from pydantic.v1.main import ModelMetaclass
except ImportError:
    from pydantic import BaseModel, create_model, ValidationError, Extra  # type: ignore
    from pydantic.main import ModelMetaclass  # type: ignore
    from pydantic.fields import ModelField  # type: ignore

from .util import SimpleFrozenDict, SimpleFrozenList  # noqa: F401

# Field used for positional arguments, e.g. [section.*.xyz]. The alias is
# required for the schema (shouldn't clash with user-defined arg names)
ARGS_FIELD = "*"
ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS"
# Aliases for fields that would otherwise shadow pydantic attributes. Can be any
# string, so we're using name + space so it looks the same in error messages etc.
RESERVED_FIELDS = {"validate": "validate\u0020"}
# Internal prefix used to mark section references for custom interpolation
SECTION_PREFIX = "__SECTION__:"
# Values that shouldn't be loaded during interpolation because it'd cause
# even explicit string values to be incorrectly parsed as bools/None etc.
JSON_EXCEPTIONS = ("true", "false", "null")
# Regex to detect whether a value contains a variable
VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}")
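
# Illustrative sketch (comment only, not executed): the variable syntax the
# regex above detects is "${...}" with dot- or colon-delimited parts, e.g.
#
#     VARIABLE_RE.search("${paths.train}")     # matches
#     VARIABLE_RE.search("${training:size}")   # matches
#     VARIABLE_RE.search("plain value")        # no match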


class CustomInterpolation(ExtendedInterpolation):
    def before_read(self, parser, section, option, value):
        # If we're dealing with a quoted string as the interpolation value,
        # make sure we load and unquote it so we don't end up with '"value"'
        try:
            json_value = srsly.json_loads(value)
            if isinstance(json_value, str) and json_value not in JSON_EXCEPTIONS:
                value = json_value
        except ValueError:
            if value and value[0] == value[-1] == "'":
                warnings.warn(
                    f"The value [{value}] seems to be single-quoted, but values "
                    "use JSON formatting, which requires double quotes."
                )
        except Exception:
            pass
        return super().before_read(parser, section, option, value)

    def before_get(self, parser, section, option, value, defaults):
        # Mostly copy-pasted from the built-in configparser implementation.
        L = []
        self.interpolate(parser, option, L, value, section, defaults, 1)
        return "".join(L)

    def interpolate(self, parser, option, accum, rest, section, map, depth):
        # Mostly copy-pasted from the built-in configparser implementation.
        # We need to overwrite this method so we can add special handling for
        # block references :( All values produced here should be strings –
        # we need to wait until the whole config is interpreted anyways so
        # filling in incomplete values here is pointless. All we need is the
        # section reference so we can fetch it later.
        rawval = parser.get(section, option, raw=True, fallback=rest)
        if depth > MAX_INTERPOLATION_DEPTH:
            raise InterpolationDepthError(option, section, rawval)
        while rest:
            p = rest.find("$")
            if p < 0:
                accum.append(rest)
                return
            if p > 0:
                accum.append(rest[:p])
                rest = rest[p:]
            # p is no longer used
            c = rest[1:2]
            if c == "$":
                accum.append("$")
                rest = rest[2:]
            elif c == "{":
                # We want to treat both ${a:b} and ${a.b} the same
                m = self._KEYCRE.match(rest)
                if m is None:
                    err = f"bad interpolation variable reference {rest}"
                    raise InterpolationSyntaxError(option, section, err)
                orig_var = m.group(1)
                path = orig_var.replace(":", ".").rsplit(".", 1)
                rest = rest[m.end() :]
                sect = section
                opt = option
                try:
                    if len(path) == 1:
                        opt = parser.optionxform(path[0])
                        if opt in map:
                            v = map[opt]
                        else:
                            # We have a block reference, so store it as a special key
                            section_name = parser[parser.optionxform(path[0])]._name
                            v = self._get_section_name(section_name)
                    elif len(path) == 2:
                        sect = path[0]
                        opt = parser.optionxform(path[1])
                        fallback = "__FALLBACK__"
                        v = parser.get(sect, opt, raw=True, fallback=fallback)
                        # If a variable doesn't exist, try again and treat the
                        # reference as a section
                        if v == fallback:
                            v = self._get_section_name(parser[f"{sect}.{opt}"]._name)
                    else:
                        err = f"More than one ':' found: {rest}"
                        raise InterpolationSyntaxError(option, section, err)
                except (KeyError, NoSectionError, NoOptionError):
                    raise InterpolationMissingOptionError(
                        option, section, rawval, orig_var
                    ) from None
                if "$" in v:
                    new_map = dict(parser.items(sect, raw=True))
                    self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1)
                else:
                    accum.append(v)
            else:
                err = "'$' must be followed by '$' or '{', found: %r" % (rest,)
                raise InterpolationSyntaxError(option, section, err)

    def _get_section_name(self, name: str) -> str:
        """Generate the name of a section. Note that we use a quoted string here
        so we can use section references within lists and load the list as
        JSON. Since section references can't be used within strings, we don't
        need the quoted vs. unquoted distinction like we do for variables.

        Examples (assuming section = {"foo": 1}):
        - value: ${section.foo} -> value: 1
        - value: "hello ${section.foo}" -> value: "hello 1"
        - value: ${section} -> value: {"foo": 1}
        - value: "${section}" -> value: {"foo": 1}
        - value: "hello ${section}" -> invalid
        """
        return f'"{SECTION_PREFIX}{name}"'


def get_configparser(interpolate: bool = True):
    config = ConfigParser(interpolation=CustomInterpolation() if interpolate else None)
    # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719
    config.optionxform = str  # type: ignore
    return config


class Config(dict):
    """This class holds the model and training configuration and can load and
    save the TOML-style configuration format from/to a string, file or bytes.
    The Config class is a subclass of dict and uses Python's ConfigParser
    under the hood.
    """

    is_interpolated: bool

    def __init__(
        self,
        data: Optional[Union[Dict[str, Any], "ConfigParser", "Config"]] = None,
        *,
        is_interpolated: Optional[bool] = None,
        section_order: Optional[List[str]] = None,
    ) -> None:
        """Initialize a new Config object with optional data."""
        dict.__init__(self)
        if data is None:
            data = {}
        if not isinstance(data, (dict, Config, ConfigParser)):
            raise ValueError(
                f"Can't initialize Config with data. Expected dict, Config or "
                f"ConfigParser but got: {type(data)}"
            )
        # Whether the config has been interpolated. We can use this to check
        # whether we need to interpolate again when it's resolved. We assume
        # that a config is interpolated by default.
        if is_interpolated is not None:
            self.is_interpolated = is_interpolated
        elif isinstance(data, Config):
            self.is_interpolated = data.is_interpolated
        else:
            self.is_interpolated = True
        if section_order is not None:
            self.section_order = section_order
        elif isinstance(data, Config):
            self.section_order = data.section_order
        else:
            self.section_order = []
        # Update with data
        self.update(self._sort(data))

    def interpolate(self) -> "Config":
        """Interpolate a config. Returns a copy of the object."""
        # This is currently the most effective way because we need our custom
        # to_str logic to run in order to re-serialize the values so we can
        # interpolate them again. ConfigParser.read_dict will just call str()
        # on all values, which isn't enough.
        return Config().from_str(self.to_str())
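
    # Usage sketch (illustrative comment only; section and value names below
    # are invented). A config loaded with interpolate=False keeps "${...}"
    # references until interpolate() is called:
    #
    #     cfg_str = '[paths]\ntrain = "data.json"\n\n[corpus]\npath = ${paths.train}'
    #     cfg = Config().from_str(cfg_str, interpolate=False)
    #     assert cfg["corpus"]["path"] == "${paths.train}"
    #     assert cfg.interpolate()["corpus"]["path"] == "data.json"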

    def interpret_config(self, config: "ConfigParser") -> None:
        """Interpret a config, parse nested sections and parse the values
        as JSON. Mostly used internally and modifies the config in place.
        """
        self._validate_sections(config)
        # Sort sections by depth, so that we can iterate breadth-first. This
        # allows us to check that we're not expanding an undefined block.
        get_depth = lambda item: len(item[0].split("."))
        for section, values in sorted(config.items(), key=get_depth):
            if section == "DEFAULT":
                # Skip [DEFAULT] section so it doesn't cause validation error
                continue
            parts = section.split(".")
            node = self
            for part in parts[:-1]:
                if part == "*":
                    node = node.setdefault(part, {})
                elif part not in node:
                    err_title = (
                        "Error parsing config section. Perhaps a section name is wrong?"
                    )
                    err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}]
                    raise ConfigValidationError(
                        config=self, errors=err, title=err_title
                    )
                else:
                    node = node[part]
            if not isinstance(node, dict):
                # Happens if both value *and* subsection were defined for a key
                err = [{"loc": parts, "msg": "found conflicting values"}]
                err_cfg = f"{self}\n{({part: dict(values)})}"
                raise ConfigValidationError(config=err_cfg, errors=err)
            # Set the default section
            node = node.setdefault(parts[-1], {})
            if not isinstance(node, dict):
                # Happens if both value *and* subsection were defined for a key
                err = [{"loc": parts, "msg": "found conflicting values"}]
                err_cfg = f"{self}\n{({part: dict(values)})}"
                raise ConfigValidationError(config=err_cfg, errors=err)
            try:
                keys_values = list(values.items())
            except InterpolationMissingOptionError as e:
                raise ConfigValidationError(desc=f"{e}") from None
            for key, value in keys_values:
                config_v = config.get(section, key)
                node[key] = self._interpret_value(config_v)
        self.replace_section_refs(self)

    def replace_section_refs(
        self, config: Union[Dict[str, Any], "Config"], parent: str = ""
    ) -> None:
        """Replace references to section blocks in the final config."""
        for key, value in config.items():
            key_parent = f"{parent}.{key}".strip(".")
            if isinstance(value, dict):
                self.replace_section_refs(value, parent=key_parent)
            elif isinstance(value, list):
                config[key] = [
                    self._get_section_ref(v, parent=[parent, key]) for v in value
                ]
            else:
                config[key] = self._get_section_ref(value, parent=[parent, key])

    def _interpret_value(self, value: Any) -> Any:
        """Interpret a single config value."""
        result = try_load_json(value)
        # If value is a string and it contains a variable, use original value
        # (not interpreted string, which could lead to double quotes:
        # ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string,
        # so we're not keeping lists as strings.
        # NOTE: This currently can't handle uninterpolated values like [${x.y}]!
        if isinstance(result, str) and VARIABLE_RE.search(value):
            result = value
        return result

    def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any:
        """Get a single section reference."""
        if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'):
            value = try_load_json(value)
        if isinstance(value, str) and value.startswith(SECTION_PREFIX):
            parts = value.replace(SECTION_PREFIX, "").split(".")
            result = self
            for item in parts:
                try:
                    result = result[item]
                except (KeyError, TypeError):  # This should never happen
                    err_title = "Error parsing reference to config section"
                    err_msg = f"Section '{'.'.join(parts)}' is not defined"
                    err = [{"loc": parts, "msg": err_msg}]
                    raise ConfigValidationError(
                        config=self, errors=err, title=err_title
                    ) from None
            return result
        elif isinstance(value, str) and SECTION_PREFIX in value:
            # String value references a section (either a dict or return
            # value of promise). We can't allow this, since variables are
            # always interpolated *before* configs are resolved.
            err_desc = (
                "Can't reference whole sections or return values of function "
                "blocks inside a string or list\n\nYou can change your variable to "
                "reference a value instead. Keep in mind that it's not "
                "possible to interpolate the return value of a registered "
                "function, since variables are interpolated when the config "
                "is loaded, and registered functions are resolved afterwards."
            )
            err = [{"loc": parent, "msg": "uses section variable in string or list"}]
            raise ConfigValidationError(errors=err, desc=err_desc)
        return value

    def copy(self) -> "Config":
        """Deepcopy the config."""
        try:
            config = copy.deepcopy(self)
        except Exception as e:
            raise ValueError(f"Couldn't deep-copy config: {e}") from e
        return Config(
            config,
            is_interpolated=self.is_interpolated,
            section_order=self.section_order,
        )

    def merge(
        self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False
    ) -> "Config":
        """Deep merge the config with updates, using current as defaults."""
        defaults = self.copy()
        updates = Config(updates).copy()
        merged = deep_merge_configs(updates, defaults, remove_extra=remove_extra)
        return Config(
            merged,
            is_interpolated=defaults.is_interpolated and updates.is_interpolated,
            section_order=defaults.section_order,
        )
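
    # Merge sketch (illustrative comment only; keys below are invented):
    # updates win over the current config and nested sections are merged
    # recursively.
    #
    #     base = Config({"training": {"size": 100, "dropout": 0.2}})
    #     merged = base.merge({"training": {"size": 200}})
    #     assert merged["training"] == {"size": 200, "dropout": 0.2}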

    def _sort(
        self, data: Union["Config", "ConfigParser", Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Sort sections using the currently defined sort order. Sort
        sections by index on section order, if available, then alphabetic, and
        account for subsections, which should always follow their parent.
        """
        sort_map = {section: i for i, section in enumerate(self.section_order)}
        sort_key = lambda x: (
            sort_map.get(x[0].split(".")[0], len(sort_map)),
            _mask_positional_args(x[0]),
        )
        return dict(sorted(data.items(), key=sort_key))

    def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None:
        """Set overrides in the ConfigParser before config is interpreted."""
        err_title = "Error parsing config overrides"
        for key, value in overrides.items():
            err_msg = "not a section value that can be overridden"
            err = [{"loc": key.split("."), "msg": err_msg}]
            if "." not in key:
                raise ConfigValidationError(errors=err, title=err_title)
            section, option = key.rsplit(".", 1)
            # Check for section and accept if option not in config[section]
            if section not in config:
                raise ConfigValidationError(errors=err, title=err_title)
            config.set(section, option, try_dump_json(value, overrides))

    def _validate_sections(self, config: "ConfigParser") -> None:
        # If the config defines top-level properties that are not sections (e.g.
        # if config was constructed from dict), those values would be added as
        # [DEFAULTS] and included in *every other section*. This is usually not
        # what we want and it can lead to very confusing results.
        default_section = config.defaults()
        if default_section:
            err_title = "Found config values without a top-level section"
            err_msg = "not part of a section"
            err = [{"loc": [k], "msg": err_msg} for k in default_section]
            raise ConfigValidationError(errors=err, title=err_title)

    def from_str(
        self, text: str, *, interpolate: bool = True, overrides: Dict[str, Any] = {}
    ) -> "Config":
        """Load the config from a string."""
        config = get_configparser(interpolate=interpolate)
        if overrides:
            config = get_configparser(interpolate=False)
        try:
            config.read_string(text)
        except ParsingError as e:
            desc = f"Make sure the sections and values are formatted correctly.\n\n{e}"
            raise ConfigValidationError(desc=desc) from None
        config._sections = self._sort(config._sections)
        self._set_overrides(config, overrides)
        self.clear()
        self.interpret_config(config)
        if overrides and interpolate:
            # Do the interpolation now. This avoids recursion because the
            # nested from_str call (via interpolate) has empty overrides.
            self = self.interpolate()
        self.is_interpolated = interpolate
        return self
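
    # Loading sketch (illustrative comment only; section names are invented):
    # overrides use dotted paths and are applied before the values are parsed.
    #
    #     cfg = Config().from_str(
    #         '[training]\nsize = 100\nuse_gpu = false',
    #         overrides={"training.size": 200},
    #     )
    #     assert cfg["training"] == {"size": 200, "use_gpu": False}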

    def to_str(self, *, interpolate: bool = True) -> str:
        """Write the config to a string."""
        flattened = get_configparser(interpolate=interpolate)
        queue: List[Tuple[tuple, "Config"]] = [(tuple(), self)]
        for path, node in queue:
            section_name = ".".join(path)
            is_kwarg = path and path[-1] != "*"
            if is_kwarg and not flattened.has_section(section_name):
                # Always create sections for non-'*' sections, not only if
                # they have leaf entries, as we don't want to expand
                # blocks that are undefined
                flattened.add_section(section_name)
            for key, value in node.items():
                if hasattr(value, "items"):
                    # Reference to a function with no arguments, serialize
                    # inline as a dict and don't create new section
                    if registry.is_promise(value) and len(value) == 1 and is_kwarg:
                        flattened.set(section_name, key, try_dump_json(value, node))
                    else:
                        queue.append((path + (key,), value))
                else:
                    flattened.set(section_name, key, try_dump_json(value, node))
        # Order so subsections follow their parent (not all sections, then all subs etc.)
        flattened._sections = self._sort(flattened._sections)
        self._validate_sections(flattened)
        string_io = io.StringIO()
        flattened.write(string_io)
        return string_io.getvalue().strip()

    def to_bytes(self, *, interpolate: bool = True) -> bytes:
        """Serialize the config to a byte string."""
        return self.to_str(interpolate=interpolate).encode("utf8")

    def from_bytes(
        self,
        bytes_data: bytes,
        *,
        interpolate: bool = True,
        overrides: Dict[str, Any] = {},
    ) -> "Config":
        """Load the config from a byte string."""
        return self.from_str(
            bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides
        )

    def to_disk(self, path: Union[str, Path], *, interpolate: bool = True):
        """Serialize the config to a file."""
        path = Path(path) if isinstance(path, str) else path
        with path.open("w", encoding="utf8") as file_:
            file_.write(self.to_str(interpolate=interpolate))

    def from_disk(
        self,
        path: Union[str, Path],
        *,
        interpolate: bool = True,
        overrides: Dict[str, Any] = {},
    ) -> "Config":
        """Load config from a file."""
        path = Path(path) if isinstance(path, str) else path
        with path.open("r", encoding="utf8") as file_:
            text = file_.read()
        return self.from_str(text, interpolate=interpolate, overrides=overrides)


def _mask_positional_args(name: str) -> List[Optional[str]]:
    """Create a section name representation that masks names
    of positional arguments to retain their order in sorts."""

    stable_name = cast(List[Optional[str]], name.split("."))

    # Remove names of sections that are a positional argument.
    for i in range(1, len(stable_name)):
        if stable_name[i - 1] == "*":
            stable_name[i] = None

    return stable_name
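
# Sketch (illustrative comment only): section names under a "*" block are
# masked so that sorting keeps positional arguments in their original order:
#
#     _mask_positional_args("model.*.encoder")  # -> ["model", "*", None]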


def try_load_json(value: str) -> Any:
    """Load a JSON string if possible, otherwise default to original value."""
    try:
        return srsly.json_loads(value)
    except Exception:
        return value


def try_dump_json(value: Any, data: Union[Dict[str, dict], Config, str] = "") -> str:
    """Dump a config value as JSON and output user-friendly error if it fails."""
    # Special case if we have a variable: it's already a string so don't dump
    # to preserve ${x:y} vs. "${x:y}"
    if isinstance(value, str) and VARIABLE_RE.search(value):
        return value
    if isinstance(value, str) and value.replace(".", "", 1).isdigit():
        # Work around values that are strings but numbers
        value = f'"{value}"'
    try:
        value = srsly.json_dumps(value)
        # Escape a literal "$" as "$$" so it survives interpolation on re-read
        value = re.sub(r"\$([^{])", r"$$\1", value)
        value = re.sub(r"\$$", "$$", value)
        return value
    except Exception as e:
        err_msg = (
            f"Couldn't serialize config value of type {type(value)}: {e}. Make "
            f"sure all values in your config are JSON-serializable. If you want "
            f"to include Python objects, use a registered function that returns "
            f"the object instead."
        )
        raise ConfigValidationError(config=data, desc=err_msg) from e
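
# Sketch (illustrative comment only): variables are passed through untouched,
# lists/dicts are dumped as JSON, and numeric strings get an extra layer of
# quotes so they stay strings after a round trip:
#
#     try_dump_json("${paths.train}")  # -> '${paths.train}' (unchanged)
#     try_dump_json(["a", "b"])        # -> a JSON string like '["a","b"]'
#     try_dump_json("123")             # wrapped in quotes so it re-loads as the string "123"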


def deep_merge_configs(
    config: Union[Dict[str, Any], Config],
    defaults: Union[Dict[str, Any], Config],
    *,
    remove_extra: bool = False,
) -> Union[Dict[str, Any], Config]:
    """Deep merge two configs."""
    if remove_extra:
        # Filter out values in the original config that are not in defaults
        keys = list(config.keys())
        for key in keys:
            if key not in defaults:
                del config[key]
    for key, value in defaults.items():
        if isinstance(value, dict):
            node = config.setdefault(key, {})
            if not isinstance(node, dict):
                continue
            value_promises = [k for k in value if k.startswith("@")]
            value_promise = value_promises[0] if value_promises else None
            node_promises = [k for k in node if k.startswith("@")] if node else []
            node_promise = node_promises[0] if node_promises else None
            # We only update the block from defaults if it refers to the same
            # registered function
            if (
                value_promise
                and node_promise
                and (
                    value_promise in node
                    and node[value_promise] != value[value_promise]
                )
            ):
                continue
            if node_promise and (
                node_promise not in value or node[node_promise] != value[node_promise]
            ):
                continue
            defaults = deep_merge_configs(node, value, remove_extra=remove_extra)
        elif key not in config:
            config[key] = value
    return config
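
# Merge sketch (illustrative comment only; the "@optimizers" names are
# invented): a block is only filled from the defaults if it refers to the same
# registered function, otherwise the user's block is left untouched.
#
#     user = {"optimizer": {"@optimizers": "SGD.v1"}}
#     defaults = {"optimizer": {"@optimizers": "Adam.v1", "learn_rate": 0.001}}
#     deep_merge_configs(user, defaults)
#     # -> {"optimizer": {"@optimizers": "SGD.v1"}}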


class ConfigValidationError(ValueError):
    def __init__(
        self,
        *,
        config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None,
        errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(),
        title: Optional[str] = "Config validation error",
        desc: Optional[str] = None,
        parent: Optional[str] = None,
        show_config: bool = True,
    ) -> None:
        """Custom error for validating configs.

        config (Union[Config, Dict[str, Dict[str, Any]], str]): The
            config the validation error refers to.
        errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]):
            A list of errors as dicts with keys "loc" (list of strings
            describing the path of the value), "msg" (validation message
            to show) and optional "type" (mostly internals).
            Same format as produced by pydantic's validation error (e.errors()).
        title (str): The error title.
        desc (str): Optional error description, displayed below the title.
        parent (str): Optional parent to use as prefix for all error locations.
            For example, parent "element" will result in "element -> a -> b".
        show_config (bool): Whether to print the whole config with the error.

        ATTRIBUTES:
        config (Union[Config, Dict[str, Dict[str, Any]], str]): The config.
        errors (Iterable[Dict[str, Any]]): The errors.
        error_types (Set[str]): All "type" values defined in the errors, if
            available. This is most relevant for the pydantic errors that define
            types like "type_error.integer". This attribute makes it easy to
            check if a config validation error includes errors of a certain
            type, e.g. to log additional information or custom help messages.
        title (str): The title.
        desc (str): The description.
        parent (str): The parent.
        show_config (bool): Whether to show the config.
        text (str): The formatted error text.
        """
        self.config = config
        self.errors = errors
        self.title = title
        self.desc = desc
        self.parent = parent
        self.show_config = show_config
        self.error_types = set()
        for error in self.errors:
            err_type = error.get("type")
            if err_type:
                self.error_types.add(err_type)
        self.text = self._format()
        ValueError.__init__(self, self.text)

    @classmethod
    def from_error(
        cls,
        err: "ConfigValidationError",
        title: Optional[str] = None,
        desc: Optional[str] = None,
        parent: Optional[str] = None,
        show_config: Optional[bool] = None,
    ) -> "ConfigValidationError":
        """Create a new ConfigValidationError based on an existing error, e.g.
        to re-raise it with different settings. If no overrides are provided,
        the values from the original error are used.

        err (ConfigValidationError): The original error.
        title (str): Overwrite error title.
        desc (str): Overwrite error description.
        parent (str): Overwrite error parent.
        show_config (bool): Overwrite whether to show config.
        RETURNS (ConfigValidationError): The new error.
        """
        return cls(
            config=err.config,
            errors=err.errors,
            title=title if title is not None else err.title,
            desc=desc if desc is not None else err.desc,
            parent=parent if parent is not None else err.parent,
            show_config=show_config if show_config is not None else err.show_config,
        )

    def _format(self) -> str:
        """Format the error message."""
        loc_divider = "->"
        data = []
        for error in self.errors:
            err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])])
            if self.parent:
                err_loc = f"{self.parent} {loc_divider} {err_loc}"
            data.append((err_loc, error.get("msg")))
        result = []
        if self.title:
            result.append(self.title)
        if self.desc:
            result.append(self.desc)
        if data:
            result.append("\n".join([f"{entry[0]}\t{entry[1]}" for entry in data]))
        if self.config and self.show_config:
            result.append(f"{self.config}")
        return "\n\n" + "\n".join(result)
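
# Sketch (illustrative comment only; bad_config and the parent string are
# placeholders): from_error() can re-raise an existing error with extra
# context, prefixing the parent to each error location.
#
#     try:
#         registry.resolve(bad_config)
#     except ConfigValidationError as e:
#         raise ConfigValidationError.from_error(e, parent="components.tagger")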


def alias_generator(name: str) -> str:
    """Generate field aliases in promise schema."""
    # Underscore fields are not allowed in model, so use alias
    if name == ARGS_FIELD_ALIAS:
        return ARGS_FIELD
    # Auto-alias fields that shadow base model attributes
    if name in RESERVED_FIELDS:
        return RESERVED_FIELDS[name]
    return name


def copy_model_field(field: ModelField, type_: Any) -> ModelField:
    """Copy a model field and assign a new type, e.g. to accept an Any type
    even though the original value is typed differently.
    """
    return ModelField(
        name=field.name,
        type_=type_,
        class_validators=field.class_validators,
        model_config=field.model_config,
        default=field.default,
        default_factory=field.default_factory,
        required=field.required,
    )


class EmptySchema(BaseModel):
    class Config:
        extra = "allow"
        arbitrary_types_allowed = True


class _PromiseSchemaConfig:
    extra = "forbid"
    arbitrary_types_allowed = True
    alias_generator = alias_generator


@dataclass
class Promise:
    registry: str
    name: str
    args: List[str]
    kwargs: Dict[str, Any]


class registry:
    @classmethod
    def has(cls, registry_name: str, func_name: str) -> bool:
        """Check whether a function is available in a registry."""
        if not hasattr(cls, registry_name):
            return False
        reg = getattr(cls, registry_name)
        return func_name in reg

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> Callable:
        """Get a registered function from a given registry."""
        if not hasattr(cls, registry_name):
            raise ValueError(f"Unknown registry: '{registry_name}'")
        reg = getattr(cls, registry_name)
        func = reg.get(func_name)
        if func is None:
            raise ValueError(f"Could not find '{func_name}' in '{registry_name}'")
        return func
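
    # Usage sketch (illustrative comment only): in practice this base class is
    # subclassed and registries are attached as class attributes, commonly
    # catalogue.Registry objects. The names below are invented:
    #
    #     import catalogue
    #
    #     class my_registry(registry):
    #         optimizers = catalogue.create("myproject", "optimizers")
    #
    #     @my_registry.optimizers.register("SGD.v1")
    #     def make_sgd(learn_rate: float = 0.001):
    #         return {"name": "SGD", "learn_rate": learn_rate}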

    @classmethod
    def resolve(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ) -> Dict[str, Any]:
        resolved, _ = cls._make(
            config, schema=schema, overrides=overrides, validate=validate, resolve=True
        )
        return resolved

    @classmethod
    def fill(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ):
        _, filled = cls._make(
            config, schema=schema, overrides=overrides, validate=validate, resolve=False
        )
        return filled
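
    # Sketch (illustrative comment only, assuming the my_registry subclass and
    # the "SGD.v1" function from the comment above): resolve() calls the
    # registered functions, fill() only validates and fills in defaults.
    #
    #     cfg = {"optimizer": {"@optimizers": "SGD.v1", "learn_rate": 0.01}}
    #     my_registry.resolve(cfg)["optimizer"]  # -> the object returned by make_sgd
    #     my_registry.fill(cfg)["optimizer"]     # -> {"@optimizers": "SGD.v1",
    #                                            #     "learn_rate": 0.01}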

    @classmethod
    def _make(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        resolve: bool = True,
        validate: bool = True,
    ) -> Tuple[Dict[str, Any], Config]:
        """Unpack a config dictionary and create two versions of the config:
        a resolved version with objects from the registry created recursively,
        and a filled version with all references to registry functions left
        intact, but filled with all values and defaults based on the type
        annotations. If validate=True, the config will be validated against the
        type annotations of the registered functions referenced in the config
        (if available) and/or the schema (if available).
        """
        # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
        # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
        if cls.is_promise(config):
            err_msg = "The top-level config object can't be a reference to a registered function."
            raise ConfigValidationError(config=config, errors=[{"msg": err_msg}])
        # If a Config was loaded with interpolate=False, we assume it needs to
        # be interpolated first, otherwise we take it at face value
        is_interpolated = not isinstance(config, Config) or config.is_interpolated
        section_order = config.section_order if isinstance(config, Config) else None
        orig_config = config
        if not is_interpolated:
            config = Config(orig_config).interpolate()
        filled, _, resolved = cls._fill(
            config, schema, validate=validate, overrides=overrides, resolve=resolve
        )
        filled = Config(filled, section_order=section_order)
        # Check that overrides didn't include invalid properties not in config
        if validate:
            cls._validate_overrides(filled, overrides)
        # Merge the original config back to preserve variables if we started
        # with a config that wasn't interpolated. Here, we prefer variables to
        # allow auto-filling a non-interpolated config without destroying
        # variable references.
        if not is_interpolated:
            filled = filled.merge(
                Config(orig_config, is_interpolated=False), remove_extra=True
            )
        return dict(resolved), filled

    @classmethod
    def _fill(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        schema: Type[BaseModel] = EmptySchema,
        *,
        validate: bool = True,
        resolve: bool = True,
        parent: str = "",
        overrides: Dict[str, Dict[str, Any]] = {},
    ) -> Tuple[
        Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any]
    ]:
        """Build three representations of the config:
        1. All promises are preserved (just like config user would provide).
        2. Promises are replaced by their return values. This is the validation
           copy and will be parsed by pydantic. It lets us include hacks to
           work around problems (e.g. handling of generators).
        3. Final copy with promises replaced by their return values.
        """
        filled: Dict[str, Any] = {}
        validation: Dict[str, Any] = {}
        final: Dict[str, Any] = {}
        for key, value in config.items():
            # If the field name is reserved, we use its alias for validation
            v_key = RESERVED_FIELDS.get(key, key)
            key_parent = f"{parent}.{key}".strip(".")
            if key_parent in overrides:
                value = overrides[key_parent]
                config[key] = value
            if cls.is_promise(value):
                if key in schema.__fields__ and not resolve:
                    # If we're not resolving the config, make sure that the field
                    # expecting the promise is typed Any so it doesn't fail
                    # validation if it doesn't receive the function return value
                    field = schema.__fields__[key]
                    schema.__fields__[key] = copy_model_field(field, Any)
                promise_schema = cls.make_promise_schema(value, resolve=resolve)
                filled[key], validation[v_key], final[key] = cls._fill(
                    value,
                    promise_schema,
                    validate=validate,
                    resolve=resolve,
                    parent=key_parent,
                    overrides=overrides,
                )
                reg_name, func_name = cls.get_constructor(final[key])
                args, kwargs = cls.parse_args(final[key])
                if resolve:
                    # Call the function and populate the field value. We can't
                    # just create an instance of the type here, since this
                    # wouldn't work for generics / more complex custom types
                    getter = cls.get(reg_name, func_name)
                    # We don't want to try/except this and raise our own error
                    # here, because we want the traceback if the function fails.
                    getter_result = getter(*args, **kwargs)
                else:
                    # We're not resolving and calling the function, so replace
                    # the getter_result with a Promise class
                    getter_result = Promise(
                        registry=reg_name, name=func_name, args=args, kwargs=kwargs
                    )
                validation[v_key] = getter_result
                final[key] = getter_result
                if isinstance(validation[v_key], GeneratorType):
                    # If value is a generator we can't validate type without
                    # consuming it (which doesn't work if it's infinite – see
                    # schedule for examples). So we skip it.
                    validation[v_key] = []
            elif hasattr(value, "items"):
                field_type = EmptySchema
                if key in schema.__fields__:
                    field = schema.__fields__[key]
                    field_type = field.type_
                    if not isinstance(field.type_, ModelMetaclass):
                        # If we don't have a pydantic schema and just a type
                        field_type = EmptySchema
                filled[key], validation[v_key], final[key] = cls._fill(
                    value,
                    field_type,
                    validate=validate,
                    resolve=resolve,
                    parent=key_parent,
                    overrides=overrides,
                )
                if key == ARGS_FIELD and isinstance(validation[v_key], dict):
                    # If the value of variable positional args is a dict (e.g.
                    # created via config blocks), only use its values
                    validation[v_key] = list(validation[v_key].values())
                    final[key] = list(final[key].values())
            else:
                filled[key] = value
                # Prevent pydantic from consuming generator if part of a union
                validation[v_key] = (
                    value if not isinstance(value, GeneratorType) else []
                )
                final[key] = value
        # Now that we've filled in all of the promises, update with defaults
        # from schema, and validate if validation is enabled
        exclude = []
        if validate:
            try:
                result = schema.parse_obj(validation)
            except ValidationError as e:
                raise ConfigValidationError(
                    config=config, errors=e.errors(), parent=parent
                ) from None
        else:
            # Same as parse_obj, but without validation
            result = schema.construct(**validation)
            # If our schema doesn't allow extra values, we need to filter them
            # manually because .construct doesn't parse anything
            if schema.Config.extra in (Extra.forbid, Extra.ignore):
                fields = schema.__fields__.keys()
                exclude = [k for k in result.__fields_set__ if k not in fields]
        exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
        validation.update(result.dict(exclude=exclude_validation))
        filled, final = cls._update_from_parsed(validation, filled, final)
        if exclude:
            filled = {k: v for k, v in filled.items() if k not in exclude}
            validation = {k: v for k, v in validation.items() if k not in exclude}
            final = {k: v for k, v in final.items() if k not in exclude}
        return filled, validation, final

    @classmethod
    def _update_from_parsed(
        cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any]
    ):
        """Recursively update the final result with the parsed config, e.g.
        values converted by pydantic.
        """
        for key, value in validation.items():
            if key in RESERVED_FIELDS.values():
                continue  # skip aliases for reserved fields
            if key not in filled:
                filled[key] = value
            if key not in final:
                final[key] = value
            if isinstance(value, dict):
                filled[key], final[key] = cls._update_from_parsed(
                    value, filled[key], final[key]
                )
            # Update final config with parsed value if they're not equal (in
            # value and in type) but not if it's a generator because we had to
            # replace that to validate it correctly
            elif key == ARGS_FIELD:
                continue  # don't substitute if list of positional args
            # Check numpy first, just in case. Use stringified type so that numpy dependency can be ditched.
            elif str(type(value)) == "<class 'numpy.ndarray'>":
                final[key] = value
            elif (
                value != final[key] or not isinstance(type(value), type(final[key]))
            ) and not isinstance(final[key], GeneratorType):
                final[key] = value
        return filled, final

    @classmethod
    def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]):
        """Validate overrides against a filled config to make sure there are
        no references to properties that don't exist and weren't used."""
        error_msg = "Invalid override: config value doesn't exist"
        errors = []
        for override_key in overrides.keys():
            if not cls._is_in_config(override_key, filled):
                errors.append({"msg": error_msg, "loc": [override_key]})
        if errors:
            raise ConfigValidationError(config=filled, errors=errors)

    @classmethod
    def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]):
        """Check whether a nested config property like "section.subsection.key"
        is in a given config."""
        tree = prop.split(".")
        obj = dict(config)
        while tree:
            key = tree.pop(0)
            if isinstance(obj, dict) and key in obj:
                obj = obj[key]
            else:
                return False
        return True

    @classmethod
    def is_promise(cls, obj: Any) -> bool:
        """Check whether an object is a "promise", i.e. contains a reference
        to a registered function (via a key starting with `"@"`).
        """
        if not hasattr(obj, "keys"):
            return False
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys):
            return True
        return False

    @classmethod
    def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]:
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys) != 1:
            err_msg = f"A block can only contain one function registry reference. Got: {id_keys}"
            raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}])
        else:
            key = id_keys[0]
            value = obj[key]
            return (key[1:], value)

    @classmethod
    def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
        args = []
        kwargs = {}
        for key, value in obj.items():
            if not key.startswith("@"):
                if key == ARGS_FIELD:
                    args = value
                elif key in RESERVED_FIELDS.values():
                    continue
                else:
                    kwargs[key] = value
        return args, kwargs
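
    # Sketch (illustrative comment only; the block contents are invented):
    # get_constructor() returns the registry and function name, and
    # parse_args() splits positional args (under the "*" key) from kwargs.
    #
    #     block = {"@optimizers": "SGD.v1", "*": ["a", "b"], "learn_rate": 0.01}
    #     registry.get_constructor(block)  # -> ("optimizers", "SGD.v1")
    #     registry.parse_args(block)       # -> (["a", "b"], {"learn_rate": 0.01})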

    @classmethod
    def make_promise_schema(
        cls, obj: Dict[str, Any], *, resolve: bool = True
    ) -> Type[BaseModel]:
        """Create a schema for a promise dict (referencing a registry function)
        by inspecting the function signature.
        """
        reg_name, func_name = cls.get_constructor(obj)
        if not resolve and not cls.has(reg_name, func_name):
            return EmptySchema
        func = cls.get(reg_name, func_name)
        # Read the argument annotations and defaults from the function signature
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)}
        for param in inspect.signature(func).parameters.values():
            # If no annotation is specified assume it's anything
            annotation = param.annotation if param.annotation != param.empty else Any
            # If no default value is specified assume that it's required
            default = param.default if param.default != param.empty else ...
            # Handle spread arguments and use their annotation as Sequence[whatever]
            if param.kind == param.VAR_POSITIONAL:
                spread_annot = Sequence[annotation]  # type: ignore
                sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default)
            else:
                name = RESERVED_FIELDS.get(param.name, param.name)
                sig_args[name] = (annotation, default)
        sig_args["__config__"] = _PromiseSchemaConfig
        return create_model("ArgModel", **sig_args)
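
    # Sketch (illustrative comment only, assuming a registered function like
    #
    #     def make_sgd(learn_rate: float = 0.001): ...
    #
    # ): the generated "ArgModel" has a required string field for the "@"-key
    # plus a "learn_rate" field typed float with default 0.001, so pydantic can
    # validate the block's values and fill in missing defaults.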


__all__ = [
    "Config",
    "registry",
    "ConfigValidationError",
    "SimpleFrozenDict",
    "SimpleFrozenList",
]