Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
4a31b16
Refactor QASM codegen with new Interpreter framework. Still has linti…
zhenrongliew Oct 3, 2025
325faa0
Fixed codegen to match QASM2 function declaration syntax (classical b…
zhenrongliew Oct 3, 2025
9fbeac4
temp
zhenrongliew Oct 6, 2025
2889aff
Fixed linting
zhenrongliew Oct 6, 2025
7ff08c8
Fix invoke method type checking complain
zhenrongliew Oct 6, 2025
004e7e4
included ssacfg in dialect groups
zhenrongliew Oct 7, 2025
60c2852
Refactor SymbolTable implementation and remove unused test files
zhenrongliew Oct 7, 2025
9831f65
refactor STIM codegen to work with new API. updated STIM/emit test ca…
zhenrongliew Oct 15, 2025
86a1ca4
Updated stim_circuit tests to new codegen
zhenrongliew Oct 16, 2025
13c165e
Refactor error handling to use InterpreterError instead of EmitError …
zhenrongliew Oct 16, 2025
0c2172f
Refactor run_analysis calls to use the updated run.
zhenrongliew Oct 20, 2025
5172e42
added xfail for some (under refactoring) QASM2 codegen tests
zhenrongliew Oct 20, 2025
ee2e178
Merge conflicts with main
zhenrongliew Oct 21, 2025
ae02b34
Revert to sync with main
zhenrongliew Oct 21, 2025
ab57782
Refactor EmitStimMain initialization to use StringIO for output buffe…
zhenrongliew Oct 22, 2025
6f72eb6
merge conflicts
zhenrongliew Oct 22, 2025
982b4f3
tests calling old emit
zhenrongliew Oct 22, 2025
80f781a
fix io into default
zhenrongliew Oct 22, 2025
434d645
fix isort complaints
zhenrongliew Oct 22, 2025
ccaf33d
black formatter
zhenrongliew Oct 22, 2025
a8f8b67
isort on test
zhenrongliew Oct 22, 2025
452b08d
Mark tests as "xfail".
zhenrongliew Oct 22, 2025
ae296f8
Add correlated_error_count to EmitStimMain and update write_line meth…
zhenrongliew Oct 24, 2025
346eefb
Resolve merge conflicts with `david/571-kirin-upgrade-branch`
zhenrongliew Oct 27, 2025
28461af
Merge with `david/571-kirin-upgrade-branch`
zhenrongliew Oct 27, 2025
46dfca9
Replace EmitError with InterpreterError in emit_circuit and related t…
zhenrongliew Oct 27, 2025
869b588
Add self argument to the body of the loads method
zhenrongliew Oct 27, 2025
4a9354c
Remove SymbolTable and WorkList from codegen classes QASM and STIM. C…
zhenrongliew Oct 28, 2025
f5d3ba1
moved `run` method to base `EmitABC`.
zhenrongliew Oct 28, 2025
88db9e9
Merge origin/david/571-kirin-upgrade-branch into dl/codegen
zhenrongliew Oct 29, 2025
c6573cd
updated stim debug
zhenrongliew Oct 29, 2025
c2ab442
Refactor: Remove unused import and add reset method to Emit classes
zhenrongliew Oct 29, 2025
34ca7e5
Move `circuit` field from EmitCirqFrame to EmitCirq
zhenrongliew Oct 29, 2025
e8c2ecf
early return in ifs marked as xfail.
zhenrongliew Oct 29, 2025
c89b643
isort
zhenrongliew Oct 29, 2025
8d05cf3
Update src/bloqade/cirq_utils/emit/base.py
zhenrongliew Oct 30, 2025
ec3a530
Merge branch 'david/571-kirin-upgrade-branch' into dl/codegen
zhenrongliew Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/bloqade/analysis/fidelity/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,15 @@ def run_analysis(
self, method: ir.Method, args: tuple | None = None, *, no_raise: bool = True
) -> tuple[ForwardFrame, Any]:
self._run_address_analysis(method, no_raise=no_raise)
return super().run_analysis(method, args, no_raise=no_raise)
return super().run(method)

def _run_address_analysis(self, method: ir.Method, no_raise: bool):
addr_analysis = AddressAnalysis(self.dialects)
addr_frame, _ = addr_analysis.run_analysis(method=method, no_raise=no_raise)
addr_frame, _ = addr_analysis.run(method=method)
self.addr_frame = addr_frame

# NOTE: make sure we have as many probabilities as we have addresses
self.atom_survival_probability = [1.0] * addr_analysis.qubit_count

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()
3 changes: 3 additions & 0 deletions src/bloqade/analysis/measure_id/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ def get_const_value(
return hint.data

return None

def method_self(self, method: ir.Method) -> MeasureId:
return self.lattice.bottom()
80 changes: 61 additions & 19 deletions src/bloqade/cirq_utils/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import cirq
from kirin import ir, types, interp
from kirin.emit import EmitABC, EmitError, EmitFrame
from kirin.emit import EmitABC, EmitFrame
from kirin.interp import MethodTable, impl
from kirin.dialects import py, func
from kirin.dialects import py, func, ilist
from typing_extensions import Self

from bloqade.squin import kernel
Expand Down Expand Up @@ -102,7 +102,7 @@ def main():
and isinstance(mt.code, func.Function)
and not mt.code.signature.output.is_subseteq(types.NoneType)
):
raise EmitError(
raise interp.exceptions.InterpreterError(
"The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
" Set `ignore_returns = True` in order to simply ignore the return values and emit a circuit."
)
Expand All @@ -116,12 +116,14 @@ def main():

symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface)
if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None:
raise EmitError("The method is not a symbol, cannot emit circuit!")
raise interp.exceptions.InterpreterError(
"The method is not a symbol, cannot emit circuit!"
)

sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap()

if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None:
raise EmitError(
raise interp.exceptions.InterpreterError(
f"The method {sym_name} does not have a signature, cannot emit circuit!"
)

Expand All @@ -135,7 +137,7 @@ def main():

assert first_stmt is not None, "Method has no statements!"
if len(args_ssa) - 1 != len(args):
raise EmitError(
raise interp.exceptions.InterpreterError(
f"The method {sym_name} takes {len(args_ssa) - 1} arguments, but you passed in {len(args)} via the `args` keyword!"
)

Expand All @@ -147,17 +149,22 @@ def main():
new_func = func.Function(
sym_name=sym_name, body=callable_region, signature=new_signature
)
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
mt_ = ir.Method(
dialects=mt.dialects,
code=new_func,
sym_name=sym_name,
)

AggressiveUnroll(mt_.dialects).fixpoint(mt_)
return emitter.run(mt_, args=())
emitter.initialize()
emitter.run(mt_)
return emitter.circuit


@dataclass
class EmitCirqFrame(EmitFrame):
qubit_index: int = 0
qubits: Sequence[cirq.Qid] | None = None
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)


def _default_kernel():
Expand All @@ -166,23 +173,24 @@ def _default_kernel():

@dataclass
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
keys = ["emit.cirq", "main"]
keys = ("emit.cirq", "emit.main")
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
void = cirq.Circuit()
qubits: Sequence[cirq.Qid] | None = None
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)

def initialize(self) -> Self:
return super().initialize()

def initialize_frame(
self, code: ir.Statement, *, has_parent_access: bool = False
self, node: ir.Statement, *, has_parent_access: bool = False
) -> EmitCirqFrame:
return EmitCirqFrame(
code, has_parent_access=has_parent_access, qubits=self.qubits
node, has_parent_access=has_parent_access, qubits=self.qubits
)

def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
return self.run_callable(method.code, args)
return self.call(method, *args)

def run_callable_region(
self,
Expand All @@ -196,7 +204,7 @@ def run_callable_region(
# NOTE: skip self arg
frame.set_values(block_args[1:], args)

results = self.eval_stmt(frame, code)
results = self.frame_eval(frame, code)
if isinstance(results, tuple):
if len(results) == 0:
return self.void
Expand All @@ -206,24 +214,36 @@ def run_callable_region(

def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
for stmt in block.stmts:
result = self.eval_stmt(frame, stmt)
result = self.frame_eval(frame, stmt)
if isinstance(result, tuple):
frame.set_values(stmt.results, result)

return frame.circuit
return self.circuit

def reset(self):
pass


@func.dialect.register(key="emit.cirq")
class __FuncEmit(MethodTable):

@impl(func.Function)
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
emit.run_ssacfg_region(frame, stmt.body, ())
return (frame.circuit,)
for block in stmt.body.blocks:
frame.current_block = block
for s in block.stmts:
frame.current_stmt = s
stmt_results = emit.frame_eval(frame, s)
if isinstance(stmt_results, tuple):
if len(stmt_results) != 0:
frame.set_values(s.results, stmt_results)
continue

return (emit.circuit,)

@impl(func.Invoke)
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
raise EmitError(
raise interp.exceptions.InterpreterError(
"Function invokes should need to be inlined! "
"If you called the emit_circuit method, that should have happened, please report this issue."
)
Expand All @@ -233,6 +253,12 @@ def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
# NOTE: should only be hit if ignore_returns == True
return ()

@impl(func.ConstantNone)
def emit_constant_none(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.ConstantNone
):
return ()


@py.indexing.dialect.register(key="emit.cirq")
class __Concrete(interp.MethodTable):
Expand All @@ -241,3 +267,19 @@ class __Concrete(interp.MethodTable):
def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
# NOTE: no support for indexing into single statements in cirq
return ()

@interp.impl(py.Constant)
def emit_constant(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: py.Constant):
return (stmt.value.data,) # pyright: ignore[reportAttributeAccessIssue]


@ilist.dialect.register(key="emit.cirq")
class __IList(interp.MethodTable):
@interp.impl(ilist.New)
def new_ilist(
self,
emit: EmitCirq,
frame: interp.Frame,
stmt: ilist.New,
):
return (ilist.IList(data=frame.get_values(stmt.values)),)
16 changes: 8 additions & 8 deletions src/bloqade/cirq_utils/emit/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def hermitian(
):
qubits = frame.get(stmt.qubits)
cirq_op = getattr(cirq, stmt.name.upper())
frame.circuit.append(cirq_op.on_each(qubits))
emit.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(gate.stmts.S)
Expand All @@ -36,7 +36,7 @@ def unitary(
if stmt.adjoint:
cirq_op = cirq_op ** (-1)

frame.circuit.append(cirq_op.on_each(qubits))
emit.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(gate.stmts.SqrtX)
Expand All @@ -58,7 +58,7 @@ def sqrt(
else:
cirq_op = cirq.YPowGate(exponent=exponent)

frame.circuit.append(cirq_op.on_each(qubits))
emit.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(gate.stmts.CX)
Expand All @@ -71,7 +71,7 @@ def control(
targets = frame.get(stmt.targets)
cirq_op = getattr(cirq, stmt.name.upper())
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
frame.circuit.append(cirq_op.on_each(cirq_qubits))
emit.circuit.append(cirq_op.on_each(cirq_qubits))
return ()

@impl(gate.stmts.Rx)
Expand All @@ -84,7 +84,7 @@ def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.RotationGat
angle = turns * 2 * math.pi
cirq_op = getattr(cirq, stmt.name.title())(rads=angle)

frame.circuit.append(cirq_op.on_each(qubits))
emit.circuit.append(cirq_op.on_each(qubits))
return ()

@impl(gate.stmts.U3)
Expand All @@ -95,10 +95,10 @@ def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.U3):
phi = frame.get(stmt.phi) * 2 * math.pi
lam = frame.get(stmt.lam) * 2 * math.pi

frame.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))
emit.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))

frame.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))
emit.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))

frame.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))
emit.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))

return ()
8 changes: 4 additions & 4 deletions src/bloqade/cirq_utils/emit/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def depolarize(
p = frame.get(stmt.p)
qubits = frame.get(stmt.qubits)
cirfq_op = cirq.depolarize(p, n_qubits=1).on_each(qubits)
frame.circuit.append(cirfq_op)
interp.circuit.append(cirfq_op)
return ()

@impl(noise.stmts.Depolarize2)
Expand All @@ -46,7 +46,7 @@ def depolarize2(
targets = frame.get(stmt.targets)
cirq_qubits = [(ctrl, target) for ctrl, target in zip(controls, targets)]
cirq_op = cirq.depolarize(p, n_qubits=2).on_each(cirq_qubits)
frame.circuit.append(cirq_op)
interp.circuit.append(cirq_op)
return ()

@impl(noise.stmts.SingleQubitPauliChannel)
Expand All @@ -62,7 +62,7 @@ def single_qubit_pauli_channel(
qubits = frame.get(stmt.qubits)

cirq_op = cirq.asymmetric_depolarize(px, py, pz).on_each(qubits)
frame.circuit.append(cirq_op)
interp.circuit.append(cirq_op)

return ()

Expand All @@ -85,6 +85,6 @@ def two_qubit_pauli_channel(
cirq_op = cirq.asymmetric_depolarize(
error_probabilities=error_probabilities
).on_each(cirq_qubits)
frame.circuit.append(cirq_op)
interp.circuit.append(cirq_op)

return ()
4 changes: 2 additions & 2 deletions src/bloqade/cirq_utils/emit/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def measure_qubit_list(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure
):
qbits = frame.get(stmt.qubits)
frame.circuit.append(cirq.measure(qbits))
emit.circuit.append(cirq.measure(qbits))
return (emit.void,)

@impl(qubit.Reset)
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Reset):
qubits = frame.get(stmt.qubits)
frame.circuit.append(
emit.circuit.append(
cirq.ResetChannel().on_each(*qubits),
)
return ()
4 changes: 1 addition & 3 deletions src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,7 @@ def run(
# NOTE: create a new register of appropriate size
n_qubits = len(self.qreg_index)
n = frame.push(py.Constant(n_qubits))
self.qreg = frame.push(
func.Invoke((n.result,), callee=qalloc, kwargs=())
).result
self.qreg = frame.push(func.Invoke((n.result,), callee=qalloc)).result

self.visit(state, stmt)

Expand Down
1 change: 0 additions & 1 deletion src/bloqade/native/upstream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .squin2native import (
GateRule as GateRule,
SquinToNative as SquinToNative,
SquinToNativePass as SquinToNativePass,
)
2 changes: 1 addition & 1 deletion src/bloqade/pyqrack/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def task(
kwargs = {}

address_analysis = AddressAnalysis(dialects=kernel.dialects)
frame, _ = address_analysis.run_analysis(kernel)
frame, _ = address_analysis.run(kernel)
if self.min_qubits == 0 and any(
isinstance(a, (UnknownQubit, UnknownReg)) for a in frame.entries.values()
):
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/pyqrack/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _get_interp(self, mt: ir.Method[Params, RetType]):
return PyQrackInterpreter(mt.dialects, memory=DynamicMemory(options))
else:
address_analysis = AddressAnalysis(mt.dialects)
frame, _ = address_analysis.run_analysis(mt)
frame, _ = address_analysis.run(mt)
if self.min_qubits == 0 and any(
isinstance(a, UnknownQubit) for a in frame.entries.values()
):
Expand Down
27 changes: 19 additions & 8 deletions src/bloqade/qasm2/dialects/expr/_emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def emit_func(

args: list[ast.Node] = []
cparams, qparams = [], []
for arg in stmt.body.blocks[0].args:
entry_args = stmt.body.blocks[0].args
user_args = entry_args[1:] if len(entry_args) > 0 else []

for arg in user_args:
assert arg.name is not None

args.append(ast.Name(id=arg.name))
Expand All @@ -29,14 +32,22 @@ def emit_func(
else:
cparams.append(arg.name)

emit.run_ssacfg_region(frame, stmt.body, tuple(args))
emit.output = ast.Gate(
name=stmt.sym_name,
cparams=cparams,
qparams=qparams,
body=frame.body,
frame.worklist.append(interp.Successor(stmt.body.blocks[0], *args))
if len(entry_args) > 0:
frame.set(entry_args[0], ast.Name(stmt.sym_name or "gate"))

while (succ := frame.worklist.pop()) is not None:
frame.set_values(succ.block.args[1:], succ.block_args)
block_header = emit.emit_block(frame, succ.block)
frame.block_ref[succ.block] = block_header
return (
ast.Gate(
name=stmt.sym_name,
cparams=cparams,
qparams=qparams,
body=frame.body,
),
)
return ()

@interp.impl(stmts.ConstInt)
@interp.impl(stmts.ConstFloat)
Expand Down
Loading
Loading