Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ Feng Ma
Florian Bruhin
Florian Dahlitz
Floris Bruynooghe
Frank Hoffmann
Fraser Stark
Gabriel Landau
Gabriel Reis
Expand Down
1 change: 1 addition & 0 deletions changelog/12818.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Assertion rewriting now preserves the source ranges of the original instructions, making it play well with tools that deal with the ``AST``, like `executing <https://github.com/alexmojaki/executing>`__.
21 changes: 14 additions & 7 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
return ast.copy_location(ast.Name(name, ast.Load()), expr)

def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
Expand Down Expand Up @@ -975,7 +975,10 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
if getattr(node, "lineno", None) is None:
# apply the assertion location to all generated ast nodes without source location
# and preserve the location of existing nodes or generated nodes with an correct location.
ast.copy_location(node, assert_)
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
Expand Down Expand Up @@ -1052,15 +1055,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
return res, pattern % (operand_expl,)

def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
explanation = f"({left_expl} {symbol} {right_expl})"
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
res = self.assign(
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
)
return res, explanation

def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
Expand Down Expand Up @@ -1089,7 +1094,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
arg_expls.append("**" + expl)

expl = "{}({})".format(func_expl, ", ".join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
Expand All @@ -1105,7 +1110,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
res = self.assign(
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
)
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = %s.%s\n}"
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
Expand Down Expand Up @@ -1146,7 +1153,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Constant(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
# Use pytest.assertion.util._reprcompare if that's available.
Expand Down
211 changes: 207 additions & 4 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import annotations

import ast
import dis
import errno
from functools import partial
import glob
import importlib
import inspect
import marshal
import os
from pathlib import Path
Expand Down Expand Up @@ -131,10 +133,211 @@
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert isinstance(n, (ast.stmt, ast.expr))
assert n.lineno == 3
assert n.col_offset == 0
assert n.end_lineno == 6
assert n.end_col_offset == 3
for location in [
(n.lineno, n.col_offset),
(n.end_lineno, n.end_col_offset),
]:
assert (3, 0) <= location <= (6, 3)

def test_positions_are_preserved(self) -> None:
"""Ensure AST positions are preserved during rewriting (#12818)."""

def preserved(code: str) -> None:
s = textwrap.dedent(code)
locations = []

def loc(msg: str | None = None) -> None:
frame = inspect.currentframe()
assert frame
frame = frame.f_back
assert frame
frame = frame.f_back
assert frame

offset = frame.f_lasti

instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)}

# skip CACHE instructions
while offset not in instructions and offset >= 0:
offset -= 1

Check warning on line 163 in testing/test_assertrewrite.py

View check run for this annotation

Codecov / codecov/patch

testing/test_assertrewrite.py#L163

Added line #L163 was not covered by tests

instruction = instructions[offset]
if sys.version_info >= (3, 11):
position = instruction.positions

Check warning on line 167 in testing/test_assertrewrite.py

View check run for this annotation

Codecov / codecov/patch

testing/test_assertrewrite.py#L167

Added line #L167 was not covered by tests
else:
position = instruction.starts_line

locations.append((msg, instruction.opname, position))

globals = {"loc": loc}

m = rewrite(s)
mod = compile(m, "<string>", "exec")
exec(mod, globals, globals)
transformed_locations = locations
locations = []

mod = compile(s, "<string>", "exec")
exec(mod, globals, globals)
original_locations = locations

assert len(original_locations) > 0
assert original_locations == transformed_locations

preserved("""
def f():
loc()
return 8

assert f() in [8]
assert (f()
in
[8])
""")

preserved("""
class T:
def __init__(self):
loc("init")
def __getitem__(self,index):
loc("getitem")
return index

assert T()[5] == 5
assert (T
()
[5]
==
5)
""")

for name, op in [
("pos", "+"),
("neg", "-"),
("invert", "~"),
]:
preserved(f"""
class T:
def __{name}__(self):
loc("{name}")
return "{name}"

assert {op}T() == "{name}"
assert ({op}
T
()
==
"{name}")
""")

for name, op in [
("add", "+"),
("sub", "-"),
("mul", "*"),
("truediv", "/"),
("floordiv", "//"),
("mod", "%"),
("pow", "**"),
("lshift", "<<"),
("rshift", ">>"),
("or", "|"),
("xor", "^"),
("and", "&"),
("matmul", "@"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc("{name}")
return other

def __r{name}__(self,other):
loc("r{name}")
return other

assert T() {op} 2 == 2
assert 2 {op} T() == 2

assert (T
()
{op}
2
==
2)

assert (2
{op}
T
()
==
2)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", "<"),
("le", "<="),
("gt", ">"),
("ge", ">="),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True

assert T() {op} 5
assert (T
()
{op}
5)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", ">"),
("le", ">="),
("gt", "<"),
("ge", "<="),
("contains", "in"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True

assert 5 {op} T()
assert (5
{op}
T
())
""")

preserved("""
def func(value):
loc("func")
return value

class T:
def __iter__(self):
loc("iter")
return iter([5])

assert func(*T()) == 5
""")

preserved("""
class T:
def __getattr__(self,name):
loc()
return name

assert T().attr == "attr"
""")

def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
Expand Down
Loading