Skip to content

Commit 7146a29

Browse files
simplified away second branch
1 parent 0e76d9e commit 7146a29

File tree

1 file changed

+46
-43
lines changed

1 file changed

+46
-43
lines changed

mypy/argmap.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
from collections.abc import Sequence
66
from typing import TYPE_CHECKING, Callable
7+
from typing_extensions import TypeGuard
78

89
from mypy import nodes
9-
from mypy.join import join_type_list
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy.typeops import make_simplified_union
1212
from mypy.types import (
1313
AnyType,
1414
Instance,
1515
ParamSpecType,
16+
ProperType,
1617
TupleType,
1718
Type,
1819
TypedDictType,
@@ -62,10 +63,11 @@ def map_actuals_to_formals(
6263
if (
6364
isinstance(actualt, UnionType)
6465
and actualt.items
65-
and is_equal_sized_tuples(actualt.items)
66+
and is_equal_sized_tuples(
67+
proper_types := [get_proper_type(t) for t in actualt.items]
68+
)
6669
):
67-
# Arbitrarily pick the first item in the union.
68-
actualt = get_proper_type(actualt.items[0])
70+
actualt = proper_types[0]
6971
if isinstance(actualt, TupleType):
7072
# A tuple actual maps to a fixed number of formals.
7173
for _ in range(len(actualt.items)):
@@ -215,18 +217,38 @@ def expand_actual_type(
215217
actual_type = get_proper_type(actual_type)
216218
if actual_kind == nodes.ARG_STAR:
217219
if isinstance(actual_type, UnionType):
218-
# special case 1: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
219-
# special case 2: union contains no static sized tuples. (e.g. `list[str | None] | list[str]`)
220-
if is_equal_sized_tuples(actual_type.items) or not any(
221-
isinstance(get_proper_type(t), TupleType) for t in actual_type.items
222-
):
223-
# If the actual type is a union, try expanding it.
224-
# Example: f(*args), where args is `list[str | None] | list[str]`,
225-
# Example: f(*args), where args is `tuple[A, B, C] | tuple[None, None, None]`
226-
# Note: there is potential for combinatorial explosion here:
227-
# f(*x1, *x2, .. *xn), if xₖ is a union of nₖ differently sized tuples,
228-
# then there are n₁ * n₂ * ... * nₖ possible combinations of pointer positions.
229-
# therefore, we only take this branch if all union members consume the same number of items.
220+
proper_types = [get_proper_type(t) for t in actual_type.items]
221+
# special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222+
if is_equal_sized_tuples(proper_types):
223+
# transform union of tuples into a tuple of unions
224+
# e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225+
tuple_args: list[Type] = [
226+
make_simplified_union(items)
227+
for items in zip(*(t.items for t in proper_types))
228+
]
229+
actual_type = TupleType(
230+
tuple_args,
231+
# use Iterable[A | B | C] as the fallback type
232+
fallback=Instance(
233+
self.context.iterable_type.type, [UnionType.make_union(tuple_args)]
234+
),
235+
)
236+
else:
237+
# reinterpret all union items as iterable types (if possible)
238+
# and return the union of the iterable item types results.
239+
from mypy.subtypes import is_subtype
240+
241+
iterable_type = self.context.iterable_type
242+
243+
def as_iterable_type(t: Type) -> Type:
244+
"""Map a type to the iterable supertype if it is a subtype."""
245+
p_t = get_proper_type(t)
246+
if isinstance(p_t, Instance) and is_subtype(t, iterable_type):
247+
return map_instance_to_supertype(p_t, iterable_type.type)
248+
if isinstance(p_t, TupleType):
249+
# Convert tuple[A, B, C] to Iterable[A | B | C].
250+
return Instance(iterable_type.type, [make_simplified_union(p_t.items)])
251+
return t
230252

231253
# create copies of self for each item in the union
232254
sub_expanders = [
@@ -239,7 +261,11 @@ def expand_actual_type(
239261
candidate_type = make_simplified_union(
240262
[
241263
e.expand_actual_type(
242-
item, actual_kind, formal_name, formal_kind, allow_unpack
264+
as_iterable_type(item),
265+
actual_kind,
266+
formal_name,
267+
formal_kind,
268+
allow_unpack,
243269
)
244270
for e, item in zip(sub_expanders, actual_type.items)
245271
]
@@ -249,28 +275,6 @@ def expand_actual_type(
249275
self.tuple_index = int(sub_expanders[0].tuple_index)
250276
self.kwargs_used = set(sub_expanders[0].kwargs_used)
251277
return candidate_type
252-
else:
253-
# otherwise, we fall back to checking using the join of the union members.
254-
# for better results we first map all instances to Iterable[T]
255-
from mypy.subtypes import is_subtype
256-
257-
iterable_type = self.context.iterable_type
258-
259-
def as_iterable_type(t: Type) -> Type:
260-
"""Map a type to the iterable supertype if it is a subtype."""
261-
p_t = get_proper_type(t)
262-
if isinstance(p_t, Instance) and is_subtype(t, iterable_type):
263-
return map_instance_to_supertype(p_t, iterable_type.type)
264-
if isinstance(p_t, TupleType):
265-
# Convert tuple[A, B, C] to Iterable[A | B | C].
266-
return Instance(iterable_type.type, [make_simplified_union(p_t.items)])
267-
return t
268-
269-
joined_type = join_type_list([as_iterable_type(t) for t in actual_type.items])
270-
assert not isinstance(get_proper_type(joined_type), TupleType)
271-
return self.expand_actual_type(
272-
joined_type, actual_kind, formal_name, formal_kind, allow_unpack
273-
)
274278

275279
if isinstance(actual_type, TypeVarTupleType):
276280
# This code path is hit when *Ts is passed to a callable and various
@@ -346,19 +350,18 @@ def as_iterable_type(t: Type) -> Type:
346350
return original_actual
347351

348352

349-
def is_equal_sized_tuples(types: Sequence[Type]) -> bool:
353+
def is_equal_sized_tuples(types: Sequence[ProperType]) -> TypeGuard[Sequence[TupleType]]:
350354
"""Check if all types are tuples of the same size."""
351355
if not types:
352356
return True
353357

354358
iterator = iter(types)
355-
first = get_proper_type(next(iterator))
359+
first = next(iterator)
356360
if not isinstance(first, TupleType):
357361
return False
358362
size = first.length()
359363

360364
for item in iterator:
361-
p_t = get_proper_type(item)
362-
if not isinstance(p_t, TupleType) or p_t.length() != size:
365+
if not isinstance(item, TupleType) or item.length() != size:
363366
return False
364367
return True

0 commit comments

Comments
 (0)