Skip to content

Commit 36a0b07

Browse files
Handle functools.Placeholder in partial
Fixes #21313.
1 parent 0ea16df commit 36a0b07

2 files changed

Lines changed: 100 additions & 3 deletions

File tree

mypy/plugins/functools.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
import mypy.semanal
1010
from mypy.argmap import map_actuals_to_formals
1111
from mypy.erasetype import erase_typevars
12+
from mypy.expandtype import expand_type
13+
from mypy.infer import infer_type_arguments
1214
from mypy.nodes import (
1315
ARG_POS,
1416
ARG_STAR2,
1517
SYMBOL_FUNCBASE_TYPES,
1618
ArgKind,
1719
Argument,
1820
CallExpr,
21+
Expression,
22+
MemberExpr,
1923
NameExpr,
2024
Var,
2125
)
@@ -30,6 +34,7 @@
3034
ParamSpecType,
3135
Type,
3236
TypeOfAny,
37+
TypeVarId,
3338
TypeVarType,
3439
UnboundType,
3540
UnionType,
@@ -41,6 +46,7 @@
4146
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
4247

4348
PARTIAL: Final = "functools.partial"
49+
PLACEHOLDER: Final = "functools.Placeholder"
4450

4551

4652
class _MethodInfo:
@@ -134,6 +140,10 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
134140
return comparison_methods
135141

136142

143+
def _is_functools_placeholder(expr: Expression) -> bool:
144+
return isinstance(expr, (NameExpr, MemberExpr)) and expr.fullname == PLACEHOLDER
145+
146+
137147
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
138148
"""Infer a more precise return type for functools.partial"""
139149
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
@@ -184,6 +194,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
184194
actual_arg_kinds = []
185195
actual_arg_names = []
186196
actual_types = []
197+
placeholder_actuals = []
187198
seen_args = set()
188199
for i, param in enumerate(ctx.args[1:], start=1):
189200
for j, a in enumerate(param):
@@ -198,6 +209,9 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
198209
actual_arg_kinds.append(ctx.arg_kinds[i][j])
199210
actual_arg_names.append(ctx.arg_names[i][j])
200211
actual_types.append(ctx.arg_types[i][j])
212+
placeholder_actuals.append(
213+
ctx.arg_kinds[i][j].is_positional() and _is_functools_placeholder(a)
214+
)
201215

202216
formal_to_actual = map_actuals_to_formals(
203217
actual_kinds=actual_arg_kinds,
@@ -215,8 +229,20 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
215229
continue
216230
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})
217231

232+
defaulted_arg_types = list(fn_type.arg_types)
233+
for i, actuals in enumerate(formal_to_actual):
234+
if any(placeholder_actuals[j] for j in actuals):
235+
# functools.Placeholder is a positional sentinel introduced in Python 3.14.
236+
# It occupies the formal slot but does not bind it, so make the validation
237+
# call accept the sentinel while preserving the original type for the
238+
# resulting partial signature below.
239+
defaulted_arg_types[i] = actual_types[
240+
next(j for j in actuals if placeholder_actuals[j])
241+
]
242+
218243
# special_sig="partial" allows omission of args/kwargs typed with ParamSpec
219244
defaulted = fn_type.copy_modified(
245+
arg_types=defaulted_arg_types,
220246
arg_kinds=[
221247
(
222248
ArgKind.ARG_OPT
@@ -273,10 +299,25 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
273299
partial_kinds = []
274300
partial_types = []
275301
partial_names = []
302+
inferred_type_vars: dict[TypeVarId, Type] = {}
303+
if len(bound.arg_types) == len(fn_type.arg_types):
304+
for i, actuals in enumerate(formal_to_actual):
305+
if not actuals or any(placeholder_actuals[j] for j in actuals):
306+
continue
307+
inferred_args = infer_type_arguments(
308+
fn_type.variables, fn_type.arg_types[i], bound.arg_types[i]
309+
)
310+
for type_var, inferred_arg in zip(fn_type.variables, inferred_args):
311+
if inferred_arg is not None and mypy.checker.is_valid_inferred_type(
312+
inferred_arg, ctx.api.options
313+
):
314+
inferred_type_vars[type_var.id] = inferred_arg
276315
# We need to fully apply any positional arguments (they cannot be respecified)
277316
# However, keyword arguments can be respecified, so just give them a default
278317
for i, actuals in enumerate(formal_to_actual):
279-
if len(bound.arg_types) == len(fn_type.arg_types):
318+
if any(placeholder_actuals[j] for j in actuals):
319+
arg_type = expand_type(fn_type.arg_types[i], inferred_type_vars)
320+
elif len(bound.arg_types) == len(fn_type.arg_types):
280321
arg_type = bound.arg_types[i]
281322
if not mypy.checker.is_valid_inferred_type(arg_type, ctx.api.options):
282323
arg_type = fn_type.arg_types[i] # bit of a hack
@@ -285,10 +326,16 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
285326
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
286327
arg_type = fn_type.arg_types[i]
287328

288-
if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
329+
if (
330+
not actuals
331+
or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2)
332+
or any(placeholder_actuals[j] for j in actuals)
333+
):
289334
partial_kinds.append(fn_type.arg_kinds[i])
290335
partial_types.append(arg_type)
291-
partial_names.append(fn_type.arg_names[i])
336+
partial_names.append(
337+
None if any(placeholder_actuals[j] for j in actuals) else fn_type.arg_names[i]
338+
)
292339
else:
293340
assert actuals
294341
if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):

test-data/unit/check-functools.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,3 +726,53 @@ def outer_c(arg: Tc) -> None:
726726
use_int_callable(partial(inner, b="")) # E: Argument 1 to "use_int_callable" has incompatible type "partial[str]"; expected "Callable[[int], int]" \
727727
# N: "partial[str].__call__" has type "def __call__(__self, *args: Any, **kwargs: Any) -> str"
728728
[builtins fixtures/tuple.pyi]
729+
730+
[case testFunctoolsPartialPlaceholder]
731+
import functools
732+
from functools import partial, Placeholder as _
733+
from typing import TypeVar
734+
735+
T = TypeVar("T")
736+
737+
738+
def foo(a: int, b: str, c: bool) -> tuple[int, str, bool]: ...
739+
740+
741+
p = partial(foo, _, "x", _)
742+
reveal_type(p) # N: Revealed type is "functools.partial[tuple[builtins.int, builtins.str, builtins.bool]]"
743+
reveal_type(p(1, True)) # N: Revealed type is "tuple[builtins.int, builtins.str, builtins.bool]"
744+
p("bad", True) # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
745+
p(1, 1) # E: Argument 2 to "foo" has incompatible type "int"; expected "bool"
746+
p(a=1, c=True) # E: Unexpected keyword argument "a" for "foo" \
747+
# E: Unexpected keyword argument "c" for "foo"
748+
749+
750+
def same(a: T, b: T) -> T: ...
751+
def same_list(a: T, b: list[T]) -> T: ...
752+
753+
754+
generic = partial(same, _, 1)
755+
reveal_type(generic) # N: Revealed type is "functools.partial[builtins.int]"
756+
generic(2)
757+
generic("bad") # E: Argument 1 to "same" has incompatible type "str"; expected "int"
758+
759+
nested_generic = partial(same_list, _, [1])
760+
reveal_type(nested_generic) # N: Revealed type is "functools.partial[builtins.int]"
761+
nested_generic(2)
762+
nested_generic("bad") # E: Argument 1 to "same_list" has incompatible type "str"; expected "int"
763+
764+
module_attr = partial(foo, functools.Placeholder, "x", functools.Placeholder)
765+
reveal_type(module_attr(1, True)) # N: Revealed type is "tuple[builtins.int, builtins.str, builtins.bool]"
766+
partial(foo, a=_) # E: Argument "a" to "foo" has incompatible type "_PlaceholderType"; expected "int"
767+
[file functools.pyi]
768+
from typing import Any, Callable, Final, Generic, TypeVar
769+
770+
_T = TypeVar("_T")
771+
772+
class _PlaceholderType: ...
773+
Placeholder: Final[_PlaceholderType]
774+
775+
class partial(Generic[_T]):
776+
def __new__(cls, func: Callable[..., _T], /, *args: Any, **kwargs: Any) -> partial[_T]: ...
777+
def __call__(self, *args: Any, **kwargs: Any) -> _T: ...
778+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)