Skip to content

Rework starargs with union argument #19651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 227 additions & 24 deletions mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,39 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, cast
from typing_extensions import NewType, TypeGuard, TypeIs

from mypy import nodes
from mypy.maptype import map_instance_to_supertype
from mypy.typeops import make_simplified_union
from mypy.types import (
AnyType,
CallableType,
Instance,
ParamSpecType,
ProperType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
TypeVarId,
TypeVarTupleType,
TypeVarType,
UnionType,
UnpackType,
flatten_nested_tuples,
get_proper_type,
)

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext


IterableType = NewType("IterableType", Instance)
"""Represents an instance of `Iterable[T]`."""


def map_actuals_to_formals(
actual_kinds: list[nodes.ArgKind],
actual_names: Sequence[str | None] | None,
Expand Down Expand Up @@ -54,6 +66,17 @@ def map_actuals_to_formals(
elif actual_kind == nodes.ARG_STAR:
# We need to know the actual type to map varargs.
actualt = get_proper_type(actual_arg_type(ai))

# Special case for union of equal sized tuples.
if (
isinstance(actualt, UnionType)
and actualt.items
and is_equal_sized_tuples(
proper_types := [get_proper_type(t) for t in actualt.items]
)
):
# pick an arbitrary member
actualt = proper_types[0]
if isinstance(actualt, TupleType):
# A tuple actual maps to a fixed number of formals.
for _ in range(len(actualt.items)):
Expand Down Expand Up @@ -193,32 +216,20 @@ def expand_actual_type(
original_actual = actual_type
actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, TypeVarTupleType):
# This code path is hit when *Ts is passed to a callable and various
# special-handling didn't catch this. The best thing we can do is to use
# the upper bound.
actual_type = get_proper_type(actual_type.upper_bound)
if isinstance(actual_type, Instance) and actual_type.args:
from mypy.subtypes import is_subtype

if is_subtype(actual_type, self.context.iterable_type):
return map_instance_to_supertype(
actual_type, self.context.iterable_type.type
).args[0]
else:
# We cannot properly unpack anything other
# than `Iterable` type with `*`.
# Just return `Any`, other parts of code would raise
# a different error for improper use.
return AnyType(TypeOfAny.from_error)
elif isinstance(actual_type, TupleType):
# parse *args as one of the following:
# IterableType | TupleType | ParamSpecType | AnyType
star_args_type = self.parse_star_args_type(actual_type)

if self.is_iterable_instance_type(star_args_type):
return star_args_type.args[0]
elif isinstance(star_args_type, TupleType):
# Get the next tuple item of a tuple *arg.
if self.tuple_index >= len(actual_type.items):
if self.tuple_index >= len(star_args_type.items):
# Exhausted a tuple -- continue to the next *args.
self.tuple_index = 1
else:
self.tuple_index += 1
item = actual_type.items[self.tuple_index - 1]
item = star_args_type.items[self.tuple_index - 1]
if isinstance(item, UnpackType) and not allow_unpack:
# An unpack item that doesn't have special handling, use upper bound as above.
unpacked = get_proper_type(item.type)
Expand All @@ -232,9 +243,9 @@ def expand_actual_type(
)
item = fallback.args[0]
return item
elif isinstance(actual_type, ParamSpecType):
elif isinstance(star_args_type, ParamSpecType):
# ParamSpec is valid in *args but it can't be unpacked.
return actual_type
return star_args_type
else:
return AnyType(TypeOfAny.from_error)
elif actual_kind == nodes.ARG_STAR2:
Expand Down Expand Up @@ -265,3 +276,195 @@ def expand_actual_type(
else:
# No translation for other kinds -- 1:1 mapping.
return original_actual

def is_iterable(self, typ: Type) -> bool:
"""Check if the type is an iterable, i.e. implements the Iterable Protocol."""
from mypy.subtypes import is_subtype

return is_subtype(typ, self.context.iterable_type)

def is_iterable_instance_type(self, typ: Type) -> TypeIs[IterableType]:
"""Check if the type is an Iterable[T]."""
p_t = get_proper_type(typ)
return isinstance(p_t, Instance) and p_t.type == self.context.iterable_type.type

def _make_iterable_instance_type(self, arg: Type) -> IterableType:
value = Instance(self.context.iterable_type.type, [arg])
return cast(IterableType, value)

def _solve_as_iterable(self, typ: Type) -> IterableType | AnyType:
r"""Use the solver to cast a type as Iterable[T].

Returns `AnyType` if solving fails.
"""
from mypy.constraints import infer_constraints_for_callable
from mypy.nodes import ARG_POS
from mypy.solve import solve_constraints

# We first create an upcast function:
# def [T] (Iterable[T]) -> Iterable[T]: ...
# and then solve for T, given the input type as the argument.
T = TypeVarType(
"T",
"T",
TypeVarId(-1),
values=[],
upper_bound=AnyType(TypeOfAny.from_omitted_generics),
default=AnyType(TypeOfAny.from_omitted_generics),
)
target = self._make_iterable_instance_type(T)
upcast_callable = CallableType(
variables=[T],
arg_types=[target],
arg_kinds=[ARG_POS],
arg_names=[None],
ret_type=target,
fallback=self.context.function_type,
)
constraints = infer_constraints_for_callable(
upcast_callable, [typ], [ARG_POS], [None], [[0]], self.context
)

(sol,), _ = solve_constraints([T], constraints)

if sol is None: # solving failed, return AnyType fallback
return AnyType(TypeOfAny.from_error)
return self._make_iterable_instance_type(sol)

def as_iterable_type(self, typ: Type) -> IterableType | AnyType:
"""Reinterpret a type as Iterable[T], or return AnyType if not possible.

This function specially handles certain types like UnionType, TupleType, and UnpackType.
Otherwise, the upcasting is performed using the solver.
"""
p_t = get_proper_type(typ)
if self.is_iterable_instance_type(p_t) or isinstance(p_t, AnyType):
return p_t
elif isinstance(p_t, UnionType):
# If the type is a union, map each item to the iterable supertype.
# the return the combined iterable type Iterable[A] | Iterable[B] -> Iterable[A | B]
converted_types = [self.as_iterable_type(get_proper_type(item)) for item in p_t.items]

if any(not self.is_iterable_instance_type(it) for it in converted_types):
# if any item could not be interpreted as Iterable[T], we return AnyType
return AnyType(TypeOfAny.from_error)
else:
# all items are iterable, return Iterable[T₁ | T₂ | ... | Tₙ]
iterable_types = cast(list[IterableType], converted_types)
arg = make_simplified_union([it.args[0] for it in iterable_types])
return self._make_iterable_instance_type(arg)
elif isinstance(p_t, TupleType):
# maps tuple[A, B, C] -> Iterable[A | B | C]
# note: proper_elements may contain UnpackType, for instance with
# tuple[None, *tuple[None, ...]]..
proper_elements = [get_proper_type(t) for t in flatten_nested_tuples(p_t.items)]
args: list[Type] = []
for p_e in proper_elements:
if isinstance(p_e, UnpackType):
r = self.as_iterable_type(p_e)
if self.is_iterable_instance_type(r):
args.append(r.args[0])
else:
# this *should* never happen, since UnpackType should
# only contain TypeVarTuple or a variable length tuple.
# However, we could get an `AnyType(TypeOfAny.from_error)`
# if for some reason the solver was triggered and failed.
args.append(r)
else:
args.append(p_e)
return self._make_iterable_instance_type(make_simplified_union(args))
elif isinstance(p_t, UnpackType):
return self.as_iterable_type(p_t.type)
elif isinstance(p_t, (TypeVarType, TypeVarTupleType)):
return self.as_iterable_type(p_t.upper_bound)
elif self.is_iterable(p_t):
# TODO: add a 'fast path' (needs measurement) that uses the map_instance_to_supertype
# mechanism? (Only if it works: gh-19662)
return self._solve_as_iterable(p_t)
return AnyType(TypeOfAny.from_error)

def parse_star_args_type(
self, typ: Type
) -> TupleType | IterableType | ParamSpecType | AnyType:
"""Parse the type of a ``*args`` argument.

Returns one of TupleType, IterableType, ParamSpecType or AnyType.
Returns AnyType(TypeOfAny.from_error) if the type cannot be parsed or is invalid.
"""
p_t = get_proper_type(typ)
if isinstance(p_t, (TupleType, ParamSpecType, AnyType)):
# just return the type as-is
return p_t
elif isinstance(p_t, TypeVarTupleType):
return self.parse_star_args_type(p_t.upper_bound)
elif isinstance(p_t, UnionType):
proper_items = [get_proper_type(t) for t in p_t.items]
# consider 2 cases:
# 1. Union of equal sized tuples, e.g. tuple[A, B] | tuple[None, None]
# In this case transform union of same-sized tuples into a tuple of unions
# e.g. tuple[A, B] | tuple[None, None] -> tuple[A | None, B | None]
if is_equal_sized_tuples(proper_items):

tuple_args: list[Type] = [
make_simplified_union(items) for items in zip(*(t.items for t in proper_items))
]
actual_type = TupleType(
tuple_args,
# use Iterable[A | B | C] as the fallback type
fallback=Instance(
self.context.iterable_type.type, [UnionType.make_union(tuple_args)]
),
)
return actual_type
# 2. Union of iterable types, e.g. Iterable[A] | Iterable[B]
# In this case return Iterable[A | B]
# Note that this covers unions of differently sized tuples as well.
else:
converted_types = [self.as_iterable_type(p_i) for p_i in proper_items]
if all(self.is_iterable_instance_type(it) for it in converted_types):
# all items are iterable, return Iterable[T1 | T2 | ... | Tn]
iterables = cast(list[IterableType], converted_types)
arg = make_simplified_union([it.args[0] for it in iterables])
return self._make_iterable_instance_type(arg)
else:
# some items in the union are not iterable, return AnyType
return AnyType(TypeOfAny.from_error)
elif self.is_iterable_instance_type(parsed := self.as_iterable_type(p_t)):
# in all other cases, we try to reinterpret the type as Iterable[T]
return parsed
return AnyType(TypeOfAny.from_error)


def is_equal_sized_tuples(types: Sequence[ProperType]) -> TypeGuard[Sequence[TupleType]]:
"""Check if all types are tuples of the same size.

We use `flatten_nested_tuples` to deal with nested tuples.
Note that the result may still contain
"""
if not types:
return True

iterator = iter(types)
typ = next(iterator)
if not isinstance(typ, TupleType):
return False
flattened_elements = flatten_nested_tuples(typ.items)
if any(
isinstance(get_proper_type(member), (UnpackType, TypeVarTupleType))
for member in flattened_elements
):
# this can happen e.g. with tuple[int, *tuple[int, ...], int]
return False
size = len(flattened_elements)

for typ in iterator:
if not isinstance(typ, TupleType):
return False
flattened_elements = flatten_nested_tuples(typ.items)
if len(flattened_elements) != size or any(
isinstance(get_proper_type(member), (UnpackType, TypeVarTupleType))
for member in flattened_elements
):
# this can happen e.g. with tuple[int, *tuple[int, ...], int]
return False
return True
48 changes: 39 additions & 9 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,7 +2285,9 @@ def infer_function_type_arguments_pass2(
def argument_infer_context(self) -> ArgumentInferContext:
if self._arg_infer_context_cache is None:
self._arg_infer_context_cache = ArgumentInferContext(
self.chk.named_type("typing.Mapping"), self.chk.named_type("typing.Iterable")
self.chk.named_type("typing.Mapping"),
self.chk.named_type("typing.Iterable"),
self.chk.named_type("builtins.function"),
)
return self._arg_infer_context_cache

Expand Down Expand Up @@ -2670,6 +2672,30 @@ def check_arg(
original_caller_type = get_proper_type(original_caller_type)
callee_type = get_proper_type(callee_type)

if isinstance(callee_type, UnpackType) and not isinstance(caller_type, UnpackType):
# it can happen that the caller_type got expanded.
# since this is from a callable definition, it should be one of the following:
# - TupleType, TypeVarTupleType, or a variable length tuple Instance.
unpack_arg = get_proper_type(callee_type.type)
if isinstance(unpack_arg, TypeVarTupleType):
# substitute with upper bound of the TypeVarTuple
unpack_arg = get_proper_type(unpack_arg.upper_bound)
# note: not using elif, since in the future upper bound may be a finite tuple
if isinstance(unpack_arg, Instance) and unpack_arg.type.fullname == "builtins.tuple":
callee_type = get_proper_type(unpack_arg.args[0])
elif isinstance(unpack_arg, TupleType):
# this branch should currently never hit, but it may hit in the future,
# if it will ever be allowed to upper bound TypeVarTuple with a tuple type.
elements = flatten_nested_tuples(unpack_arg.items)
if m < len(elements):
# pick the corresponding item from the tuple
callee_type = get_proper_type(elements[m])
else:
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, context)
return
else:
raise TypeError(f"did not expect unpack_arg to be of type {type(unpack_arg)=}")

if isinstance(caller_type, DeletedType):
self.msg.deleted_as_rvalue(caller_type, context)
# Only non-abstract non-protocol class can be given where Type[...] is expected...
Expand Down Expand Up @@ -5225,29 +5251,33 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
ctx = ctx_item.type
else:
ctx = None
tt = self.accept(item.expr, ctx)
tt = get_proper_type(tt)
if isinstance(tt, TupleType):
if find_unpack_in_list(tt.items) is not None:
original_arg_type = self.accept(item.expr, ctx)
# convert arg type to one of TupleType, IterableType, AnyType or
arg_type_expander = ArgTypeExpander(self.argument_infer_context())
star_args_type = arg_type_expander.parse_star_args_type(original_arg_type)
if isinstance(star_args_type, TupleType):
if find_unpack_in_list(star_args_type.items) is not None:
if seen_unpack_in_items:
# Multiple unpack items are not allowed in tuples,
# fall back to instance type.
return self.check_lst_expr(e, "builtins.tuple", "<tuple>")
else:
seen_unpack_in_items = True
items.extend(tt.items)
items.extend(star_args_type.items)
# Note: this logic depends on full structure match in tuple_context_matches().
if unpack_in_context:
j += 1
else:
# If there is an unpack in expressions, but not in context, this will
# result in an error later, just do something predictable here.
j += len(tt.items)
j += len(star_args_type.items)
else:
if allow_precise_tuples and not seen_unpack_in_items:
# Handle (x, *y, z), where y is e.g. tuple[Y, ...].
if isinstance(tt, Instance) and self.chk.type_is_iterable(tt):
item_type = self.chk.iterable_item_type(tt, e)
if isinstance(star_args_type, Instance) and self.chk.type_is_iterable(
star_args_type
):
item_type = self.chk.iterable_item_type(star_args_type, e)
mapped = self.chk.named_generic_type("builtins.tuple", [item_type])
items.append(UnpackType(mapped))
seen_unpack_in_items = True
Expand Down
Loading
Loading