Skip to content

Commit 790f390

Browse files
rafaelhaRoger-luo
authored andcommitted
Ensure SourceInfo is preserved in inline pass and offsets are kepts (#541)
This PR ensures that source info is preserved after inlining or cloning in `.similar()`. Additionally I found that offsets were lost (which are now added in `lowering.py`). I tested all of these changes on kirin 0.17.30 together with bloqade-circuit. Blocks QuEraComputing/bloqade-circuit#552 Addresses #540
1 parent 759c092 commit 790f390

File tree

8 files changed

+91
-3
lines changed

8 files changed

+91
-3
lines changed

src/kirin/ir/group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def wrapper(py_func: Callable) -> Method:
204204
f"overwriting function definition of `{py_func.__name__}`"
205205
)
206206

207-
lineno_offset = call_site_frame.f_lineno - 1
207+
lineno_offset = py_func.__code__.co_firstlineno - 1
208208
file = call_site_frame.f_code.co_filename
209209

210210
code = self.lowering.python_function(py_func, lineno_offset=lineno_offset)

src/kirin/ir/nodes/stmt.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ def insert_after(self, stmt: Statement) -> None:
267267
self.parent = stmt.parent
268268
stmt._next_stmt = self
269269

270+
if self.source is None and stmt.source is not None:
271+
self.source = stmt.source
272+
270273
if self.parent:
271274
self.parent._stmt_len += 1
272275

@@ -302,6 +305,9 @@ def insert_before(self, stmt: Statement) -> None:
302305
self.parent = stmt.parent
303306
stmt._prev_stmt = self
304307

308+
if self.source is None and stmt.source is not None:
309+
self.source = stmt.source
310+
305311
if self.parent:
306312
self.parent._stmt_len += 1
307313

@@ -506,6 +512,7 @@ def from_stmt(
506512
attributes=attributes or other.attributes,
507513
result_types=[result.type for result in other._results],
508514
args_slice=other._name_args_slice,
515+
source=other.source,
509516
)
510517
# inherit the hint:
511518
for result, other_result in zip(obj._results, other._results):

src/kirin/lowering/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def _push_stmt(self, stmt: StmtType) -> StmtType:
7575
raise BuildError(
7676
f"Unsupported dialect `{stmt.dialect.name}` from statement {stmt.name}"
7777
)
78-
self.curr_block.stmts.append(stmt)
7978
if stmt.source is None:
8079
stmt.source = self.state.source
80+
self.curr_block.stmts.append(stmt)
8181
return stmt
8282

8383
def _push_block(self, block: Block):

src/kirin/lowering/python/lowering.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def lower_global(self, state: State[ast.AST], node: ast.AST) -> LoweringABC.Resu
138138
def visit(self, state: State[ast.AST], node: ast.AST) -> Result:
139139
if hasattr(node, "lineno"):
140140
state.source = SourceInfo.from_ast(node, state.file)
141+
state.source.offset(state.lineno_offset, state.col_offset)
141142
name = node.__class__.__name__
142143
if name in self.registry.ast_table:
143144
return self.registry.ast_table[name].lower(state, node)
@@ -148,7 +149,8 @@ def generic_visit(self, state: State[ast.AST], node: ast.AST) -> Result:
148149

149150
def visit_Call(self, state: State[ast.AST], node: ast.Call) -> Result:
150151
if hasattr(node.func, "lineno"):
151-
state.source = SourceInfo.from_ast(node.func, state.file)
152+
state.source = SourceInfo.from_ast(node, state.file)
153+
state.source.offset(state.lineno_offset, state.col_offset)
152154

153155
global_callee_result = state.get_global(node.func, no_raise=True)
154156
if global_callee_result is None:

src/kirin/rewrite/inline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ def inline_call_like(
107107

108108
# NOTE: we cannot change region because it may be used elsewhere
109109
inline_region: ir.Region = region.clone()
110+
111+
# Preserve source information by attributing inlined code to the call site
112+
if call_like.source is not None:
113+
for block in inline_region.blocks:
114+
if block.source is None:
115+
block.source = call_like.source
116+
for stmt in block.stmts:
117+
if stmt.source is None:
118+
stmt.source = call_like.source
119+
110120
parent_block: ir.Block = call_like.parent_block
111121
parent_region: ir.Region = call_like.parent_region
112122

test/ir/test_stmt.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from kirin.ir import Block
4+
from kirin.source import SourceInfo
45
from kirin.dialects import py
56

67

@@ -50,3 +51,12 @@ def test_stmt_from_stmt():
5051
y = x.from_stmt(x)
5152

5253
assert y.result.hints["const"] == py.constant.types.Int
54+
55+
56+
def test_stmt_from_stmt_preserves_source_info():
57+
x = py.Constant(1)
58+
x.source = SourceInfo(lineno=1, col_offset=0, end_lineno=None, end_col_offset=None)
59+
60+
y = x.from_stmt(x)
61+
assert y.source == x.source
62+
assert y.source is x.source

test/lowering/test_source_info.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pytest
2+
3+
from kirin.source import SourceInfo
4+
from kirin.prelude import basic_no_opt
5+
6+
7+
def get_line_of(target: str) -> int:
8+
for i, line in enumerate(open(__file__), 1):
9+
if target in line:
10+
return i
11+
12+
13+
@pytest.mark.parametrize("similar", [True, False])
14+
def test_stmt_source_info(similar: bool):
15+
@basic_no_opt
16+
def test(x: int):
17+
y = 2
18+
a = 4**2
19+
return y + 2 + a
20+
21+
if similar:
22+
test = test.similar()
23+
24+
stmts = test.callable_region.blocks[0].stmts
25+
26+
def get_line_from_source_info(source: SourceInfo) -> int:
27+
return source.lineno + source.lineno_begin
28+
29+
for stmt in stmts:
30+
assert stmt.source.file == __file__
31+
32+
assert get_line_from_source_info(stmts.at(0).source) == get_line_of("y = 2")
33+
assert get_line_from_source_info(stmts.at(2).source) == get_line_of("a = 4**2")
34+
assert get_line_from_source_info(stmts.at(4).source) == get_line_of(
35+
"return y + 2 + a"
36+
)

test/passes/test_inline_pass.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin.prelude import basic_no_opt
22
from kirin.passes.inline import InlinePass
3+
from kirin.dialects.py.constant import Constant
34

45

56
@basic_no_opt
@@ -40,3 +41,25 @@ def main_inline_pass2(x: int):
4041
assert a == b
4142

4243
assert len(main_inline_pass2.callable_region.blocks[0].stmts) == 4
44+
45+
46+
def test_inline_preserves_source_info():
47+
def get_line_of(target: str) -> int:
48+
for i, line in enumerate(open(__file__), 1):
49+
if target in line:
50+
return i
51+
52+
@basic_no_opt
53+
def main_inline_pass(x: int):
54+
y = inline_func(x)
55+
return y + 2
56+
57+
inline = InlinePass(main_inline_pass.dialects)
58+
inline(main_inline_pass)
59+
60+
stmt = main_inline_pass.callable_region.blocks[0].stmts.at(0)
61+
line = stmt.source.lineno + stmt.source.lineno_begin
62+
assert stmt.value.data == 1
63+
assert isinstance(stmt, Constant)
64+
65+
assert get_line_of("return x - 1") == line

0 commit comments

Comments
 (0)