diff --git a/src/kirin/ir/group.py b/src/kirin/ir/group.py index 5cbeaa50f..99175a7b4 100644 --- a/src/kirin/ir/group.py +++ b/src/kirin/ir/group.py @@ -216,7 +216,7 @@ def wrapper(py_func: Callable) -> Method: f"`{py_func.__name__}` is already defined in the current scope and is not a Method." ) - lineno_offset = call_site_frame.f_lineno - 1 + lineno_offset = py_func.__code__.co_firstlineno - 1 file = call_site_frame.f_code.co_filename code = self.lowering.python_function(py_func, lineno_offset=lineno_offset) diff --git a/src/kirin/ir/nodes/stmt.py b/src/kirin/ir/nodes/stmt.py index 0581b6649..fe30e3fb5 100644 --- a/src/kirin/ir/nodes/stmt.py +++ b/src/kirin/ir/nodes/stmt.py @@ -267,6 +267,9 @@ def insert_after(self, stmt: Statement) -> None: self.parent = stmt.parent stmt._next_stmt = self + if self.source is None and stmt.source is not None: + self.source = stmt.source + if self.parent: self.parent._stmt_len += 1 @@ -302,6 +305,9 @@ def insert_before(self, stmt: Statement) -> None: self.parent = stmt.parent stmt._prev_stmt = self + if self.source is None and stmt.source is not None: + self.source = stmt.source + if self.parent: self.parent._stmt_len += 1 @@ -506,6 +512,7 @@ def from_stmt( attributes=attributes or other.attributes, result_types=[result.type for result in other._results], args_slice=other._name_args_slice, + source=other.source, ) # inherit the hint: for result, other_result in zip(obj._results, other._results): diff --git a/src/kirin/lowering/frame.py b/src/kirin/lowering/frame.py index 6eb88ef2e..60893faab 100644 --- a/src/kirin/lowering/frame.py +++ b/src/kirin/lowering/frame.py @@ -75,9 +75,9 @@ def _push_stmt(self, stmt: StmtType) -> StmtType: raise BuildError( f"Unsupported dialect `{stmt.dialect.name}` from statement {stmt.name}" ) - self.curr_block.stmts.append(stmt) if stmt.source is None: stmt.source = self.state.source + self.curr_block.stmts.append(stmt) return stmt def _push_block(self, block: Block): diff --git a/src/kirin/lowering/python/lowering.py b/src/kirin/lowering/python/lowering.py index edf303224..4cff7b719 100644 --- a/src/kirin/lowering/python/lowering.py +++ b/src/kirin/lowering/python/lowering.py @@ -138,6 +138,7 @@ def lower_global(self, state: State[ast.AST], node: ast.AST) -> LoweringABC.Resu def visit(self, state: State[ast.AST], node: ast.AST) -> Result: if hasattr(node, "lineno"): state.source = SourceInfo.from_ast(node, state.file) + state.source.offset(state.lineno_offset, state.col_offset) name = node.__class__.__name__ if name in self.registry.ast_table: return self.registry.ast_table[name].lower(state, node) @@ -148,7 +149,8 @@ def generic_visit(self, state: State[ast.AST], node: ast.AST) -> Result: def visit_Call(self, state: State[ast.AST], node: ast.Call) -> Result: if hasattr(node.func, "lineno"): - state.source = SourceInfo.from_ast(node.func, state.file) + state.source = SourceInfo.from_ast(node, state.file) + state.source.offset(state.lineno_offset, state.col_offset) global_callee_result = state.get_global(node.func, no_raise=True) if global_callee_result is None: diff --git a/src/kirin/rewrite/inline.py b/src/kirin/rewrite/inline.py index 7c75aa1f2..a15d2fd5a 100644 --- a/src/kirin/rewrite/inline.py +++ b/src/kirin/rewrite/inline.py @@ -97,6 +97,16 @@ def inline_call_like( # NOTE: we cannot change region because it may be used elsewhere inline_region: ir.Region = region.clone() + + # Preserve source information by attributing inlined code to the call site + if call_like.source is not None: + for block in inline_region.blocks: + if block.source is None: + block.source = call_like.source + for stmt in block.stmts: + if stmt.source is None: + stmt.source = call_like.source + parent_block: ir.Block = call_like.parent_block parent_region: ir.Region = call_like.parent_region diff --git a/test/ir/test_stmt.py b/test/ir/test_stmt.py index 3b1c7faef..1cb15777d 100644 --- a/test/ir/test_stmt.py +++ b/test/ir/test_stmt.py @@ -1,6 +1,7 @@ import pytest from kirin.ir import Block +from kirin.source import SourceInfo from kirin.dialects import py @@ -50,3 +51,12 @@ def test_stmt_from_stmt(): y = x.from_stmt(x) assert y.result.hints["const"] == py.constant.types.Int + + +def test_stmt_from_stmt_preserves_source_info(): + x = py.Constant(1) + x.source = SourceInfo(lineno=1, col_offset=0, end_lineno=None, end_col_offset=None) + + y = x.from_stmt(x) + assert y.source == x.source + assert y.source is x.source diff --git a/test/lowering/test_source_info.py b/test/lowering/test_source_info.py new file mode 100644 index 000000000..93afc2f7f --- /dev/null +++ b/test/lowering/test_source_info.py @@ -0,0 +1,36 @@ +import pytest + +from kirin.source import SourceInfo +from kirin.prelude import basic_no_opt + + +def get_line_of(target: str) -> int: + for i, line in enumerate(open(__file__), 1): + if target in line: + return i + + +@pytest.mark.parametrize("similar", [True, False]) +def test_stmt_source_info(similar: bool): + @basic_no_opt + def test(x: int): + y = 2 + a = 4**2 + return y + 2 + a + + if similar: + test = test.similar() + + stmts = test.callable_region.blocks[0].stmts + + def get_line_from_source_info(source: SourceInfo) -> int: + return source.lineno + source.lineno_begin + + for stmt in stmts: + assert stmt.source.file == __file__ + + assert get_line_from_source_info(stmts.at(0).source) == get_line_of("y = 2") + assert get_line_from_source_info(stmts.at(2).source) == get_line_of("a = 4**2") + assert get_line_from_source_info(stmts.at(4).source) == get_line_of( + "return y + 2 + a" + ) diff --git a/test/passes/test_inline_pass.py b/test/passes/test_inline_pass.py index 9c5aed872..de5230875 100644 --- a/test/passes/test_inline_pass.py +++ b/test/passes/test_inline_pass.py @@ -1,5 +1,6 @@ from kirin.prelude import basic_no_opt from kirin.passes.inline import InlinePass +from kirin.dialects.py.constant import Constant @basic_no_opt @@ -40,3 +41,25 @@ def main_inline_pass2(x: int): assert a == b assert len(main_inline_pass2.callable_region.blocks[0].stmts) == 4 + + +def test_inline_preserves_source_info(): + def get_line_of(target: str) -> int: + for i, line in enumerate(open(__file__), 1): + if target in line: + return i + + @basic_no_opt + def main_inline_pass(x: int): + y = inline_func(x) + return y + 2 + + inline = InlinePass(main_inline_pass.dialects) + inline(main_inline_pass) + + stmt = main_inline_pass.callable_region.blocks[0].stmts.at(0) + line = stmt.source.lineno + stmt.source.lineno_begin + assert stmt.value.data == 1 + assert isinstance(stmt, Constant) + + assert get_line_of("return x - 1") == line