"""Utilities related to attribute docstring extraction.""" from __future__ import annotations import ast import inspect import textwrap from typing import Any class DocstringVisitor(ast.NodeVisitor): def __init__(self) -> None: super().__init__() self.target: str | None = None self.attrs: dict[str, str] = {} self.previous_node_type: type[ast.AST] | None = None def visit(self, node: ast.AST) -> Any: node_result = super().visit(node) self.previous_node_type = type(node) return node_result def visit_AnnAssign(self, node: ast.AnnAssign) -> Any: if isinstance(node.target, ast.Name): self.target = node.target.id def visit_Expr(self, node: ast.Expr) -> Any: if ( isinstance(node.value, ast.Constant) and isinstance(node.value.value, str) and self.previous_node_type is ast.AnnAssign ): docstring = inspect.cleandoc(node.value.value) if self.target: self.attrs[self.target] = docstring self.target = None def _dedent_source_lines(source: list[str]) -> str: # Required for nested class definitions, e.g. in a function block dedent_source = textwrap.dedent(''.join(source)) if dedent_source.startswith((' ', '\t')): # We are in the case where there's a dedented (usually multiline) string # at a lower indentation level than the class itself. We wrap our class # in a function as a workaround. dedent_source = f'def dedent_workaround():\n{dedent_source}' return dedent_source def _extract_source_from_frame(cls: type[Any]) -> list[str] | None: frame = inspect.currentframe() while frame: if inspect.getmodule(frame) is inspect.getmodule(cls): lnum = frame.f_lineno try: lines, _ = inspect.findsource(frame) except OSError: # Source can't be retrieved (maybe because running in an interactive terminal), # we don't want to error here. pass else: block_lines = inspect.getblock(lines[lnum - 1 :]) dedent_source = _dedent_source_lines(block_lines) try: block_tree = ast.parse(dedent_source) except SyntaxError: pass else: stmt = block_tree.body[0] if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround': # `_dedent_source_lines` wrapped the class around the workaround function stmt = stmt.body[0] if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__: return block_lines frame = frame.f_back def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]: """Map model attributes and their corresponding docstring. Args: cls: The class of the Pydantic model to inspect. use_inspect: Whether to skip usage of frames to find the object and use the `inspect` module instead. Returns: A mapping containing attribute names and their corresponding docstring. """ if use_inspect: # Might not work as expected if two classes have the same name in the same source file. try: source, _ = inspect.getsourcelines(cls) except OSError: return {} else: source = _extract_source_from_frame(cls) if not source: return {} dedent_source = _dedent_source_lines(source) visitor = DocstringVisitor() visitor.visit(ast.parse(dedent_source)) return visitor.attrs