Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ Mark Abramowitz
Mark Dickinson
Mark Vong
Marko Pacak
marko1olo
Markus Unterwaditzer
Martijn Faassen
Martin Altmayer
Expand Down
1 change: 1 addition & 0 deletions changelog/14445.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed assertion rewriting leaking walrus operator state into later assertions in the same function.
32 changes: 30 additions & 2 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
self.statements: list[ast.stmt] = []
self.variables: list[str] = []
self.variable_counter = itertools.count()
self.variables_overwrite[self.scope] = {}
self.variable_restore_names: list[tuple[str, str]] = []

if self.enable_assertion_pass_hook:
self.format_variables: list[str] = []
Expand All @@ -881,6 +883,15 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:

negation = ast.UnaryOp(ast.Not(), top_condition)

def restore_walrus_targets() -> list[ast.Assign]:
return [
ast.Assign(
[ast.Name(target_id, ast.Store())],
ast.Name(temp_id, ast.Load()),
)
for target_id, temp_id in self.variable_restore_names
]

if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
msg = self.pop_format_context(ast.Constant(explanation))

Expand All @@ -899,6 +910,7 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
raise_ = ast.Raise(exc, None)
statements_fail = []
statements_fail.extend(self.expl_stmts)
statements_fail.extend(restore_walrus_targets())
statements_fail.append(raise_)

# Passed
Expand All @@ -918,7 +930,10 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
[*self.expl_stmts, hook_call_pass],
[],
)
statements_pass: list[ast.stmt] = [hook_impl_test]
statements_pass: list[ast.stmt] = [
hook_impl_test,
*restore_walrus_targets(),
]

# Test for assertion condition
main_test = ast.If(negation, statements_fail, statements_pass)
Expand Down Expand Up @@ -947,7 +962,9 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)

body.extend(restore_walrus_targets())
body.append(raise_)
self.statements.extend(restore_walrus_targets())

# Clear temporary variables by setting them to None.
if self.variables:
Expand Down Expand Up @@ -1001,6 +1018,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner
restore_after_operand: tuple[str, str] | None = None
match v:
# Check if the left operand is an ast.NamedExpr and the value has already been visited
case ast.Compare(
Expand All @@ -1012,9 +1030,14 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined]
restore_after_operand = (target_id, pytest_temp)
else:
restore_after_operand = None
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
if restore_after_operand is not None:
self.variable_restore_names.append(restore_after_operand)
expl_format = self.pop_format_context(ast.Constant(expl))
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
Expand Down Expand Up @@ -1119,13 +1142,16 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms: list[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
restore_after_compare: tuple[str, str] | None = None
match (next_operand, left_res):
case (
ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id),
) if target_id == name_id:
next_operand.target.id = self.variable()
temp_id = self.variable()
next_operand.target.id = temp_id
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
restore_after_compare = (name_id, temp_id)

next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, ast.Compare | ast.BoolOp):
Expand All @@ -1137,6 +1163,8 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
expls.append(ast.Constant(expl))
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
self.statements.append(ast.Assign([store_names[i]], res_expr))
if restore_after_compare is not None:
self.variable_restore_names.append(restore_after_compare)
left_res, left_expl = next_res, next_expl
# Use pytest.assertion.util._reprcompare if that's available.
expl_call = self.helper(
Expand Down
49 changes: 49 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,55 @@ def test_walrus_operator_not_override_value():
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_value_changes_cleared_after_each_assert(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
class Counter:
def __init__(self):
self.value = 0

def increment(self):
self.value += 1

def test_walrus_operator_change_value_between_asserts():
counter = Counter()
assert (before := counter.value) == 0
counter.increment()
assert before != (after := counter.value)
assert before == 0
assert after == 1
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_restore_with_assertion_pass_hook(
self, pytester: Pytester
) -> None:
pytester.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
pytester.makepyfile(
"""
def test_walrus_operator_pass_compare_restore():
a = "Hello"
assert a != (a := a.lower())
assert a == "hello"

def test_walrus_operator_pass_bool_restore():
a = True
assert a and ((a := False) is False) and (a is False)
assert a is False

def test_walrus_operator_fail_compare_explanation():
a = "Hello"
assert a == (a := a.lower())
"""
)
result = pytester.runpytest()
result.assert_outcomes(passed=2, failed=1)
result.stdout.fnmatch_lines(["*assert 'Hello' == 'hello'"])

def test_assertion_namedexpr_compare_left_overwrite(
self, pytester: Pytester
) -> None:
Expand Down
Loading