3
3
from __future__ import annotations
4
4
5
5
from collections .abc import Sequence
6
- from typing import TYPE_CHECKING , Callable
7
- from typing_extensions import TypeGuard
6
+ from typing import TYPE_CHECKING , Callable , cast
7
+ from typing_extensions import NewType , TypeGuard
8
8
9
9
from mypy import nodes
10
10
from mypy .maptype import map_instance_to_supertype
11
11
from mypy .typeops import make_simplified_union
12
12
from mypy .types import (
13
13
AnyType ,
14
+ CallableType ,
14
15
Instance ,
15
16
ParamSpecType ,
16
17
ProperType ,
17
18
TupleType ,
18
19
Type ,
19
20
TypedDictType ,
20
21
TypeOfAny ,
22
+ TypeVarId ,
21
23
TypeVarTupleType ,
24
+ TypeVarType ,
22
25
UnionType ,
23
26
UnpackType ,
27
+ flatten_nested_tuples ,
24
28
get_proper_type ,
25
29
)
26
30
27
31
if TYPE_CHECKING :
28
32
from mypy .infer import ArgumentInferContext
29
33
30
34
35
+ IterableType = NewType ("IterableType" , Instance )
36
+ """Represents an instance of `Iterable[T]`."""
37
+
38
+
31
39
def map_actuals_to_formals (
32
40
actual_kinds : list [nodes .ArgKind ],
33
41
actual_names : Sequence [str | None ] | None ,
@@ -216,92 +224,41 @@ def expand_actual_type(
216
224
original_actual = actual_type
217
225
actual_type = get_proper_type (actual_type )
218
226
if actual_kind == nodes .ARG_STAR :
219
- if isinstance (actual_type , UnionType ):
220
- proper_types = [get_proper_type (t ) for t in actual_type .items ]
221
- # special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222
- if is_equal_sized_tuples (proper_types ):
223
- # transform union of tuples into a tuple of unions
224
- # e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225
- tuple_args : list [Type ] = [
226
- make_simplified_union (items )
227
- for items in zip (* (t .items for t in proper_types ))
228
- ]
229
- actual_type = TupleType (
230
- tuple_args ,
231
- # use Iterable[A | B | C] as the fallback type
232
- fallback = Instance (
233
- self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
234
- ),
235
- )
236
- else :
237
- # reinterpret all union items as iterable types (if possible)
238
- # and return the union of the iterable item types results.
239
- from mypy .subtypes import is_subtype
240
-
241
- iterable_type = self .context .iterable_type
242
-
243
- def as_iterable_type (t : Type ) -> Type :
244
- """Map a type to the iterable supertype if it is a subtype."""
245
- p_t = get_proper_type (t )
246
- if isinstance (p_t , Instance ) and is_subtype (t , iterable_type ):
247
- return map_instance_to_supertype (p_t , iterable_type .type )
248
- if isinstance (p_t , TupleType ):
249
- # Convert tuple[A, B, C] to Iterable[A | B | C].
250
- return Instance (iterable_type .type , [make_simplified_union (p_t .items )])
251
- return t
252
-
253
- # create copies of self for each item in the union
254
- sub_expanders = [
255
- ArgTypeExpander (context = self .context ) for _ in actual_type .items
256
- ]
257
- for expander in sub_expanders :
258
- expander .tuple_index = int (self .tuple_index )
259
- expander .kwargs_used = set (self .kwargs_used )
260
-
261
- candidate_type = make_simplified_union (
262
- [
263
- e .expand_actual_type (
264
- as_iterable_type (item ),
265
- actual_kind ,
266
- formal_name ,
267
- formal_kind ,
268
- allow_unpack ,
269
- )
270
- for e , item in zip (sub_expanders , actual_type .items )
271
- ]
272
- )
273
- assert all (expander == sub_expanders [0 ] for expander in sub_expanders )
274
- # carry over the new state if all sub-expanders are the same state
275
- self .tuple_index = int (sub_expanders [0 ].tuple_index )
276
- self .kwargs_used = set (sub_expanders [0 ].kwargs_used )
277
- return candidate_type
278
-
279
- if isinstance (actual_type , TypeVarTupleType ):
280
- # This code path is hit when *Ts is passed to a callable and various
281
- # special-handling didn't catch this. The best thing we can do is to use
282
- # the upper bound.
283
- actual_type = get_proper_type (actual_type .upper_bound )
284
- if isinstance (actual_type , Instance ) and actual_type .args :
285
- from mypy .subtypes import is_subtype
286
-
287
- if is_subtype (actual_type , self .context .iterable_type ):
288
- return map_instance_to_supertype (
289
- actual_type , self .context .iterable_type .type
290
- ).args [0 ]
291
- else :
292
- # We cannot properly unpack anything other
293
- # than `Iterable` type with `*`.
294
- # Just return `Any`, other parts of code would raise
295
- # a different error for improper use.
296
- return AnyType (TypeOfAny .from_error )
297
- elif isinstance (actual_type , TupleType ):
227
+ # parse *args as one of the following:
228
+ # IterableType | TupleType | ParamSpecType | AnyType
229
+ star_args = self .parse_star_args_type (actual_type )
230
+ # star_args = actual_type
231
+
232
+ # print(f"expand_actual_type: {actual_type=} {star_args=}")
233
+
234
+ # if isinstance(star_args, TypeVarTupleType):
235
+ # # This code path is hit when *Ts is passed to a callable and various
236
+ # # special-handling didn't catch this. The best thing we can do is to use
237
+ # # the upper bound.
238
+ # star_args = get_proper_type(star_args.upper_bound)
239
+ # if isinstance(star_args, Instance) and star_args.args:
240
+ # from mypy.subtypes import is_subtype
241
+ #
242
+ # if is_subtype(star_args, self.context.iterable_type):
243
+ # return map_instance_to_supertype(
244
+ # star_args, self.context.iterable_type.type
245
+ # ).args[0]
246
+ # else:
247
+ # # We cannot properly unpack anything other
248
+ # # than `Iterable` type with `*`.
249
+ # # Just return `Any`, other parts of code would raise
250
+ # # a different error for improper use.
251
+ # return AnyType(TypeOfAny.from_error)
252
+ if self .is_iterable_type (star_args ):
253
+ return star_args .args [0 ]
254
+ elif isinstance (star_args , TupleType ):
298
255
# Get the next tuple item of a tuple *arg.
299
- if self .tuple_index >= len (actual_type .items ):
256
+ if self .tuple_index >= len (star_args .items ):
300
257
# Exhausted a tuple -- continue to the next *args.
301
258
self .tuple_index = 1
302
259
else :
303
260
self .tuple_index += 1
304
- item = actual_type .items [self .tuple_index - 1 ]
261
+ item = star_args .items [self .tuple_index - 1 ]
305
262
if isinstance (item , UnpackType ) and not allow_unpack :
306
263
# An unpack item that doesn't have special handling, use upper bound as above.
307
264
unpacked = get_proper_type (item .type )
@@ -315,9 +272,9 @@ def as_iterable_type(t: Type) -> Type:
315
272
)
316
273
item = fallback .args [0 ]
317
274
return item
318
- elif isinstance (actual_type , ParamSpecType ):
275
+ elif isinstance (star_args , ParamSpecType ):
319
276
# ParamSpec is valid in *args but it can't be unpacked.
320
- return actual_type
277
+ return star_args
321
278
else :
322
279
return AnyType (TypeOfAny .from_error )
323
280
elif actual_kind == nodes .ARG_STAR2 :
@@ -349,19 +306,197 @@ def as_iterable_type(t: Type) -> Type:
349
306
# No translation for other kinds -- 1:1 mapping.
350
307
return original_actual
351
308
309
+ def is_iterable (self , typ : Type ) -> bool :
310
+ from mypy .subtypes import is_subtype
311
+
312
+ return is_subtype (typ , self .context .iterable_type )
313
+
314
+ def is_iterable_instance_subtype (self , typ : Type ) -> TypeGuard [Instance ]:
315
+ from mypy .subtypes import is_subtype
316
+
317
+ p_t = get_proper_type (typ )
318
+ return (
319
+ isinstance (p_t , Instance )
320
+ and bool (p_t .args )
321
+ and is_subtype (p_t , self .context .iterable_type )
322
+ )
323
+
324
+ def is_iterable_type (self , typ : Type ) -> TypeGuard [IterableType ]:
325
+ """Check if the type is an Iterable[T] or a subtype of it."""
326
+ p_t = get_proper_type (typ )
327
+ return isinstance (p_t , Instance ) and p_t .type == self .context .iterable_type .type
328
+
329
+ def as_iterable_type (self , typ : Type ) -> IterableType | AnyType :
330
+ """Reinterpret a type as Iterable[T], or return AnyType if not possible."""
331
+ p_t = get_proper_type (typ )
332
+ if self .is_iterable_type (p_t ):
333
+ return p_t
334
+ elif self .is_iterable_instance_subtype (p_t ):
335
+ cls = self .context .iterable_type .type
336
+ return cast (IterableType , map_instance_to_supertype (p_t , cls ))
337
+ elif isinstance (p_t , UnionType ):
338
+ # If the type is a union, map each item to the iterable supertype.
339
+ # the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
340
+ converted_types = [self .as_iterable_type (get_proper_type (item )) for item in p_t .items ]
341
+ # if an item could not be interpreted as Iterable[T], we return AnyType
342
+ if all (self .is_iterable_type (it ) for it in converted_types ):
343
+ # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
344
+ iterable_types = cast (list [IterableType ], converted_types )
345
+ arg = make_simplified_union ([it .args [0 ] for it in iterable_types ])
346
+ return self .make_iterable_type (arg )
347
+ return AnyType (TypeOfAny .from_error )
348
+ elif isinstance (p_t , TupleType ):
349
+ # maps tuple[A, B, C] -> Iterable[A | B | C]
350
+ # note: proper_elements may contain UnpackType, for instance with
351
+ # tuple[None, *tuple[None, ...]]..
352
+ proper_elements = [get_proper_type (t ) for t in flatten_nested_tuples (p_t .items )]
353
+ args : list [Type ] = []
354
+ for p_e in proper_elements :
355
+ if isinstance (p_e , UnpackType ):
356
+ r = self .as_iterable_type (p_e )
357
+ if self .is_iterable_type (r ):
358
+ args .append (r .args [0 ])
359
+ else :
360
+ args .append (r )
361
+ else :
362
+ args .append (p_e )
363
+ return self .make_iterable_type (make_simplified_union (args ))
364
+ if isinstance (p_t , UnpackType ):
365
+ return self .as_iterable_type (p_t .type )
366
+ if isinstance (p_t , (TypeVarType , TypeVarTupleType )):
367
+ return self .as_iterable_type (p_t .upper_bound )
368
+ # fallback: use the solver to reinterpret the type as Iterable[T]
369
+ if self .is_iterable (p_t ):
370
+ return self ._solve_as_iterable (p_t )
371
+ return AnyType (TypeOfAny .from_error )
372
+
373
+ def make_iterable_type (self , arg : Type ) -> IterableType :
374
+ value = Instance (self .context .iterable_type .type , [arg ])
375
+ return cast (IterableType , value )
376
+
377
+ def parse_star_args_type (
378
+ self , typ : Type
379
+ ) -> TupleType | IterableType | ParamSpecType | AnyType :
380
+ """Parse the type of a *args argument.
381
+
382
+ Returns one TupleType, IterableType, ParamSpecType or AnyType.
383
+ """
384
+ p_t = get_proper_type (typ )
385
+ if isinstance (p_t , (TupleType , ParamSpecType , AnyType )):
386
+ # just return the type as-is
387
+ return p_t
388
+ elif isinstance (p_t , TypeVarTupleType ):
389
+ return self .parse_star_args_type (p_t .upper_bound )
390
+ elif isinstance (p_t , UnionType ):
391
+ proper_items = [get_proper_type (t ) for t in p_t .items ]
392
+ # consider 2 cases:
393
+ # 1. Union of equal sized tuples, e.g. tuple[A, B] | tuple[None, None]
394
+ # In this case transform union of same-sized tuples into a tuple of unions
395
+ # e.g. tuple[A, B] | tuple[None, None] -> tuple[A | None, B | None]
396
+ if is_equal_sized_tuples (proper_items ):
397
+
398
+ tuple_args : list [Type ] = [
399
+ make_simplified_union (items ) for items in zip (* (t .items for t in proper_items ))
400
+ ]
401
+ actual_type = TupleType (
402
+ tuple_args ,
403
+ # use Iterable[A | B | C] as the fallback type
404
+ fallback = Instance (
405
+ self .context .iterable_type .type , [UnionType .make_union (tuple_args )]
406
+ ),
407
+ )
408
+ return actual_type
409
+ # 2. Union of iterable types, e.g. Iterable[A] | Iterable[B]
410
+ # In this case return Iterable[A | B]
411
+ # Note that this covers unions of differently sized tuples as well.
412
+ else :
413
+ converted_types = [self .as_iterable_type (p_i ) for p_i in proper_items ]
414
+ if all (self .is_iterable_type (it ) for it in converted_types ):
415
+ # all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
416
+ iterables = cast (list [IterableType ], converted_types )
417
+ arg = make_simplified_union ([it .args [0 ] for it in iterables ])
418
+ return self .make_iterable_type (arg )
419
+ else :
420
+ # some items in the union are not iterable, return AnyType
421
+ return AnyType (TypeOfAny .from_error )
422
+ elif self .is_iterable_type (parsed := self .as_iterable_type (p_t )):
423
+ # in all other cases, we try to reinterpret the type as Iterable[T]
424
+ return parsed
425
+ return AnyType (TypeOfAny .from_error )
426
+
427
+ def _solve_as_iterable (self , typ : Type ) -> IterableType | AnyType :
428
+ r"""Use the solver to cast a type as Iterable[T].
429
+
430
+ Returns the type as-is if solving fails.
431
+ """
432
+ from mypy .constraints import infer_constraints_for_callable
433
+ from mypy .nodes import ARG_POS
434
+ from mypy .solve import solve_constraints
435
+
436
+ iterable_kind = self .context .iterable_type .type
437
+
438
+ # We first create an upcast function:
439
+ # def [T] (Iterable[T]) -> Iterable[T]: ...
440
+ # and then solve for T, given the input type as the argument.
441
+ T = TypeVarType (
442
+ "T" ,
443
+ "T" ,
444
+ TypeVarId (- 1 ),
445
+ values = [],
446
+ upper_bound = AnyType (TypeOfAny .special_form ),
447
+ default = AnyType (TypeOfAny .special_form ),
448
+ )
449
+ target = Instance (iterable_kind , [T ])
450
+
451
+ upcast_callable = CallableType (
452
+ variables = [T ],
453
+ arg_types = [target ],
454
+ arg_kinds = [ARG_POS ],
455
+ arg_names = [None ],
456
+ ret_type = T ,
457
+ fallback = self .context .function_type ,
458
+ )
459
+ constraints = infer_constraints_for_callable (
460
+ upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
461
+ )
462
+
463
+ (sol ,), _ = solve_constraints ([T ], constraints )
464
+
465
+ if sol is None : # solving failed, return AnyType fallback
466
+ return AnyType (TypeOfAny .from_error )
467
+ return self .make_iterable_type (sol )
468
+
352
469
353
470
def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [Sequence [TupleType ]]:
354
- """Check if all types are tuples of the same size."""
471
+ """Check if all types are tuples of the same size.
472
+
473
+ We use `flatten_nested_tuples` to deal with nested tuples.
474
+ Note that the result may still contain
475
+ """
355
476
if not types :
356
477
return True
357
478
358
479
iterator = iter (types )
359
- first = next (iterator )
360
- if not isinstance (first , TupleType ):
480
+ typ = next (iterator )
481
+ if not isinstance (typ , TupleType ):
482
+ return False
483
+ flattened_elements = flatten_nested_tuples (typ .items )
484
+ if any (
485
+ isinstance (get_proper_type (member ), (UnpackType , TypeVarTupleType ))
486
+ for member in flattened_elements
487
+ ):
488
+ # this can happen e.g. with tuple[int, *tuple[int, ...], int]
361
489
return False
362
- size = first . length ( )
490
+ size = len ( flattened_elements )
363
491
364
- for item in iterator :
365
- if not isinstance (item , TupleType ) or item .length () != size :
492
+ for typ in iterator :
493
+ if not isinstance (typ , TupleType ):
494
+ return False
495
+ flattened_elements = flatten_nested_tuples (typ .items )
496
+ if len (flattened_elements ) != size or any (
497
+ isinstance (get_proper_type (member ), (UnpackType , TypeVarTupleType ))
498
+ for member in flattened_elements
499
+ ):
500
+ # this can happen e.g. with tuple[int, *tuple[int, ...], int]
366
501
return False
367
502
return True
0 commit comments