4
4
5
5
from collections .abc import Sequence
6
6
from typing import TYPE_CHECKING , Callable
7
+ from typing_extensions import TypeGuard
7
8
8
9
from mypy import nodes
9
- from mypy .join import join_type_list
10
10
from mypy .maptype import map_instance_to_supertype
11
11
from mypy .typeops import make_simplified_union
12
12
from mypy .types import (
13
13
AnyType ,
14
14
Instance ,
15
15
ParamSpecType ,
16
+ ProperType ,
16
17
TupleType ,
17
18
Type ,
18
19
TypedDictType ,
@@ -62,10 +63,11 @@ def map_actuals_to_formals(
62
63
if (
63
64
isinstance (actualt , UnionType )
64
65
and actualt .items
65
- and is_equal_sized_tuples (actualt .items )
66
+ and is_equal_sized_tuples (
67
+ proper_types := [get_proper_type (t ) for t in actualt .items ]
68
+ )
66
69
):
67
- # Arbitrarily pick the first item in the union.
68
- actualt = get_proper_type (actualt .items [0 ])
70
+ actualt = proper_types [0 ]
69
71
if isinstance (actualt , TupleType ):
70
72
# A tuple actual maps to a fixed number of formals.
71
73
for _ in range (len (actualt .items )):
@@ -215,18 +217,38 @@ def expand_actual_type(
215
217
actual_type = get_proper_type (actual_type )
216
218
if actual_kind == nodes .ARG_STAR :
217
219
if isinstance (actual_type , UnionType ):
218
- # special case 1: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
219
- # special case 2: union contains no static sized tuples. (e.g. `list[str | None] | list[str]`)
220
- if is_equal_sized_tuples (actual_type .items ) or not any (
221
- isinstance (get_proper_type (t ), TupleType ) for t in actual_type .items
222
- ):
223
- # If the actual type is a union, try expanding it.
224
- # Example: f(*args), where args is `list[str | None] | list[str]`,
225
- # Example: f(*args), where args is `tuple[A, B, C] | tuple[None, None, None]`
226
- # Note: there is potential for combinatorial explosion here:
227
- # f(*x1, *x2, .. *xn), if xₖ is a union of nₖ differently sized tuples,
228
- # then there are n₁ * n₂ * ... * nₖ possible combinations of pointer positions.
229
- # therefore, we only take this branch if all union members consume the same number of items.
220
+ proper_types = [get_proper_type (t ) for t in actual_type .items ]
221
+ # special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222
+ if is_equal_sized_tuples (proper_types ):
223
+ # transform union of tuples into a tuple of unions
224
+ # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225
+ tuple_args : list [Type ] = [
226
+ make_simplified_union (items )
227
+ for items in zip (* (t .items for t in proper_types ))
228
+ ]
229
+ actual_type = TupleType (
230
+ tuple_args ,
231
+ # use Iterable[A | B | C] as the fallback type
232
+ fallback = Instance (
233
+ self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234
+ ),
235
+ )
236
+ else :
237
+ # reinterpret all union items as iterable types (if possible)
238
+ # and return the union of the iterable item types results.
239
+ from mypy .subtypes import is_subtype
240
+
241
+ iterable_type = self .context .iterable_type
242
+
243
+ def as_iterable_type (t : Type ) -> Type :
244
+ """Map a type to the iterable supertype if it is a subtype."""
245
+ p_t = get_proper_type (t )
246
+ if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
247
+ return map_instance_to_supertype (p_t , iterable_type .type )
248
+ if isinstance (p_t , TupleType ):
249
+ # Convert tuple[A, B, C] to Iterable[A | B | C].
250
+ return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251
+ return t
230
252
231
253
# create copies of self for each item in the union
232
254
sub_expanders = [
@@ -239,7 +261,11 @@ def expand_actual_type(
239
261
candidate_type = make_simplified_union (
240
262
[
241
263
e .expand_actual_type (
242
- item , actual_kind , formal_name , formal_kind , allow_unpack
264
+ as_iterable_type (item ),
265
+ actual_kind ,
266
+ formal_name ,
267
+ formal_kind ,
268
+ allow_unpack ,
243
269
)
244
270
for e , item in zip (sub_expanders , actual_type .items )
245
271
]
@@ -249,28 +275,6 @@ def expand_actual_type(
249
275
self .tuple_index = int (sub_expanders [0 ].tuple_index )
250
276
self .kwargs_used = set (sub_expanders [0 ].kwargs_used )
251
277
return candidate_type
252
- else :
253
- # otherwise, we fall back to checking using the join of the union members.
254
- # for better results we first map all instances to Iterable[T]
255
- from mypy .subtypes import is_subtype
256
-
257
- iterable_type = self .context .iterable_type
258
-
259
- def as_iterable_type (t : Type ) -> Type :
260
- """Map a type to the iterable supertype if it is a subtype."""
261
- p_t = get_proper_type (t )
262
- if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
263
- return map_instance_to_supertype (p_t , iterable_type .type )
264
- if isinstance (p_t , TupleType ):
265
- # Convert tuple[A, B, C] to Iterable[A | B | C].
266
- return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
267
- return t
268
-
269
- joined_type = join_type_list ([as_iterable_type (t ) for t in actual_type .items ])
270
- assert not isinstance (get_proper_type (joined_type ), TupleType )
271
- return self .expand_actual_type (
272
- joined_type , actual_kind , formal_name , formal_kind , allow_unpack
273
- )
274
278
275
279
if isinstance (actual_type , TypeVarTupleType ):
276
280
# This code path is hit when *Ts is passed to a callable and various
@@ -346,19 +350,18 @@ def as_iterable_type(t: Type) -> Type:
346
350
return original_actual
347
351
348
352
349
- def is_equal_sized_tuples (types : Sequence [Type ]) -> bool :
353
+ def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [ Sequence [ TupleType ]] :
350
354
"""Check if all types are tuples of the same size."""
351
355
if not types :
352
356
return True
353
357
354
358
iterator = iter (types )
355
- first = get_proper_type ( next (iterator ) )
359
+ first = next (iterator )
356
360
if not isinstance (first , TupleType ):
357
361
return False
358
362
size = first .length ()
359
363
360
364
for item in iterator :
361
- p_t = get_proper_type (item )
362
- if not isinstance (p_t , TupleType ) or p_t .length () != size :
365
+ if not isinstance (item , TupleType ) or item .length () != size :
363
366
return False
364
367
return True
0 commit comments