4
4
5
5
from collections .abc import Sequence
6
6
from typing import TYPE_CHECKING , Callable
7
+ from typing_extensions import TypeGuard
7
8
8
9
from mypy import nodes
9
10
from mypy .maptype import map_instance_to_supertype
11
+ from mypy .typeops import make_simplified_union
10
12
from mypy .types import (
11
13
AnyType ,
12
14
Instance ,
13
15
ParamSpecType ,
16
+ ProperType ,
14
17
TupleType ,
15
18
Type ,
16
19
TypedDictType ,
17
20
TypeOfAny ,
18
21
TypeVarTupleType ,
22
+ UnionType ,
19
23
UnpackType ,
20
24
get_proper_type ,
21
25
)
@@ -54,6 +58,16 @@ def map_actuals_to_formals(
54
58
elif actual_kind == nodes .ARG_STAR :
55
59
# We need to know the actual type to map varargs.
56
60
actualt = get_proper_type (actual_arg_type (ai ))
61
+
62
+ # Special case for union of equal sized tuples.
63
+ if (
64
+ isinstance (actualt , UnionType )
65
+ and actualt .items
66
+ and is_equal_sized_tuples (
67
+ proper_types := [get_proper_type (t ) for t in actualt .items ]
68
+ )
69
+ ):
70
+ actualt = proper_types [0 ]
57
71
if isinstance (actualt , TupleType ):
58
72
# A tuple actual maps to a fixed number of formals.
59
73
for _ in range (len (actualt .items )):
@@ -171,6 +185,15 @@ def __init__(self, context: ArgumentInferContext) -> None:
171
185
# Type context for `*` and `**` arg kinds.
172
186
self .context = context
173
187
188
+ def __eq__ (self , other : object ) -> bool :
189
+ if isinstance (other , ArgTypeExpander ):
190
+ return (
191
+ self .tuple_index == other .tuple_index
192
+ and self .kwargs_used == other .kwargs_used
193
+ and self .context == other .context
194
+ )
195
+ return NotImplemented
196
+
174
197
def expand_actual_type (
175
198
self ,
176
199
actual_type : Type ,
@@ -193,6 +216,66 @@ def expand_actual_type(
193
216
original_actual = actual_type
194
217
actual_type = get_proper_type (actual_type )
195
218
if actual_kind == nodes .ARG_STAR :
219
+ if isinstance (actual_type , UnionType ):
220
+ proper_types = [get_proper_type (t ) for t in actual_type .items ]
221
+ # special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222
+ if is_equal_sized_tuples (proper_types ):
223
+ # transform union of tuples into a tuple of unions
224
+ # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225
+ tuple_args : list [Type ] = [
226
+ make_simplified_union (items )
227
+ for items in zip (* (t .items for t in proper_types ))
228
+ ]
229
+ actual_type = TupleType (
230
+ tuple_args ,
231
+ # use Iterable[A | B | C] as the fallback type
232
+ fallback = Instance (
233
+ self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234
+ ),
235
+ )
236
+ else :
237
+ # reinterpret all union items as iterable types (if possible)
238
+ # and return the union of the iterable item types results.
239
+ from mypy .subtypes import is_subtype
240
+
241
+ iterable_type = self .context .iterable_type
242
+
243
+ def as_iterable_type (t : Type ) -> Type :
244
+ """Map a type to the iterable supertype if it is a subtype."""
245
+ p_t = get_proper_type (t )
246
+ if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
247
+ return map_instance_to_supertype (p_t , iterable_type .type )
248
+ if isinstance (p_t , TupleType ):
249
+ # Convert tuple[A, B, C] to Iterable[A | B | C].
250
+ return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251
+ return t
252
+
253
+ # create copies of self for each item in the union
254
+ sub_expanders = [
255
+ ArgTypeExpander (context = self .context ) for _ in actual_type .items
256
+ ]
257
+ for expander in sub_expanders :
258
+ expander .tuple_index = int (self .tuple_index )
259
+ expander .kwargs_used = set (self .kwargs_used )
260
+
261
+ candidate_type = make_simplified_union (
262
+ [
263
+ e .expand_actual_type (
264
+ as_iterable_type (item ),
265
+ actual_kind ,
266
+ formal_name ,
267
+ formal_kind ,
268
+ allow_unpack ,
269
+ )
270
+ for e , item in zip (sub_expanders , actual_type .items )
271
+ ]
272
+ )
273
+ assert all (expander == sub_expanders [0 ] for expander in sub_expanders )
274
+ # carry over the new state if all sub-expanders are the same state
275
+ self .tuple_index = int (sub_expanders [0 ].tuple_index )
276
+ self .kwargs_used = set (sub_expanders [0 ].kwargs_used )
277
+ return candidate_type
278
+
196
279
if isinstance (actual_type , TypeVarTupleType ):
197
280
# This code path is hit when *Ts is passed to a callable and various
198
281
# special-handling didn't catch this. The best thing we can do is to use
@@ -265,3 +348,20 @@ def expand_actual_type(
265
348
else :
266
349
# No translation for other kinds -- 1:1 mapping.
267
350
return original_actual
351
+
352
+
353
+ def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [Sequence [TupleType ]]:
354
+ """Check if all types are tuples of the same size."""
355
+ if not types :
356
+ return True
357
+
358
+ iterator = iter (types )
359
+ first = next (iterator )
360
+ if not isinstance (first , TupleType ):
361
+ return False
362
+ size = first .length ()
363
+
364
+ for item in iterator :
365
+ if not isinstance (item , TupleType ) or item .length () != size :
366
+ return False
367
+ return True
0 commit comments