Skip to content

Commit c3ca0d6

Browse files
authored
Add initial singledispatch typechecking support (#10694)
Adds initial support for type checking uses of the singledispatch decorator, including support for passing the dispatch type to register as an argument or using a type annotation. This implementation has several limitations that probably aren't going away any time soon: * The dispatch type of registered implementations is required to be a subtype of the first argument of the main singledispatch function. That allows us to type check calls to singledispatch functions without needing to have all registered implementations available, which allows incremental mode to work by preventing register calls in one module from affecting the type of the singledispatch function in another seemingly unrelated module. That also significantly simplifies the implementation when all register calls are in the same module as the main singledispatch function and makes the error messages much easier to understand. * Registered implementations are not tracked, meaning that we can't show errors about multiple registered implementations having the same dispatch type or other errors that require looking at multiple registered implementations at the same time. There are also some other limitations that aren't supported yet but probably could get better support: * The 2 argument version of register * Checking that the type annotation version of register isn't used before Python 3.7 * Uses of singledispatchmethod * Uses of registry or dispatch * Making redefining an underscore function not an error * Checking that registered functions have the same number of arguments as the fallback
1 parent cae5d3c commit c3ca0d6

File tree

4 files changed

+407
-59
lines changed

4 files changed

+407
-59
lines changed

mypy/plugins/common.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES,
55
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
66
)
7-
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
7+
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
88
from mypy.semanal import set_callable_name
99
from mypy.types import (
1010
CallableType, Overloaded, Type, TypeVarDef, deserialize_type, get_proper_type,
@@ -103,16 +103,15 @@ def add_method(
103103

104104

105105
def add_method_to_class(
106-
api: SemanticAnalyzerPluginInterface,
106+
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface],
107107
cls: ClassDef,
108108
name: str,
109109
args: List[Argument],
110110
return_type: Type,
111111
self_type: Optional[Type] = None,
112112
tvar_def: Optional[TypeVarDef] = None,
113113
) -> None:
114-
"""Adds a new method to a class definition.
115-
"""
114+
"""Adds a new method to a class definition."""
116115
info = cls.info
117116

118117
# First remove any previously generated methods with the same name
@@ -123,7 +122,15 @@ def add_method_to_class(
123122
cls.defs.body.remove(sym.node)
124123

125124
self_type = self_type or fill_typevars(info)
126-
function_type = api.named_type('__builtins__.function')
125+
# TODO: semanal.py and checker.py seem to have subtly different implementations of
126+
# named_type/named_generic_type (starting with the fact that we have to use different names
127+
# for builtins), so it's easier to just check which one we're dealing with here and pick the
128+
# correct function to use than to try to add a named_type method that behaves the same for
129+
# both. We should probably combine those implementations at some point.
130+
if isinstance(api, SemanticAnalyzerPluginInterface):
131+
function_type = api.named_type('__builtins__.function')
132+
else:
133+
function_type = api.named_generic_type('builtins.function', [])
127134

128135
args = [Argument(Var('self'), self_type, None, ARG_POS)] + args
129136
arg_types, arg_names, arg_kinds = [], [], []

mypy/plugins/default.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010
from mypy.plugins.common import try_getting_str_literals
1111
from mypy.types import (
12-
Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType,
12+
FunctionLike, Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType,
1313
TypeVarDef, TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType
1414
)
1515
from mypy.subtypes import is_subtype
@@ -22,19 +22,21 @@ class DefaultPlugin(Plugin):
2222

2323
def get_function_hook(self, fullname: str
2424
) -> Optional[Callable[[FunctionContext], Type]]:
25-
from mypy.plugins import ctypes
25+
from mypy.plugins import ctypes, singledispatch
2626

2727
if fullname == 'contextlib.contextmanager':
2828
return contextmanager_callback
2929
elif fullname == 'builtins.open' and self.python_version[0] == 3:
3030
return open_callback
3131
elif fullname == 'ctypes.Array':
3232
return ctypes.array_constructor_callback
33+
elif fullname == 'functools.singledispatch':
34+
return singledispatch.create_singledispatch_function_callback
3335
return None
3436

3537
def get_method_signature_hook(self, fullname: str
36-
) -> Optional[Callable[[MethodSigContext], CallableType]]:
37-
from mypy.plugins import ctypes
38+
) -> Optional[Callable[[MethodSigContext], FunctionLike]]:
39+
from mypy.plugins import ctypes, singledispatch
3840

3941
if fullname == 'typing.Mapping.get':
4042
return typed_dict_get_signature_callback
@@ -48,11 +50,13 @@ def get_method_signature_hook(self, fullname: str
4850
return typed_dict_delitem_signature_callback
4951
elif fullname == 'ctypes.Array.__setitem__':
5052
return ctypes.array_setitem_callback
53+
elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
54+
return singledispatch.call_singledispatch_function_callback
5155
return None
5256

5357
def get_method_hook(self, fullname: str
5458
) -> Optional[Callable[[MethodContext], Type]]:
55-
from mypy.plugins import ctypes
59+
from mypy.plugins import ctypes, singledispatch
5660

5761
if fullname == 'typing.Mapping.get':
5862
return typed_dict_get_callback
@@ -72,6 +76,10 @@ def get_method_hook(self, fullname: str
7276
return ctypes.array_iter_callback
7377
elif fullname == 'pathlib.Path.open':
7478
return path_open_callback
79+
elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
80+
return singledispatch.singledispatch_register_callback
81+
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
82+
return singledispatch.call_singledispatch_function_after_register_argument
7583
return None
7684

7785
def get_attribute_hook(self, fullname: str

mypy/plugins/singledispatch.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from mypy.messages import format_type
2+
from mypy.plugins.common import add_method_to_class
3+
from mypy.nodes import (
4+
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, ARG_STAR, ARG_OPT, Context
5+
)
6+
from mypy.subtypes import is_subtype
7+
from mypy.types import (
8+
AnyType, CallableType, Instance, NoneType, Overloaded, Type, TypeOfAny, get_proper_type,
9+
FunctionLike
10+
)
11+
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
12+
from typing import List, NamedTuple, Optional, Sequence, TypeVar, Union
13+
from typing_extensions import Final
14+
15+
SingledispatchTypeVars = NamedTuple('SingledispatchTypeVars', [
16+
('return_type', Type),
17+
('fallback', CallableType),
18+
])
19+
20+
RegisterCallableInfo = NamedTuple('RegisterCallableInfo', [
21+
('register_type', Type),
22+
('singledispatch_obj', Instance),
23+
])
24+
25+
SINGLEDISPATCH_TYPE = 'functools._SingleDispatchCallable'
26+
27+
SINGLEDISPATCH_REGISTER_METHOD = '{}.register'.format(SINGLEDISPATCH_TYPE) # type: Final
28+
29+
SINGLEDISPATCH_CALLABLE_CALL_METHOD = '{}.__call__'.format(SINGLEDISPATCH_TYPE) # type: Final
30+
31+
32+
def get_singledispatch_info(typ: Instance) -> SingledispatchTypeVars:
33+
return SingledispatchTypeVars(*typ.args) # type: ignore
34+
35+
36+
T = TypeVar('T')
37+
38+
39+
def get_first_arg(args: List[List[T]]) -> Optional[T]:
40+
"""Get the element that corresponds to the first argument passed to the function"""
41+
if args and args[0]:
42+
return args[0][0]
43+
return None
44+
45+
46+
REGISTER_RETURN_CLASS = '_SingleDispatchRegisterCallable'
47+
48+
REGISTER_CALLABLE_CALL_METHOD = 'functools.{}.__call__'.format(
49+
REGISTER_RETURN_CLASS
50+
) # type: Final
51+
52+
53+
def make_fake_register_class_instance(api: CheckerPluginInterface, type_args: Sequence[Type]
54+
) -> Instance:
55+
defn = ClassDef(REGISTER_RETURN_CLASS, Block([]))
56+
defn.fullname = 'functools.{}'.format(REGISTER_RETURN_CLASS)
57+
info = TypeInfo(SymbolTable(), defn, "functools")
58+
obj_type = api.named_generic_type('builtins.object', []).type
59+
info.bases = [Instance(obj_type, [])]
60+
info.mro = [info, obj_type]
61+
defn.info = info
62+
63+
func_arg = Argument(Var('name'), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS)
64+
add_method_to_class(api, defn, '__call__', [func_arg], NoneType())
65+
66+
return Instance(info, type_args)
67+
68+
69+
PluginContext = Union[FunctionContext, MethodContext]
70+
71+
72+
def fail(ctx: PluginContext, msg: str, context: Optional[Context]) -> None:
73+
"""Emit an error message.
74+
75+
This tries to emit an error message at the location specified by `context`, falling back to the
76+
location specified by `ctx.context`. This is helpful when the only context information about
77+
where you want to put the error message may be None (like it is for `CallableType.definition`)
78+
and falling back to the location of the calling function is fine."""
79+
# TODO: figure out if there is some more reliable way of getting context information, so this
80+
# function isn't necessary
81+
if context is not None:
82+
err_context = context
83+
else:
84+
err_context = ctx.context
85+
ctx.api.fail(msg, err_context)
86+
87+
88+
def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
89+
"""Called for functools.singledispatch"""
90+
func_type = get_proper_type(get_first_arg(ctx.arg_types))
91+
if isinstance(func_type, CallableType):
92+
93+
if len(func_type.arg_kinds) < 1:
94+
fail(
95+
ctx,
96+
'Singledispatch function requires at least one argument',
97+
func_type.definition,
98+
)
99+
return ctx.default_return_type
100+
101+
elif func_type.arg_kinds[0] not in (ARG_POS, ARG_OPT, ARG_STAR):
102+
fail(
103+
ctx,
104+
'First argument to singledispatch function must be a positional argument',
105+
func_type.definition,
106+
)
107+
return ctx.default_return_type
108+
109+
# singledispatch returns an instance of functools._SingleDispatchCallable according to
110+
# typeshed
111+
singledispatch_obj = get_proper_type(ctx.default_return_type)
112+
assert isinstance(singledispatch_obj, Instance)
113+
singledispatch_obj.args += (func_type,)
114+
115+
return ctx.default_return_type
116+
117+
118+
def singledispatch_register_callback(ctx: MethodContext) -> Type:
119+
"""Called for functools._SingleDispatchCallable.register"""
120+
assert isinstance(ctx.type, Instance)
121+
# TODO: check that there's only one argument
122+
first_arg_type = get_proper_type(get_first_arg(ctx.arg_types))
123+
if isinstance(first_arg_type, (CallableType, Overloaded)) and first_arg_type.is_type_obj():
124+
# HACK: We receieved a class as an argument to register. We need to be able
125+
# to access the function that register is being applied to, and the typeshed definition
126+
# of register has it return a generic Callable, so we create a new
127+
# SingleDispatchRegisterCallable class, define a __call__ method, and then add a
128+
# plugin hook for that.
129+
130+
# is_subtype doesn't work when the right type is Overloaded, so we need the
131+
# actual type
132+
register_type = first_arg_type.items()[0].ret_type
133+
type_args = RegisterCallableInfo(register_type, ctx.type)
134+
register_callable = make_fake_register_class_instance(
135+
ctx.api,
136+
type_args
137+
)
138+
return register_callable
139+
elif isinstance(first_arg_type, CallableType):
140+
# TODO: do more checking for registered functions
141+
register_function(ctx, ctx.type, first_arg_type)
142+
143+
# register doesn't modify the function it's used on
144+
return ctx.default_return_type
145+
146+
147+
def register_function(ctx: PluginContext, singledispatch_obj: Instance, func: Type,
148+
register_arg: Optional[Type] = None) -> None:
149+
150+
func = get_proper_type(func)
151+
if not isinstance(func, CallableType):
152+
return
153+
metadata = get_singledispatch_info(singledispatch_obj)
154+
dispatch_type = get_dispatch_type(func, register_arg)
155+
if dispatch_type is None:
156+
# TODO: report an error here that singledispatch requires at least one argument
157+
# (might want to do the error reporting in get_dispatch_type)
158+
return
159+
fallback = metadata.fallback
160+
161+
fallback_dispatch_type = fallback.arg_types[0]
162+
if not is_subtype(dispatch_type, fallback_dispatch_type):
163+
164+
fail(ctx, 'Dispatch type {} must be subtype of fallback function first argument {}'.format(
165+
format_type(dispatch_type), format_type(fallback_dispatch_type)
166+
), func.definition)
167+
return
168+
169+
170+
def get_dispatch_type(func: CallableType, register_arg: Optional[Type]) -> Optional[Type]:
171+
if register_arg is not None:
172+
return register_arg
173+
if func.arg_types:
174+
return func.arg_types[0]
175+
return None
176+
177+
178+
def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
179+
"""Called on the function after passing a type to register"""
180+
register_callable = ctx.type
181+
if isinstance(register_callable, Instance):
182+
type_args = RegisterCallableInfo(*register_callable.args) # type: ignore
183+
func = get_first_arg(ctx.arg_types)
184+
if func is not None:
185+
register_function(ctx, type_args.singledispatch_obj, func, type_args.register_type)
186+
return ctx.default_return_type
187+
188+
189+
def rename_func(func: CallableType, new_name: CallableType) -> CallableType:
190+
"""Return a new CallableType that is `function` with the name of `new_name`"""
191+
if new_name.name is not None:
192+
signature_used = func.with_name(new_name.name)
193+
else:
194+
signature_used = func
195+
return signature_used
196+
197+
198+
def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
199+
"""Called for functools._SingleDispatchCallable.__call__"""
200+
if not isinstance(ctx.type, Instance):
201+
return ctx.default_signature
202+
metadata = get_singledispatch_info(ctx.type)
203+
return metadata.fallback

0 commit comments

Comments
 (0)