From 99725c99f92c5f9521f67c3218288cf39cc79786 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 10 Aug 2025 12:36:04 +0200 Subject: [PATCH 1/5] Fix `--strict-equality` for iteratively visited code. --- mypy/errors.py | 54 ++++++++++++++++++++++++++--- mypy/messages.py | 18 +++++++++- test-data/unit/check-narrowing.test | 26 ++++++++++++++ 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/mypy/errors.py b/mypy/errors.py index d75c1c62a1ed..f28eaefefc53 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -230,9 +230,9 @@ def filtered_errors(self) -> list[ErrorInfo]: class IterationDependentErrors: """An `IterationDependentErrors` instance serves to collect the `unreachable`, - `redundant-expr`, and `redundant-casts` errors, as well as the revealed types, - handled by the individual `IterationErrorWatcher` instances sequentially applied to - the same code section.""" + `redundant-expr`, and `redundant-casts` errors, as well as the revealed types and + non-overlapping types, handled by the individual `IterationErrorWatcher` instances + sequentially applied to the same code section.""" # One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per # iteration step. Meaning of the tuple items: ErrorCode, message, line, column, @@ -248,9 +248,18 @@ class IterationDependentErrors: # end_line, end_column: revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]] + # One dictionary of non-overlapping types per iteration step. Meaning of the key + # tuple items: line, column, end_line, end_column, kind: + nonoverlapping_types: list[ + dict[ + tuple[int, int, int | None, int | None, str], tuple[Type, Type] + ], + ] + def __init__(self) -> None: self.uselessness_errors = [] self.unreachable_lines = [] + self.nonoverlapping_types = [] self.revealed_types = defaultdict(list) def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]: @@ -270,6 +279,39 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod context.end_column = error_info[5] yield error_info[1], context, error_info[0] + + def yield_nonoverlapping_types(self) -> Iterator[ + tuple[tuple[list[Type], list[Type]], str, Context] + ]: + """Report expressions were non-overlapping types were detected for all iterations + were the expression was reachable.""" + + selected = set() + for candidate in set(chain(*self.nonoverlapping_types)): + if all( + (candidate in nonoverlap) or (candidate[0] in lines) + for nonoverlap, lines in zip( + self.nonoverlapping_types, self.unreachable_lines + ) + ): + selected.add(candidate) + + persistent_nonoverlaps: dict[ + tuple[int, int, int | None, int | None, str], tuple[list[Type], list[Type]] + ] = defaultdict(lambda: ([], [])) + for nonoverlaps in self.nonoverlapping_types: + for candidate, (left, right) in nonoverlaps.items(): + if candidate in selected: + types = persistent_nonoverlaps[candidate] + types[0].append(left) + types[1].append(right) + + for error_info, types in persistent_nonoverlaps.items(): + context = Context(line=error_info[0], column=error_info[1]) + context.end_line = error_info[2] + context.end_column = error_info[3] + yield (types[0], types[1]), error_info[4], context + def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: """Yield all types revealed in at least one iteration step.""" @@ -282,8 +324,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: class IterationErrorWatcher(ErrorWatcher): """Error watcher that filters and separately collects `unreachable` errors, - `redundant-expr` and `redundant-casts` errors, and revealed types when analysing - code sections iteratively to help avoid making too-hasty reports.""" + `redundant-expr` and `redundant-casts` errors, and revealed types and + non-overlapping types when analysing code sections iteratively to help avoid + making too-hasty reports.""" iteration_dependent_errors: IterationDependentErrors @@ -304,6 +347,7 @@ def __init__( ) self.iteration_dependent_errors = iteration_dependent_errors iteration_dependent_errors.uselessness_errors.append(set()) + iteration_dependent_errors.nonoverlapping_types.append({}) iteration_dependent_errors.unreachable_lines.append(set()) def on_error(self, file: str, info: ErrorInfo) -> bool: diff --git a/mypy/messages.py b/mypy/messages.py index f626d4c71916..5790531c2ad0 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1625,6 +1625,19 @@ def incompatible_typevar_value( ) def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None: + + # In loops (and similar cases), the same expression might be analysed multiple + # times and thereby confronted with different types. We only want to raise a + # `comparison-overlap` error if it occurs in all cases and therefore collect the + # respective types of the current iteration here so that we can report the error + # later if it is persistent over all iteration steps: + for watcher in self.errors.get_watchers(): + if isinstance(watcher, IterationErrorWatcher): + watcher.iteration_dependent_errors.nonoverlapping_types[-1][ + (ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind) + ] = (left, right) + return + left_str = "element" if kind == "container" else "left operand" right_str = "container item" if kind == "container" else "right operand" message = "Non-overlapping {} check ({} type: {}, {} type: {})" @@ -2511,8 +2524,11 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None: for error_info in iter_errors.yield_uselessness_error_infos(): self.fail(*error_info[:2], code=error_info[2]) + msu = mypy.typeops.make_simplified_union + for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types(): + self.dangerous_comparison(msu(nonoverlaps[0]), msu(nonoverlaps[1]), kind, context) for types, context in iter_errors.yield_revealed_type_infos(): - self.reveal_type(mypy.typeops.make_simplified_union(types), context) + self.reveal_type(msu(types), context) def quote_type_string(type_string: str) -> str: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7fffd3ce94e5..0b62ca7baf16 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2446,6 +2446,32 @@ while x is not None and b(): x = f() [builtins fixtures/primitives.pyi] +[case testAvoidFalseNonOverlappingEqualityCheckInLoop1] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +x = 1 +while True: + if x == str(): + break + x = str() + if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int") + break +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop2] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +class A: ... +class B: ... +class C: ... + +x = A() +while True: + if x == C(): # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "C") + break + x = B() +[builtins fixtures/primitives.pyi] + [case testNarrowPromotionsInsideUnions1] from typing import Union From 016eda607c11f456ddc9da85edc0b6d6bd1606f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 10 Aug 2025 10:43:08 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/errors.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/mypy/errors.py b/mypy/errors.py index f28eaefefc53..6fce24d42d24 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -251,9 +251,7 @@ class IterationDependentErrors: # One dictionary of non-overlapping types per iteration step. Meaning of the key # tuple items: line, column, end_line, end_column, kind: nonoverlapping_types: list[ - dict[ - tuple[int, int, int | None, int | None, str], tuple[Type, Type] - ], + dict[tuple[int, int, int | None, int | None, str], tuple[Type, Type]], ] def __init__(self) -> None: @@ -279,10 +277,9 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod context.end_column = error_info[5] yield error_info[1], context, error_info[0] - - def yield_nonoverlapping_types(self) -> Iterator[ - tuple[tuple[list[Type], list[Type]], str, Context] - ]: + def yield_nonoverlapping_types( + self, + ) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]: """Report expressions were non-overlapping types were detected for all iterations were the expression was reachable.""" @@ -290,9 +287,7 @@ def yield_nonoverlapping_types(self) -> Iterator[ for candidate in set(chain(*self.nonoverlapping_types)): if all( (candidate in nonoverlap) or (candidate[0] in lines) - for nonoverlap, lines in zip( - self.nonoverlapping_types, self.unreachable_lines - ) + for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines) ): selected.add(candidate) From 2182e5d1260c9143c0209a210f31623dad3c73e5 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 10 Aug 2025 21:20:05 +0200 Subject: [PATCH 3/5] add testAvoidFalseNonOverlappingEqualityCheckInLoop3 --- test-data/unit/check-narrowing.test | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 0b62ca7baf16..de0451a02795 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2472,6 +2472,15 @@ while True: x = B() [builtins fixtures/primitives.pyi] +[case testAvoidFalseNonOverlappingEqualityCheckInLoop3] +# flags: --strict-equality + +for y in [1.0]: + if y is not None or y != "None": # E: Non-overlapping equality check (left operand type: "float", right operand type: "Literal['None']") + ... + +[builtins fixtures/primitives.pyi] + [case testNarrowPromotionsInsideUnions1] from typing import Union From edf63d7e8f9b01f4094efff0643e2d26da037857 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 11 Aug 2025 08:09:41 +0200 Subject: [PATCH 4/5] respect that other error watcher may want to filter comparison-overlap errors --- mypy/messages.py | 2 ++ test-data/unit/check-narrowing.test | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/messages.py b/mypy/messages.py index 5790531c2ad0..bfb3a4789da7 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1632,6 +1632,8 @@ def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) # respective types of the current iteration here so that we can report the error # later if it is persistent over all iteration steps: for watcher in self.errors.get_watchers(): + if watcher._filter: + return if isinstance(watcher, IterationErrorWatcher): watcher.iteration_dependent_errors.nonoverlapping_types[-1][ (ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index de0451a02795..04f2c2fccd34 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2476,7 +2476,7 @@ while True: # flags: --strict-equality for y in [1.0]: - if y is not None or y != "None": # E: Non-overlapping equality check (left operand type: "float", right operand type: "Literal['None']") + if y is not None or y != "None": ... [builtins fixtures/primitives.pyi] From 6c3b489446f655f46236a0b527c64e014c1d9d6c Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 11 Aug 2025 21:16:00 +0200 Subject: [PATCH 5/5] respect that other error watcher may want to filter comparison-overlap errors (return -> break) --- mypy/messages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/messages.py b/mypy/messages.py index bfb3a4789da7..95c74a14de8c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1633,7 +1633,7 @@ def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) # later if it is persistent over all iteration steps: for watcher in self.errors.get_watchers(): if watcher._filter: - return + break if isinstance(watcher, IterationErrorWatcher): watcher.iteration_dependent_errors.nonoverlapping_types[-1][ (ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind)