@@ -75,6 +75,7 @@ def map_actuals_to_formals(
75
75
proper_types := [get_proper_type (t ) for t in actualt .items ]
76
76
)
77
77
):
78
+ # pick an arbitrary member
78
79
actualt = proper_types [0 ]
79
80
if isinstance (actualt , TupleType ):
80
81
# A tuple actual maps to a fixed number of formals.
@@ -193,15 +194,6 @@ def __init__(self, context: ArgumentInferContext) -> None:
193
194
# Type context for `*` and `**` arg kinds.
194
195
self .context = context
195
196
196
- def __eq__ (self , other : object ) -> bool :
197
- if isinstance (other , ArgTypeExpander ):
198
- return (
199
- self .tuple_index == other .tuple_index
200
- and self .kwargs_used == other .kwargs_used
201
- and self .context == other .context
202
- )
203
- return NotImplemented
204
-
205
197
def expand_actual_type (
206
198
self ,
207
199
actual_type : Type ,
@@ -227,29 +219,8 @@ def expand_actual_type(
227
219
# parse *args as one of the following:
228
220
# IterableType | TupleType | ParamSpecType | AnyType
229
221
star_args = self .parse_star_args_type (actual_type )
230
- # star_args = actual_type
231
-
232
- # print(f"expand_actual_type: {actual_type=} {star_args=}")
233
-
234
- # if isinstance(star_args, TypeVarTupleType):
235
- # # This code path is hit when *Ts is passed to a callable and various
236
- # # special-handling didn't catch this. The best thing we can do is to use
237
- # # the upper bound.
238
- # star_args = get_proper_type(star_args.upper_bound)
239
- # if isinstance(star_args, Instance) and star_args.args:
240
- # from mypy.subtypes import is_subtype
241
- #
242
- # if is_subtype(star_args, self.context.iterable_type):
243
- # return map_instance_to_supertype(
244
- # star_args, self.context.iterable_type.type
245
- # ).args[0]
246
- # else:
247
- # # We cannot properly unpack anything other
248
- # # than `Iterable` type with `*`.
249
- # # Just return `Any`, other parts of code would raise
250
- # # a different error for improper use.
251
- # return AnyType(TypeOfAny.from_error)
252
- if self .is_iterable_type (star_args ):
222
+
223
+ if self .is_iterable_instance_type (star_args ):
253
224
return star_args .args [0 ]
254
225
elif isinstance (star_args , TupleType ):
255
226
# Get the next tuple item of a tuple *arg.
@@ -321,30 +292,75 @@ def is_iterable_instance_subtype(self, typ: Type) -> TypeGuard[Instance]:
321
292
and is_subtype (p_t , self .context .iterable_type )
322
293
)
323
294
324
- def is_iterable_type (self , typ : Type ) -> TypeGuard [IterableType ]:
295
+ def is_iterable_instance_type (self , typ : Type ) -> TypeGuard [IterableType ]:
325
296
"""Check if the type is an Iterable[T] or a subtype of it."""
326
297
p_t = get_proper_type (typ )
327
298
return isinstance (p_t , Instance ) and p_t .type == self .context .iterable_type .type
328
299
300
+ def _make_iterable_instance_type (self , arg : Type ) -> IterableType :
301
+ value = Instance (self .context .iterable_type .type , [arg ])
302
+ return cast (IterableType , value )
303
+
304
+ def _solve_as_iterable (self , typ : Type ) -> IterableType | AnyType :
305
+ r"""Use the solver to cast a type as Iterable[T].
306
+
307
+ Returns `AnyType` if solving fails.
308
+ """
309
+ from mypy .constraints import infer_constraints_for_callable
310
+ from mypy .nodes import ARG_POS
311
+ from mypy .solve import solve_constraints
312
+
313
+ iterable_kind = self .context .iterable_type .type
314
+
315
+ # We first create an upcast function:
316
+ # def [T] (Iterable[T]) -> Iterable[T]: ...
317
+ # and then solve for T, given the input type as the argument.
318
+ T = TypeVarType (
319
+ "T" ,
320
+ "T" ,
321
+ TypeVarId (- 1 ),
322
+ values = [],
323
+ upper_bound = AnyType (TypeOfAny .special_form ),
324
+ default = AnyType (TypeOfAny .special_form ),
325
+ )
326
+ target = Instance (iterable_kind , [T ])
327
+
328
+ upcast_callable = CallableType (
329
+ variables = [T ],
330
+ arg_types = [target ],
331
+ arg_kinds = [ARG_POS ],
332
+ arg_names = [None ],
333
+ ret_type = T ,
334
+ fallback = self .context .function_type ,
335
+ )
336
+ constraints = infer_constraints_for_callable (
337
+ upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
338
+ )
339
+
340
+ (sol ,), _ = solve_constraints ([T ], constraints )
341
+
342
+ if sol is None : # solving failed, return AnyType fallback
343
+ return AnyType (TypeOfAny .from_error )
344
+ return self ._make_iterable_instance_type (sol )
345
+
329
346
def as_iterable_type (self , typ : Type ) -> IterableType | AnyType :
330
347
"""Reinterpret a type as Iterable[T], or return AnyType if not possible."""
331
348
p_t = get_proper_type (typ )
332
- if self .is_iterable_type (p_t ):
349
+ if self .is_iterable_instance_type (p_t ) or isinstance ( p_t , AnyType ):
333
350
return p_t
334
- elif self .is_iterable_instance_subtype (p_t ):
335
- cls = self .context .iterable_type .type
336
- return cast (IterableType , map_instance_to_supertype (p_t , cls ))
337
351
elif isinstance (p_t , UnionType ):
338
352
# If the type is a union, map each item to the iterable supertype.
339
353
# the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
340
354
converted_types = [self .as_iterable_type (get_proper_type (item )) for item in p_t .items ]
341
- # if an item could not be interpreted as Iterable[T], we return AnyType
342
- if all (self .is_iterable_type (it ) for it in converted_types ):
355
+
356
+ if any (not self .is_iterable_instance_type (it ) for it in converted_types ):
357
+ # if any item could not be interpreted as Iterable[T], we return AnyType
358
+ return AnyType (TypeOfAny .from_error )
359
+ else :
343
360
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
344
361
iterable_types = cast (list [IterableType ], converted_types )
345
362
arg = make_simplified_union ([it .args [0 ] for it in iterable_types ])
346
- return self .make_iterable_type (arg )
347
- return AnyType (TypeOfAny .from_error )
363
+ return self ._make_iterable_instance_type (arg )
348
364
elif isinstance (p_t , TupleType ):
349
365
# maps tuple[A, B, C] -> Iterable[A | B | C]
350
366
# note: proper_elements may contain UnpackType, for instance with
@@ -354,26 +370,24 @@ def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
354
370
for p_e in proper_elements :
355
371
if isinstance (p_e , UnpackType ):
356
372
r = self .as_iterable_type (p_e )
357
- if self .is_iterable_type (r ):
373
+ if self .is_iterable_instance_type (r ):
358
374
args .append (r .args [0 ])
359
375
else :
376
+ # this *should* never happen
360
377
args .append (r )
361
378
else :
362
379
args .append (p_e )
363
- return self .make_iterable_type (make_simplified_union (args ))
364
- if isinstance (p_t , UnpackType ):
380
+ return self ._make_iterable_instance_type (make_simplified_union (args ))
381
+ elif isinstance (p_t , UnpackType ):
365
382
return self .as_iterable_type (p_t .type )
366
- if isinstance (p_t , (TypeVarType , TypeVarTupleType )):
383
+ elif isinstance (p_t , (TypeVarType , TypeVarTupleType )):
367
384
return self .as_iterable_type (p_t .upper_bound )
368
- # fallback: use the solver to reinterpret the type as Iterable[T]
369
- if self .is_iterable (p_t ):
385
+ elif self .is_iterable (p_t ):
386
+ # TODO: add a 'fast path' (needs measurement) that uses the map_instance_to_supertype
387
+ # mechanism? (Only if it works: gh-19662)
370
388
return self ._solve_as_iterable (p_t )
371
389
return AnyType (TypeOfAny .from_error )
372
390
373
- def make_iterable_type (self , arg : Type ) -> IterableType :
374
- value = Instance (self .context .iterable_type .type , [arg ])
375
- return cast (IterableType , value )
376
-
377
391
def parse_star_args_type (
378
392
self , typ : Type
379
393
) -> TupleType | IterableType | ParamSpecType | AnyType :
@@ -411,61 +425,19 @@ def parse_star_args_type(
411
425
# Note that this covers unions of differently sized tuples as well.
412
426
else :
413
427
converted_types = [self .as_iterable_type (p_i ) for p_i in proper_items ]
414
- if all (self .is_iterable_type (it ) for it in converted_types ):
428
+ if all (self .is_iterable_instance_type (it ) for it in converted_types ):
415
429
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
416
430
iterables = cast (list [IterableType ], converted_types )
417
431
arg = make_simplified_union ([it .args [0 ] for it in iterables ])
418
- return self .make_iterable_type (arg )
432
+ return self ._make_iterable_instance_type (arg )
419
433
else :
420
434
# some items in the union are not iterable, return AnyType
421
435
return AnyType (TypeOfAny .from_error )
422
- elif self .is_iterable_type (parsed := self .as_iterable_type (p_t )):
436
+ elif self .is_iterable_instance_type (parsed := self .as_iterable_type (p_t )):
423
437
# in all other cases, we try to reinterpret the type as Iterable[T]
424
438
return parsed
425
439
return AnyType (TypeOfAny .from_error )
426
440
427
- def _solve_as_iterable (self , typ : Type ) -> IterableType | AnyType :
428
- r"""Use the solver to cast a type as Iterable[T].
429
-
430
- Returns the type as-is if solving fails.
431
- """
432
- from mypy .constraints import infer_constraints_for_callable
433
- from mypy .nodes import ARG_POS
434
- from mypy .solve import solve_constraints
435
-
436
- iterable_kind = self .context .iterable_type .type
437
-
438
- # We first create an upcast function:
439
- # def [T] (Iterable[T]) -> Iterable[T]: ...
440
- # and then solve for T, given the input type as the argument.
441
- T = TypeVarType (
442
- "T" ,
443
- "T" ,
444
- TypeVarId (- 1 ),
445
- values = [],
446
- upper_bound = AnyType (TypeOfAny .special_form ),
447
- default = AnyType (TypeOfAny .special_form ),
448
- )
449
- target = Instance (iterable_kind , [T ])
450
-
451
- upcast_callable = CallableType (
452
- variables = [T ],
453
- arg_types = [target ],
454
- arg_kinds = [ARG_POS ],
455
- arg_names = [None ],
456
- ret_type = T ,
457
- fallback = self .context .function_type ,
458
- )
459
- constraints = infer_constraints_for_callable (
460
- upcast_callable , [typ ], [ARG_POS ], [None ], [[0 ]], context = self .context
461
- )
462
-
463
- (sol ,), _ = solve_constraints ([T ], constraints )
464
-
465
- if sol is None : # solving failed, return AnyType fallback
466
- return AnyType (TypeOfAny .from_error )
467
- return self .make_iterable_type (sol )
468
-
469
441
470
442
def is_equal_sized_tuples (types : Sequence [ProperType ]) -> TypeGuard [Sequence [TupleType ]]:
471
443
"""Check if all types are tuples of the same size.
0 commit comments