ai-content-maker/.venv/Lib/site-packages/thinc/mypy.py

267 lines
11 KiB
Python

import itertools
from typing import Dict, List
from mypy.checker import TypeChecker
from mypy.errorcodes import ErrorCode
from mypy.errors import Errors
from mypy.nodes import CallExpr, Decorator, Expression, FuncDef, MypyFile, NameExpr
from mypy.options import Options
from mypy.plugin import CheckerPluginInterface, FunctionContext, Plugin
from mypy.subtypes import is_subtype
from mypy.types import CallableType, Instance, Type, TypeVarType
# Fully qualified name of the Thinc `Model` class. The plugin compares it
# against the `fullname` of return and argument types to recognise models.
thinc_model_fullname = "thinc.model.Model"
# Fully qualified name of the output type variable that marks functions which
# chain their model arguments (each model's output feeds the next model).
chained_out_fullname = "thinc.types.XY_YZ_OutT"
# Fully qualified name of the output type variable that marks functions whose
# model arguments must all share the same input type and the same output type.
intoin_outtoout_out_fullname = "thinc.types.XY_XY_OutT"
def plugin(version: str):
    """
    Return the plugin class exposed by this module.

    The `version` argument is accepted but not inspected: the same
    `ThincPlugin` class is returned for every value.
    """
    return ThincPlugin
class ThincPlugin(Plugin):
    """Mypy plugin that refines the return type of model-combining calls."""

    def __init__(self, options: Options) -> None:
        """Forward `options` unchanged to the base `Plugin` initializer."""
        super().__init__(options)

    def get_function_hook(self, fullname: str):
        """
        Return the hook to run for a call to the function named `fullname`.

        The same `function_hook` is returned for every name; `fullname` is not
        inspected here.
        """
        return function_hook
def function_hook(ctx: FunctionContext) -> Type:
    """
    Return a refined return type for the call described by `ctx`.

    Delegates to `get_reducers_type`. An `AssertionError` from that function
    means no more specific type could be determined, in which case the default
    return type computed by mypy is returned unchanged.
    """
    try:
        refined_type = get_reducers_type(ctx)
    except AssertionError:
        # Add more function callbacks here
        return ctx.default_return_type
    return refined_type
def get_reducers_type(ctx: FunctionContext) -> Type:
    """
    Determine a more specific model type for functions that combine models.

    This function operates on function *calls*. It analyzes each function call
    by looking at the function definition and the arguments passed as part of
    the function call, then determines a more specific return type for the
    function call.

    This method accepts a `FunctionContext` as part of the Mypy plugin
    interface. This function context provides easy access to:

    * `args`: List of "actual arguments" filling each "formal argument" of the
      called function. "Actual arguments" are those passed to the function
      as part of the function call. "Formal arguments" are the parameters
      defined by the function definition. The same actual argument may serve
      to fill multiple formal arguments. In some cases the relationship may
      even be ambiguous. For example, calling `range(*args)`, the actual
      argument `*args` may fill the `start`, `stop` or `step` formal
      arguments, depending on the length of the list.

      The `args` list is of length `num_formals`, with each element
      corresponding to a formal argument. Each value in the `args` list is a
      list of actual arguments which may fill the formal argument. For
      example, in the function call `range(*args, num)`, `num` may fill the
      `start`, `stop` or `step` formal arguments depending on the length of
      `args`, so type-checking needs to consider all of these possibilities.
    * `arg_types`: Type annotation (or inferred type) of each argument. Like
      `args`, this value is a list of lists with an outer list entry for each
      formal argument and an inner list entry for each possible actual
      argument for the formal argument.
    * `arg_kinds`: "Kind" of argument passed to the function call. Argument
      kinds include positional, star (`*args`), named (`x=y`) and star2
      (`**kwargs`) arguments (among others). Like `args`, this value is a list
      of lists.
    * `context`: AST node representing the function call with all available
      type information. Notable attributes include:
      * `args` and `arg_kinds`: Simple list of actual arguments, not mapped to
        formal arguments.
      * `callee`: AST node representing the function being called. Typically
        this is a `NameExpr`. To resolve this node to the function definition
        it references, accessing `callee.node` will usually return either a
        `FuncDef` or `Decorator` node.
    * etc.

    This function infers a more specific type for model-combining functions by
    making certain assumptions about how the function operates based on the
    order of its formal arguments and its return type.

    If the return type is `Model[InT, XY_YZ_OutT]`, the output of each
    argument is expected to be used as the input to the next argument. It's
    therefore necessary to check that the output type of each model is
    compatible with the input type of the following model. The combined model
    has the type `Model[InT, OutT]`, where `InT` is the input type of the
    first model and `OutT` is the output type of the last model.

    If the return type is `Model[InT, XY_XY_OutT]`, all model arguments
    receive input of the same type and are expected to produce output of the
    same type. It's therefore necessary to check that all models have the same
    input types and the same output types. The combined model has the type
    `Model[InT, OutT]`, where `InT` is the input type of all model arguments
    and `OutT` is the output type of all model arguments.

    Raises:
        AssertionError: Raised if a more specific model type couldn't be
            determined, indicating that the default general return type should
            be used. This includes calls that supply no model arguments.
    """
    # Verify that we have a type-checking API and a default return type (presumably a
    # `thinc.model.Model` instance)
    assert isinstance(ctx.api, TypeChecker)
    assert isinstance(ctx.default_return_type, Instance)

    # Verify that we're inspecting a call to a plainly defined or decorated function
    assert isinstance(ctx.context, CallExpr)
    callee = ctx.context.callee
    assert isinstance(callee, NameExpr)
    callee_node = callee.node
    assert isinstance(callee_node, (FuncDef, Decorator))
    callee_node_type = callee_node.type
    assert isinstance(callee_node_type, CallableType)

    # Verify that the callable returns a `thinc.model.Model`
    # TODO: Use `map_instance_to_supertype` to map subtypes to `Model` instances.
    #       (figure out how to look up the `TypeInfo` for a class outside of the module being type-checked)
    callee_return_type = callee_node_type.ret_type
    assert isinstance(callee_return_type, Instance)
    assert callee_return_type.type.fullname == thinc_model_fullname
    assert callee_return_type.args
    assert len(callee_return_type.args) == 2

    # Obtain the output type parameter of the `thinc.model.Model` return type
    # of the called API function
    out_type = callee_return_type.args[1]

    # Check if the `Model`'s output type parameter is one of the "special
    # type variables" defined to represent model composition (chaining) and
    # homogeneous reduction
    assert isinstance(out_type, TypeVarType)
    assert out_type.fullname
    if out_type.fullname not in {intoin_outtoout_out_fullname, chained_out_fullname}:
        return ctx.default_return_type

    # Extract type of each argument used to call the API function, making sure that they are also
    # `thinc.model.Model` instances
    args = list(itertools.chain.from_iterable(ctx.args))
    arg_types = []
    for arg_type in itertools.chain.from_iterable(ctx.arg_types):
        # TODO: Use `map_instance_to_supertype` to map subtypes to `Model` instances.
        assert isinstance(arg_type, Instance)
        assert arg_type.type.fullname == thinc_model_fullname
        assert len(arg_type.args) == 2
        arg_types.append(arg_type)

    # Without any model argument there is nothing to derive the combined type
    # from. Indexing `arg_types` below would raise `IndexError`, which the
    # caller does not treat as "fall back to the default type", so signal the
    # fallback explicitly. A `raise` is used so the guard survives `python -O`.
    if not arg_types:
        raise AssertionError("Thinc mypy plugin: call has no model arguments")

    # Determine if passed models will be chained or if they all need to have
    # the same input and output type
    if out_type.fullname == chained_out_fullname:
        # Models will be chained, meaning that the output of each model will
        # be passed as the input to the next model. The generated model takes
        # the first model's input and returns the last model's output.
        check_pair = check_chained
        combined_out_type = arg_types[-1].args[1]
    elif out_type.fullname == intoin_outtoout_out_fullname:
        # Models must have the same input and output types. The generated
        # model accepts and returns the same types as all passed models.
        check_pair = check_intoin_outtoout
        combined_out_type = arg_types[0].args[1]
    else:
        # Not reachable given the membership test above; kept as an explicit
        # `raise` so the fallback still triggers when asserts are disabled.
        raise AssertionError(
            "Thinc mypy plugin error: it should return before this point"
        )

    # Verify that each neighboring pair of models is compatible. Every entry
    # of `arg_types` was already verified to be a two-parameter `Model`
    # instance when the list was built.
    neighboring_args = zip(args, args[1:])
    neighboring_types = zip(arg_types, arg_types[1:])
    for (arg1, arg2), (type1, type2) in zip(neighboring_args, neighboring_types):
        check_pair(
            l1_arg=arg1, l1_type=type1, l2_arg=arg2, l2_type=type2, api=ctx.api
        )

    return Instance(
        ctx.default_return_type.type, [arg_types[0].args[0], combined_out_type]
    )
def check_chained(
    *,
    l1_arg: Expression,
    l1_type: Instance,
    l2_arg: Expression,
    l2_type: Instance,
    api: CheckerPluginInterface,
):
    """
    Report an error pair when two chained layers do not fit together.

    The second type argument of `l1_type` (the first layer's output) must be a
    subtype of the first type argument of `l2_type` (the next layer's input).
    When it is not, one failure is reported on `l1_arg` with the
    `error_layer_output` code and one on `l2_arg` with the `error_layer_input`
    code. Nothing is reported, and nothing is returned, when the types fit.
    """
    # Guard clause: compatible neighbours need no diagnostics.
    if is_subtype(l1_type.args[1], l2_type.args[0]):
        return

    # Flag the producing layer...
    api.fail(
        f"Layer outputs type ({l1_type.args[1]}) but the next layer expects ({l2_type.args[0]}) as an input",
        l1_arg,
        code=error_layer_output,
    )
    # ...and the consuming layer, so both call-site arguments are marked.
    api.fail(
        f"Layer input type ({l2_type.args[0]}) is not compatible with output ({l1_type.args[1]}) from previous layer",
        l2_arg,
        code=error_layer_input,
    )
def check_intoin_outtoout(
    *,
    l1_arg: Expression,
    l1_type: Instance,
    l2_arg: Expression,
    l2_type: Instance,
    api: CheckerPluginInterface,
):
    """
    Report errors when two neighboring layers differ in input or output type.

    The first type arguments (inputs) of `l1_type` and `l2_type` are compared
    with `!=`, then the second type arguments (outputs). Each mismatch is
    reported twice, once on `l1_arg` and once on `l2_arg`, using
    `error_layer_input` for input mismatches and `error_layer_output` for
    output mismatches. The two comparisons are independent of each other.
    """
    inputs_differ = l1_type.args[0] != l2_type.args[0]
    if inputs_differ:
        api.fail(
            f"Layer input ({l1_type.args[0]}) not compatible with next layer input ({l2_type.args[0]})",
            l1_arg,
            code=error_layer_input,
        )
        api.fail(
            f"Layer input ({l2_type.args[0]}) not compatible with previous layer input ({l1_type.args[0]})",
            l2_arg,
            code=error_layer_input,
        )

    outputs_differ = l1_type.args[1] != l2_type.args[1]
    if outputs_differ:
        api.fail(
            f"Layer output ({l1_type.args[1]}) not compatible with next layer output ({l2_type.args[1]})",
            l1_arg,
            code=error_layer_output,
        )
        api.fail(
            f"Layer output ({l2_type.args[1]}) not compatible with previous layer output ({l1_type.args[1]})",
            l2_arg,
            code=error_layer_output,
        )
# Error codes passed as `code=` to `api.fail` by `check_chained` and
# `check_intoin_outtoout`.
# Used for failures about a layer's input type.
error_layer_input = ErrorCode("layer-mismatch-input", "Invalid layer input", "Thinc")
# Used for failures about a layer's output type.
error_layer_output = ErrorCode("layer-mismatch-output", "Invalid layer output", "Thinc")
class IntrospectChecker(TypeChecker):
    """`TypeChecker` subclass that carries an extra `_error_messages` list."""

    def __init__(
        self,
        errors: Errors,
        modules: Dict[str, MypyFile],
        options: Options,
        tree: MypyFile,
        path: str,
        plugin: Plugin,
        per_line_checking_time_ns: Dict[int, int],
    ) -> None:
        """
        Create the message list, then run the base `TypeChecker` initializer.

        All arguments are forwarded positionally, in the order received, to
        `TypeChecker.__init__`.
        """
        # Set before the base initializer runs, as in the original ordering.
        self._error_messages: List[str] = []
        super().__init__(
            errors,
            modules,
            options,
            tree,
            path,
            plugin,
            per_line_checking_time_ns,
        )