From fd53440951594fbf6a0dcf3872194e4871c30052 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 10:37:47 +0200 Subject: [PATCH 01/13] fix(rewrite): prevent walrus operator double evaluation in assertions Fixes #14445 - assertion rewriting evaluated NamedExpr (:=) expressions multiple times, causing side effects to fire repeatedly. The root cause was the `variables_overwrite` mechanism which stored and re-evaluated NamedExpr AST nodes in subsequent assertions, in `_call_reprcompare`'s results tuple, and in explanation formatting. The fix: - visit_NamedExpr: reference the target variable in explanations instead of re-evaluating the full expression - visit_Compare: assign left-side NamedExpr to a temp before right-side hoisting; freeze left_res when a comparator walrus targets the same name; replace NamedExpr entries in `results` with target variables - visit_BoolOp: capture short-circuit condition in a stable temp for the explanation path; remove walrus target rename logic - visit_Call: remove variables_overwrite substitution (walrus now properly assigns to user variables in its natural evaluation position) - Remove variables_overwrite, scope tracking, Sentinel class Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 102 +++++++++++-------------------- testing/test_assertrewrite.py | 68 ++++++++++++++++++++- 2 files changed, 102 insertions(+), 68 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 99815b70cf1..d3af1db26bc 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -3,7 +3,6 @@ from __future__ import annotations import ast -from collections import defaultdict from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator @@ -58,10 +57,6 @@ from _pytest.assertion import AssertionState -class Sentinel: - pass - - assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -69,9 +64,6 @@ class Sentinel: PYC_EXT = ".py" + ((__debug__ and "c") or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT -# Special marker that denotes we have just left a scope definition -_SCOPE_END_MARKER = Sentinel() - class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -652,14 +644,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. - :scope: A tuple containing the current scope used for variables_overwrite. - - :variables_overwrite: A dict filled with references to variables - that change value within an assert. This happens when a variable is - reassigned with the walrus operator - - This state, except the variables_overwrite, is reset on every new assert - statement visited and used by the other visitors. + This state is reset on every new assert statement visited and used by + the other visitors. """ def __init__( @@ -675,10 +661,6 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.scope: tuple[ast.AST, ...] = () - self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = ( - defaultdict(dict) - ) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -728,16 +710,9 @@ def run(self, mod: ast.Module) -> None: mod.body[pos:pos] = imports # Collect asserts. - self.scope = (mod,) - nodes: list[ast.AST | Sentinel] = [mod] + nodes: list[ast.AST] = [mod] while nodes: node = nodes.pop() - if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef): - self.scope = tuple((*self.scope, node)) - nodes.append(_SCOPE_END_MARKER) - if node == _SCOPE_END_MARKER: - self.scope = self.scope[:-1] - continue assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): @@ -964,15 +939,17 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: return self.statements def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: - # This method handles the 'walrus operator' repr of the target - # name if it's a local variable or _should_repr_global_name() - # thinks it's acceptable. + # Return the NamedExpr as-is so it evaluates in its natural position + # (preserving left-to-right evaluation order). For the explanation, + # reference the target variable (already assigned by the walrus) to + # avoid re-evaluating the expression. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id + target_name = ast.Name(target_id, ast.Load()) inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) - dorepr = self.helper("_should_repr_global_name", name) + dorepr = self.helper("_should_repr_global_name", target_name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) - expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) + expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id)) return name, self.explanation_param(expr) def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]: @@ -998,20 +975,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: for i, v in enumerate(boolop.values): if i: fail_inner: list[ast.stmt] = [] - # cond is set in a prior loop iteration below - self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 + # expl_cond is set in a prior loop iteration below + self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner - match v: - # Check if the left operand is an ast.NamedExpr and the value has already been visited - case ast.Compare( - left=ast.NamedExpr(target=ast.Name(id=target_id)) - ) if target_id in [ - e.id for e in boolop.values[:i] if hasattr(e, "id") - ]: - pytest_temp = self.variable() - self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] - # mypy's false positive, we're checking that the 'target' attribute exists. - v.left.target.id = pytest_temp # type:ignore[attr-defined] self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1022,8 +988,16 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: cond: ast.expr = res if is_or: cond = ast.UnaryOp(ast.Not(), cond) + # Capture the condition in a temp variable so the explanation + # path (which runs after walrus operators may have modified + # the original variable) sees the correct truthiness. + cond_var = self.variable() + body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond)) + expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841 inner: list[ast.stmt] = [] - self.statements.append(ast.If(cond, inner, [])) + self.statements.append( + ast.If(ast.Name(cond_var, ast.Load()), inner, []) + ) self.statements = body = inner self.statements = save self.expl_stmts = fail_save @@ -1053,19 +1027,10 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( - self.scope, {} - ): - arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - match keyword.value: - case ast.Name(id=id) if id in self.variables_overwrite.get( - self.scope, {} - ): - keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: @@ -1100,17 +1065,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: self.push_format_context() - # We first check if we have overwritten a variable in the previous assert - match comp.left: - case ast.Name(id=name_id) if name_id in self.variables_overwrite.get( - self.scope, {} - ): - comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment] - case ast.NamedExpr(target=ast.Name(id=target_id)): - self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" + # If the left operand is a NamedExpr, assign it to a temp so the + # walrus executes before any right-side expressions are hoisted. + if isinstance(left_res, ast.NamedExpr): + left_res = self.assign(left_res) res_variables = [self.variable() for i in range(len(comp.ops))] load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables] @@ -1119,13 +1080,16 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: + # If the next operand is a walrus that assigns to the same name as + # the current left_res, we must freeze left_res's value before the + # walrus modifies it. match (next_operand, left_res): case ( ast.NamedExpr(target=ast.Name(id=target_id)), ast.Name(id=name_id), ) if target_id == name_id: - next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + left_res = self.assign(left_res) + results[-1] = left_res next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): @@ -1138,6 +1102,12 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl + # Replace NamedExpr entries in results with their target variable + # to avoid re-evaluating walrus operators in the explanation path. + results = [ + ast.Name(r.target.id, ast.Load()) if isinstance(r, ast.NamedExpr) else r + for r in results + ] # Use pytest.assertion.util._reprcompare if that's available. expl_call = self.helper( "_call_reprcompare", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 2668001af65..7c131c9a4f5 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) + result.stdout.fnmatch_lines(["*assert not (False and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + result.stdout.fnmatch_lines(["*assert not (None and None is None)"]) def test_assertion_walrus_operator_value_changes_cleared_after_each_test( self, pytester: Pytester @@ -1846,6 +1846,70 @@ def test_2(): assert result.ret == 0 +class TestIssue14445: + """Regression tests for #14445: walrus operator double evaluation.""" + + def test_walrus_no_double_eval_basic(self, pytester: Pytester) -> None: + """Walrus captures the value at assignment time, not re-evaluated later.""" + pytester.makepyfile( + """ + class Counter: + def __init__(self): + self.value = 0 + def increment(self): + self.value += 1 + + def test_walrus_in_assertion_basic(): + c = Counter() + assert (before := c.value) == 0 + c.increment() + assert before != (after := c.value) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_running_counter(self, pytester: Pytester) -> None: + """Walrus increments fire exactly once per assert statement.""" + pytester.makepyfile( + """ + def test_walrus_running_counter(): + count = 0 + items = [] + items.append("a") + assert (count := count + 1) == len(items) + items.append("b") + assert (count := count + 1) == len(items) + items.append("c") + assert (count := count + 1) == len(items) + assert count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_in_function_call(self, pytester: Pytester) -> None: + """Walrus in function call arguments not evaluated twice.""" + pytester.makepyfile( + """ + call_count = 0 + + def side_effect(): + global call_count + call_count += 1 + return call_count + + def test_walrus_side_effect(): + assert (val := side_effect()) == 1 + assert val == 1 + assert (val := side_effect()) == 2 + assert val == 2 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" ) From 9f371c71c74a4271c30a91342ff6a6f981ed6150 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 12:59:20 +0200 Subject: [PATCH 02/13] Add changelog fragment for #14445 Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- changelog/14445.bugfix.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/14445.bugfix.rst diff --git a/changelog/14445.bugfix.rst b/changelog/14445.bugfix.rst new file mode 100644 index 00000000000..aaae0c615f5 --- /dev/null +++ b/changelog/14445.bugfix.rst @@ -0,0 +1 @@ +Fixed assertion rewriting evaluating walrus operator (``:=``) expressions multiple times, causing incorrect test results when the expression had side effects (e.g., incrementing a counter or calling a function). From df0e0a9d2c5fda04ed59b7c94c714ba9f369ef03 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 13:04:00 +0200 Subject: [PATCH 03/13] test(rewrite): add xfail tests for remaining walrus edge cases Add tests for two remaining walrus double-evaluation scenarios: - Bare NamedExpr as BoolOp operand evaluated twice via condition check - Same walrus target in chained comparison evaluated multiple times Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- testing/test_assertrewrite.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 7c131c9a4f5..91394263756 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1909,6 +1909,46 @@ def test_walrus_side_effect(): result = pytester.runpytest() assert result.ret == 0 + @pytest.mark.xfail(reason="BoolOp condition re-evaluates walrus operand") + def test_walrus_no_double_eval_in_boolop(self, pytester: Pytester) -> None: + """Bare walrus as a BoolOp operand must not be evaluated twice.""" + pytester.makepyfile( + """ + call_count = 0 + + def side_effect(): + global call_count + call_count += 1 + return call_count + + def test_walrus_boolop(): + assert (x := side_effect()) and x == 1 + assert call_count == 1 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.xfail(reason="Chained compare re-evaluates walrus with same target") + def test_walrus_no_double_eval_chained_compare(self, pytester: Pytester) -> None: + """Same walrus target in chained comparison must evaluate each once.""" + pytester.makepyfile( + """ + call_count = 0 + + def track(value): + global call_count + call_count += 1 + return value + + def test_walrus_chained(): + assert (x := track(1)) < (x := track(3)) < (x := track(5)) + assert call_count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" From 60f8e24297a28925baaf50921faa1872721b8321 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 13:05:41 +0200 Subject: [PATCH 04/13] fix(rewrite): avoid double evaluation of walrus in BoolOp condition Use the already-assigned res_var to build the short-circuit condition instead of the raw visitor result, preventing bare NamedExpr operands from being evaluated a second time when checking truthiness. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 9 +++++---- testing/test_assertrewrite.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index d3af1db26bc..d8ea1c5ec8e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -985,12 +985,13 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: - cond: ast.expr = res + # Use res_var (already assigned above) rather than res directly, + # so that NamedExpr operands aren't evaluated a second time. + cond: ast.expr = ast.Name(res_var, ast.Load()) if is_or: cond = ast.UnaryOp(ast.Not(), cond) - # Capture the condition in a temp variable so the explanation - # path (which runs after walrus operators may have modified - # the original variable) sees the correct truthiness. + # Capture the condition in a stable temp for the explanation + # path — res_var is overwritten by subsequent operands. cond_var = self.variable() body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond)) expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841 diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 91394263756..341514b377e 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1909,7 +1909,6 @@ def test_walrus_side_effect(): result = pytester.runpytest() assert result.ret == 0 - @pytest.mark.xfail(reason="BoolOp condition re-evaluates walrus operand") def test_walrus_no_double_eval_in_boolop(self, pytester: Pytester) -> None: """Bare walrus as a BoolOp operand must not be evaluated twice.""" pytester.makepyfile( From ffdc372a3613faed3af8f6d9e86c030223d945ed Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 13:08:10 +0200 Subject: [PATCH 05/13] fix(rewrite): assign walrus comparators to temps in chained comparisons In a chained comparison like `(x := f()) < (x := g()) < (x := h())`, each NamedExpr comparator is now assigned to a temp variable so it evaluates exactly once. Previously the raw NamedExpr node would be reused as left_res in the next iteration, causing double evaluation. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Sonnet 4 --- src/_pytest/assertion/rewrite.py | 11 +++++------ testing/test_assertrewrite.py | 1 - 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index d8ea1c5ec8e..3fa3217f6e0 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1095,6 +1095,11 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" + # Assign NamedExpr comparators to a temp so each walrus evaluates + # exactly once — critical for chained comparisons where the same + # node would otherwise be re-evaluated as left_res next iteration. + if isinstance(next_res, ast.NamedExpr): + next_res = self.assign(next_res) results.append(next_res) sym = BINOP_MAP[op.__class__] syms.append(ast.Constant(sym)) @@ -1103,12 +1108,6 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl - # Replace NamedExpr entries in results with their target variable - # to avoid re-evaluating walrus operators in the explanation path. - results = [ - ast.Name(r.target.id, ast.Load()) if isinstance(r, ast.NamedExpr) else r - for r in results - ] # Use pytest.assertion.util._reprcompare if that's available. expl_call = self.helper( "_call_reprcompare", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 341514b377e..11995321826 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1928,7 +1928,6 @@ def test_walrus_boolop(): result = pytester.runpytest() assert result.ret == 0 - @pytest.mark.xfail(reason="Chained compare re-evaluates walrus with same target") def test_walrus_no_double_eval_chained_compare(self, pytester: Pytester) -> None: """Same walrus target in chained comparison must evaluate each once.""" pytester.makepyfile( From 806b27e8c98e2e97a4bf3cbcc324220024c91c84 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 16:56:20 +0200 Subject: [PATCH 06/13] test(rewrite): add systematic assertion rewriting test infrastructure Create testing/test_assertrewrite_coverage.py with reusable helpers for verifying assertion rewriting behavior across all expression types: - get_failure_message: compile rewritten source and extract failure text - assert_introspects: verify failure messages contain expected intermediates - assert_single_evaluation: verify no double-evaluation of side effects - assert_passes_when_true: verify rewritten asserts don't false-positive - assert_semantically_equivalent: verify rewrite preserves pass/fail semantics Includes smoke tests validating the helpers themselves. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- testing/test_assertrewrite_coverage.py | 297 +++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 testing/test_assertrewrite_coverage.py diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py new file mode 100644 index 00000000000..8faa07071f5 --- /dev/null +++ b/testing/test_assertrewrite_coverage.py @@ -0,0 +1,297 @@ +"""Systematic coverage tests for assertion rewriting. + +This module provides a structured testing framework that verifies assertion +rewriting behavior across all expression types, checking: + +1. Introspection depth: failure messages contain expected intermediate values +2. Semantic correctness: rewritten code has identical behavior to original +3. Single evaluation: side-effecting expressions are not evaluated multiple times +""" + +from __future__ import annotations + +import ast +from collections.abc import Callable +from collections.abc import Mapping +import sys +import textwrap +from typing import cast + +from _pytest.assertion.rewrite import rewrite_asserts +import pytest + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +def _rewrite_source(src: str) -> ast.Module: + """Parse and rewrite assertions in source code.""" + tree = ast.parse(src) + rewrite_asserts(tree, src.encode()) + return tree + + +def get_failure_message( + src: str, + extra_ns: Mapping[str, object] | None = None, +) -> str: + """Compile rewritten source, execute it, and return the failure message. + + The source should contain a function named ``check`` with a failing assert. + Returns the AssertionError message string. + + Raises AssertionError via pytest.fail if the code does not raise. + """ + src = textwrap.dedent(src) + mod = _rewrite_source(src) + code = compile(mod, "", "exec") + ns: dict[str, object] = {} + if extra_ns is not None: + ns.update(extra_ns) + exec(code, ns) + func = cast(Callable[[], None], ns["check"]) + try: + func() + except AssertionError: + s = str(sys.exc_info()[1]) + if not s.startswith("assert"): + return "AssertionError: " + s + return s + else: + pytest.fail("check() did not raise AssertionError") + + +def assert_introspects( + src: str, + *, + must_contain: list[str], + must_not_contain: list[str] | None = None, + extra_ns: Mapping[str, object] | None = None, +) -> str: + """Verify a failing assert produces a message with expected intermediate values. + + Parameters + ---------- + src : str + Source code containing a ``check()`` function with a failing assertion. + must_contain : list[str] + Substrings that MUST appear in the failure message. + must_not_contain : list[str] | None + Substrings that must NOT appear in the failure message. + extra_ns : Mapping[str, object] | None + Additional namespace entries available during execution. + + Returns + ------- + str + The full failure message (for further inspection if needed). + """ + msg = get_failure_message(src, extra_ns=extra_ns) + for expected in must_contain: + assert expected in msg, ( + f"Expected {expected!r} in failure message.\nGot:\n{msg}" + ) + for unexpected in must_not_contain or []: + assert unexpected not in msg, ( + f"Did NOT expect {unexpected!r} in failure message.\nGot:\n{msg}" + ) + return msg + + +def assert_single_evaluation( + src: str, + *, + expected_call_count: int = 1, + extra_ns: Mapping[str, object] | None = None, +) -> None: + """Verify side-effecting expressions in assert are evaluated exactly once. + + The source should define a ``check()`` function and use a ``counter`` list + (provided via extra_ns or defined in the source) that tracks how many times + a side-effecting expression is evaluated. + + Parameters + ---------- + src : str + Source containing a ``check()`` function whose assert has side effects. + expected_call_count : int + How many times the side-effecting expression should be evaluated. + extra_ns : Mapping[str, object] | None + Additional namespace. Should include ``counter`` if not defined in src. + """ + src = textwrap.dedent(src) + mod = _rewrite_source(src) + code = compile(mod, "", "exec") + ns: dict[str, object] = {"counter": [0]} + if extra_ns is not None: + ns.update(extra_ns) + exec(code, ns) + func = cast(Callable[[], None], ns["check"]) + counter = cast(list[int], ns["counter"]) + counter[0] = 0 + try: + func() + except AssertionError: + pass + actual = counter[0] + assert actual == expected_call_count, ( + f"Expression evaluated {actual} times, expected {expected_call_count}" + ) + + +def assert_passes_when_true( + src: str, + *, + extra_ns: Mapping[str, object] | None = None, +) -> None: + """Verify rewritten assertion does not raise when the condition is true. + + Parameters + ---------- + src : str + Source containing a ``check()`` function with a passing assertion. + extra_ns : Mapping[str, object] | None + Additional namespace entries available during execution. + """ + src = textwrap.dedent(src) + mod = _rewrite_source(src) + code = compile(mod, "", "exec") + ns: dict[str, object] = {} + if extra_ns is not None: + ns.update(extra_ns) + exec(code, ns) + func = cast(Callable[[], None], ns["check"]) + func() + + +def assert_semantically_equivalent( + src: str, + *, + extra_ns: Mapping[str, object] | None = None, +) -> None: + """Verify rewritten code has same pass/fail semantics as unrewritten code. + + Runs the source both with and without rewriting, and asserts they agree + on whether an AssertionError is raised. + + Parameters + ---------- + src : str + Source containing a ``check()`` function with an assertion. + extra_ns : Mapping[str, object] | None + Additional namespace entries available during execution. + """ + src = textwrap.dedent(src) + + # Run without rewriting + plain_code = compile(src, "", "exec") + plain_ns: dict[str, object] = {} + if extra_ns is not None: + plain_ns.update(extra_ns) + exec(plain_code, plain_ns) + plain_func = cast(Callable[[], None], plain_ns["check"]) + plain_raised = False + try: + plain_func() + except AssertionError: + plain_raised = True + + # Run with rewriting + mod = _rewrite_source(src) + rewritten_code = compile(mod, "", "exec") + rewritten_ns: dict[str, object] = {} + if extra_ns is not None: + rewritten_ns.update(extra_ns) + exec(rewritten_code, rewritten_ns) + rewritten_func = cast(Callable[[], None], rewritten_ns["check"]) + rewritten_raised = False + try: + rewritten_func() + except AssertionError: + rewritten_raised = True + + assert plain_raised == rewritten_raised, ( + f"Semantic mismatch: plain {'raised' if plain_raised else 'passed'}, " + f"rewritten {'raised' if rewritten_raised else 'passed'}" + ) + + +# --------------------------------------------------------------------------- +# Smoke tests for the helpers themselves +# --------------------------------------------------------------------------- + + +class TestHelpersSmokeTest: + """Verify the test helpers work correctly.""" + + def test_get_failure_message_returns_message(self) -> None: + msg = get_failure_message(""" +def check(): + assert 1 == 2 +""") + assert "assert 1 == 2" in msg + + def test_get_failure_message_fails_on_passing_assert(self) -> None: + with pytest.raises(pytest.fail.Exception, match="did not raise"): + get_failure_message(""" +def check(): + assert 1 == 1 +""") + + def test_assert_introspects_succeeds(self) -> None: + assert_introspects( + """ +def check(): + x = 3 + assert x == 5 +""", + must_contain=["assert 3 == 5"], + ) + + def test_assert_introspects_fails_on_missing(self) -> None: + with pytest.raises(AssertionError, match=r"Expected.*in failure"): + assert_introspects( + """ +def check(): + assert 1 == 2 +""", + must_contain=["this is not in the message"], + ) + + def test_assert_single_evaluation(self) -> None: + assert_single_evaluation(""" +def check(): + def inc(): + counter[0] += 1 + return False + assert inc() +""") + + def test_assert_passes_when_true(self) -> None: + assert_passes_when_true(""" +def check(): + assert 1 == 1 +""") + + def test_assert_semantically_equivalent_passing(self) -> None: + assert_semantically_equivalent(""" +def check(): + assert 1 == 1 +""") + + def test_assert_semantically_equivalent_failing(self) -> None: + assert_semantically_equivalent(""" +def check(): + assert 1 == 2 +""") + + def test_assert_semantically_equivalent_detects_mismatch(self) -> None: + # This would only trigger on a bug in the rewriter itself; + # for now just verify both paths execute without error. + assert_semantically_equivalent(""" +def check(): + x = [1, 2, 3] + assert len(x) == 3 +""") From 04fcadb293602556b6f376cd9d8aa66870744149 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 17:01:34 +0200 Subject: [PATCH 07/13] test(rewrite): add introspection matrix covering all expression types Add parametrized test classes that document current assertion rewriting behavior for each AST expression type: - Compare, BoolOp, UnaryOp, BinOp: verified working with correct output - Call: works but shows repr for local/variable callables - Attribute, Name, Walrus: verified working - Subscript, IfExp, ContainerLiteral: marked xfail (blind spots) - MethodCall: shows noisy bound-method intermediate (xfail) - Comprehension, FString: semantics preserved, result shown in compare The xfail tests document expected improvements and will be unmarked as each blind spot is addressed in subsequent commits. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- testing/test_assertrewrite_coverage.py | 471 +++++++++++++++++++++++++ 1 file changed, 471 insertions(+) diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index 8faa07071f5..fb8fa317bee 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -295,3 +295,474 @@ def check(): x = [1, 2, 3] assert len(x) == 3 """) + + +# --------------------------------------------------------------------------- +# Introspection matrix: verify what information each expression type exposes +# --------------------------------------------------------------------------- + + +class TestIntrospectionCompare: + """Comparisons (==, !=, <, >, <=, >=, in, not in, is, is not).""" + + def test_simple_equality(self) -> None: + assert_introspects( + """ +def check(): + x = 3 + assert x == 5 +""", + must_contain=["assert 3 == 5"], + ) + + def test_chained_compare(self) -> None: + # Chained compares only show the failing pair + assert_introspects( + """ +def check(): + x = 10 + assert 1 < x < 5 +""", + must_contain=["assert 10 < 5"], + ) + + def test_in_operator(self) -> None: + assert_introspects( + """ +def check(): + x = 4 + assert x in [1, 2, 3] +""", + must_contain=["assert 4 in [1, 2, 3]"], + ) + + def test_not_in_operator(self) -> None: + assert_introspects( + """ +def check(): + x = 2 + assert x not in [1, 2, 3] +""", + must_contain=["assert 2 not in [1, 2, 3]"], + ) + + def test_is_operator(self) -> None: + assert_introspects( + """ +def check(): + x = [] + y = [] + assert x is y +""", + must_contain=["assert [] is []"], + ) + + +class TestIntrospectionBoolOp: + """Boolean operations (and, or) with short-circuit.""" + + def test_and_both_shown(self) -> None: + assert_introspects( + """ +def check(): + a = True + b = False + assert a and b +""", + must_contain=["(True and False)"], + ) + + def test_or_both_shown(self) -> None: + assert_introspects( + """ +def check(): + a = False + b = False + assert a or b +""", + must_contain=["(False or False)"], + ) + + def test_and_short_circuit(self) -> None: + assert_introspects( + """ +def check(): + a = False + assert a and explode +""", + must_contain=["False"], + ) + + +class TestIntrospectionUnaryOp: + """Unary operations (not, ~, -, +).""" + + def test_not(self) -> None: + assert_introspects( + """ +def check(): + x = True + assert not x +""", + must_contain=["assert not True"], + ) + + def test_invert(self) -> None: + # ~(-1) == 0, which is falsy + assert_introspects( + """ +def check(): + x = -1 + assert ~x +""", + must_contain=["assert ~-1"], + ) + + +class TestIntrospectionBinOp: + """Binary operations (+, -, *, /, etc.).""" + + def test_addition(self) -> None: + assert_introspects( + """ +def check(): + x = 3 + y = 4 + assert x + y == 10 +""", + must_contain=["(3 + 4)"], + ) + + def test_subtraction(self) -> None: + assert_introspects( + """ +def check(): + x = 3 + y = 4 + assert x - y == 10 +""", + must_contain=["(3 - 4)"], + ) + + +class TestIntrospectionCall: + """Function/method calls.""" + + def test_simple_call_shows_result(self) -> None: + # Currently local functions show full repr in the "where" line + assert_introspects( + """ +def check(): + def f(): + return 42 + assert f() == 100 +""", + must_contain=["where 42 = ", "()"], + ) + + def test_call_with_args_shows_result(self) -> None: + assert_introspects( + """ +def check(): + def f(x): + return x * 2 + assert f(3) == 10 +""", + must_contain=["where 6 = ", "(3)"], + ) + + @pytest.mark.xfail( + reason="Local function calls show full repr: blind spot" + ) + def test_simple_call_clean_name(self) -> None: + """Ideally the message should show 'f()' not '()'.""" + assert_introspects( + """ +def check(): + def f(): + return 42 + assert f() == 100 +""", + must_contain=["where 42 = f()"], + must_not_contain=[" None: + assert_introspects( + """ +def check(): + class Obj: + def method(self): + return 42 + obj = Obj() + assert obj.method() == 100 +""", + must_contain=["42", "100"], + ) + + +class TestIntrospectionAttribute: + """Attribute access.""" + + def test_attribute_access(self) -> None: + assert_introspects( + """ +def check(): + class Obj: + x = 3 + def __repr__(self): + return "Obj()" + obj = Obj() + assert obj.x == 5 +""", + must_contain=["where 3 = Obj().x"], + ) + + +class TestIntrospectionName: + """Variable name display.""" + + def test_local_variable_shown(self) -> None: + assert_introspects( + """ +def check(): + result = 42 + assert result == 100 +""", + must_contain=["assert 42 == 100"], + ) + + +class TestIntrospectionSubscript: + """Subscript / indexing — currently hits generic_visit.""" + + @pytest.mark.xfail(reason="Subscript not introspected: blind spot") + def test_dict_subscript_shows_key_and_container(self) -> None: + assert_introspects( + """ +def check(): + d = {"a": 1, "b": 2} + assert d["a"] == 99 +""", + must_contain=["where 1 = ", '["a"]'], + ) + + @pytest.mark.xfail(reason="Subscript not introspected: blind spot") + def test_list_subscript_shows_index_and_container(self) -> None: + assert_introspects( + """ +def check(): + items = [10, 20, 30] + assert items[1] == 99 +""", + must_contain=["where 20 = ", "[1]"], + ) + + def test_subscript_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + d = {"key": "value"} + assert d["key"] == "wrong" +""") + + def test_subscript_in_compare_shows_value(self) -> None: + """Even without decomposition, the value is shown in comparisons.""" + assert_introspects( + """ +def check(): + d = {"a": 1} + assert d["a"] == 99 +""", + must_contain=["assert 1 == 99"], + ) + + +class TestIntrospectionIfExp: + """Ternary / if-expression — currently hits generic_visit.""" + + @pytest.mark.xfail(reason="IfExp not introspected: blind spot") + def test_ifexp_shows_condition_and_branch(self) -> None: + assert_introspects( + """ +def check(): + flag = True + assert (0 if flag else 1) == 1 +""", + must_contain=["flag", "True"], + ) + + def test_ifexp_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + flag = True + assert (0 if flag else 1) == 1 +""") + + def test_ifexp_in_compare_shows_result(self) -> None: + assert_introspects( + """ +def check(): + flag = True + assert (0 if flag else 1) == 99 +""", + must_contain=["assert 0 == 99"], + ) + + +class TestIntrospectionContainerLiteral: + """Container literals ([...], {...}, {k:v}) — currently hits generic_visit.""" + + @pytest.mark.xfail(reason="Container literals not introspected: blind spot") + def test_list_literal_shows_elements(self) -> None: + assert_introspects( + """ +def check(): + def f(): + return 99 + assert [f(), 2, 3] == [1, 2, 3] +""", + must_contain=["where 99 = f()"], + ) + + def test_list_literal_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + assert [1, 2, 3] == [1, 2, 4] +""") + + def test_dict_literal_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + assert {"a": 1} == {"a": 2} +""") + + +class TestIntrospectionComprehension: + """Comprehensions — currently hits generic_visit.""" + + def test_listcomp_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + assert [x * 2 for x in range(3)] == [0, 2, 5] +""") + + def test_listcomp_in_compare_shows_result(self) -> None: + assert_introspects( + """ +def check(): + assert [x * 2 for x in range(3)] == [0, 2, 5] +""", + must_contain=["[0, 2, 4]"], + ) + + +class TestIntrospectionFString: + """F-string expressions — currently hits generic_visit.""" + + def test_fstring_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + x = 42 + assert f"value={x}" == "value=99" +""") + + def test_fstring_in_compare_shows_result(self) -> None: + assert_introspects( + """ +def check(): + x = 42 + assert f"value={x}" == "value=99" +""", + must_contain=["value=42"], + ) + + +class TestIntrospectionMethodCall: + """Method calls — bound method intermediate display.""" + + def test_method_call_result_shown(self) -> None: + assert_introspects( + """ +def check(): + class Obj: + def compute(self): + return 42 + def __repr__(self): + return "Obj()" + obj = Obj() + assert obj.compute() == 100 +""", + must_contain=["where 42 = "], + ) + + @pytest.mark.xfail( + reason="Method call shows noisy bound-method intermediate: blind spot" + ) + def test_method_call_no_bound_method_noise(self) -> None: + """The 'where method = obj.method' line is noisy and unhelpful.""" + msg = get_failure_message(""" +def check(): + class Obj: + def compute(self): + return 42 + def __repr__(self): + return "Obj()" + obj = Obj() + assert obj.compute() == 100 +""") + lines = msg.splitlines() + # Ideally the message should NOT have a separate "where compute = ..." + # line showing the bound method object — it adds noise without value + for line in lines: + assert "where compute = " not in line, ( + f"Noisy bound-method intermediate found:\n{msg}" + ) + + def test_callable_variable_shows_result(self) -> None: + # Current behavior: shows full function repr, not variable name + assert_introspects( + """ +def check(): + def factory(): + return 42 + fn = factory + assert fn() == 100 +""", + must_contain=["where 42 = ", "()"], + ) + + @pytest.mark.xfail(reason="Callable variables show repr: blind spot") + def test_callable_variable_clean_name(self) -> None: + """Ideally should show 'fn()' not '()'.""" + assert_introspects( + """ +def check(): + def factory(): + return 42 + fn = factory + assert fn() == 100 +""", + must_contain=["where 42 = fn()"], + must_not_contain=[" None: + assert_introspects( + """ +def check(): + x = 10 + assert (y := x * 2) == 100 +""", + must_contain=["assert 20 == 100"], + ) + + def test_walrus_semantics_preserved(self) -> None: + assert_semantically_equivalent(""" +def check(): + x = 10 + assert (y := x * 2) == 100 +""") From 0b8a68163f7092945f575586e3fd3d0fb9e44010 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 17:02:38 +0200 Subject: [PATCH 08/13] test(rewrite): add single-evaluation test suite for all expression types Add TestSingleEvaluation class that systematically verifies no expression is evaluated multiple times during assertion rewriting. Covers: - Calls in compare, boolean, unary, binop contexts - Attribute/property access - Subscript (dict __getitem__) - Walrus operator in compare, boolean, and chained compare - Method calls - Nested calls (inner + outer counted separately) - Multiple comparators in chained comparisons - IfExp conditions - Comprehension generators All tests pass, confirming the #14445 fix holds across all expression types. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- testing/test_assertrewrite_coverage.py | 172 +++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index fb8fa317bee..b21360edf2a 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -766,3 +766,175 @@ def check(): x = 10 assert (y := x * 2) == 100 """) + + +# --------------------------------------------------------------------------- +# Single-evaluation tests: ensure no expression is evaluated multiple times +# --------------------------------------------------------------------------- + + +class TestSingleEvaluation: + """Verify the rewriter doesn't cause double-evaluation of side effects. + + Each test uses a counter to track how many times a side-effecting + expression is evaluated. The rewritten assert should evaluate each + expression exactly once, regardless of whether the assertion passes or fails. + """ + + def test_call_in_compare_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return 42 + assert side_effect() == 100 +""") + + def test_call_in_boolean_and_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return True + assert side_effect() and False +""") + + def test_call_in_boolean_or_short_circuit(self) -> None: + # With `or`, if first is truthy, second is NOT evaluated + assert_single_evaluation( + """ +def check(): + def first(): + counter[0] += 1 + return False + def second(): + counter[0] += 1 + return False + assert first() or second() +""", + expected_call_count=2, + ) + + def test_call_in_unary_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return True + assert not side_effect() +""") + + def test_call_in_binop_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return 5 + assert side_effect() + 1 == 100 +""") + + def test_attribute_access_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + class Obj: + @property + def prop(self): + counter[0] += 1 + return 42 + obj = Obj() + assert obj.prop == 100 +""") + + def test_subscript_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + class CountingDict(dict): + def __getitem__(self, key): + counter[0] += 1 + return super().__getitem__(key) + d = CountingDict(a=1) + assert d["a"] == 100 +""") + + def test_walrus_in_compare_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return 42 + assert (x := side_effect()) == 100 +""") + + def test_walrus_in_boolean_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return 42 + assert (x := side_effect()) and False +""") + + def test_walrus_in_chained_compare_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def side_effect(): + counter[0] += 1 + return 5 + assert 1 < (x := side_effect()) < 3 +""") + + def test_method_call_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + class Obj: + def compute(self): + counter[0] += 1 + return 42 + obj = Obj() + assert obj.compute() == 100 +""") + + def test_nested_calls_each_evaluated_once(self) -> None: + assert_single_evaluation( + """ +def check(): + def outer(x): + counter[0] += 1 + return x + 1 + def inner(): + counter[0] += 1 + return 5 + assert outer(inner()) == 100 +""", + expected_call_count=2, + ) + + def test_multiple_comparators_evaluated_once_each(self) -> None: + assert_single_evaluation( + """ +def check(): + def make_val(n): + counter[0] += 1 + return n + assert make_val(1) < make_val(5) < make_val(3) +""", + expected_call_count=3, + ) + + def test_ifexp_condition_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def cond(): + counter[0] += 1 + return True + assert (0 if cond() else 1) == 1 +""") + + def test_comprehension_generator_evaluated_once(self) -> None: + assert_single_evaluation(""" +def check(): + def items(): + counter[0] += 1 + return [1, 2, 3] + assert [x * 2 for x in items()] == [2, 4, 7] +""") From 3b2d44ba9bfb37638f75036a5bd657611392692c Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 17:39:51 +0200 Subject: [PATCH 09/13] feat(rewrite): add visit_Subscript for container[key] introspection Implement a dedicated visitor for ast.Subscript in AssertionRewriter that decomposes container[key] expressions into separate container and key introspection. This produces failure messages like: assert 1 == 99 + where 1 = {'a': 1, 'b': 2}['a'] Previously, subscript expressions hit generic_visit and only showed the final value without decomposition into container and key. Slices (a[1:3]) still fall back to generic_visit since decomposing start/stop/step is rarely useful in assertion messages. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- src/_pytest/assertion/rewrite.py | 17 +++++++++++++++++ testing/test_assertrewrite_coverage.py | 6 ++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 3fa3217f6e0..397ac885e74 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1052,6 +1052,23 @@ def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]: new_starred = ast.Starred(res, starred.ctx) return new_starred, "*" + expl + def visit_Subscript(self, subscript: ast.Subscript) -> tuple[ast.Name, str]: + if not isinstance(subscript.ctx, ast.Load): + return self.generic_visit(subscript) + # For Slice objects (a[1:3]), fall back to generic — decomposing + # start/stop/step is rarely useful in assertion messages. + if isinstance(subscript.slice, ast.Slice): + return self.generic_visit(subscript) + value, value_expl = self.visit(subscript.value) + slice_res, slice_expl = self.visit(subscript.slice) + res = self.assign( + ast.copy_location(ast.Subscript(value, slice_res, ast.Load()), subscript) + ) + res_expl = self.explanation_param(self.display(res)) + pat = "%s\n{%s = %s[%s]\n}" + expl = pat % (res_expl, res_expl, value_expl, slice_expl) + return res, expl + def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: if not isinstance(attr.ctx, ast.Load): return self.generic_visit(attr) diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index b21360edf2a..eca61a6820d 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -534,9 +534,8 @@ def check(): class TestIntrospectionSubscript: - """Subscript / indexing — currently hits generic_visit.""" + """Subscript / indexing — now has dedicated visitor.""" - @pytest.mark.xfail(reason="Subscript not introspected: blind spot") def test_dict_subscript_shows_key_and_container(self) -> None: assert_introspects( """ @@ -544,10 +543,9 @@ def check(): d = {"a": 1, "b": 2} assert d["a"] == 99 """, - must_contain=["where 1 = ", '["a"]'], + must_contain=["where 1 = ", "['a']"], ) - @pytest.mark.xfail(reason="Subscript not introspected: blind spot") def test_list_subscript_shows_index_and_container(self) -> None: assert_introspects( """ From 79d59da23469ccc1418eda06e30b13e06ab282ab Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 17:53:11 +0200 Subject: [PATCH 10/13] feat(rewrite): add visit_IfExp for ternary expression introspection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement a dedicated visitor for ast.IfExp that introspects the condition value while preserving short-circuit semantics. Produces failure messages like: assert 0 == 99 + where 0 = (... if True else ...) The condition is rewritten for introspection (showing its evaluated value), but branches are kept as-is to preserve Python's short-circuit behavior — only the selected branch is evaluated. Previously, IfExp hit generic_visit showing only the final result without any insight into which branch was taken or why. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- src/_pytest/assertion/rewrite.py | 14 ++++++++++++++ testing/test_assertrewrite_coverage.py | 25 ++++++++++++++++++++----- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 397ac885e74..7d0a338a4f7 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1052,6 +1052,20 @@ def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]: new_starred = ast.Starred(res, starred.ctx) return new_starred, "*" + expl + def visit_IfExp(self, ifexp: ast.IfExp) -> tuple[ast.Name, str]: + # Introspect the condition but keep branches as-is to preserve + # short-circuit semantics (only the selected branch is evaluated). + cond_res, cond_expl = self.visit(ifexp.test) + # Reconstruct the IfExp with the rewritten condition but original + # branches to avoid evaluating both sides. + res = self.assign( + ast.copy_location(ast.IfExp(cond_res, ifexp.body, ifexp.orelse), ifexp) + ) + res_expl = self.explanation_param(self.display(res)) + pat = "%s\n{%s = (... if %s else ...)\n}" + expl = pat % (res_expl, res_expl, cond_expl) + return res, expl + def visit_Subscript(self, subscript: ast.Subscript) -> tuple[ast.Name, str]: if not isinstance(subscript.ctx, ast.Load): return self.generic_visit(subscript) diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index eca61a6820d..86107706c4d 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -576,17 +576,16 @@ def check(): class TestIntrospectionIfExp: - """Ternary / if-expression — currently hits generic_visit.""" + """Ternary / if-expression — now has dedicated visitor.""" - @pytest.mark.xfail(reason="IfExp not introspected: blind spot") - def test_ifexp_shows_condition_and_branch(self) -> None: + def test_ifexp_shows_condition_value(self) -> None: assert_introspects( """ def check(): flag = True assert (0 if flag else 1) == 1 """, - must_contain=["flag", "True"], + must_contain=["if True else"], ) def test_ifexp_semantics_preserved(self) -> None: @@ -603,9 +602,25 @@ def check(): flag = True assert (0 if flag else 1) == 99 """, - must_contain=["assert 0 == 99"], + must_contain=["assert 0 == 99", "if True else"], ) + def test_ifexp_short_circuit_true(self) -> None: + """Orelse branch must NOT be evaluated when condition is True.""" + assert_passes_when_true(""" +def check(): + flag = True + assert (1 if flag else (1/0)) == 1 +""") + + def test_ifexp_short_circuit_false(self) -> None: + """Body branch must NOT be evaluated when condition is False.""" + assert_passes_when_true(""" +def check(): + flag = False + assert (1/0 if flag else 1) == 1 +""") + class TestIntrospectionContainerLiteral: """Container literals ([...], {...}, {k:v}) — currently hits generic_visit.""" From 6369f5462d086cde4d6f26c8154c3c0de9fef146 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 18:04:35 +0200 Subject: [PATCH 11/13] feat(rewrite): flatten method call display to remove bound-method noise MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor visit_Call to detect obj.method() patterns (where call.func is an ast.Attribute) and produce a flat explanation format: assert 42 == 100 + where 42 = Obj().compute() Instead of the previous nested format: assert 42 == 100 + where 42 = compute() + where compute = Obj().compute The bound-method intermediate line was noisy and unhelpful — users want to see what object the method was called on and what it returned, not the method object itself. Regular function calls (non-attribute) are unchanged. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- src/_pytest/assertion/rewrite.py | 43 ++++++++++++++++++++++++++ testing/test_assertrewrite_coverage.py | 14 +++------ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 7d0a338a4f7..a7896844059 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1023,6 +1023,12 @@ def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]: return res, explanation def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: + # For method calls (obj.method()), produce a flat explanation like + # "where result = obj.method(args)" instead of nesting the attribute + # access as a separate "where method = obj.method" line. + if isinstance(call.func, ast.Attribute) and isinstance(call.func.ctx, ast.Load): + return self._visit_method_call(call) + new_func, func_expl = self.visit(call.func) arg_expls = [] new_args = [] @@ -1046,6 +1052,43 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}" return res, outer_expl + def _visit_method_call(self, call: ast.Call) -> tuple[ast.Name, str]: + r"""Handle obj.method(...) calls with a flat explanation format. + + Produces: "result\n{result = obj_repr.method(args)\n}" + instead of nesting the bound-method intermediate. + """ + attr = call.func + assert isinstance(attr, ast.Attribute) + + # Visit the object (receiver) for introspection. + obj_res, obj_expl = self.visit(attr.value) + + # Visit arguments. + arg_expls = [] + new_args = [] + new_kwargs = [] + for arg in call.args: + res, expl = self.visit(arg) + arg_expls.append(expl) + new_args.append(res) + for keyword in call.keywords: + res, expl = self.visit(keyword.value) + new_kwargs.append(ast.keyword(keyword.arg, res)) + if keyword.arg: + arg_expls.append(keyword.arg + "=" + expl) + else: + arg_expls.append("**" + expl) + + # Build the call using the rewritten object's attribute. + new_func = ast.Attribute(obj_res, attr.attr, ast.Load()) + new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call) + res = self.assign(new_call) + res_expl = self.explanation_param(self.display(res)) + args_str = ", ".join(arg_expls) + expl = f"{res_expl}\n{{{res_expl} = {obj_expl}.{attr.attr}({args_str})\n}}" + return res, expl + def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]: # A Starred node can appear in a function call. res, expl = self.visit(starred.value) diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index 86107706c4d..ec027bf8dd5 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -691,9 +691,10 @@ def check(): class TestIntrospectionMethodCall: - """Method calls — bound method intermediate display.""" + """Method calls — flat obj.method() display without bound-method noise.""" - def test_method_call_result_shown(self) -> None: + def test_method_call_flat_format(self) -> None: + """Method calls show 'where result = obj.method()' in one line.""" assert_introspects( """ def check(): @@ -705,14 +706,11 @@ def __repr__(self): obj = Obj() assert obj.compute() == 100 """, - must_contain=["where 42 = "], + must_contain=["where 42 = Obj().compute()"], ) - @pytest.mark.xfail( - reason="Method call shows noisy bound-method intermediate: blind spot" - ) def test_method_call_no_bound_method_noise(self) -> None: - """The 'where method = obj.method' line is noisy and unhelpful.""" + """No separate 'where compute = obj.compute' line.""" msg = get_failure_message(""" def check(): class Obj: @@ -724,8 +722,6 @@ def __repr__(self): assert obj.compute() == 100 """) lines = msg.splitlines() - # Ideally the message should NOT have a separate "where compute = ..." - # line showing the bound method object — it adds noise without value for line in lines: assert "where compute = " not in line, ( f"Noisy bound-method intermediate found:\n{msg}" From 55316621ba139cbb0b1b4d6bbbe1d86088acbd50 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 18:07:35 +0200 Subject: [PATCH 12/13] test(rewrite): add edge case and regression tests for new visitors Add TestEdgeCases class combining the new visitors (Subscript, IfExp, method call) with existing ones to verify correct behavior in complex scenarios: - Subscript with variable keys, call keys, and nested subscripts - Method calls with arguments, chained calls, and global objects - IfExp with call conditions - Walrus operator in subscript keys - Single-evaluation guarantees for all new visitors - Custom assert messages still work with new decomposition - Complex assertions combining multiple visitor types Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- testing/test_assertrewrite_coverage.py | 190 +++++++++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index ec027bf8dd5..f75dd866000 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -947,3 +947,193 @@ def items(): return [1, 2, 3] assert [x * 2 for x in items()] == [2, 4, 7] """) + + +# --------------------------------------------------------------------------- +# Edge cases: combinations of new visitors with existing ones +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + """Regression and edge-case tests combining multiple expression types.""" + + def test_subscript_with_variable_key(self) -> None: + """Subscript where the key is a variable (not constant).""" + assert_introspects( + """ +def check(): + d = {"hello": 42} + key = "hello" + assert d[key] == 100 +""", + must_contain=["where 42 = ", "['hello']"], + ) + + def test_subscript_with_call_key(self) -> None: + """Subscript where the key is a function call.""" + assert_introspects( + """ +def check(): + d = {0: "zero", 1: "one"} + def get_key(): + return 0 + assert d[get_key()] == "wrong" +""", + must_contain=["'zero'", "'wrong'"], + ) + + def test_nested_subscript(self) -> None: + """Nested subscript: d[k1][k2].""" + assert_introspects( + """ +def check(): + d = {"a": {"b": 42}} + assert d["a"]["b"] == 100 +""", + must_contain=["42", "100"], + ) + + def test_method_call_with_args(self) -> None: + """Method call with arguments shows flat format.""" + assert_introspects( + """ +def check(): + class Calculator: + def add(self, a, b): + return a + b + def __repr__(self): + return "Calc()" + c = Calculator() + assert c.add(2, 3) == 10 +""", + must_contain=["where 5 = Calc().add(2, 3)"], + ) + + def test_chained_method_calls(self) -> None: + """Chained method call: obj.method1().method2().""" + assert_introspects( + """ +def check(): + class Builder: + def __init__(self, val=0): + self.val = val + def add(self, n): + return Builder(self.val + n) + def result(self): + return self.val + def __repr__(self): + return f"Builder({self.val})" + b = Builder() + assert b.add(5).result() == 100 +""", + must_contain=["where 5 = ", ".result()"], + ) + + def test_subscript_on_method_result(self) -> None: + """Subscript on method return value: obj.method()[key].""" + assert_introspects( + """ +def check(): + class Store: + def get_data(self): + return {"x": 42} + def __repr__(self): + return "Store()" + s = Store() + assert s.get_data()["x"] == 100 +""", + must_contain=["42", "100"], + ) + + def test_ifexp_with_call_condition(self) -> None: + """IfExp where condition is a function call.""" + assert_introspects( + """ +def check(): + def is_ready(): + return False + assert (1 if is_ready() else 0) == 1 +""", + must_contain=["if False else"], + ) + + def test_walrus_in_subscript(self) -> None: + """Walrus operator used as subscript key.""" + assert_semantically_equivalent(""" +def check(): + d = {1: "one", 2: "two"} + x = 1 + assert d[(y := x + 1)] == "wrong" +""") + + def test_method_call_single_evaluation(self) -> None: + """Method with side effects is only called once.""" + assert_single_evaluation(""" +def check(): + class Obj: + def compute(self): + counter[0] += 1 + return 42 + obj = Obj() + assert obj.compute() == 100 +""") + + def test_subscript_single_evaluation(self) -> None: + """Custom __getitem__ with side effects is only called once.""" + assert_single_evaluation(""" +def check(): + class CountingList: + def __init__(self, items): + self.items = items + def __getitem__(self, idx): + counter[0] += 1 + return self.items[idx] + def __repr__(self): + return repr(self.items) + lst = CountingList([10, 20, 30]) + assert lst[1] == 99 +""") + + def test_ifexp_condition_single_evaluation(self) -> None: + """IfExp condition with side effects is only evaluated once.""" + assert_single_evaluation(""" +def check(): + def check_flag(): + counter[0] += 1 + return True + assert (0 if check_flag() else 1) == 99 +""") + + def test_complex_assertion_semantics(self) -> None: + """Complex assertion combining multiple new visitors.""" + assert_semantically_equivalent(""" +def check(): + class Config: + def __init__(self): + self.data = {"timeout": 30} + def get(self, key): + return self.data[key] + cfg = Config() + flag = True + assert (cfg.get("timeout") if flag else 0) > 60 +""") + + def test_assert_with_message_still_works(self) -> None: + """Assert with a custom message still works with new visitors.""" + msg = get_failure_message(""" +def check(): + d = {"key": 42} + assert d["key"] == 100, "custom failure message" +""") + assert "custom failure message" in msg + + def test_method_call_on_global(self) -> None: + """Method call on a global/module-level object.""" + assert_introspects( + """ +items = [1, 2, 3] +def check(): + assert items.count(99) == 1 +""", + must_contain=["where 0 = [1, 2, 3].count(99)"], + ) From dc7af3741974bdd70d17f557d3b1dd761ebd7734 Mon Sep 17 00:00:00 2001 From: Ronny Pfannschmidt Date: Fri, 8 May 2026 22:29:52 +0200 Subject: [PATCH 13/13] fix(rewrite): update test_assert_matches for flat method call format Update raises_group.py::test_assert_matches to expect the new flat method call format (where False = RaisesExc(TypeError).matches(...)) instead of the old nested format with a separate bound-method line. Also address Copilot review: use copy.deepcopy in assert_semantically_equivalent to isolate mutable state between the plain and rewritten execution runs. Co-authored-by: Cursor AI Co-authored-by: Anthropic Claude Opus 4 --- testing/python/raises_group.py | 3 +-- testing/test_assertrewrite_coverage.py | 9 +++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/testing/python/raises_group.py b/testing/python/raises_group.py index 8b311bd0eed..bc9c97bed5e 100644 --- a/testing/python/raises_group.py +++ b/testing/python/raises_group.py @@ -1237,8 +1237,7 @@ def test_assert_matches() -> None: match=wrap_escape( "`ValueError()` is not an instance of `TypeError`\n" "assert False\n" - " + where False = matches(ValueError())\n" - " + where matches = RaisesExc(TypeError).matches" + " + where False = RaisesExc(TypeError).matches(ValueError())" ), ): # you'd need to do this arcane incantation diff --git a/testing/test_assertrewrite_coverage.py b/testing/test_assertrewrite_coverage.py index f75dd866000..a0c1f8b68c1 100644 --- a/testing/test_assertrewrite_coverage.py +++ b/testing/test_assertrewrite_coverage.py @@ -13,6 +13,7 @@ import ast from collections.abc import Callable from collections.abc import Mapping +import copy import sys import textwrap from typing import cast @@ -185,11 +186,11 @@ def assert_semantically_equivalent( """ src = textwrap.dedent(src) - # Run without rewriting + # Run without rewriting — use deepcopy of extra_ns to isolate mutable state plain_code = compile(src, "", "exec") plain_ns: dict[str, object] = {} if extra_ns is not None: - plain_ns.update(extra_ns) + plain_ns.update(copy.deepcopy(dict(extra_ns))) exec(plain_code, plain_ns) plain_func = cast(Callable[[], None], plain_ns["check"]) plain_raised = False @@ -198,12 +199,12 @@ def assert_semantically_equivalent( except AssertionError: plain_raised = True - # Run with rewriting + # Run with rewriting — fresh deepcopy so mutations from first run don't leak mod = _rewrite_source(src) rewritten_code = compile(mod, "", "exec") rewritten_ns: dict[str, object] = {} if extra_ns is not None: - rewritten_ns.update(extra_ns) + rewritten_ns.update(copy.deepcopy(dict(extra_ns))) exec(rewritten_code, rewritten_ns) rewritten_func = cast(Callable[[], None], rewritten_ns["check"]) rewritten_raised = False