267 lines
11 KiB
Python
267 lines
11 KiB
Python
import itertools
|
|
from typing import Dict, List
|
|
|
|
from mypy.checker import TypeChecker
|
|
from mypy.errorcodes import ErrorCode
|
|
from mypy.errors import Errors
|
|
from mypy.nodes import CallExpr, Decorator, Expression, FuncDef, MypyFile, NameExpr
|
|
from mypy.options import Options
|
|
from mypy.plugin import CheckerPluginInterface, FunctionContext, Plugin
|
|
from mypy.subtypes import is_subtype
|
|
from mypy.types import CallableType, Instance, Type, TypeVarType
|
|
|
|
thinc_model_fullname = "thinc.model.Model"
|
|
chained_out_fullname = "thinc.types.XY_YZ_OutT"
|
|
intoin_outtoout_out_fullname = "thinc.types.XY_XY_OutT"
|
|
|
|
|
|
def plugin(version: str):
|
|
return ThincPlugin
|
|
|
|
|
|
class ThincPlugin(Plugin):
|
|
def __init__(self, options: Options) -> None:
|
|
super().__init__(options)
|
|
|
|
def get_function_hook(self, fullname: str):
|
|
return function_hook
|
|
|
|
|
|
def function_hook(ctx: FunctionContext) -> Type:
|
|
try:
|
|
return get_reducers_type(ctx)
|
|
except AssertionError:
|
|
# Add more function callbacks here
|
|
return ctx.default_return_type
|
|
|
|
|
|
def get_reducers_type(ctx: FunctionContext) -> Type:
|
|
"""
|
|
Determine a more specific model type for functions that combine models.
|
|
|
|
This function operates on function *calls*. It analyzes each function call
|
|
by looking at the function definition and the arguments passed as part of
|
|
the function call, then determines a more specific return type for the
|
|
function call.
|
|
|
|
This method accepts a `FunctionContext` as part of the Mypy plugin
|
|
interface. This function context provides easy access to:
|
|
* `args`: List of "actual arguments" filling each "formal argument" of the
|
|
called function. "Actual arguments" are those passed to the function
|
|
as part of the function call. "Formal arguments" are the parameters
|
|
defined by the function definition. The same actual argument may serve
|
|
to fill multiple formal arguments. In some cases the relationship may
|
|
even be ambiguous. For example, calling `range(*args)`, the actual
|
|
argument `*args` may fill the `start`, `stop` or `step` formal
|
|
arguments, depending on the length of the list.
|
|
|
|
The `args` list is of length `num_formals`, with each element
|
|
corresponding to a formal argument. Each value in the `args` list is a
|
|
list of actual arguments which may fill the formal argument. For
|
|
example, in the function call `range(*args, num)`, `num` may fill the
|
|
`start`, `end` or `step` formal arguments depending on the length of
|
|
`args`, so type-checking needs to consider all of these possibilities.
|
|
* `arg_types`: Type annotation (or inferred type) of each argument. Like
|
|
`args`, this value is a list of lists with an outer list entry for each
|
|
formal argument and an inner list entry for each possible actual
|
|
argument for the formal argument.
|
|
* `arg_kinds`: "Kind" of argument passed to the function call. Argument
|
|
kinds include positional, star (`*args`), named (`x=y`) and star2
|
|
(`**kwargs`) arguments (among others). Like `args`, this value is a list
|
|
of lists.
|
|
* `context`: AST node representing the function call with all available
|
|
type information. Notable attributes include:
|
|
* `args` and `arg_kinds`: Simple list of actual arguments, not mapped to
|
|
formal arguments.
|
|
* `callee`: AST node representing the function being called. Typically
|
|
this is a `NameExpr`. To resolve this node to the function definition
|
|
it references, accessing `callee.node` will usually return either a
|
|
`FuncDef` or `Decorator` node.
|
|
* etc.
|
|
|
|
This function infers a more specific type for model-combining functions by
|
|
making certain assumptions about how the function operates based on the
|
|
order of its formal arguments and its return type.
|
|
|
|
If the return type is `Model[InT, XY_YZ_OutT]`, the output of each
|
|
argument is expected to be used as the input to the next argument. It's
|
|
therefore necessary to check that the output type of each model is
|
|
compatible with the input type of the following model. The combined model
|
|
has the type `Model[InT, OutT]`, where `InT` is the input type of the
|
|
first model and `OutT` is the output type of the last model.
|
|
|
|
If the return type is `Model[InT, XY_XY_OutT]`, all model arguments
|
|
receive input of the same type and are expected to produce output of the
|
|
same type. It's therefore necessary to check that all models have the same
|
|
input types and the same output types. The combined model has the type
|
|
`Model[InT, OutT]`, where `InT` is the input type of all model arguments
|
|
and `OutT` is the output type of all model arguments.
|
|
|
|
Raises:
|
|
AssertionError: Raised if a more specific model type couldn't be
|
|
determined, indicating that the default general return type should
|
|
be used.
|
|
"""
|
|
# Verify that we have a type-checking API and a default return type (presumably a
|
|
# `thinc.model.Model` instance)
|
|
assert isinstance(ctx.api, TypeChecker)
|
|
assert isinstance(ctx.default_return_type, Instance)
|
|
|
|
# Verify that we're inspecting a function call to a callable defined or decorated function
|
|
assert isinstance(ctx.context, CallExpr)
|
|
callee = ctx.context.callee
|
|
assert isinstance(callee, NameExpr)
|
|
callee_node = callee.node
|
|
assert isinstance(callee_node, (FuncDef, Decorator))
|
|
callee_node_type = callee_node.type
|
|
assert isinstance(callee_node_type, CallableType)
|
|
|
|
# Verify that the callable returns a `thinc.model.Model`
|
|
# TODO: Use `map_instance_to_supertype` to map subtypes to `Model` instances.
|
|
# (figure out how to look up the `TypeInfo` for a class outside of the module being type-checked)
|
|
callee_return_type = callee_node_type.ret_type
|
|
assert isinstance(callee_return_type, Instance)
|
|
assert callee_return_type.type.fullname == thinc_model_fullname
|
|
assert callee_return_type.args
|
|
assert len(callee_return_type.args) == 2
|
|
|
|
# Obtain the output type parameter of the `thinc.model.Model` return type
|
|
# of the called API function
|
|
out_type = callee_return_type.args[1]
|
|
|
|
# Check if the `Model`'s output type parameter is one of the "special
|
|
# type variables" defined to represent model composition (chaining) and
|
|
# homogeneous reduction
|
|
assert isinstance(out_type, TypeVarType)
|
|
assert out_type.fullname
|
|
if out_type.fullname not in {intoin_outtoout_out_fullname, chained_out_fullname}:
|
|
return ctx.default_return_type
|
|
|
|
# Extract type of each argument used to call the API function, making sure that they are also
|
|
# `thinc.model.Model` instances
|
|
args = list(itertools.chain(*ctx.args))
|
|
arg_types = []
|
|
for arg_type in itertools.chain(*ctx.arg_types):
|
|
# TODO: Use `map_instance_to_supertype` to map subtypes to `Model` instances.
|
|
assert isinstance(arg_type, Instance)
|
|
assert arg_type.type.fullname == thinc_model_fullname
|
|
assert len(arg_type.args) == 2
|
|
arg_types.append(arg_type)
|
|
|
|
# Collect neighboring pairs of arguments and their types
|
|
arg_pairs = list(zip(args[:-1], args[1:]))
|
|
arg_types_pairs = list(zip(arg_types[:-1], arg_types[1:]))
|
|
|
|
# Determine if passed models will be chained or if they all need to have
|
|
# the same input and output type
|
|
if out_type.fullname == chained_out_fullname:
|
|
# Models will be chained, meaning that the output of each model will
|
|
# be passed as the input to the next model
|
|
# Verify that model inputs and outputs are compatible
|
|
for (arg1, arg2), (type1, type2) in zip(arg_pairs, arg_types_pairs):
|
|
assert isinstance(type1, Instance)
|
|
assert isinstance(type2, Instance)
|
|
assert type1.type.fullname == thinc_model_fullname
|
|
assert type2.type.fullname == thinc_model_fullname
|
|
check_chained(
|
|
l1_arg=arg1, l1_type=type1, l2_arg=arg2, l2_type=type2, api=ctx.api
|
|
)
|
|
|
|
# Generated model takes the first model's input and returns the last model's output
|
|
return Instance(
|
|
ctx.default_return_type.type, [arg_types[0].args[0], arg_types[-1].args[1]]
|
|
)
|
|
elif out_type.fullname == intoin_outtoout_out_fullname:
|
|
# Models must have the same input and output types
|
|
# Verify that model inputs and outputs are compatible
|
|
for (arg1, arg2), (type1, type2) in zip(arg_pairs, arg_types_pairs):
|
|
assert isinstance(type1, Instance)
|
|
assert isinstance(type2, Instance)
|
|
assert type1.type.fullname == thinc_model_fullname
|
|
assert type2.type.fullname == thinc_model_fullname
|
|
check_intoin_outtoout(
|
|
l1_arg=arg1, l1_type=type1, l2_arg=arg2, l2_type=type2, api=ctx.api
|
|
)
|
|
|
|
# Generated model accepts and returns the same types as all passed models
|
|
return Instance(
|
|
ctx.default_return_type.type, [arg_types[0].args[0], arg_types[0].args[1]]
|
|
)
|
|
|
|
# Make sure the default return type is returned if no branch was selected
|
|
assert False, "Thinc mypy plugin error: it should return before this point"
|
|
|
|
|
|
def check_chained(
|
|
*,
|
|
l1_arg: Expression,
|
|
l1_type: Instance,
|
|
l2_arg: Expression,
|
|
l2_type: Instance,
|
|
api: CheckerPluginInterface,
|
|
):
|
|
if not is_subtype(l1_type.args[1], l2_type.args[0]):
|
|
api.fail(
|
|
f"Layer outputs type ({l1_type.args[1]}) but the next layer expects ({l2_type.args[0]}) as an input",
|
|
l1_arg,
|
|
code=error_layer_output,
|
|
)
|
|
api.fail(
|
|
f"Layer input type ({l2_type.args[0]}) is not compatible with output ({l1_type.args[1]}) from previous layer",
|
|
l2_arg,
|
|
code=error_layer_input,
|
|
)
|
|
|
|
|
|
def check_intoin_outtoout(
|
|
*,
|
|
l1_arg: Expression,
|
|
l1_type: Instance,
|
|
l2_arg: Expression,
|
|
l2_type: Instance,
|
|
api: CheckerPluginInterface,
|
|
):
|
|
if l1_type.args[0] != l2_type.args[0]:
|
|
api.fail(
|
|
f"Layer input ({l1_type.args[0]}) not compatible with next layer input ({l2_type.args[0]})",
|
|
l1_arg,
|
|
code=error_layer_input,
|
|
)
|
|
api.fail(
|
|
f"Layer input ({l2_type.args[0]}) not compatible with previous layer input ({l1_type.args[0]})",
|
|
l2_arg,
|
|
code=error_layer_input,
|
|
)
|
|
if l1_type.args[1] != l2_type.args[1]:
|
|
api.fail(
|
|
f"Layer output ({l1_type.args[1]}) not compatible with next layer output ({l2_type.args[1]})",
|
|
l1_arg,
|
|
code=error_layer_output,
|
|
)
|
|
api.fail(
|
|
f"Layer output ({l2_type.args[1]}) not compatible with previous layer output ({l1_type.args[1]})",
|
|
l2_arg,
|
|
code=error_layer_output,
|
|
)
|
|
|
|
|
|
error_layer_input = ErrorCode("layer-mismatch-input", "Invalid layer input", "Thinc")
|
|
error_layer_output = ErrorCode("layer-mismatch-output", "Invalid layer output", "Thinc")
|
|
|
|
|
|
class IntrospectChecker(TypeChecker):
|
|
def __init__(
|
|
self,
|
|
errors: Errors,
|
|
modules: Dict[str, MypyFile],
|
|
options: Options,
|
|
tree: MypyFile,
|
|
path: str,
|
|
plugin: Plugin,
|
|
per_line_checking_time_ns: Dict[int, int],
|
|
):
|
|
self._error_messages: List[str] = []
|
|
super().__init__(
|
|
errors, modules, options, tree, path, plugin, per_line_checking_time_ns
|
|
)
|