Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.22.0",
"scipy>=1.13.1",
"kirin-toolchain~=0.17.30",
"kirin-toolchain~=0.20.0",
"rich>=13.9.4",
"pydantic>=1.3.0,<2.11.0",
"pandas>=2.2.3",
Expand Down
67 changes: 4 additions & 63 deletions src/bloqade/native/upstream/squin2native.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from itertools import chain
from dataclasses import field, dataclass

from kirin import ir, passes, rewrite
from kirin import ir, rewrite
from kirin.dialects import py, func
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.passes.callgraph import CallGraphPass, ReplaceMethods
from kirin.analysis.callgraph import CallGraph

from bloqade.native import kernel, broadcast
from bloqade.squin.gate import stmts, dialect as gate_dialect
from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph


class GateRule(RewriteRule):
Expand Down Expand Up @@ -46,63 +45,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
return RewriteResult(has_done_something=True)


@dataclass
class UpdateDialectsOnCallGraph(passes.Pass):
"""Update All dialects on the call graph to a new set of dialects given to this pass.

Usage:
pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects)
pass_(some_method)

Note: This pass does not update the dialects of the input method, but copies
all other methods invoked within it before updating their dialects.

"""

fold_pass: passes.Fold = field(init=False)

def __post_init__(self):
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
mt_map = {}

cg = CallGraph(mt)

all_methods = set(sum(map(tuple, cg.defs.values()), ()))
for original_mt in all_methods:
if original_mt is mt:
new_mt = original_mt
else:
new_mt = original_mt.similar(self.dialects)
mt_map[original_mt] = new_mt

result = RewriteResult()

for _, new_mt in mt_map.items():
result = (
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
)
self.fold_pass(new_mt)

return result


@dataclass
class SquinToNativePass(passes.Pass):

call_graph_pass: CallGraphPass = field(init=False)

def __post_init__(self):
rule = rewrite.Walk(GateRule())
self.call_graph_pass = CallGraphPass(
self.dialects, rule, no_raise=self.no_raise
)

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
return self.call_graph_pass.unsafe_run(mt)


class SquinToNative:
"""A Target that converts Squin gates to native gates."""

Expand All @@ -126,11 +68,10 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:

out = mt.similar(new_dialects)
UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out)
SquinToNativePass(new_dialects, no_raise=no_raise)(out)
CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out)
# verify all kernels in the callgraph
new_callgraph = CallGraph(out)
all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers)
for ker in all_kernels:
for ker in new_callgraph.edges.keys():
ker.verify()

return out
5 changes: 5 additions & 0 deletions src/bloqade/rewrite/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from .callgraph import (
CallGraphPass as CallGraphPass,
ReplaceMethods as ReplaceMethods,
UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph,
)
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
116 changes: 116 additions & 0 deletions src/bloqade/rewrite/passes/callgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from dataclasses import field, dataclass

from kirin import ir, passes, rewrite
from kirin.analysis import CallGraph
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.dialects.func.stmts import Invoke


@dataclass
class ReplaceMethods(RewriteRule):
new_symbols: dict[ir.Method, ir.Method]

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if (
not isinstance(node, Invoke)
or (new_callee := self.new_symbols.get(node.callee)) is None
):
return RewriteResult()

node.replace_by(
Invoke(
inputs=node.inputs,
callee=new_callee,
purity=node.purity,
)
)

return RewriteResult(has_done_something=True)


@dataclass
class UpdateDialectsOnCallGraph(passes.Pass):
"""Update All dialects on the call graph to a new set of dialects given to this pass.

Usage:
pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects)
pass_(some_method)

Note: This pass does not update the dialects of the input method, but copies
all other methods invoked within it before updating their dialects.

"""

fold_pass: passes.Fold = field(init=False)

def __post_init__(self):
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
mt_map = {}

cg = CallGraph(mt)

all_methods = set(sum(map(tuple, cg.defs.values()), ()))
for original_mt in all_methods:
if original_mt is mt:
new_mt = original_mt
else:
new_mt = original_mt.similar(self.dialects)
mt_map[original_mt] = new_mt

result = RewriteResult()

for _, new_mt in mt_map.items():
result = (
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
)
self.fold_pass(new_mt)

return result


@dataclass
class CallGraphPass(passes.Pass):
"""Copy all functions in the call graph and apply a rule to each of them.


Usage:
rule = Walk(SomeRewriteRule())
pass_ = CallGraphPass(rule=rule, dialects=...)
pass_(some_method)

Note: This pass modifies the input method in place, but copies
all methods invoked within it before applying the rule to them.

"""

rule: RewriteRule
"""The rule to apply to each function in the call graph."""

fold_pass: passes.Fold = field(init=False)

def __post_init__(self):
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
result = RewriteResult()
mt_map = {}

cg = CallGraph(mt)

all_methods = set(cg.edges.keys())
for original_mt in all_methods:
if original_mt is mt:
new_mt = original_mt
else:
new_mt = original_mt.similar()
result = self.rule.rewrite(new_mt.code).join(result)
mt_map[original_mt] = new_mt

if result.has_done_something:
for _, new_mt in mt_map.items():
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code)
self.fold_pass(new_mt)

return result
Loading