Skip to content

Commit 5e239a3

Browse files
always use the solver meachnism
1 parent df7b3f0 commit 5e239a3

File tree

2 files changed

+80
-97
lines changed

2 files changed

+80
-97
lines changed

mypy/argmap.py

Lines changed: 68 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def map_actuals_to_formals(
7575
proper_types := [get_proper_type(t) for t in actualt.items]
7676
)
7777
):
78+
# pick an arbitrary member
7879
actualt = proper_types[0]
7980
if isinstance(actualt, TupleType):
8081
# A tuple actual maps to a fixed number of formals.
@@ -193,15 +194,6 @@ def __init__(self, context: ArgumentInferContext) -> None:
193194
# Type context for `*` and `**` arg kinds.
194195
self.context = context
195196

196-
def __eq__(self, other: object) -> bool:
197-
if isinstance(other, ArgTypeExpander):
198-
return (
199-
self.tuple_index == other.tuple_index
200-
and self.kwargs_used == other.kwargs_used
201-
and self.context == other.context
202-
)
203-
return NotImplemented
204-
205197
def expand_actual_type(
206198
self,
207199
actual_type: Type,
@@ -227,29 +219,8 @@ def expand_actual_type(
227219
# parse *args as one of the following:
228220
# IterableType | TupleType | ParamSpecType | AnyType
229221
star_args = self.parse_star_args_type(actual_type)
230-
# star_args = actual_type
231-
232-
# print(f"expand_actual_type: {actual_type=} {star_args=}")
233-
234-
# if isinstance(star_args, TypeVarTupleType):
235-
# # This code path is hit when *Ts is passed to a callable and various
236-
# # special-handling didn't catch this. The best thing we can do is to use
237-
# # the upper bound.
238-
# star_args = get_proper_type(star_args.upper_bound)
239-
# if isinstance(star_args, Instance) and star_args.args:
240-
# from mypy.subtypes import is_subtype
241-
#
242-
# if is_subtype(star_args, self.context.iterable_type):
243-
# return map_instance_to_supertype(
244-
# star_args, self.context.iterable_type.type
245-
# ).args[0]
246-
# else:
247-
# # We cannot properly unpack anything other
248-
# # than `Iterable` type with `*`.
249-
# # Just return `Any`, other parts of code would raise
250-
# # a different error for improper use.
251-
# return AnyType(TypeOfAny.from_error)
252-
if self.is_iterable_type(star_args):
222+
223+
if self.is_iterable_instance_type(star_args):
253224
return star_args.args[0]
254225
elif isinstance(star_args, TupleType):
255226
# Get the next tuple item of a tuple *arg.
@@ -321,30 +292,75 @@ def is_iterable_instance_subtype(self, typ: Type) -> TypeGuard[Instance]:
321292
and is_subtype(p_t, self.context.iterable_type)
322293
)
323294

324-
def is_iterable_type(self, typ: Type) -> TypeGuard[IterableType]:
295+
def is_iterable_instance_type(self, typ: Type) -> TypeGuard[IterableType]:
325296
"""Check if the type is an Iterable[T] or a subtype of it."""
326297
p_t = get_proper_type(typ)
327298
return isinstance(p_t, Instance) and p_t.type == self.context.iterable_type.type
328299

300+
def _make_iterable_instance_type(self, arg: Type) -> IterableType:
301+
value = Instance(self.context.iterable_type.type, [arg])
302+
return cast(IterableType, value)
303+
304+
def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
305+
r"""Use the solver to cast a type as Iterable[T].
306+
307+
Returns `AnyType` if solving fails.
308+
"""
309+
from mypy.constraints import infer_constraints_for_callable
310+
from mypy.nodes import ARG_POS
311+
from mypy.solve import solve_constraints
312+
313+
iterable_kind = self.context.iterable_type.type
314+
315+
# We first create an upcast function:
316+
# def [T] (Iterable[T]) -> Iterable[T]: ...
317+
# and then solve for T, given the input type as the argument.
318+
T = TypeVarType(
319+
"T",
320+
"T",
321+
TypeVarId(-1),
322+
values=[],
323+
upper_bound=AnyType(TypeOfAny.special_form),
324+
default=AnyType(TypeOfAny.special_form),
325+
)
326+
target = Instance(iterable_kind, [T])
327+
328+
upcast_callable = CallableType(
329+
variables=[T],
330+
arg_types=[target],
331+
arg_kinds=[ARG_POS],
332+
arg_names=[None],
333+
ret_type=T,
334+
fallback=self.context.function_type,
335+
)
336+
constraints = infer_constraints_for_callable(
337+
upcast_callable, [typ], [ARG_POS], [None], [[0]], context=self.context
338+
)
339+
340+
(sol,), _ = solve_constraints([T], constraints)
341+
342+
if sol is None: # solving failed, return AnyType fallback
343+
return AnyType(TypeOfAny.from_error)
344+
return self._make_iterable_instance_type(sol)
345+
329346
def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
330347
"""Reinterpret a type as Iterable[T], or return AnyType if not possible."""
331348
p_t = get_proper_type(typ)
332-
if self.is_iterable_type(p_t):
349+
if self.is_iterable_instance_type(p_t) or isinstance(p_t, AnyType):
333350
return p_t
334-
elif self.is_iterable_instance_subtype(p_t):
335-
cls = self.context.iterable_type.type
336-
return cast(IterableType, map_instance_to_supertype(p_t, cls))
337351
elif isinstance(p_t, UnionType):
338352
# If the type is a union, map each item to the iterable supertype.
339353
# the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
340354
converted_types = [self.as_iterable_type(get_proper_type(item)) for item in p_t.items]
341-
# if an item could not be interpreted as Iterable[T], we return AnyType
342-
if all(self.is_iterable_type(it) for it in converted_types):
355+
356+
if any(not self.is_iterable_instance_type(it) for it in converted_types):
357+
# if any item could not be interpreted as Iterable[T], we return AnyType
358+
return AnyType(TypeOfAny.from_error)
359+
else:
343360
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
344361
iterable_types = cast(list[IterableType], converted_types)
345362
arg = make_simplified_union([it.args[0] for it in iterable_types])
346-
return self.make_iterable_type(arg)
347-
return AnyType(TypeOfAny.from_error)
363+
return self._make_iterable_instance_type(arg)
348364
elif isinstance(p_t, TupleType):
349365
# maps tuple[A, B, C] -> Iterable[A | B | C]
350366
# note: proper_elements may contain UnpackType, for instance with
@@ -354,26 +370,24 @@ def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
354370
for p_e in proper_elements:
355371
if isinstance(p_e, UnpackType):
356372
r = self.as_iterable_type(p_e)
357-
if self.is_iterable_type(r):
373+
if self.is_iterable_instance_type(r):
358374
args.append(r.args[0])
359375
else:
376+
# this *should* never happen
360377
args.append(r)
361378
else:
362379
args.append(p_e)
363-
return self.make_iterable_type(make_simplified_union(args))
364-
if isinstance(p_t, UnpackType):
380+
return self._make_iterable_instance_type(make_simplified_union(args))
381+
elif isinstance(p_t, UnpackType):
365382
return self.as_iterable_type(p_t.type)
366-
if isinstance(p_t, (TypeVarType, TypeVarTupleType)):
383+
elif isinstance(p_t, (TypeVarType, TypeVarTupleType)):
367384
return self.as_iterable_type(p_t.upper_bound)
368-
# fallback: use the solver to reinterpret the type as Iterable[T]
369-
if self.is_iterable(p_t):
385+
elif self.is_iterable(p_t):
386+
# TODO: add a 'fast path' (needs measurement) that uses the map_instance_to_supertype
387+
# mechanism? (Only if it works: gh-19662)
370388
return self._solve_as_iterable(p_t)
371389
return AnyType(TypeOfAny.from_error)
372390

373-
def make_iterable_type(self, arg: Type) -> IterableType:
374-
value = Instance(self.context.iterable_type.type, [arg])
375-
return cast(IterableType, value)
376-
377391
def parse_star_args_type(
378392
self, typ: Type
379393
) -> TupleType | IterableType | ParamSpecType | AnyType:
@@ -411,61 +425,19 @@ def parse_star_args_type(
411425
# Note that this covers unions of differently sized tuples as well.
412426
else:
413427
converted_types = [self.as_iterable_type(p_i) for p_i in proper_items]
414-
if all(self.is_iterable_type(it) for it in converted_types):
428+
if all(self.is_iterable_instance_type(it) for it in converted_types):
415429
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
416430
iterables = cast(list[IterableType], converted_types)
417431
arg = make_simplified_union([it.args[0] for it in iterables])
418-
return self.make_iterable_type(arg)
432+
return self._make_iterable_instance_type(arg)
419433
else:
420434
# some items in the union are not iterable, return AnyType
421435
return AnyType(TypeOfAny.from_error)
422-
elif self.is_iterable_type(parsed := self.as_iterable_type(p_t)):
436+
elif self.is_iterable_instance_type(parsed := self.as_iterable_type(p_t)):
423437
# in all other cases, we try to reinterpret the type as Iterable[T]
424438
return parsed
425439
return AnyType(TypeOfAny.from_error)
426440

427-
def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
428-
r"""Use the solver to cast a type as Iterable[T].
429-
430-
Returns the type as-is if solving fails.
431-
"""
432-
from mypy.constraints import infer_constraints_for_callable
433-
from mypy.nodes import ARG_POS
434-
from mypy.solve import solve_constraints
435-
436-
iterable_kind = self.context.iterable_type.type
437-
438-
# We first create an upcast function:
439-
# def [T] (Iterable[T]) -> Iterable[T]: ...
440-
# and then solve for T, given the input type as the argument.
441-
T = TypeVarType(
442-
"T",
443-
"T",
444-
TypeVarId(-1),
445-
values=[],
446-
upper_bound=AnyType(TypeOfAny.special_form),
447-
default=AnyType(TypeOfAny.special_form),
448-
)
449-
target = Instance(iterable_kind, [T])
450-
451-
upcast_callable = CallableType(
452-
variables=[T],
453-
arg_types=[target],
454-
arg_kinds=[ARG_POS],
455-
arg_names=[None],
456-
ret_type=T,
457-
fallback=self.context.function_type,
458-
)
459-
constraints = infer_constraints_for_callable(
460-
upcast_callable, [typ], [ARG_POS], [None], [[0]], context=self.context
461-
)
462-
463-
(sol,), _ = solve_constraints([T], constraints)
464-
465-
if sol is None: # solving failed, return AnyType fallback
466-
return AnyType(TypeOfAny.from_error)
467-
return self.make_iterable_type(sol)
468-
469441

470442
def is_equal_sized_tuples(types: Sequence[ProperType]) -> TypeGuard[Sequence[TupleType]]:
471443
"""Check if all types are tuples of the same size.

test-data/unit/check-kwargs.test

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,6 @@ def test_union_variable_size_tuples(
724724
NESTED = Union[str, list[NESTED]]
725725
def test_union_recursive(x: Union[list[Union[NESTED, None]], list[NESTED]]) -> None:
726726
reveal_type( {*x} ) # N: Revealed type is "builtins.set[Union[builtins.str, builtins.list[Union[builtins.str, builtins.list[...]]], None]]"
727-
728727
[builtins fixtures/primitives.pyi]
729728

730729

@@ -870,6 +869,18 @@ def test_bad_case(
870869
# E: Argument 1 to "g" has incompatible type "*tuple[Union[A, B, C, None], ...]"; expected "Optional[D]"
871870
[builtins fixtures/tuple.pyi]
872871

872+
873+
[case testListExpressionWithListSubtypeStarArgs]
874+
# https://github.com/python/mypy/issues/19662
875+
class MyList(list[int]): ...
876+
877+
def test(x: MyList, y: list[int]) -> None:
878+
reveal_type( [*x] ) # N: Revealed type is "builtins.list[builtins.int]"
879+
reveal_type( [*y] ) # N: Revealed type is "builtins.list[builtins.int]"
880+
[builtins fixtures/list.pyi]
881+
882+
883+
873884
[case testPassingEmptyDictWithStars]
874885
def f(): pass
875886
def g(x=1): pass

0 commit comments

Comments
 (0)