Skip to content

Commit 049afbc

Browse files
rework starargs with union argument
1 parent 5a78607 commit 049afbc

File tree

4 files changed

+424
-1
lines changed

4 files changed

+424
-1
lines changed

mypy/argmap.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,22 @@
44

55
from collections.abc import Sequence
66
from typing import TYPE_CHECKING, Callable
7+
from typing_extensions import TypeGuard
78

89
from mypy import nodes
910
from mypy.maptype import map_instance_to_supertype
11+
from mypy.typeops import make_simplified_union
1012
from mypy.types import (
1113
AnyType,
1214
Instance,
1315
ParamSpecType,
16+
ProperType,
1417
TupleType,
1518
Type,
1619
TypedDictType,
1720
TypeOfAny,
1821
TypeVarTupleType,
22+
UnionType,
1923
UnpackType,
2024
get_proper_type,
2125
)
@@ -54,6 +58,16 @@ def map_actuals_to_formals(
5458
elif actual_kind == nodes.ARG_STAR:
5559
# We need to know the actual type to map varargs.
5660
actualt = get_proper_type(actual_arg_type(ai))
61+
62+
# Special case for union of equal sized tuples.
63+
if (
64+
isinstance(actualt, UnionType)
65+
and actualt.items
66+
and is_equal_sized_tuples(
67+
proper_types := [get_proper_type(t) for t in actualt.items]
68+
)
69+
):
70+
actualt = proper_types[0]
5771
if isinstance(actualt, TupleType):
5872
# A tuple actual maps to a fixed number of formals.
5973
for _ in range(len(actualt.items)):
@@ -171,6 +185,15 @@ def __init__(self, context: ArgumentInferContext) -> None:
171185
# Type context for `*` and `**` arg kinds.
172186
self.context = context
173187

188+
def __eq__(self, other: object) -> bool:
189+
if isinstance(other, ArgTypeExpander):
190+
return (
191+
self.tuple_index == other.tuple_index
192+
and self.kwargs_used == other.kwargs_used
193+
and self.context == other.context
194+
)
195+
return NotImplemented
196+
174197
def expand_actual_type(
175198
self,
176199
actual_type: Type,
@@ -193,6 +216,66 @@ def expand_actual_type(
193216
original_actual = actual_type
194217
actual_type = get_proper_type(actual_type)
195218
if actual_kind == nodes.ARG_STAR:
219+
if isinstance(actual_type, UnionType):
220+
proper_types = [get_proper_type(t) for t in actual_type.items]
221+
# special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222+
if is_equal_sized_tuples(proper_types):
223+
# transform union of tuples into a tuple of unions
224+
# e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225+
tuple_args: list[Type] = [
226+
make_simplified_union(items)
227+
for items in zip(*(t.items for t in proper_types))
228+
]
229+
actual_type = TupleType(
230+
tuple_args,
231+
# use Iterable[A | B | C] as the fallback type
232+
fallback=Instance(
233+
self.context.iterable_type.type, [UnionType.make_union(tuple_args)]
234+
),
235+
)
236+
else:
237+
# reinterpret all union items as iterable types (if possible)
238+
# and return the union of the iterable item types results.
239+
from mypy.subtypes import is_subtype
240+
241+
iterable_type = self.context.iterable_type
242+
243+
def as_iterable_type(t: Type) -> Type:
244+
"""Map a type to the iterable supertype if it is a subtype."""
245+
p_t = get_proper_type(t)
246+
if isinstance(p_t, Instance) and is_subtype(t, iterable_type):
247+
return map_instance_to_supertype(p_t, iterable_type.type)
248+
if isinstance(p_t, TupleType):
249+
# Convert tuple[A, B, C] to Iterable[A | B | C].
250+
return Instance(iterable_type.type, [make_simplified_union(p_t.items)])
251+
return t
252+
253+
# create copies of self for each item in the union
254+
sub_expanders = [
255+
ArgTypeExpander(context=self.context) for _ in actual_type.items
256+
]
257+
for expander in sub_expanders:
258+
expander.tuple_index = int(self.tuple_index)
259+
expander.kwargs_used = set(self.kwargs_used)
260+
261+
candidate_type = make_simplified_union(
262+
[
263+
e.expand_actual_type(
264+
as_iterable_type(item),
265+
actual_kind,
266+
formal_name,
267+
formal_kind,
268+
allow_unpack,
269+
)
270+
for e, item in zip(sub_expanders, actual_type.items)
271+
]
272+
)
273+
assert all(expander == sub_expanders[0] for expander in sub_expanders)
274+
# carry over the new state if all sub-expanders are the same state
275+
self.tuple_index = int(sub_expanders[0].tuple_index)
276+
self.kwargs_used = set(sub_expanders[0].kwargs_used)
277+
return candidate_type
278+
196279
if isinstance(actual_type, TypeVarTupleType):
197280
# This code path is hit when *Ts is passed to a callable and various
198281
# special-handling didn't catch this. The best thing we can do is to use
@@ -265,3 +348,20 @@ def expand_actual_type(
265348
else:
266349
# No translation for other kinds -- 1:1 mapping.
267350
return original_actual
351+
352+
353+
def is_equal_sized_tuples(types: Sequence[ProperType]) -> TypeGuard[Sequence[TupleType]]:
354+
"""Check if all types are tuples of the same size."""
355+
if not types:
356+
return True
357+
358+
iterator = iter(types)
359+
first = next(iterator)
360+
if not isinstance(first, TupleType):
361+
return False
362+
size = first.length()
363+
364+
for item in iterator:
365+
if not isinstance(item, TupleType) or item.length() != size:
366+
return False
367+
return True

mypy/checkexpr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
freshen_function_type_vars,
2828
)
2929
from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
30+
from mypy.join import join_type_list
3031
from mypy.literals import literal
3132
from mypy.maptype import map_instance_to_supertype
3233
from mypy.meet import is_overlapping_types, narrow_declared_type
@@ -5227,6 +5228,11 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
52275228
ctx = None
52285229
tt = self.accept(item.expr, ctx)
52295230
tt = get_proper_type(tt)
5231+
if isinstance(tt, UnionType):
5232+
# Coercing union to join allows better inference in some
5233+
# special cases like `tuple[A, B] | tuple[C, D]`
5234+
tt = get_proper_type(join_type_list(tt.items))
5235+
52305236
if isinstance(tt, TupleType):
52315237
if find_unpack_in_list(tt.items) is not None:
52325238
if seen_unpack_in_items:

0 commit comments

Comments
 (0)