335 lines
12 KiB
Python
335 lines
12 KiB
Python
|
from __future__ import annotations as _annotations
|
||
|
|
||
|
import warnings
|
||
|
from contextlib import contextmanager
|
||
|
from typing import (
|
||
|
TYPE_CHECKING,
|
||
|
Any,
|
||
|
Callable,
|
||
|
cast,
|
||
|
)
|
||
|
|
||
|
from pydantic_core import core_schema
|
||
|
from typing_extensions import (
|
||
|
Literal,
|
||
|
Self,
|
||
|
)
|
||
|
|
||
|
from ..aliases import AliasGenerator
|
||
|
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
|
||
|
from ..errors import PydanticUserError
|
||
|
from ..warnings import PydanticDeprecatedSince20
|
||
|
|
||
|
if not TYPE_CHECKING:
    # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
    # and https://youtrack.jetbrains.com/issue/PY-51428
    # At runtime only (type checkers treat `TYPE_CHECKING` as true and skip this), the name
    # `DeprecationWarning` in this module is rebound to `PydanticDeprecatedSince20`, so
    # `warnings.warn(..., DeprecationWarning)` below emits pydantic's category instead of the builtin one.
    DeprecationWarning = PydanticDeprecatedSince20

if TYPE_CHECKING:
    # Only needed for the `schema_generator` annotation in this chunk; presumably kept out of the runtime
    # import graph to avoid an import cycle -- TODO confirm.
    from .._internal._schema_generation_shared import GenerateSchema

# Warning text emitted by `prepare_config` when it converts a legacy class-based `Config`.
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
|
||
|
|
||
|
|
||
|
class ConfigWrapper:
    """Internal wrapper for Config which exposes ConfigDict items as attributes.

    Attribute lookups that are not real attributes resolve, in order, against the wrapped `config_dict`,
    then the module-level `config_defaults`, and otherwise raise `AttributeError`.
    """

    __slots__ = ('config_dict',)

    config_dict: ConfigDict

    # all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they
    # stop matching
    title: str | None
    str_to_lower: bool
    str_to_upper: bool
    str_strip_whitespace: bool
    str_min_length: int
    str_max_length: int | None
    extra: ExtraValues | None
    frozen: bool
    populate_by_name: bool
    use_enum_values: bool
    validate_assignment: bool
    arbitrary_types_allowed: bool
    from_attributes: bool
    # whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
    # to construct error `loc`s, default `True`
    loc_by_alias: bool
    alias_generator: Callable[[str], str] | AliasGenerator | None
    ignored_types: tuple[type, ...]
    allow_inf_nan: bool
    json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
    json_encoders: dict[type[object], JsonEncoder] | None

    # new in V2
    strict: bool
    # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
    revalidate_instances: Literal['always', 'never', 'subclass-instances']
    ser_json_timedelta: Literal['iso8601', 'float']
    ser_json_bytes: Literal['utf8', 'base64']
    ser_json_inf_nan: Literal['null', 'constants']
    # whether to validate default values during validation, default False
    validate_default: bool
    validate_return: bool
    protected_namespaces: tuple[str, ...]
    hide_input_in_errors: bool
    defer_build: bool
    plugin_settings: dict[str, object] | None
    schema_generator: type[GenerateSchema] | None
    json_schema_serialization_defaults_required: bool
    json_schema_mode_override: Literal['validation', 'serialization', None]
    coerce_numbers_to_str: bool
    regex_engine: Literal['rust-regex', 'python-re']
    validation_error_cause: bool
    use_attribute_docstrings: bool
    cache_strings: bool | Literal['all', 'keys', 'none']

    def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
        """Wrap `config`.

        Args:
            config: A `ConfigDict`/plain dict, a legacy class-based `Config`, or `None`.
            check: When `True`, normalize `config` through `prepare_config` (which converts class-based
                configs and warns about deprecated keys). When `False`, `config` is stored as-is and is
                assumed to already be a `ConfigDict`.
        """
        if check:
            self.config_dict = prepare_config(config)
        else:
            self.config_dict = cast(ConfigDict, config)

    @classmethod
    def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
        """Build a new `ConfigWrapper` instance for a `BaseModel`.

        The config wrapper is built from the following sources, in descending order of priority:
            - options from `kwargs`
            - options from the `namespace`
            - options from the base classes (`bases`)

        Args:
            bases: A tuple of base classes.
            namespace: The namespace of the class being created.
            kwargs: The kwargs passed to the class being created. Recognized config keys are popped
                out of this dict (it is mutated in place).

        Returns:
            A `ConfigWrapper` instance for `BaseModel`.

        Raises:
            PydanticUserError: If `model_config` is annotated without being assigned, or if both a
                `Config` class and `model_config` are defined in the namespace.
        """
        config_new = ConfigDict()
        for base in bases:
            config = getattr(base, 'model_config', None)
            if config:
                # `update` copies the key/value pairs, so the base's own dict is never mutated here.
                config_new.update(config)

        config_class_from_namespace = namespace.get('Config')
        config_dict_from_namespace = namespace.get('model_config')

        raw_annotations = namespace.get('__annotations__', {})
        # Compare with `None` rather than relying on truthiness: `model_config: ConfigDict = ConfigDict()`
        # assigns an empty (falsy) dict, which is a valid configuration and not a field declaration.
        if raw_annotations.get('model_config') and config_dict_from_namespace is None:
            raise PydanticUserError(
                '`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
                code='model-config-invalid-field-name',
            )

        if config_class_from_namespace and config_dict_from_namespace:
            raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')

        config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)

        config_new.update(config_from_namespace)

        # Class keyword arguments that are config keys override everything else; they are removed from
        # `kwargs` so they are not forwarded elsewhere.
        for k in list(kwargs.keys()):
            if k in config_keys:
                config_new[k] = kwargs.pop(k)

        return cls(config_new)

    # we don't show `__getattr__` to type checkers so missing attributes cause errors
    if not TYPE_CHECKING:  # pragma: no branch

        def __getattr__(self, name: str) -> Any:
            # `config_dict` is a slot. When it has not been set (e.g. on an instance created via `__new__`
            # by `copy.copy` or unpickling), reading it raises AttributeError, which routes back into this
            # method. Without this guard, `self.config_dict` below would recurse until RecursionError,
            # which `hasattr`/`getattr(..., default)` do not catch.
            if name == 'config_dict':
                raise AttributeError(f'Config has no attribute {name!r}')
            try:
                return self.config_dict[name]
            except KeyError:
                try:
                    return config_defaults[name]
                except KeyError:
                    raise AttributeError(f'Config has no attribute {name!r}') from None

    def core_config(self, obj: Any) -> core_schema.CoreConfig:
        """Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.

        Pass `obj=None` if you do not want to attempt to infer the `title`.

        We don't use getattr here since we don't want to populate with defaults: only keys explicitly
        present in `config_dict` (and not `None`) are forwarded.

        Args:
            obj: An object used to populate `title` if not set in config.

        Returns:
            A `CoreConfig` object created from config.
        """

        def dict_not_none(**kwargs: Any) -> Any:
            # Drop unset options so pydantic-core falls back to its own defaults.
            return {k: v for k, v in kwargs.items() if v is not None}

        core_config = core_schema.CoreConfig(
            **dict_not_none(
                title=self.config_dict.get('title') or (obj and obj.__name__),
                extra_fields_behavior=self.config_dict.get('extra'),
                allow_inf_nan=self.config_dict.get('allow_inf_nan'),
                populate_by_name=self.config_dict.get('populate_by_name'),
                str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
                str_to_lower=self.config_dict.get('str_to_lower'),
                str_to_upper=self.config_dict.get('str_to_upper'),
                strict=self.config_dict.get('strict'),
                ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
                ser_json_bytes=self.config_dict.get('ser_json_bytes'),
                ser_json_inf_nan=self.config_dict.get('ser_json_inf_nan'),
                from_attributes=self.config_dict.get('from_attributes'),
                loc_by_alias=self.config_dict.get('loc_by_alias'),
                revalidate_instances=self.config_dict.get('revalidate_instances'),
                validate_default=self.config_dict.get('validate_default'),
                str_max_length=self.config_dict.get('str_max_length'),
                str_min_length=self.config_dict.get('str_min_length'),
                hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
                coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
                regex_engine=self.config_dict.get('regex_engine'),
                validation_error_cause=self.config_dict.get('validation_error_cause'),
                cache_strings=self.config_dict.get('cache_strings'),
            )
        )
        return core_config

    def __repr__(self):
        c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
        return f'ConfigWrapper({c})'
|
||
|
|
||
|
|
||
|
class ConfigWrapperStack:
    """A stack of `ConfigWrapper` instances; the most recently pushed one is the active config."""

    def __init__(self, config_wrapper: ConfigWrapper):
        # The initial wrapper forms the bottom of the stack and is never popped by `push`.
        self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]

    @property
    def tail(self) -> ConfigWrapper:
        """The wrapper currently on top of the stack."""
        return self._config_wrapper_stack[-1]

    @contextmanager
    def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
        """Make `config_wrapper` the active config for the duration of a `with` block.

        `None` leaves the stack untouched. A plain `ConfigDict` is wrapped without validation
        (`check=False`). The pushed entry is removed on exit, even if the block raises.
        """
        if config_wrapper is not None:
            if isinstance(config_wrapper, ConfigWrapper):
                wrapper = config_wrapper
            else:
                wrapper = ConfigWrapper(config_wrapper, check=False)
            stack = self._config_wrapper_stack
            stack.append(wrapper)
            try:
                yield
            finally:
                stack.pop()
        else:
            yield
|
||
|
|
||
|
|
||
|
# Fallback values returned by `ConfigWrapper.__getattr__` for keys absent from the wrapped `config_dict`.
# Should provide an entry for every annotation declared on `ConfigWrapper`; keep the two in sync.
config_defaults = ConfigDict(
    title=None,
    str_to_lower=False,
    str_to_upper=False,
    str_strip_whitespace=False,
    str_min_length=0,
    str_max_length=None,
    # let the model / dataclass decide how to handle it
    extra=None,
    frozen=False,
    populate_by_name=False,
    use_enum_values=False,
    validate_assignment=False,
    arbitrary_types_allowed=False,
    from_attributes=False,
    loc_by_alias=True,
    alias_generator=None,
    ignored_types=(),
    allow_inf_nan=True,
    json_schema_extra=None,
    strict=False,
    revalidate_instances='never',
    ser_json_timedelta='iso8601',
    ser_json_bytes='utf8',
    ser_json_inf_nan='null',
    validate_default=False,
    validate_return=False,
    protected_namespaces=('model_',),
    hide_input_in_errors=False,
    json_encoders=None,
    defer_build=False,
    plugin_settings=None,
    schema_generator=None,
    json_schema_serialization_defaults_required=False,
    json_schema_mode_override=None,
    coerce_numbers_to_str=False,
    regex_engine='rust-regex',
    validation_error_cause=False,
    use_attribute_docstrings=False,
    cache_strings=True,
)
|
||
|
|
||
|
|
||
|
def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
    """Normalize any supported config input into a `ConfigDict`.

    Args:
        config: `None`, a dict/`ConfigDict`, or a legacy class-based `Config` whose non-dunder
            attributes are collected as config keys (this path emits a deprecation warning).

    Returns:
        The resulting `ConfigDict`. A dict input is returned as the same object (not copied);
        `None` yields a new empty `ConfigDict`. Deprecated V1 keys are reported via
        `check_deprecated` before returning.
    """
    if config is None:
        return ConfigDict()

    if isinstance(config, dict):
        raw = config
    else:
        # V1-style class config: warn first, then harvest every attribute that is not a dunder.
        warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
        raw = {}
        for attr_name in dir(config):
            if attr_name.startswith('__'):
                continue
            raw[attr_name] = getattr(config, attr_name)

    result = cast(ConfigDict, raw)
    check_deprecated(result)
    return result
|
||
|
|
||
|
|
||
|
# Every key annotated on `ConfigDict`; `ConfigWrapper.for_model` uses it to pick config options out of
# class keyword arguments.
config_keys = set(ConfigDict.__annotations__)


# V1 config keys with no V2 equivalent; `check_deprecated` warns when any of them is present.
V2_REMOVED_KEYS = {
    'allow_mutation',
    'copy_on_model_validation',
    'error_msg_templates',
    'fields',
    'getter_dict',
    'json_dumps',
    'json_loads',
    'post_init_call',
    'smart_union',
    'underscore_attrs_are_private',
}
# V1 config key -> its V2 name; `check_deprecated` names the replacement in its warning.
V2_RENAMED_KEYS = {
    'allow_population_by_field_name': 'populate_by_name',
    'anystr_lower': 'str_to_lower',
    'anystr_strip_whitespace': 'str_strip_whitespace',
    'anystr_upper': 'str_to_upper',
    'keep_untouched': 'ignored_types',
    'max_anystr_length': 'str_max_length',
    'min_anystr_length': 'str_min_length',
    'orm_mode': 'from_attributes',
    'schema_extra': 'json_schema_extra',
    'validate_all': 'validate_default',
}
|
||
|
|
||
|
|
||
|
def check_deprecated(config_dict: ConfigDict) -> None:
    """Emit a single `UserWarning` listing any V1-only config keys found in `config_dict`.

    Renamed keys are listed first (each with its V2 replacement), then removed keys; each group
    is sorted. Nothing is emitted when no such key is present.

    Args:
        config_dict: The input config.
    """
    present_keys = config_dict.keys()
    removed = sorted(V2_REMOVED_KEYS & present_keys)
    renamed = sorted(V2_RENAMED_KEYS.keys() & present_keys)
    if not (removed or renamed):
        return

    lines = ['Valid config keys have changed in V2:']
    lines.extend(f'* {old_key!r} has been renamed to {V2_RENAMED_KEYS[old_key]!r}' for old_key in renamed)
    lines.extend(f'* {old_key!r} has been removed' for old_key in removed)
    warnings.warn('\n'.join(lines), UserWarning)
|