diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 2c41f2e273cc..e384e8c5e5b3 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -68,6 +68,7 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarId, TypeVarLikeType, TypeVarTupleType, TypeVarType, @@ -1400,6 +1401,67 @@ def analyze_typeddict_access( fallback=mx.chk.named_type("builtins.function"), name=name, ) + elif name == "get": + # synthesize TypedDict.get() overloads + t = TypeVarType( + "T", + "T", + id=TypeVarId(-1), + values=[], + upper_bound=mx.chk.named_type("builtins.object"), + default=AnyType(TypeOfAny.from_omitted_generics), + ) + str_type = mx.chk.named_type("builtins.str") + fn_type = mx.chk.named_type("builtins.function") + object_type = mx.chk.named_type("builtins.object") + + overloads: list[CallableType] = [] + # add two overloads per TypedDictType spec + for key, val in typ.items.items(): + # first overload: def(Literal[key]) -> val + no_default = CallableType( + arg_types=[LiteralType(key, fallback=str_type)], + arg_kinds=[ARG_POS], + arg_names=[None], + ret_type=val, + fallback=fn_type, + name=name, + ) + # second Overload: def [T] (Literal[key], default: T | Val, /) -> T | Val + with_default = CallableType( + variables=[t], + arg_types=[LiteralType(key, fallback=str_type), UnionType.make_union([val, t])], + arg_kinds=[ARG_POS, ARG_POS], + arg_names=[None, None], + ret_type=UnionType.make_union([val, t]), + fallback=fn_type, + name=name, + ) + overloads.append(no_default) + overloads.append(with_default) + + # finally, add fallback overloads when a key is used that is not in the TypedDict + # def (str) -> object + fallback_no_default = CallableType( + arg_types=[str_type], + arg_kinds=[ARG_POS], + arg_names=[None], + ret_type=object_type, + fallback=fn_type, + name=name, + ) + # def (str, object) -> object + fallback_with_default = CallableType( + arg_types=[str_type, object_type], + arg_kinds=[ARG_POS, ARG_POS], + arg_names=[None, None], + ret_type=object_type, + fallback=fn_type, + name=name, + ) + overloads.append(fallback_no_default) + overloads.append(fallback_with_default) + return Overloaded(overloads) return _analyze_member_access(name, typ.fallback, mx, override_info) diff --git a/mypy/join.py b/mypy/join.py index 099df02680f0..49ddbc35373c 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -69,6 +69,10 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: # Simplest case: join two types with the same base type (but # potentially different arguments). + last_known_value = ( + None if t.last_known_value != s.last_known_value else t.last_known_value + ) + # Combine type arguments. args: list[Type] = [] # N.B: We use zip instead of indexing because the lengths might have @@ -104,10 +108,10 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = join_types(ta, sa, self) if len(type_var.values) != 0 and new_type not in type_var.values: self.seen_instances.pop() - return object_from_instance(t) + return object_from_instance(t, last_known_value=last_known_value) if not is_subtype(new_type, type_var.upper_bound): self.seen_instances.pop() - return object_from_instance(t) + return object_from_instance(t, last_known_value=last_known_value) # TODO: contravariant case should use meet but pass seen instances as # an argument to keep track of recursive checks. elif type_var.variance in (INVARIANT, CONTRAVARIANT): @@ -117,7 +121,7 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = ta elif not is_equivalent(ta, sa): self.seen_instances.pop() - return object_from_instance(t) + return object_from_instance(t, last_known_value=last_known_value) else: # If the types are different but equivalent, then an Any is involved # so using a join in the contravariant case is also OK. @@ -141,11 +145,17 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = join_types(ta, sa, self) assert new_type is not None args.append(new_type) - result: ProperType = Instance(t.type, args) + result: ProperType = Instance(t.type, args, last_known_value=last_known_value) elif t.type.bases and is_proper_subtype( t, s, subtype_context=SubtypeContext(ignore_type_params=True) ): result = self.join_instances_via_supertype(t, s) + elif s.type.bases and is_proper_subtype( + s, t, subtype_context=SubtypeContext(ignore_type_params=True) + ): + result = self.join_instances_via_supertype(s, t) + elif is_subtype(t, s, subtype_context=SubtypeContext(ignore_type_params=True)): + result = self.join_instances_via_supertype(t, s) else: # Now t is not a subtype of s, and t != s. Now s could be a subtype # of t; alternatively, we need to find a common supertype. This works @@ -621,13 +631,16 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType: def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType): if t == self.s: + # E.g. Literal["x"], Literal["x"] -> Literal["x"] return t if self.s.fallback.type.is_enum and t.fallback.type.is_enum: return mypy.typeops.make_simplified_union([self.s, t]) return join_types(self.s.fallback, t.fallback) elif isinstance(self.s, Instance) and self.s.last_known_value == t: - return t + # E.g. Literal["x"], Literal["x"]? -> Literal["x"]? + return self.s else: + # E.g. Literal["x"], Literal["y"]? -> str return join_types(self.s, t.fallback) def visit_partial_type(self, t: PartialType) -> ProperType: @@ -848,10 +861,12 @@ def combine_arg_names( return new_names -def object_from_instance(instance: Instance) -> Instance: +def object_from_instance( + instance: Instance, last_known_value: LiteralType | None = None +) -> Instance: """Construct the type 'builtins.object' from an instance type.""" # Use the fact that 'object' is always the last class in the mro. - res = Instance(instance.type.mro[-1], []) + res = Instance(instance.type.mro[-1], [], last_known_value=last_known_value) return res diff --git a/mypy/meet.py b/mypy/meet.py index 349c15e668c3..649fe201e980 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -81,6 +81,30 @@ def meet_types(s: Type, t: Type) -> ProperType: t = get_proper_type(t) if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type: + # special casing for dealing with last known values + lkv: LiteralType | None + + if s.last_known_value is None: + lkv = t.last_known_value + elif t.last_known_value is None: + lkv = s.last_known_value + else: + lkv_meet = meet_types(s.last_known_value, t.last_known_value) + if isinstance(lkv_meet, UninhabitedType): + lkv = None + elif isinstance(lkv_meet, LiteralType): + lkv = lkv_meet + else: + msg = ( + f"Unexpected result: " + f"meet of {s.last_known_value=!s} and {t.last_known_value=!s} " + f"resulted in {lkv_meet!s}" + ) + raise ValueError(msg) + + t = t.copy_modified(last_known_value=lkv) + s = s.copy_modified(last_known_value=lkv) + # Code in checker.py should merge any extra_items where possible, so we # should have only compatible extra_items here. We check this before # the below subtype check, so that extra_attrs will not get erased. @@ -1088,8 +1112,14 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType: def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType) and self.s == t: return t - elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s): - return t + elif isinstance(self.s, Instance): + # if is_subtype(t.fallback, self.s): + # return t + if self.s.last_known_value is not None: + # meet(Literal["max"]?, Literal["max"]) -> Literal["max"] + # meet(Literal["sum"]?, Literal["max"]) -> Never + return meet_types(self.s.last_known_value, t) + return self.default(self.s) else: return self.default(self.s) diff --git a/mypy/solve.py b/mypy/solve.py index fbbcac2520ad..19b7255d0307 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -319,7 +319,8 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: elif top is None: candidate = bottom elif is_subtype(bottom, top): - candidate = bottom + # Need to meet in case like Literal["x"]? <: T <: Literal["x"] + candidate = meet_types(bottom, top) else: candidate = None return candidate diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..497689b740bb 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -549,6 +549,13 @@ def visit_instance(self, left: Instance) -> bool: assert isinstance(erased, Instance) t = erased nominal = True + if self.proper_subtype and right.last_known_value is not None: + if left.last_known_value is None: + # E.g. str is not a proper subtype of Literal["x"]? + nominal = False + else: + # E.g. Literal[A]? <: Literal[B]? requires A <: B + nominal &= self._is_subtype(left.last_known_value, right.last_known_value) if right.type.has_type_var_tuple_type: # For variadic instances we simply find the correct type argument mappings, # all the heavy lifting is done by the tuple subtyping. @@ -629,8 +636,14 @@ def visit_instance(self, left: Instance) -> bool: return True if isinstance(item, Instance): return is_named_instance(item, "builtins.object") - if isinstance(right, LiteralType) and left.last_known_value is not None: - return self._is_subtype(left.last_known_value, right) + if isinstance(right, LiteralType): + if self.proper_subtype: + # Instance types like Literal["sum"]? is *assignable* to Literal["sum"], + # but is not a proper subtype of it. (Literal["sum"]? is a gradual type, + # that is a proper subtype of str, and assignable to Literal["sum"]. + return False + if left.last_known_value is not None: + return self._is_subtype(left.last_known_value, right) if isinstance(right, FunctionLike): # Special case: Instance can be a subtype of Callable / Overloaded. call = find_member("__call__", left, left, is_operator=True) @@ -965,6 +978,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: def visit_literal_type(self, left: LiteralType) -> bool: if isinstance(self.right, LiteralType): return left == self.right + elif ( + isinstance(self.right, Instance) + and self.right.last_known_value is not None + and self.proper_subtype + ): + return self._is_subtype(left, self.right.last_known_value) else: return self._is_subtype(left.fallback, self.right) @@ -2127,6 +2146,11 @@ def covers_at_runtime(item: Type, supertype: Type) -> bool: item = get_proper_type(item) supertype = get_proper_type(supertype) + # Use last known value for Instance types, if available. + # This ensures that e.g. Literal["max"]? is covered by Literal["max"]. + if isinstance(item, Instance) and item.last_known_value is not None: + item = item.last_known_value + # Since runtime type checks will ignore type arguments, erase the types. if not (isinstance(supertype, FunctionLike) and supertype.is_type_obj()): supertype = erase_type(supertype) diff --git a/mypy/test/testsubtypes.py b/mypy/test/testsubtypes.py index b75c22bca7f7..5be32f628de1 100644 --- a/mypy/test/testsubtypes.py +++ b/mypy/test/testsubtypes.py @@ -1,7 +1,7 @@ from __future__ import annotations from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT -from mypy.subtypes import is_subtype +from mypy.subtypes import is_proper_subtype, is_subtype, restrict_subtype_away from mypy.test.helpers import Suite from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture from mypy.types import Instance, TupleType, Type, UninhabitedType, UnpackType @@ -277,6 +277,74 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None: def test_fallback_not_subtype_of_tuple(self) -> None: self.assert_not_subtype(self.fx.a, TupleType([self.fx.b], fallback=self.fx.a)) + def test_literal(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" ≲ str -> YES + # str ≲ "x" -> NO + # "x"? ≲ str -> YES + # str ≲ "x"? -> YES + self.assert_subtype(str1, str_type) + self.assert_not_subtype(str_type, str1) + self.assert_subtype(str1_inst, str_type) + self.assert_subtype(str_type, str1_inst) + + # other operand is the same literal + # "x" ≲ "x" -> YES + # "x" ≲ "x"? -> YES + # "x"? ≲ "x" -> YES + # "x"? ≲ "x"? -> YES + self.assert_subtype(str1, str1) + self.assert_subtype(str1, str1_inst) + self.assert_subtype(str1_inst, str1) + self.assert_subtype(str1_inst, str1_inst) + + # other operand is a different literal + # "x" ≲ "y" -> NO + # "x" ≲ "y"? -> YES + # "x"? ≲ "y" -> NO + # "x"? ≲ "y"? -> YES + self.assert_not_subtype(str1, str2) + self.assert_subtype(str1, str2_inst) + self.assert_not_subtype(str1_inst, str2) + self.assert_subtype(str1_inst, str2_inst) + + # check proper subtyping + # other operand is the fallback type + # "x" <: str -> YES + # str <: "x" -> NO + # "x"? <: str -> YES + # str <: "x"? -> NO + self.assert_proper_subtype(str1, str_type) + self.assert_not_proper_subtype(str_type, str1) + self.assert_proper_subtype(str1_inst, str_type) + self.assert_not_proper_subtype(str_type, str1_inst) + + # other operand is the same literal + # "x" <: "x" -> YES + # "x" <: "x"? -> YES + # "x"? <: "x" -> NO + # "x"? <: "x"? -> YES + self.assert_proper_subtype(str1, str1) + self.assert_proper_subtype(str1, str1_inst) + self.assert_not_proper_subtype(str1_inst, str1) + self.assert_proper_subtype(str1_inst, str1_inst) + + # other operand is a different literal + # "x" <: "y" -> NO + # "x" <: "y"? -> NO + # "x"? <: "y" -> NO + # "x"? <: "y"? -> NO + self.assert_not_proper_subtype(str1, str2) + self.assert_not_proper_subtype(str1, str2_inst) + self.assert_not_proper_subtype(str1_inst, str2) + self.assert_not_proper_subtype(str1_inst, str2_inst) + # IDEA: Maybe add these test cases (they are tested pretty well in type # checker tests already): # * more interface subtyping test cases @@ -287,6 +355,12 @@ def test_fallback_not_subtype_of_tuple(self) -> None: # * any type # * generic function types + def assert_proper_subtype(self, s: Type, t: Type) -> None: + assert is_proper_subtype(s, t), f"{s} not proper subtype of {t}" + + def assert_not_proper_subtype(self, s: Type, t: Type) -> None: + assert not is_proper_subtype(s, t), f"{s} not proper subtype of {t}" + def assert_subtype(self, s: Type, t: Type) -> None: assert is_subtype(s, t), f"{s} not subtype of {t}" @@ -304,3 +378,53 @@ def assert_equivalent(self, s: Type, t: Type) -> None: def assert_unrelated(self, s: Type, t: Type) -> None: self.assert_not_subtype(s, t) self.assert_not_subtype(t, s) + + +class RestrictionSuite(Suite): + # Tests for type restrictions "A - B", i.e. ``T <: A and not T <: B``. + + def setUp(self) -> None: + self.fx = TypeFixture() + + def assert_restriction(self, s: Type, t: Type, expected: Type) -> None: + actual = restrict_subtype_away(s, t) + msg = f"restrict_subtype_away({s}, {t}) == {{}} ({{}} expected)" + self.assertEqual(actual, expected, msg=msg.format(actual, expected)) + + def test_literal(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + uninhabited = self.fx.uninhabited + + # other operand is the fallback type + # "x" - str -> Never + # str - "x" -> str + # "x"? - str -> Never + # str - "x"? -> Never + self.assert_restriction(str1, str_type, uninhabited) + self.assert_restriction(str_type, str1, str_type) + self.assert_restriction(str1_inst, str_type, uninhabited) + self.assert_restriction(str_type, str1_inst, uninhabited) + + # other operand is the same literal + # "x" - "x" -> Never + # "x" - "x"? -> Never + # "x"? - "x" -> Never + # "x"? - "x"? -> Never + self.assert_restriction(str1, str1, uninhabited) + self.assert_restriction(str1, str1_inst, uninhabited) + self.assert_restriction(str1_inst, str1, uninhabited) + self.assert_restriction(str1_inst, str1_inst, uninhabited) + + # other operand is a different literal + # "x" - "y" -> "x" + # "x" - "y"? -> Never + # "x"? - "y" -> "x"? + # "x"? - "y"? -> Never + self.assert_restriction(str1, str2, str1) + self.assert_restriction(str1, str2_inst, uninhabited) + self.assert_restriction(str1_inst, str2, str1_inst) + self.assert_restriction(str1_inst, str2_inst, uninhabited) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 0fe41bc28ecd..bac0da779756 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -601,7 +601,7 @@ def test_simplified_union_with_literals(self) -> None: [fx.lit1_inst, fx.lit3_inst], UnionType([fx.lit1_inst, fx.lit3_inst]) ) self.assert_simplified_union([fx.lit1_inst, fx.uninhabited], fx.lit1_inst) - self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1) + self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1_inst) self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst])) self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst])) @@ -651,7 +651,46 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), ) - self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst + ) + + def test_simplified_union_with_mixed_str_literals2(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" | str -> str + # str | "x" -> str + # "x"? | str -> str + # str | "x"? -> str + self.assert_simplified_union([str1, str_type], str_type) + self.assert_simplified_union([str_type, str1], str_type) + self.assert_simplified_union([str1_inst, str_type], str_type) + self.assert_simplified_union([str_type, str1_inst], str_type) + + # other operand is the same literal + # "x" | "x" -> "x" + # "x" | "x"? -> "x"? + # "x"? | "x" -> "x"? + # "x"? | "x"? -> "x"? + self.assert_simplified_union([str1, str1], str1) + self.assert_simplified_union([str1, str1_inst], str1_inst) + self.assert_simplified_union([str1_inst, str1], str1_inst) + self.assert_simplified_union([str1_inst, str1_inst], str1_inst) + + # other operand is a different literal + # "x" | "y" -> "x" | "y" + # "x" | "y"? -> "x" | "y"? + # "x"? | "y" -> "x"? | "y" + # "x"? | "y"? -> "x"? | "y"? + self.assert_simplified_union([str1, str2], UnionType([str1, str2])) + self.assert_simplified_union([str1, str2_inst], UnionType([str1, str2_inst])) + self.assert_simplified_union([str1_inst, str2], UnionType([str1_inst, str2])) + self.assert_simplified_union([str1_inst, str2_inst], UnionType([str1_inst, str2_inst])) def assert_simplified_union(self, original: list[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) @@ -1011,6 +1050,39 @@ def test_literal_type(self) -> None: UnionType([lit2, lit3]), UnionType([lit1, lit2]), UnionType([lit2, lit3, lit1]) ) + def test_mixed_literal_types(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" , str -> str + # str , "x" -> str + # "x"?, str -> str + # str , "x"? -> str + self.assert_join(str1, str_type, str_type) + self.assert_join(str1_inst, str_type, str_type) + + # other operand is the same literal + # "x" , "x" -> "x" + # "x" , "x"? -> "x"? + # "x"?, "x" -> "x"? + # "x"?, "x"? -> "x"? + self.assert_join(str1, str1, str1) + self.assert_join(str1, str1_inst, str1_inst) + self.assert_join(str1_inst, str1_inst, str1_inst) + + # other operand is a different literal + # "x" , "y" -> str (TODO: consider using "x" | "y" (treat real literals like enum)) + # "x" , "y"? -> str + # "x"?, "y" -> str + # "x"?, "y"? -> str + self.assert_join(str1, str2, str_type) + self.assert_join(str1, str2_inst, str_type) + self.assert_join(str1_inst, str2_inst, str_type) + def test_variadic_tuple_joins(self) -> None: # These tests really test just the "arity", to be sure it is handled correctly. self.assert_join( @@ -1304,6 +1376,39 @@ def test_literal_type(self) -> None: assert is_same_type(lit1, narrow_declared_type(lit1, a)) assert is_same_type(lit2, narrow_declared_type(lit2, a)) + def test_mixed_literal_types(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" , str -> "x" + # str , "x" -> "x" + # "x"?, str -> "x"? + # str , "x"? -> "x"? + self.assert_meet(str1, str_type, str1) + self.assert_meet(str1_inst, str_type, str1_inst) + + # other operand is the same literal + # "x" , "x" -> "x" + # "x" , "x"? -> "x" + # "x"?, "x" -> "x" + # "x"?, "x"? -> "x"? + self.assert_meet(str1, str1, str1) + self.assert_meet(str1, str1_inst, str1) + self.assert_meet(str1_inst, str1_inst, str1_inst) + + # other operand is a different literal + # "x" , "y" -> Never + # "x" , "y"? -> Never + # "x"?, "y" -> Never + # "x"?, "y"? -> str + self.assert_meet_uninhabited(str1, str2) + self.assert_meet_uninhabited(str1, str2_inst) + self.assert_meet(str1_inst, str2_inst, str_type) + # FIX generic interfaces + ranges def assert_meet_uninhabited(self, s: Type, t: Type) -> None: diff --git a/mypy/typeops.py b/mypy/typeops.py index 88b3c5da48ce..a51e8745a641 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -572,6 +572,8 @@ def make_simplified_union( * [int, Any] -> Union[int, Any] (Any types are not simplified away!) * [Any, Any] -> Any * [int, Union[bytes, str]] -> Union[int, bytes, str] + * [Literal[1]?, Literal[1]] -> Literal[1]? + * Literal["max"]?, Literal["max", "sum"] -> Literal["max"]? | Literal["sum"] Note: This must NOT be used during semantic analysis, since TypeInfos may not be fully initialized. @@ -600,13 +602,22 @@ def make_simplified_union( ): simplified_set = try_contracting_literals_in_union(simplified_set) + # Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]? + proper_items: list[ProperType] = [get_proper_type(t) for t in simplified_set] + last_known_values: list[LiteralType | None] = [ + p_t.last_known_value if isinstance(p_t, Instance) else None for p_t in proper_items + ] + simplified_set = [ + item for item, p_t in zip(simplified_set, proper_items) if p_t not in last_known_values + ] + result = get_proper_type(UnionType.make_union(simplified_set, line, column)) nitems = len(items) if nitems > 1 and ( nitems > 2 or not (type(items[0]) is NoneType or type(items[1]) is NoneType) ): - # Step 5: At last, we erase any (inconsistent) extra attributes on instances. + # Step 6: At last, we erase any (inconsistent) extra attributes on instances. # Initialize with None instead of an empty set as a micro-optimization. The set # is needed very rarely, so we try to avoid constructing it. diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 53efcc0d22e3..c101f9dc9536 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -4076,7 +4076,7 @@ def check_and(maybe: bool) -> None: bar = None if maybe and (foo := [1])[(bar := 0)]: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" - reveal_type(bar) # N: Revealed type is "builtins.int" + reveal_type(bar) # N: Revealed type is "Literal[0]?" else: reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]" reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]" @@ -4102,7 +4102,7 @@ def check_or(maybe: bool) -> None: reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]" else: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" - reveal_type(bar) # N: Revealed type is "builtins.int" + reveal_type(bar) # N: Revealed type is "Literal[0]?" def check_or_nested(maybe: bool) -> None: foo = None diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3c9290b8dbbb..2e452031c092 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2980,18 +2980,46 @@ z: Type[Literal[1, 2]] # E: Type[...] can't contain "Union[Literal[...], Litera [case testJoinLiteralAndInstance] from typing import Generic, TypeVar, Literal -T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T", covariant=False, contravariant=False) +S = TypeVar("S", covariant=False, contravariant=False) -class A(Generic[T]): ... +class A_inv(Generic[T]): ... +class A_co(Generic[T_co]): ... -def f(a: A[T], t: T) -> T: ... -def g(a: T, t: A[T]) -> T: ... +def check_inv(obj: A_inv[Literal[1]]) -> None: + def f(a: A_inv[S], t: S) -> S: ... + def g(a: S, t: A_inv[S]) -> S: ... -def check(obj: A[Literal[1]]) -> None: reveal_type(f(obj, 1)) # N: Revealed type is "Literal[1]" - reveal_type(f(obj, '')) # E: Cannot infer value of type parameter "T" of "f" \ - # N: Revealed type is "Any" + reveal_type(f(obj, '')) # E: Cannot infer value of type parameter "S" of "f" \ + # N: Revealed type is "Any" reveal_type(g(1, obj)) # N: Revealed type is "Literal[1]" - reveal_type(g('', obj)) # E: Cannot infer value of type parameter "T" of "g" \ - # N: Revealed type is "Any" + reveal_type(g('', obj)) # E: Cannot infer value of type parameter "S" of "g" \ + # N: Revealed type is "Any" + +def check_co(obj: A_co[Literal[1]]) -> None: + def f(a: A_co[S], t: S) -> S: ... + def g(a: S, t: A_co[S]) -> S: ... + + reveal_type(f(obj, 1)) # N: Revealed type is "builtins.int" + reveal_type(f(obj, '')) # N: Revealed type is "builtins.object" + reveal_type(g(1, obj)) # N: Revealed type is "builtins.int" + reveal_type(g('', obj)) # N: Revealed type is "builtins.object" + +[case testJoinLiteralInstanceAndEnum] +from typing import Final, TypeVar +from enum import StrEnum + +T = TypeVar("T") +def join(a: T, b: T) -> T: ... + +class Foo(StrEnum): + A = "a" + +CONST: Final = "const" + +reveal_type(CONST) # N: Revealed type is "Literal['const']?" +reveal_type(join(Foo.A, CONST)) # N: Revealed type is "builtins.str" +reveal_type(join(CONST, Foo.A)) # N: Revealed type is "builtins.str" [builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index f264167cb067..01c9525b0f35 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1354,7 +1354,7 @@ m: str match m: case a if a := "test": - reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(a) # N: Revealed type is "Literal['test']?" [case testMatchNarrowingPatternGuard] m: object @@ -2627,7 +2627,7 @@ def int_literal() -> None: case other: other # E: Statement is unreachable -def str_literal() -> None: +def str_literal_from_literal() -> None: match 'foo': case 'a' as s: reveal_type(s) # N: Revealed type is "Literal['a']" @@ -2636,6 +2636,16 @@ def str_literal() -> None: case other: other # E: Statement is unreachable + +def str_literal_from_str(arg: str) -> None: + match arg: + case 'a' as s: + reveal_type(s) # N: Revealed type is "Literal['a']" + case str(i): + reveal_type(i) # N: Revealed type is "builtins.str" + case other: + other # E: Statement is unreachable + [case testMatchOperations] # flags: --warn-unreachable @@ -2686,7 +2696,7 @@ match m[k]: match 0: case 0 as i: - reveal_type(i) # N: Revealed type is "Literal[0]?" + reveal_type(i) # N: Revealed type is "Literal[0]" case int(i): i # E: Statement is unreachable case other: diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index dd3f793fd02b..8b7e8441c6b7 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -214,10 +214,10 @@ i(arg=0) # E: Unexpected keyword argument "arg" from typing import Final, NamedTuple, Optional, List if a := 2: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[2]?" while b := "x": - reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "Literal['x']?" l = [y2 := 1, y2 + 2, y2 + 3] reveal_type(y2) # N: Revealed type is "builtins.int" @@ -242,10 +242,10 @@ reveal_type(new_v) # N: Revealed type is "builtins.int" def f(x: int = (c := 4)) -> int: if a := 2: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[2]?" while b := "x": - reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "Literal['x']?" x = (y := 1) + (z := 2) reveal_type(x) # N: Revealed type is "builtins.int" @@ -284,7 +284,7 @@ def f(x: int = (c := 4)) -> int: f(x=(y7 := 3)) reveal_type(y7) # N: Revealed type is "builtins.int" - reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "builtins.int" + reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "Literal[3]?" y8 # E: Name "y8" is not defined y7 = 1.0 # E: Incompatible types in assignment (expression has type "float", variable has type "int") @@ -325,16 +325,16 @@ def check_binder(x: Optional[int], y: Optional[int], z: Optional[int], a: Option reveal_type(y) # N: Revealed type is "Union[builtins.int, None]" if x and (y := 1): - reveal_type(y) # N: Revealed type is "builtins.int" + reveal_type(y) # N: Revealed type is "Literal[1]?" if (a := 1) and x: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[1]?" if (b := 1) or x: reveal_type(b) # N: Revealed type is "builtins.int" if z := 1: - reveal_type(z) # N: Revealed type is "builtins.int" + reveal_type(z) # N: Revealed type is "Literal[1]?" def check_partial() -> None: x = None diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 34cae74d795b..86b917cb5249 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -1016,7 +1016,7 @@ class A: pass D = TypedDict('D', {'x': List[int], 'y': int}) d: D reveal_type(d.get('x', [])) # N: Revealed type is "builtins.list[builtins.int]" -d.get('x', ['x']) # E: List item 0 has incompatible type "str"; expected "int" +reveal_type(d.get('x', ['x'])) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]" a = [''] reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]" [builtins fixtures/dict.pyi] @@ -1026,14 +1026,22 @@ reveal_type(d.get('x', a)) # N: Revealed type is "Union[builtins.list[builtins.i from typing import TypedDict D = TypedDict('D', {'x': int, 'y': str}) d: D -d.get() # E: All overload variants of "get" of "Mapping" require at least one argument \ +d.get() # E: All overload variants of "get" require at least one argument \ # N: Possible overload variants: \ - # N: def get(self, k: str) -> object \ - # N: def [V] get(self, k: str, default: object) -> object -d.get('x', 1, 2) # E: No overload variant of "get" of "Mapping" matches argument types "str", "int", "int" \ + # N: def get(Literal['x'], /) -> int \ + # N: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] \ + # N: def get(Literal['y'], /) -> str \ + # N: def [T] get(Literal['y'], Union[str, T], /) -> Union[str, T] \ + # N: def get(str, /) -> object \ + # N: def get(str, object, /) -> object +d.get('x', 1, 2) # E: No overload variant of "get" matches argument types "str", "int", "int" \ # N: Possible overload variants: \ - # N: def get(self, k: str) -> object \ - # N: def [V] get(self, k: str, default: Union[int, V]) -> object + # N: def get(Literal['x'], /) -> int \ + # N: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] \ + # N: def get(Literal['y'], /) -> str \ + # N: def [T] get(Literal['y'], Union[int, T], /) -> Union[str, T] \ + # N: def get(str, /) -> object \ + # N: def get(str, object, /) -> object x = d.get('z') reveal_type(x) # N: Revealed type is "builtins.object" s = '' diff --git a/test-data/unit/check-warnings.test b/test-data/unit/check-warnings.test index a2d201fa301d..acc122d8fb89 100644 --- a/test-data/unit/check-warnings.test +++ b/test-data/unit/check-warnings.test @@ -49,13 +49,6 @@ from typing import cast a = 1 b = cast(object, 1) -[case testCastFromLiteralRedundant] -# flags: --warn-redundant-casts -from typing import cast - -cast(int, 1) -[out] -main:4: error: Redundant cast to "int" [case testCastFromUnionOfAnyOk] # flags: --warn-redundant-casts diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 9b5d8a1ac54c..c1826e7c5d67 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1046,11 +1046,14 @@ reveal_type(d.get(s)) _testTypedDictGet.py:6: note: Revealed type is "Union[builtins.int, None]" _testTypedDictGet.py:7: note: Revealed type is "Union[builtins.str, None]" _testTypedDictGet.py:8: note: Revealed type is "builtins.object" -_testTypedDictGet.py:9: error: All overload variants of "get" of "Mapping" require at least one argument +_testTypedDictGet.py:9: error: All overload variants of "get" require at least one argument _testTypedDictGet.py:9: note: Possible overload variants: -_testTypedDictGet.py:9: note: def get(self, str, /) -> object -_testTypedDictGet.py:9: note: def get(self, str, /, default: object) -> object -_testTypedDictGet.py:9: note: def [_T] get(self, str, /, default: _T) -> object +_testTypedDictGet.py:9: note: def get(Literal['x'], /) -> int +_testTypedDictGet.py:9: note: def [T] get(Literal['x'], Union[int, T], /) -> Union[int, T] +_testTypedDictGet.py:9: note: def get(Literal['y'], /) -> str +_testTypedDictGet.py:9: note: def [T] get(Literal['y'], Union[str, T], /) -> Union[str, T] +_testTypedDictGet.py:9: note: def get(str, /) -> object +_testTypedDictGet.py:9: note: def get(str, object, /) -> object _testTypedDictGet.py:11: note: Revealed type is "builtins.object" [case testTypedDictMappingMethods]