Skip to content

Commit 9c27701

Browse files
reworked argmapper with star args
1 parent a6f4fcc commit 9c27701

File tree

4 files changed

+414
-205
lines changed

4 files changed

+414
-205
lines changed

mypy/argmap.py

Lines changed: 226 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,39 @@
33
from __future__ import annotations
44

55
from collections.abc import Sequence
6-
from typing import TYPE_CHECKING, Callable
7-
from typing_extensions import TypeGuard
6+
from typing import TYPE_CHECKING, Callable, cast
7+
from typing_extensions import NewType, TypeGuard
88

99
from mypy import nodes
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy.typeops import make_simplified_union
1212
from mypy.types import (
1313
AnyType,
14+
CallableType,
1415
Instance,
1516
ParamSpecType,
1617
ProperType,
1718
TupleType,
1819
Type,
1920
TypedDictType,
2021
TypeOfAny,
22+
TypeVarId,
2123
TypeVarTupleType,
24+
TypeVarType,
2225
UnionType,
2326
UnpackType,
27+
flatten_nested_tuples,
2428
get_proper_type,
2529
)
2630

2731
if TYPE_CHECKING:
2832
from mypy.infer import ArgumentInferContext
2933

3034

35+
IterableType = NewType("IterableType", Instance)
36+
"""Represents an instance of `Iterable[T]`."""
37+
38+
3139
def map_actuals_to_formals(
3240
actual_kinds: list[nodes.ArgKind],
3341
actual_names: Sequence[str | None] | None,
@@ -216,92 +224,41 @@ def expand_actual_type(
216224
original_actual = actual_type
217225
actual_type = get_proper_type(actual_type)
218226
if actual_kind == nodes.ARG_STAR:
219-
if isinstance(actual_type, UnionType):
220-
proper_types = [get_proper_type(t) for t in actual_type.items]
221-
# special case: union of equal sized tuples. (e.g. `tuple[int, int] | tuple[None, None]`)
222-
if is_equal_sized_tuples(proper_types):
223-
# transform union of tuples into a tuple of unions
224-
# e.g. tuple[A, B, C] | tuple[None, None, None] -> tuple[A | None, B | None, C | None]
225-
tuple_args: list[Type] = [
226-
make_simplified_union(items)
227-
for items in zip(*(t.items for t in proper_types))
228-
]
229-
actual_type = TupleType(
230-
tuple_args,
231-
# use Iterable[A | B | C] as the fallback type
232-
fallback=Instance(
233-
self.context.iterable_type.type, [UnionType.make_union(tuple_args)]
234-
),
235-
)
236-
else:
237-
# reinterpret all union items as iterable types (if possible)
238-
# and return the union of the iterable item types results.
239-
from mypy.subtypes import is_subtype
240-
241-
iterable_type = self.context.iterable_type
242-
243-
def as_iterable_type(t: Type) -> Type:
244-
"""Map a type to the iterable supertype if it is a subtype."""
245-
p_t = get_proper_type(t)
246-
if isinstance(p_t, Instance) and is_subtype(t, iterable_type):
247-
return map_instance_to_supertype(p_t, iterable_type.type)
248-
if isinstance(p_t, TupleType):
249-
# Convert tuple[A, B, C] to Iterable[A | B | C].
250-
return Instance(iterable_type.type, [make_simplified_union(p_t.items)])
251-
return t
252-
253-
# create copies of self for each item in the union
254-
sub_expanders = [
255-
ArgTypeExpander(context=self.context) for _ in actual_type.items
256-
]
257-
for expander in sub_expanders:
258-
expander.tuple_index = int(self.tuple_index)
259-
expander.kwargs_used = set(self.kwargs_used)
260-
261-
candidate_type = make_simplified_union(
262-
[
263-
e.expand_actual_type(
264-
as_iterable_type(item),
265-
actual_kind,
266-
formal_name,
267-
formal_kind,
268-
allow_unpack,
269-
)
270-
for e, item in zip(sub_expanders, actual_type.items)
271-
]
272-
)
273-
assert all(expander == sub_expanders[0] for expander in sub_expanders)
274-
# carry over the new state if all sub-expanders are the same state
275-
self.tuple_index = int(sub_expanders[0].tuple_index)
276-
self.kwargs_used = set(sub_expanders[0].kwargs_used)
277-
return candidate_type
278-
279-
if isinstance(actual_type, TypeVarTupleType):
280-
# This code path is hit when *Ts is passed to a callable and various
281-
# special-handling didn't catch this. The best thing we can do is to use
282-
# the upper bound.
283-
actual_type = get_proper_type(actual_type.upper_bound)
284-
if isinstance(actual_type, Instance) and actual_type.args:
285-
from mypy.subtypes import is_subtype
286-
287-
if is_subtype(actual_type, self.context.iterable_type):
288-
return map_instance_to_supertype(
289-
actual_type, self.context.iterable_type.type
290-
).args[0]
291-
else:
292-
# We cannot properly unpack anything other
293-
# than `Iterable` type with `*`.
294-
# Just return `Any`, other parts of code would raise
295-
# a different error for improper use.
296-
return AnyType(TypeOfAny.from_error)
297-
elif isinstance(actual_type, TupleType):
227+
# parse *args as one of the following:
228+
# IterableType | TupleType | ParamSpecType | AnyType
229+
star_args = self.parse_star_args_type(actual_type)
230+
# star_args = actual_type
231+
232+
# print(f"expand_actual_type: {actual_type=} {star_args=}")
233+
234+
# if isinstance(star_args, TypeVarTupleType):
235+
# # This code path is hit when *Ts is passed to a callable and various
236+
# # special-handling didn't catch this. The best thing we can do is to use
237+
# # the upper bound.
238+
# star_args = get_proper_type(star_args.upper_bound)
239+
# if isinstance(star_args, Instance) and star_args.args:
240+
# from mypy.subtypes import is_subtype
241+
#
242+
# if is_subtype(star_args, self.context.iterable_type):
243+
# return map_instance_to_supertype(
244+
# star_args, self.context.iterable_type.type
245+
# ).args[0]
246+
# else:
247+
# # We cannot properly unpack anything other
248+
# # than `Iterable` type with `*`.
249+
# # Just return `Any`, other parts of code would raise
250+
# # a different error for improper use.
251+
# return AnyType(TypeOfAny.from_error)
252+
if self.is_iterable_type(star_args):
253+
return star_args.args[0]
254+
elif isinstance(star_args, TupleType):
298255
# Get the next tuple item of a tuple *arg.
299-
if self.tuple_index >= len(actual_type.items):
256+
if self.tuple_index >= len(star_args.items):
300257
# Exhausted a tuple -- continue to the next *args.
301258
self.tuple_index = 1
302259
else:
303260
self.tuple_index += 1
304-
item = actual_type.items[self.tuple_index - 1]
261+
item = star_args.items[self.tuple_index - 1]
305262
if isinstance(item, UnpackType) and not allow_unpack:
306263
# An unpack item that doesn't have special handling, use upper bound as above.
307264
unpacked = get_proper_type(item.type)
@@ -315,9 +272,9 @@ def as_iterable_type(t: Type) -> Type:
315272
)
316273
item = fallback.args[0]
317274
return item
318-
elif isinstance(actual_type, ParamSpecType):
275+
elif isinstance(star_args, ParamSpecType):
319276
# ParamSpec is valid in *args but it can't be unpacked.
320-
return actual_type
277+
return star_args
321278
else:
322279
return AnyType(TypeOfAny.from_error)
323280
elif actual_kind == nodes.ARG_STAR2:
@@ -349,19 +306,197 @@ def as_iterable_type(t: Type) -> Type:
349306
# No translation for other kinds -- 1:1 mapping.
350307
return original_actual
351308

309+
def is_iterable(self, typ: Type) -> bool:
310+
from mypy.subtypes import is_subtype
311+
312+
return is_subtype(typ, self.context.iterable_type)
313+
314+
def is_iterable_instance_subtype(self, typ: Type) -> TypeGuard[Instance]:
315+
from mypy.subtypes import is_subtype
316+
317+
p_t = get_proper_type(typ)
318+
return (
319+
isinstance(p_t, Instance)
320+
and bool(p_t.args)
321+
and is_subtype(p_t, self.context.iterable_type)
322+
)
323+
324+
def is_iterable_type(self, typ: Type) -> TypeGuard[IterableType]:
325+
"""Check if the type is an Iterable[T] or a subtype of it."""
326+
p_t = get_proper_type(typ)
327+
return isinstance(p_t, Instance) and p_t.type == self.context.iterable_type.type
328+
329+
def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
330+
"""Reinterpret a type as Iterable[T], or return AnyType if not possible."""
331+
p_t = get_proper_type(typ)
332+
if self.is_iterable_type(p_t):
333+
return p_t
334+
elif self.is_iterable_instance_subtype(p_t):
335+
cls = self.context.iterable_type.type
336+
return cast(IterableType, map_instance_to_supertype(p_t, cls))
337+
elif isinstance(p_t, UnionType):
338+
# If the type is a union, map each item to the iterable supertype.
339+
# the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
340+
converted_types = [self.as_iterable_type(get_proper_type(item)) for item in p_t.items]
341+
# if an item could not be interpreted as Iterable[T], we return AnyType
342+
if all(self.is_iterable_type(it) for it in converted_types):
343+
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
344+
iterable_types = cast(list[IterableType], converted_types)
345+
arg = make_simplified_union([it.args[0] for it in iterable_types])
346+
return self.make_iterable_type(arg)
347+
return AnyType(TypeOfAny.from_error)
348+
elif isinstance(p_t, TupleType):
349+
# maps tuple[A, B, C] -> Iterable[A | B | C]
350+
# note: proper_elements may contain UnpackType, for instance with
351+
# tuple[None, *tuple[None, ...]]..
352+
proper_elements = [get_proper_type(t) for t in flatten_nested_tuples(p_t.items)]
353+
args: list[Type] = []
354+
for p_e in proper_elements:
355+
if isinstance(p_e, UnpackType):
356+
r = self.as_iterable_type(p_e)
357+
if self.is_iterable_type(r):
358+
args.append(r.args[0])
359+
else:
360+
args.append(r)
361+
else:
362+
args.append(p_e)
363+
return self.make_iterable_type(make_simplified_union(args))
364+
if isinstance(p_t, UnpackType):
365+
return self.as_iterable_type(p_t.type)
366+
if isinstance(p_t, (TypeVarType, TypeVarTupleType)):
367+
return self.as_iterable_type(p_t.upper_bound)
368+
# fallback: use the solver to reinterpret the type as Iterable[T]
369+
if self.is_iterable(p_t):
370+
return self._solve_as_iterable(p_t)
371+
return AnyType(TypeOfAny.from_error)
372+
373+
def make_iterable_type(self, arg: Type) -> IterableType:
374+
value = Instance(self.context.iterable_type.type, [arg])
375+
return cast(IterableType, value)
376+
377+
def parse_star_args_type(
378+
self, typ: Type
379+
) -> TupleType | IterableType | ParamSpecType | AnyType:
380+
"""Parse the type of a *args argument.
381+
382+
Returns one TupleType, IterableType, ParamSpecType or AnyType.
383+
"""
384+
p_t = get_proper_type(typ)
385+
if isinstance(p_t, (TupleType, ParamSpecType, AnyType)):
386+
# just return the type as-is
387+
return p_t
388+
elif isinstance(p_t, TypeVarTupleType):
389+
return self.parse_star_args_type(p_t.upper_bound)
390+
elif isinstance(p_t, UnionType):
391+
proper_items = [get_proper_type(t) for t in p_t.items]
392+
# consider 2 cases:
393+
# 1. Union of equal sized tuples, e.g. tuple[A, B] | tuple[None, None]
394+
# In this case transform union of same-sized tuples into a tuple of unions
395+
# e.g. tuple[A, B] | tuple[None, None] -> tuple[A | None, B | None]
396+
if is_equal_sized_tuples(proper_items):
397+
398+
tuple_args: list[Type] = [
399+
make_simplified_union(items) for items in zip(*(t.items for t in proper_items))
400+
]
401+
actual_type = TupleType(
402+
tuple_args,
403+
# use Iterable[A | B | C] as the fallback type
404+
fallback=Instance(
405+
self.context.iterable_type.type, [UnionType.make_union(tuple_args)]
406+
),
407+
)
408+
return actual_type
409+
# 2. Union of iterable types, e.g. Iterable[A] | Iterable[B]
410+
# In this case return Iterable[A | B]
411+
# Note that this covers unions of differently sized tuples as well.
412+
else:
413+
converted_types = [self.as_iterable_type(p_i) for p_i in proper_items]
414+
if all(self.is_iterable_type(it) for it in converted_types):
415+
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
416+
iterables = cast(list[IterableType], converted_types)
417+
arg = make_simplified_union([it.args[0] for it in iterables])
418+
return self.make_iterable_type(arg)
419+
else:
420+
# some items in the union are not iterable, return AnyType
421+
return AnyType(TypeOfAny.from_error)
422+
elif self.is_iterable_type(parsed := self.as_iterable_type(p_t)):
423+
# in all other cases, we try to reinterpret the type as Iterable[T]
424+
return parsed
425+
return AnyType(TypeOfAny.from_error)
426+
427+
def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
428+
r"""Use the solver to cast a type as Iterable[T].
429+
430+
Returns the type as-is if solving fails.
431+
"""
432+
from mypy.constraints import infer_constraints_for_callable
433+
from mypy.nodes import ARG_POS
434+
from mypy.solve import solve_constraints
435+
436+
iterable_kind = self.context.iterable_type.type
437+
438+
# We first create an upcast function:
439+
# def [T] (Iterable[T]) -> Iterable[T]: ...
440+
# and then solve for T, given the input type as the argument.
441+
T = TypeVarType(
442+
"T",
443+
"T",
444+
TypeVarId(-1),
445+
values=[],
446+
upper_bound=AnyType(TypeOfAny.special_form),
447+
default=AnyType(TypeOfAny.special_form),
448+
)
449+
target = Instance(iterable_kind, [T])
450+
451+
upcast_callable = CallableType(
452+
variables=[T],
453+
arg_types=[target],
454+
arg_kinds=[ARG_POS],
455+
arg_names=[None],
456+
ret_type=T,
457+
fallback=self.context.function_type,
458+
)
459+
constraints = infer_constraints_for_callable(
460+
upcast_callable, [typ], [ARG_POS], [None], [[0]], context=self.context
461+
)
462+
463+
(sol,), _ = solve_constraints([T], constraints)
464+
465+
if sol is None: # solving failed, return AnyType fallback
466+
return AnyType(TypeOfAny.from_error)
467+
return self.make_iterable_type(sol)
468+
352469

353470
def is_equal_sized_tuples(types: Sequence[ProperType]) -> TypeGuard[Sequence[TupleType]]:
354-
"""Check if all types are tuples of the same size."""
471+
"""Check if all types are tuples of the same size.
472+
473+
We use `flatten_nested_tuples` to deal with nested tuples.
474+
Note that the result may still contain
475+
"""
355476
if not types:
356477
return True
357478

358479
iterator = iter(types)
359-
first = next(iterator)
360-
if not isinstance(first, TupleType):
480+
typ = next(iterator)
481+
if not isinstance(typ, TupleType):
482+
return False
483+
flattened_elements = flatten_nested_tuples(typ.items)
484+
if any(
485+
isinstance(get_proper_type(member), (UnpackType, TypeVarTupleType))
486+
for member in flattened_elements
487+
):
488+
# this can happen e.g. with tuple[int, *tuple[int, ...], int]
361489
return False
362-
size = first.length()
490+
size = len(flattened_elements)
363491

364-
for item in iterator:
365-
if not isinstance(item, TupleType) or item.length() != size:
492+
for typ in iterator:
493+
if not isinstance(typ, TupleType):
494+
return False
495+
flattened_elements = flatten_nested_tuples(typ.items)
496+
if len(flattened_elements) != size or any(
497+
isinstance(get_proper_type(member), (UnpackType, TypeVarTupleType))
498+
for member in flattened_elements
499+
):
500+
# this can happen e.g. with tuple[int, *tuple[int, ...], int]
366501
return False
367502
return True

0 commit comments

Comments
 (0)