diff --git a/src/bloqade/native/upstream/squin2native.py b/src/bloqade/native/upstream/squin2native.py index 97604630..2a9131f2 100644 --- a/src/bloqade/native/upstream/squin2native.py +++ b/src/bloqade/native/upstream/squin2native.py @@ -1,14 +1,13 @@ from itertools import chain -from dataclasses import field, dataclass -from kirin import ir, passes, rewrite +from kirin import ir, rewrite from kirin.dialects import py, func from kirin.rewrite.abc import RewriteRule, RewriteResult -from kirin.passes.callgraph import CallGraphPass, ReplaceMethods from kirin.analysis.callgraph import CallGraph from bloqade.native import kernel, broadcast from bloqade.squin.gate import stmts, dialect as gate_dialect +from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph class GateRule(RewriteRule): @@ -46,63 +45,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return RewriteResult(has_done_something=True) -@dataclass -class UpdateDialectsOnCallGraph(passes.Pass): - """Update All dialects on the call graph to a new set of dialects given to this pass. - - Usage: - pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects) - pass_(some_method) - - Note: This pass does not update the dialects of the input method, but copies - all other methods invoked within it before updating their dialects. - - """ - - fold_pass: passes.Fold = field(init=False) - - def __post_init__(self): - self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise) - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - mt_map = {} - - cg = CallGraph(mt) - - all_methods = set(sum(map(tuple, cg.defs.values()), ())) - for original_mt in all_methods: - if original_mt is mt: - new_mt = original_mt - else: - new_mt = original_mt.similar(self.dialects) - mt_map[original_mt] = new_mt - - result = RewriteResult() - - for _, new_mt in mt_map.items(): - result = ( - rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result) - ) - self.fold_pass(new_mt) - - return result - - -@dataclass -class SquinToNativePass(passes.Pass): - - call_graph_pass: CallGraphPass = field(init=False) - - def __post_init__(self): - rule = rewrite.Walk(GateRule()) - self.call_graph_pass = CallGraphPass( - self.dialects, rule, no_raise=self.no_raise - ) - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - return self.call_graph_pass.unsafe_run(mt) - - class SquinToNative: """A Target that converts Squin gates to native gates.""" @@ -126,11 +68,10 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method: out = mt.similar(new_dialects) UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out) - SquinToNativePass(new_dialects, no_raise=no_raise)(out) + CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out) # verify all kernels in the callgraph new_callgraph = CallGraph(out) - all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers) - for ker in all_kernels: + for ker in new_callgraph.edges.keys(): ker.verify() return out diff --git a/src/bloqade/rewrite/passes/__init__.py b/src/bloqade/rewrite/passes/__init__.py index b6236a76..e35a23cb 100644 --- a/src/bloqade/rewrite/passes/__init__.py +++ b/src/bloqade/rewrite/passes/__init__.py @@ -1,2 +1,7 @@ +from .callgraph import ( + CallGraphPass as CallGraphPass, + ReplaceMethods as ReplaceMethods, + UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph, +) from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList diff --git a/src/bloqade/rewrite/passes/callgraph.py b/src/bloqade/rewrite/passes/callgraph.py new file mode 100644 index 00000000..0b5e64f0 --- /dev/null +++ b/src/bloqade/rewrite/passes/callgraph.py @@ -0,0 +1,116 @@ +from dataclasses import field, dataclass + +from kirin import ir, passes, rewrite +from kirin.analysis import CallGraph +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.dialects.func.stmts import Invoke + + +@dataclass +class ReplaceMethods(RewriteRule): + new_symbols: dict[ir.Method, ir.Method] + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if ( + not isinstance(node, Invoke) + or (new_callee := self.new_symbols.get(node.callee)) is None + ): + return RewriteResult() + + node.replace_by( + Invoke( + inputs=node.inputs, + callee=new_callee, + purity=node.purity, + ) + ) + + return RewriteResult(has_done_something=True) + + +@dataclass +class UpdateDialectsOnCallGraph(passes.Pass): + """Update All dialects on the call graph to a new set of dialects given to this pass. + + Usage: + pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects) + pass_(some_method) + + Note: This pass does not update the dialects of the input method, but copies + all other methods invoked within it before updating their dialects. + + """ + + fold_pass: passes.Fold = field(init=False) + + def __post_init__(self): + self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + mt_map = {} + + cg = CallGraph(mt) + + all_methods = set(sum(map(tuple, cg.defs.values()), ())) + for original_mt in all_methods: + if original_mt is mt: + new_mt = original_mt + else: + new_mt = original_mt.similar(self.dialects) + mt_map[original_mt] = new_mt + + result = RewriteResult() + + for _, new_mt in mt_map.items(): + result = ( + rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result) + ) + self.fold_pass(new_mt) + + return result + + +@dataclass +class CallGraphPass(passes.Pass): + """Copy all functions in the call graph and apply a rule to each of them. + + + Usage: + rule = Walk(SomeRewriteRule()) + pass_ = CallGraphPass(rule=rule, dialects=...) + pass_(some_method) + + Note: This pass modifies the input method in place, but copies + all methods invoked within it before applying the rule to them. + + """ + + rule: RewriteRule + """The rule to apply to each function in the call graph.""" + + fold_pass: passes.Fold = field(init=False) + + def __post_init__(self): + self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + result = RewriteResult() + mt_map = {} + + cg = CallGraph(mt) + + all_methods = set(cg.edges.keys()) + for original_mt in all_methods: + if original_mt is mt: + new_mt = original_mt + else: + new_mt = original_mt.similar() + result = self.rule.rewrite(new_mt.code).join(result) + mt_map[original_mt] = new_mt + + if result.has_done_something: + for _, new_mt in mt_map.items(): + rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code) + self.fold_pass(new_mt) + + return result