diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index d1438b00..7a6034c5 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -1,8 +1,7 @@ -from typing import TypeVar from dataclasses import field from kirin import ir, interp -from kirin.analysis import Forward, const +from kirin.analysis import Forward from kirin.analysis.forward import ForwardFrame from bloqade.types import QubitType @@ -15,7 +14,7 @@ class AddressAnalysis(Forward[Address]): This analysis pass can be used to track the global addresses of qubits and wires. """ - keys = ["qubit.address"] + keys = ("qubit.address",) lattice = Address next_address: int = field(init=False) @@ -24,37 +23,22 @@ def initialize(self): self.next_address: int = 0 return self - @property - def qubit_count(self) -> int: - """Total number of qubits found by the analysis.""" - return self.next_address - - T = TypeVar("T") - - def get_const_value(self, typ: type[T], value: ir.SSAValue) -> T: - if isinstance(hint := value.hints.get("const"), const.Value): - data = hint.data - if isinstance(data, typ): - return hint.data - raise interp.InterpreterError( - f"Expected constant value , got {data}" - ) - raise interp.InterpreterError( - f"Expected constant value , got {value}" - ) + def method_self(self, method: ir.Method) -> Address: + return self.lattice.bottom() - def eval_stmt_fallback( - self, frame: ForwardFrame[Address], stmt: ir.Statement - ) -> tuple[Address, ...] | interp.SpecialValue[Address]: + def eval_fallback( + self, frame: ForwardFrame[Address], node: ir.Statement + ) -> interp.StatementResult[Address]: return tuple( ( self.lattice.top() if result.type.is_subseteq(QubitType) else self.lattice.bottom() ) - for result in stmt.results + for result in node.results ) - def run_method(self, method: ir.Method, args: tuple[Address, ...]): - # NOTE: we do not support dynamic calls here, thus no need to propagate method object - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + @property + def qubit_count(self) -> int: + """Total number of qubits found by the analysis.""" + return self.next_address diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index ee5d8414..5749a7e0 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -4,7 +4,7 @@ from kirin import interp from kirin.analysis import ForwardFrame, const -from kirin.dialects import cf, py, scf, func, ilist +from kirin.dialects import py, scf, func, ilist from bloqade import squin @@ -74,7 +74,7 @@ class PyIndexing(interp.MethodTable): @interp.impl(py.GetItem) def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem): # Integer index into the thing being indexed - idx = interp.get_const_value(int, stmt.index) + idx = interp.expect_const(stmt.index, int) # The object being indexed into obj = frame.get(stmt.obj) # The `data` attributes holds onto other Address types @@ -116,15 +116,8 @@ def invoke(self, interp_: AddressAnalysis, frame: interp.Frame, stmt: func.Invok # TODO: support lambda? -@cf.dialect.register(key="qubit.address") -class Cf(cf.typeinfer.TypeInfer): - # NOTE: cf just re-use the type infer method table - # it's the same process as type infer. - pass - - @scf.dialect.register(key="qubit.address") -class Scf(scf.absint.Methods): +class Scf(interp.MethodTable): @interp.impl(scf.For) def for_loop( @@ -134,7 +127,7 @@ def for_loop( stmt: scf.For, ): if not isinstance(hint := stmt.iterable.hints.get("const"), const.Value): - return interp_.eval_stmt_fallback(frame, stmt) + return interp_.eval_fallback(frame, stmt) iterable = hint.data loop_vars = frame.get_values(stmt.initializers) @@ -144,7 +137,7 @@ def for_loop( # NOTE: we need to actually run iteration in case there are # new allocations/re-assign in the loop body. for _ in iterable: - with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame: + with interp_.new_frame(stmt) as body_frame: body_frame.entries.update(frame.entries) body_frame.set_values( block_args, @@ -218,7 +211,7 @@ def new( frame: ForwardFrame[Address], stmt: squin.qubit.New, ): - n_qubits = interp_.get_const_value(int, stmt.n_qubits) + n_qubits = interp_.expect_const(stmt.n_qubits, int) addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits)) interp_.next_address += n_qubits return (addr,) diff --git a/src/bloqade/qasm2/__init__.py b/src/bloqade/qasm2/__init__.py index a6c61517..3a978507 100644 --- a/src/bloqade/qasm2/__init__.py +++ b/src/bloqade/qasm2/__init__.py @@ -1,11 +1,11 @@ from bloqade.types import Qubit as Qubit, QubitType as QubitType from . import ( - emit as emit, glob as glob, parse as parse, dialects as dialects, parallel as parallel, + emit as emit, ) from .types import ( Bit as Bit, @@ -15,6 +15,6 @@ CRegType as CRegType, QRegType as QRegType, ) -from .groups import gate as gate, main as main, extended as extended +from .groups import main as main, extended as extended from ._wrappers import * # noqa: F403 from ._qasm_loading import loads as loads, loadfile as loadfile diff --git a/src/bloqade/qasm2/emit/__init__.py b/src/bloqade/qasm2/_emit/__init__.py similarity index 100% rename from src/bloqade/qasm2/emit/__init__.py rename to src/bloqade/qasm2/_emit/__init__.py diff --git a/src/bloqade/qasm2/emit/base.py b/src/bloqade/qasm2/_emit/base.py similarity index 100% rename from src/bloqade/qasm2/emit/base.py rename to src/bloqade/qasm2/_emit/base.py diff --git a/src/bloqade/qasm2/emit/gate.py b/src/bloqade/qasm2/_emit/gate.py similarity index 100% rename from src/bloqade/qasm2/emit/gate.py rename to src/bloqade/qasm2/_emit/gate.py diff --git a/src/bloqade/qasm2/emit/main.py b/src/bloqade/qasm2/_emit/main.py similarity index 100% rename from src/bloqade/qasm2/emit/main.py rename to src/bloqade/qasm2/_emit/main.py diff --git a/src/bloqade/qasm2/emit/target.py b/src/bloqade/qasm2/_emit/target.py similarity index 100% rename from src/bloqade/qasm2/emit/target.py rename to src/bloqade/qasm2/_emit/target.py diff --git a/src/bloqade/qasm2/dialects/core/__init__.py b/src/bloqade/qasm2/dialects/core/__init__.py index 0c65c8a6..f8448009 100644 --- a/src/bloqade/qasm2/dialects/core/__init__.py +++ b/src/bloqade/qasm2/dialects/core/__init__.py @@ -1,3 +1,3 @@ -from . import _emit as _emit, address as address, _typeinfer as _typeinfer +from . import address as address, _typeinfer as _typeinfer, _emit as _emit from .stmts import * # noqa: F403 from ._dialect import dialect as dialect diff --git a/src/bloqade/qasm2/dialects/core/_emit.py b/src/bloqade/qasm2/dialects/core/_emit.py index 783a44fb..85344cac 100644 --- a/src/bloqade/qasm2/dialects/core/_emit.py +++ b/src/bloqade/qasm2/dialects/core/_emit.py @@ -1,68 +1,62 @@ +from __future__ import annotations + from kirin import interp +from bloqade.qasm2.emit import QASM2, Frame from bloqade.qasm2.parse import ast -from bloqade.qasm2.emit.main import EmitQASM2Main, EmitQASM2Frame from . import stmts from ._dialect import dialect -@dialect.register(key="emit.qasm2.main") +@dialect.register(key="emit.qasm2") class Core(interp.MethodTable): @interp.impl(stmts.CRegNew) - def emit_creg_new( - self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.CRegNew - ): - n_bits = emit.assert_node(ast.Number, frame.get(stmt.n_bits)) + def emit_creg_new(self, emit: QASM2, frame: Frame, stmt: stmts.CRegNew): + n_bits = frame.get_casted(stmt.n_bits, ast.Number) # check if its int first, because Int.is_integer() is added for >=3.12 assert isinstance(n_bits.value, int), "expected integer" - name = emit.ssa_id[stmt.result] + name = frame.ssa[stmt.result] frame.body.append(ast.CReg(name=name, size=int(n_bits.value))) return (ast.Name(name),) @interp.impl(stmts.QRegNew) - def emit_qreg_new( - self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.QRegNew - ): - n_bits = emit.assert_node(ast.Number, frame.get(stmt.n_qubits)) + def emit_qreg_new(self, emit: QASM2, frame: Frame, stmt: stmts.QRegNew): + n_bits = frame.get_casted(stmt.n_qubits, ast.Number) assert isinstance(n_bits.value, int), "expected integer" - name = emit.ssa_id[stmt.result] + name = frame.ssa[stmt.result] frame.body.append(ast.QReg(name=name, size=int(n_bits.value))) return (ast.Name(name),) @interp.impl(stmts.Reset) - def emit_reset(self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.Reset): - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_reset(self, emit: QASM2, frame: Frame, stmt: stmts.Reset): + qarg: ast.Name | ast.Bit = frame.get(stmt.qarg) # type: ignore frame.body.append(ast.Reset(qarg=qarg)) return () @interp.impl(stmts.Measure) - def emit_measure( - self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.Measure - ): - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) - carg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.carg)) + def emit_measure(self, emit: QASM2, frame: Frame, stmt: stmts.Measure): + qarg: ast.Bit | ast.Name = frame.get(stmt.qarg) # type: ignore + carg: ast.Name | ast.Bit = frame.get(stmt.carg) # type: ignore frame.body.append(ast.Measure(qarg=qarg, carg=carg)) return () @interp.impl(stmts.CRegEq) - def emit_creg_eq( - self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.CRegEq - ): - lhs = emit.assert_node(ast.Expr, frame.get(stmt.lhs)) - rhs = emit.assert_node(ast.Expr, frame.get(stmt.rhs)) + def emit_creg_eq(self, emit: QASM2, frame: Frame, stmt: stmts.CRegEq): + lhs = frame.get_casted(stmt.lhs, ast.Expr) + rhs = frame.get_casted(stmt.rhs, ast.Expr) return (ast.Cmp(lhs=lhs, rhs=rhs),) @interp.impl(stmts.CRegGet) @interp.impl(stmts.QRegGet) def emit_qreg_get( self, - emit: EmitQASM2Main, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.QRegGet | stmts.CRegGet, ): - reg = emit.assert_node(ast.Name, frame.get(stmt.reg)) - idx = emit.assert_node(ast.Number, frame.get(stmt.idx)) + reg = frame.get_casted(stmt.reg, ast.Name) + idx = frame.get_casted(stmt.idx, ast.Number) assert isinstance(idx.value, int) return (ast.Bit(reg, int(idx.value)),) diff --git a/src/bloqade/qasm2/dialects/core/address.py b/src/bloqade/qasm2/dialects/core/address.py index f7505b52..1677fe2a 100644 --- a/src/bloqade/qasm2/dialects/core/address.py +++ b/src/bloqade/qasm2/dialects/core/address.py @@ -22,7 +22,7 @@ def new( frame: interp.Frame[Address], stmt: QRegNew, ): - n_qubits = interp.get_const_value(int, stmt.n_qubits) + n_qubits = interp.expect_const(stmt.n_qubits, int) addr = AddressReg(range(interp.next_address, interp.next_address + n_qubits)) interp.next_address += n_qubits return (addr,) @@ -30,7 +30,7 @@ def new( @interp.impl(QRegGet) def get(self, interp: AddressAnalysis, frame: interp.Frame[Address], stmt: QRegGet): addr = frame.get(stmt.reg) - pos = interp.get_const_value(int, stmt.idx) + pos = interp.expect_const(stmt.idx, int) if isinstance(addr, AddressReg): global_idx = addr.data[pos] return (AddressQubit(global_idx),) diff --git a/src/bloqade/qasm2/dialects/expr/__init__.py b/src/bloqade/qasm2/dialects/expr/__init__.py index 15d0bbae..55972afd 100644 --- a/src/bloqade/qasm2/dialects/expr/__init__.py +++ b/src/bloqade/qasm2/dialects/expr/__init__.py @@ -1,3 +1,3 @@ -from . import _emit as _emit, _interp as _interp, _from_python as _from_python +from . import _interp as _interp, _from_python as _from_python, _emit as _emit from .stmts import * # noqa: F403 from ._dialect import dialect as dialect diff --git a/src/bloqade/qasm2/dialects/expr/_emit.py b/src/bloqade/qasm2/dialects/expr/_emit.py index f429cb85..01e4487e 100644 --- a/src/bloqade/qasm2/dialects/expr/_emit.py +++ b/src/bloqade/qasm2/dialects/expr/_emit.py @@ -1,49 +1,24 @@ +from __future__ import annotations from typing import Literal from kirin import interp from bloqade.qasm2.parse import ast -from bloqade.qasm2.types import QubitType -from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame +from bloqade.qasm2.emit import QASM2, Frame from . import stmts from ._dialect import dialect -@dialect.register(key="emit.qasm2.gate") +@dialect.register(key="emit.qasm2") class EmitExpr(interp.MethodTable): - @interp.impl(stmts.GateFunction) - def emit_func( - self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction - ): - - args: list[ast.Node] = [] - cparams, qparams = [], [] - for arg in stmt.body.blocks[0].args: - assert arg.name is not None - - args.append(ast.Name(id=arg.name)) - if arg.type.is_subseteq(QubitType): - qparams.append(arg.name) - else: - cparams.append(arg.name) - - emit.run_ssacfg_region(frame, stmt.body, tuple(args)) - emit.output = ast.Gate( - name=stmt.sym_name, - cparams=cparams, - qparams=qparams, - body=frame.body, - ) - return () - @interp.impl(stmts.ConstInt) @interp.impl(stmts.ConstFloat) def emit_const_int( self, - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.ConstInt | stmts.ConstFloat, ): return (ast.Number(stmt.value),) @@ -51,15 +26,15 @@ def emit_const_int( @interp.impl(stmts.ConstPI) def emit_const_pi( self, - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.ConstPI, ): return (ast.Pi(),) @interp.impl(stmts.Neg) - def emit_neg(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Neg): - arg = emit.assert_node(ast.Expr, frame.get(stmt.value)) + def emit_neg(self, emit: QASM2, frame: Frame, stmt: stmts.Neg): + arg = frame.get_casted(stmt.value, ast.Expr) return (ast.UnaryOp("-", arg),) @interp.impl(stmts.Sin) @@ -68,37 +43,37 @@ def emit_neg(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Neg): @interp.impl(stmts.Exp) @interp.impl(stmts.Log) @interp.impl(stmts.Sqrt) - def emit_sin(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt): - arg = emit.assert_node(ast.Expr, frame.get(stmt.value)) + def emit_sin(self, emit: QASM2, frame: Frame, stmt): + arg = frame.get_casted(stmt.value, ast.Expr) return (ast.Call(stmt.name, [arg]),) def emit_binop( self, sym: Literal["+", "-", "*", "/", "^"], - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt, ): - lhs = emit.assert_node(ast.Expr, frame.get(stmt.lhs)) - rhs = emit.assert_node(ast.Expr, frame.get(stmt.rhs)) + lhs = frame.get_casted(stmt.lhs, ast.Expr) + rhs = frame.get_casted(stmt.rhs, ast.Expr) return (ast.BinOp(sym, lhs, rhs),) @interp.impl(stmts.Add) - def emit_add(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add): + def emit_add(self, emit: QASM2, frame: Frame, stmt: stmts.Add): return self.emit_binop("+", emit, frame, stmt) @interp.impl(stmts.Sub) - def emit_sub(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add): + def emit_sub(self, emit: QASM2, frame: Frame, stmt: stmts.Add): return self.emit_binop("-", emit, frame, stmt) @interp.impl(stmts.Mul) - def emit_mul(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add): + def emit_mul(self, emit: QASM2, frame: Frame, stmt: stmts.Add): return self.emit_binop("*", emit, frame, stmt) @interp.impl(stmts.Div) - def emit_div(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add): + def emit_div(self, emit: QASM2, frame: Frame, stmt: stmts.Add): return self.emit_binop("/", emit, frame, stmt) @interp.impl(stmts.Pow) - def emit_pow(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add): + def emit_pow(self, emit: QASM2, frame: Frame, stmt: stmts.Add): return self.emit_binop("^", emit, frame, stmt) diff --git a/src/bloqade/qasm2/dialects/expr/stmts.py b/src/bloqade/qasm2/dialects/expr/stmts.py index 4e9c0bf3..57793cab 100644 --- a/src/bloqade/qasm2/dialects/expr/stmts.py +++ b/src/bloqade/qasm2/dialects/expr/stmts.py @@ -1,50 +1,12 @@ +from __future__ import annotations + from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.print.printer import Printer -from kirin.dialects.func.attrs import Signature from ._dialect import dialect -class GateFuncOpCallableInterface(ir.CallableStmtInterface["GateFunction"]): - - @classmethod - def get_callable_region(cls, stmt: "GateFunction") -> ir.Region: - return stmt.body - - -@statement(dialect=dialect) -class GateFunction(ir.Statement): - """Special Function for qasm2 gate subroutine.""" - - name = "gate.func" - traits = frozenset( - { - ir.IsolatedFromAbove(), - ir.SymbolOpInterface(), - ir.HasSignature(), - GateFuncOpCallableInterface(), - } - ) - sym_name: str = info.attribute() - signature: Signature = info.attribute() - body: ir.Region = info.region(multi=True) - - def print_impl(self, printer: Printer) -> None: - with printer.rich(style="red"): - printer.plain_print(self.name + " ") - - with printer.rich(style="cyan"): - printer.plain_print(self.sym_name) - - self.signature.print_impl(printer) - printer.plain_print(" ") - self.body.print_impl(printer) - - with printer.rich(style="black"): - printer.plain_print(f" // gate.func {self.sym_name}") - - @statement(dialect=dialect) class ConstInt(ir.Statement): """IR Statement representing a constant integer value.""" diff --git a/src/bloqade/qasm2/dialects/glob.py b/src/bloqade/qasm2/dialects/glob.py index 04defbfa..5fe2a856 100644 --- a/src/bloqade/qasm2/dialects/glob.py +++ b/src/bloqade/qasm2/dialects/glob.py @@ -2,9 +2,7 @@ from kirin.decl import info, statement from kirin.dialects import ilist -from bloqade.qasm2.parse import ast from bloqade.qasm2.types import QRegType -from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame from bloqade.squin.analysis.schedule import DagScheduleAnalysis dialect = ir.Dialect("qasm2.glob") @@ -28,18 +26,18 @@ def ugate(self, interp: DagScheduleAnalysis, frame: interp.Frame, stmt: UGate): return () -@dialect.register(key="emit.qasm2.gate") -class GlobEmit(interp.MethodTable): - @interp.impl(UGate) - def ugate(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: UGate): - registers = [ - emit.assert_node(ast.Name, reg) - for reg in frame.get_casted(stmt.registers, ilist.IList) - ] - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - frame.body.append( - ast.GlobUGate(theta=theta, phi=phi, lam=lam, registers=registers) - ) - return () +# @dialect.register(key="emit.qasm2.gate") +# class GlobEmit(interp.MethodTable): +# @interp.impl(UGate) +# def ugate(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: UGate): +# registers = [ +# emit.assert_node(ast.Name, reg) +# for reg in frame.get_casted(stmt.registers, ilist.IList) +# ] +# theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) +# phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) +# lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) +# frame.body.append( +# ast.GlobUGate(theta=theta, phi=phi, lam=lam, registers=registers) +# ) +# return () diff --git a/src/bloqade/qasm2/dialects/parallel.py b/src/bloqade/qasm2/dialects/parallel.py index 216b4d91..f428d444 100644 --- a/src/bloqade/qasm2/dialects/parallel.py +++ b/src/bloqade/qasm2/dialects/parallel.py @@ -1,13 +1,11 @@ -from typing import Any - from kirin import ir, types, interp, lowering from kirin.decl import info, statement from kirin.analysis import ForwardFrame from kirin.dialects import ilist -from bloqade.qasm2.parse import ast from bloqade.qasm2.types import QubitType -from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame + +# from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame from bloqade.squin.analysis.schedule import DagScheduleAnalysis dialect = ir.Dialect("qasm2.parallel") @@ -41,49 +39,49 @@ class RZ(ir.Statement): theta: ir.SSAValue = info.argument(types.Float) -@dialect.register(key="emit.qasm2.gate") -class Parallel(interp.MethodTable): - - def _emit_parallel_qargs( - self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, qargs: ir.SSAValue - ): - qargs_: ilist.IList[ast.Node, Any] = frame.get(qargs) # type: ignore - return [(emit.assert_node((ast.Name, ast.Bit), qarg),) for qarg in qargs_] - - @interp.impl(UGate) - def ugate(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: UGate): - qargs = self._emit_parallel_qargs(emit, frame, stmt.qargs) - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - frame.body.append( - ast.ParaU3Gate( - theta=theta, phi=phi, lam=lam, qargs=ast.ParallelQArgs(qargs=qargs) - ) - ) - return () - - @interp.impl(RZ) - def rz(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: RZ): - qargs = self._emit_parallel_qargs(emit, frame, stmt.qargs) - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - frame.body.append( - ast.ParaRZGate(theta=theta, qargs=ast.ParallelQArgs(qargs=qargs)) - ) - return () - - @interp.impl(CZ) - def cz(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: CZ): - ctrls = self._emit_parallel_qargs(emit, frame, stmt.ctrls) - qargs = self._emit_parallel_qargs(emit, frame, stmt.qargs) - frame.body.append( - ast.ParaCZGate( - qargs=ast.ParallelQArgs( - qargs=[ctrl + qarg for ctrl, qarg in zip(ctrls, qargs)] - ) - ) - ) - return () +# @dialect.register(key="emit.qasm2.gate") +# class Parallel(interp.MethodTable): + +# def _emit_parallel_qargs( +# self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, qargs: ir.SSAValue +# ): +# qargs_: ilist.IList[ast.Node, Any] = frame.get(qargs) # type: ignore +# return [(emit.assert_node((ast.Name, ast.Bit), qarg),) for qarg in qargs_] + +# @interp.impl(UGate) +# def ugate(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: UGate): +# qargs = self._emit_parallel_qargs(emit, frame, stmt.qargs) +# theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) +# phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) +# lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) +# frame.body.append( +# ast.ParaU3Gate( +# theta=theta, phi=phi, lam=lam, qargs=ast.ParallelQArgs(qargs=qargs) +# ) +# ) +# return () + +# @interp.impl(RZ) +# def rz(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: RZ): +# qargs = self._emit_parallel_qargs(emit, frame, stmt.qargs) +# theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) +# frame.body.append( +# ast.ParaRZGate(theta=theta, qargs=ast.ParallelQArgs(qargs=qargs)) +# ) +# return () + +# @interp.impl(CZ) +# def cz(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: CZ): +# ctrls = self._emit_parallel_qargs(emit, frame, stmt.ctrls) +# qargs = self._emit_parallel_qargs(emit, frame, stmt.qargs) +# frame.body.append( +# ast.ParaCZGate( +# qargs=ast.ParallelQArgs( +# qargs=[ctrl + qarg for ctrl, qarg in zip(ctrls, qargs)] +# ) +# ) +# ) +# return () @dialect.register(key="qasm2.schedule.dag") diff --git a/src/bloqade/qasm2/dialects/uop/__init__.py b/src/bloqade/qasm2/dialects/uop/__init__.py index 7ab47ef6..e926539a 100644 --- a/src/bloqade/qasm2/dialects/uop/__init__.py +++ b/src/bloqade/qasm2/dialects/uop/__init__.py @@ -1,4 +1,4 @@ -from . import _emit as _emit, stmts as stmts +from . import stmts as stmts, _emit as _emit from .stmts import * # noqa: F403 from ._dialect import dialect as dialect from .schedule import * # noqa: F403 diff --git a/src/bloqade/qasm2/dialects/uop/_emit.py b/src/bloqade/qasm2/dialects/uop/_emit.py index e3ded03f..ecae352f 100644 --- a/src/bloqade/qasm2/dialects/uop/_emit.py +++ b/src/bloqade/qasm2/dialects/uop/_emit.py @@ -1,52 +1,51 @@ +from __future__ import annotations + from kirin import interp +from bloqade.qasm2.emit import QASM2, Frame from bloqade.qasm2.parse import ast -from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame from . import stmts from ._dialect import dialect -@dialect.register(key="emit.qasm2.gate") +@dialect.register(key="emit.qasm2") class UOp(interp.MethodTable): @interp.impl(stmts.CX) def emit_cx( self, - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.CX, ): - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append(ast.CXGate(ctrl=ctrl, qarg=qarg)) return () @interp.impl(stmts.UGate) def emit_ugate( self, - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.UGate, ): - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + theta = frame.get_casted(stmt.theta, ast.Expr) + phi = frame.get_casted(stmt.phi, ast.Expr) + lam = frame.get_casted(stmt.lam, ast.Expr) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append(ast.UGate(theta=theta, phi=phi, lam=lam, qarg=qarg)) return () @interp.impl(stmts.Barrier) def emit_barrier( self, - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.Barrier, ): - qargs = [ - emit.assert_node((ast.Bit, ast.Name), frame.get(qarg)) - for qarg in stmt.qargs - ] + qargs = [frame.get_casted(qarg, (ast.Bit, ast.Name)) for qarg in stmt.qargs] frame.body.append(ast.Barrier(qargs=qargs)) return () @@ -62,9 +61,9 @@ def emit_barrier( @interp.impl(stmts.T) @interp.impl(stmts.Tdag) def emit_single_qubit_gate( - self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.SingleQubitGate + self, emit: QASM2, frame: Frame, stmt: stmts.SingleQubitGate ): - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[], qargs=[qarg]) ) @@ -75,31 +74,31 @@ def emit_single_qubit_gate( @interp.impl(stmts.RZ) def emit_rotation( self, - emit: EmitQASM2Gate, - frame: EmitQASM2Frame, + emit: QASM2, + frame: Frame, stmt: stmts.RX | stmts.RY | stmts.RZ, ): - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + theta = frame.get_casted(stmt.theta, ast.Expr) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[theta], qargs=[qarg]) ) return () @interp.impl(stmts.U1) - def emit_u1(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.U1): - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_u1(self, emit: QASM2, frame: Frame, stmt: stmts.U1): + lam = frame.get_casted(stmt.lam, ast.Expr) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[lam], qargs=[qarg]) ) return () @interp.impl(stmts.U2) - def emit_u2(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.U2): - phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_u2(self, emit: QASM2, frame: Frame, stmt: stmts.U2): + phi = frame.get_casted(stmt.phi, ast.Expr) + lam = frame.get_casted(stmt.lam, ast.Expr) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[phi, lam], qargs=[qarg]) ) @@ -110,21 +109,20 @@ def emit_u2(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.U2): @interp.impl(stmts.CZ) @interp.impl(stmts.CY) @interp.impl(stmts.CH) - def emit_two_qubit_gate( - self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CZ - ): - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_two_qubit_gate(self, emit: QASM2, frame: Frame, stmt: stmts.CZ): + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[], qargs=[ctrl, qarg]) ) return () @interp.impl(stmts.CCX) - def emit_ccx(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CCX): - ctrl1 = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl1)) - ctrl2 = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl2)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_ccx(self, emit: QASM2, frame: Frame, stmt: stmts.CCX): + ctrl1 = frame.get_casted(stmt.ctrl1, (ast.Bit, ast.Name)) + ctrl2 = frame.get_casted(stmt.ctrl2, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction( name=ast.Name(stmt.name), params=[], qargs=[ctrl1, ctrl2, qarg] @@ -133,10 +131,11 @@ def emit_ccx(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CCX): return () @interp.impl(stmts.CSwap) - def emit_cswap(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CSwap): - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg1 = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg1)) - qarg2 = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg2)) + def emit_cswap(self, emit: QASM2, frame: Frame, stmt: stmts.CSwap): + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg1 = frame.get_casted(stmt.qarg1, (ast.Bit, ast.Name)) + qarg2 = frame.get_casted(stmt.qarg2, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction( name=ast.Name(stmt.name), params=[], qargs=[ctrl, qarg1, qarg2] @@ -147,32 +146,35 @@ def emit_cswap(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CSw @interp.impl(stmts.CRZ) @interp.impl(stmts.CRY) @interp.impl(stmts.CRX) - def emit_cr(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CRX): - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_cr(self, emit: QASM2, frame: Frame, stmt: stmts.CRX): + lam = frame.get_casted(stmt.lam, ast.Expr) + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[lam], qargs=[ctrl, qarg]) ) return () @interp.impl(stmts.CU1) - def emit_cu1(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CU1): - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_cu1(self, emit: QASM2, frame: Frame, stmt: stmts.CU1): + lam = frame.get_casted(stmt.lam, ast.Expr) + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction(name=ast.Name(stmt.name), params=[lam], qargs=[ctrl, qarg]) ) return () @interp.impl(stmts.CU3) - def emit_cu3(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CU3): - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_cu3(self, emit: QASM2, frame: Frame, stmt: stmts.CU3): + theta = frame.get_casted(stmt.theta, ast.Expr) + phi = frame.get_casted(stmt.phi, ast.Expr) + lam = frame.get_casted(stmt.lam, ast.Expr) + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction( name=ast.Name(stmt.name), params=[theta, phi, lam], qargs=[ctrl, qarg] @@ -181,13 +183,14 @@ def emit_cu3(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CU3): return () @interp.impl(stmts.CU) - def emit_cu(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CU): - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) - lam = emit.assert_node(ast.Expr, frame.get(stmt.lam)) - gamma = emit.assert_node(ast.Expr, frame.get(stmt.gamma)) - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_cu(self, emit: QASM2, frame: Frame, stmt: stmts.CU): + theta = frame.get_casted(stmt.theta, ast.Expr) + phi = frame.get_casted(stmt.phi, ast.Expr) + lam = frame.get_casted(stmt.lam, ast.Expr) + gamma = frame.get_casted(stmt.gamma, ast.Expr) + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction( name=ast.Name(stmt.name), @@ -199,10 +202,11 @@ def emit_cu(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.CU): @interp.impl(stmts.RZZ) @interp.impl(stmts.RXX) - def emit_r2q(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.RZZ): - theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) - ctrl = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.ctrl)) - qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg)) + def emit_r2q(self, emit: QASM2, frame: Frame, stmt: stmts.RZZ): + theta = frame.get_casted(stmt.theta, ast.Expr) + ctrl = frame.get_casted(stmt.ctrl, (ast.Bit, ast.Name)) + qarg = frame.get_casted(stmt.qarg, (ast.Bit, ast.Name)) + frame.body.append( ast.Instruction( name=ast.Name(stmt.name), params=[theta], qargs=[ctrl, qarg] diff --git a/src/bloqade/qasm2/emit.py b/src/bloqade/qasm2/emit.py new file mode 100644 index 00000000..3624c341 --- /dev/null +++ b/src/bloqade/qasm2/emit.py @@ -0,0 +1,194 @@ +from abc import ABC +from dataclasses import field, dataclass + +from kirin import ir, emit, interp, idtable +from kirin.dialects import scf, func +from kirin.worklist import WorkList + +from bloqade.qasm2.parse import ast +from bloqade.qasm2.types import QubitType +from bloqade.qasm2.dialects import glob, noise, parallel + + +@dataclass +class Frame(emit.EmitFrame[ast.Node | None]): + ssa: idtable.IdTable[ir.SSAValue] + body: list[ast.Statement] = field(default_factory=list) + + +@dataclass +class QASM2(emit.EmitABC[Frame, ast.Node | None], ABC): + keys = ("emit.qasm2",) + void = None + + # options + prefix: str = field(default="", kw_only=True) + prefix_if_none: str = field(default="var_", kw_only=True) + + # state + callables: emit.julia.SymbolTable = field(init=False) + worklist: WorkList[ir.Statement] = field(init=False) + + def initialize(self): + super().initialize() + self.callables = emit.julia.SymbolTable() + self.worklist = WorkList() + return self + + def initialize_frame( + self, node: ir.Statement, *, has_parent_access: bool = False + ) -> Frame: + return Frame( + node, + ssa=idtable.IdTable[ir.SSAValue]( + prefix=self.prefix, + prefix_if_none=self.prefix_if_none, + ), + has_parent_access=has_parent_access, + ) + + def run(self, node: ir.Method | ir.Statement): + if isinstance(node, ir.Method): + node = node.code + + if self.dialects.data.intersection( + (parallel.dialect, glob.dialect, noise.dialect) + ): + header = ast.Kirin([dialect.name for dialect in self.dialects]) + else: + header = ast.OPENQASM(ast.Version(2, 0)) + + body = node.get_present_trait(ir.CallableStmtInterface).get_callable_region( + node + ) + with self.eval_context(): + block = body.blocks[0] + with self.new_frame(node) as frame: + frame.current_block = block + frame.current_stmt = block.first_stmt + if len(body.blocks) != 1: + raise interp.InterpreterError( + "QASM2 does not support general control flow" + ) + + for stmt in block.stmts: + if isinstance(stmt, func.ConstantNone): + continue + if isinstance(stmt, func.Return): + # QASM2 does not support return values + if not isinstance(stmt.value.owner, func.ConstantNone): + raise interp.InterpreterError( + "QASM2 does not support return values, kernel must return None" + ) + break + + ret = self.frame_eval(frame, stmt) + if isinstance(ret, tuple): + frame.set_values(stmt.results, ret) + elif ret is not None: + raise interp.InterpreterError( + f"QASM2 does not support return values or general control flows: {ret}" + ) + + code = ast.MainProgram(header=header, statements=frame.body) + + while self.worklist: + callable = self.worklist.pop() + if callable is None: + break + + _, gate = self.eval(callable) + if isinstance(gate, tuple) and isinstance(gate[0], ast.Gate): + code.statements.insert(0, gate[0]) # insert at the beginning + else: + raise interp.InterpreterError( + f"invalid result generated by {callable}: {gate}" + ) + + return code + + +@func.dialect.register(key="emit.qasm2") +class Func(interp.MethodTable): + + @interp.impl(func.Function) + def emit_func( + self, + emit: QASM2, + frame: Frame, + stmt: func.Function, + ): + args: list[ast.Node] = [] + cparams, qparams = [], [] + for arg in stmt.body.blocks[0].args: + assert arg.name is not None + + args.append(ast.Name(id=arg.name)) + if arg.type.is_subseteq(QubitType): + qparams.append(arg.name) + else: + cparams.append(arg.name) + + if len(stmt.body.blocks) != 1: + raise interp.InterpreterError( + f"Gate function {stmt.name} must have exactly one block" + ) + + block = stmt.body.blocks[0] + with emit.new_frame(stmt) as frame: + for node in block.stmts: + if isinstance(node, func.ConstantNone): + continue + + # NOTE: this is a single block if we see a return + # statement, we break. In QASM2 return can only + # be return None when used as a gate function + if isinstance(node, func.Return): + break + + ret = emit.frame_eval(frame, node) + if isinstance(ret, tuple): + frame.set_values(node.results, ret) + elif ret is not None: + raise interp.InterpreterError( + "QASM2 does not support return values or general control flows" + ) + + return ( + ast.Gate( + name=stmt.sym_name, cparams=cparams, qparams=qparams, body=frame.body + ), + ) + + +@scf.dialect.register(key="emit.qasm2") +class Scf(interp.MethodTable): + + @interp.impl(scf.IfElse) + def emit_if_else( + self, + emit: QASM2, + frame: Frame, + stmt: scf.IfElse, + ): + assert ( + len(stmt.then_body.blocks) == 1 + ), "QASM2 if can only have a single statement" + block = stmt.then_body.blocks[0] + first_stmt = block.first_stmt + last_stmt = block.last_stmt + assert ( + first_stmt is not None and len(block.stmts) == 2 + ), "QASM2 if can only have a single statement" + assert ( + isinstance(last_stmt, scf.Yield) and not last_stmt.values + ), "QASM2 if can only have a single gate statement" + assert len(stmt.else_body.blocks) == 0, "QASM2 if does not support else" + + with emit.new_frame(stmt) as body_frame: + emit.frame_eval(body_frame, first_stmt) + + frame.body.append( + ast.IfStmt(cond=frame.get_typed(stmt.cond, ast.Cmp), body=body_frame.body) + ) + return diff --git a/src/bloqade/qasm2/emit__/__init__.py b/src/bloqade/qasm2/emit__/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/qasm2/emit__/base.py b/src/bloqade/qasm2/emit__/base.py new file mode 100644 index 00000000..8c25481d --- /dev/null +++ b/src/bloqade/qasm2/emit__/base.py @@ -0,0 +1,51 @@ +from abc import ABC +from dataclasses import field, dataclass + +from kirin import ir, emit, idtable +from kirin.worklist import WorkList + +from bloqade.qasm2.parse import ast + + +@dataclass +class QASM2EmitFrame(emit.EmitFrame[ast.Node | None]): + ssa: idtable.IdTable[ir.SSAValue] + body: list[ast.Statement] = field(default_factory=list) + + +@dataclass +class QASM2EmitBase(emit.EmitABC[QASM2EmitFrame, ast.Node | None], ABC): + void = None + + # options + prefix: str = field(default="", kw_only=True) + prefix_if_none: str = field(default="var_", kw_only=True) + + # state + callables: emit.julia.SymbolTable = field(init=False) + worklist: WorkList[ir.Statement] = field(init=False) + + def initialize_frame( + self, node: ir.Statement, *, has_parent_access: bool = False + ) -> QASM2EmitFrame: + return QASM2EmitFrame( + node, + ssa=idtable.IdTable[ir.SSAValue]( + prefix=self.prefix, + prefix_if_none=self.prefix_if_none, + ), + has_parent_access=has_parent_access, + ) + + def run(self, node: ir.Method | ir.Statement): + if isinstance(node, ir.Method): + node = node.code + + with self.eval_context(): + self.callables.add(node) + self.worklist.append(node) + while self.worklist: + callable = self.worklist.pop() + if callable is None: + break + frame, _ = self.eval(callable) diff --git a/src/bloqade/qasm2/groups.py b/src/bloqade/qasm2/groups.py index 280638c0..c67b6b9d 100644 --- a/src/bloqade/qasm2/groups.py +++ b/src/bloqade/qasm2/groups.py @@ -1,6 +1,6 @@ from kirin import ir, passes from kirin.prelude import structural_no_opt -from kirin.dialects import scf, func, ilist, lowering +from kirin.dialects import scf, func, ilist, ssacfg, lowering from bloqade.qasm2.dialects import ( uop, @@ -15,37 +15,37 @@ from bloqade.qasm2.rewrite.desugar import IndexingDesugarPass -@ir.dialect_group([uop, func, expr, lowering.func, lowering.call]) -def gate(self): - fold_pass = passes.Fold(self) - typeinfer_pass = passes.TypeInfer(self) +# @ir.dialect_group([uop, func, expr, lowering.func, lowering.call, ssacfg]) +# def gate(self): +# fold_pass = passes.Fold(self) +# typeinfer_pass = passes.TypeInfer(self) - def run_pass( - method: ir.Method, - *, - fold: bool = True, - ): - method.verify() +# def run_pass( +# method: ir.Method, +# *, +# fold: bool = True, +# ): +# method.verify() - if isinstance(method.code, func.Function): - new_code = expr.GateFunction( - sym_name=method.code.sym_name, - signature=method.code.signature, - body=method.code.body, - ) - method.code = new_code - else: - raise ValueError( - "Gate Method code must be a Function, cannot be lambda/closure" - ) +# if isinstance(method.code, func.Function): +# new_code = expr.GateFunction( +# sym_name=method.code.sym_name, +# signature=method.code.signature, +# body=method.code.body, +# ) +# method.code = new_code +# else: +# raise ValueError( +# "Gate Method code must be a Function, cannot be lambda/closure" +# ) - if fold: - fold_pass(method) +# if fold: +# fold_pass(method) - typeinfer_pass(method) - method.verify_type() +# typeinfer_pass(method) +# method.verify_type() - return run_pass +# return run_pass @ir.dialect_group( @@ -56,6 +56,7 @@ def run_pass( scf, indexing, func, + ssacfg, lowering.func, lowering.call, ] diff --git a/src/bloqade/qasm2/parse/ast.py b/src/bloqade/qasm2/parse/ast.py index f9a4aa97..33cb053e 100644 --- a/src/bloqade/qasm2/parse/ast.py +++ b/src/bloqade/qasm2/parse/ast.py @@ -63,7 +63,7 @@ class Gate(Statement): name: str cparams: list[str] qparams: list[str] - body: list[UOp | Barrier] + body: list[Statement] # list[UOp | Barrier] @dataclass @@ -81,7 +81,7 @@ class QOp(Statement): @dataclass class IfStmt(Statement): cond: Cmp - body: list[QOp] + body: list[Statement] # list[QOp] @dataclass diff --git a/src/bloqade/qasm2/passes/fold.py b/src/bloqade/qasm2/passes/fold.py index afb1b880..67eb9dfc 100644 --- a/src/bloqade/qasm2/passes/fold.py +++ b/src/bloqade/qasm2/passes/fold.py @@ -21,8 +21,6 @@ from kirin.ir.method import Method from kirin.rewrite.abc import RewriteResult -from bloqade.qasm2.dialects import expr - @dataclass class QASM2Fold(Pass): @@ -71,8 +69,8 @@ def unsafe_run(self, mt: Method) -> RewriteResult: ) def inline_simple(node: ir.Statement): - if isinstance(node, expr.GateFunction): - return self.inline_gate_subroutine + # if isinstance(node, expr.GateFunction): + # return self.inline_gate_subroutine if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)): return True # always inline calls outside of loops and if-else diff --git a/src/bloqade/qbraid/target.py b/src/bloqade/qbraid/target.py index 994dcf4a..8f0a4da5 100644 --- a/src/bloqade/qbraid/target.py +++ b/src/bloqade/qbraid/target.py @@ -6,7 +6,7 @@ from qbraid import QbraidProvider from qbraid.runtime import QbraidJob -from bloqade.qasm2.emit import QASM2 +# from bloqade.qasm2.emit import QASM2 class qBraid: @@ -71,16 +71,16 @@ def emit( An object you can query for the status of your submission as well as obtain simulator results from. """ - + ... # Convert method to QASM2 string - qasm2_emitter = QASM2( - allow_parallel=self.allow_parallel, - allow_global=self.allow_global, - qelib1=self.qelib1, - ) - qasm2_prog = qasm2_emitter.emit_str(method) + # # qasm2_emitter = QASM2( + # # allow_parallel=self.allow_parallel, + # # allow_global=self.allow_global, + # # qelib1=self.qelib1, + # # ) + # # qasm2_prog = qasm2_emitter.emit_str(method) - # Submit the QASM2 string to the qBraid simulator - quera_qasm_simulator = self.provider.get_device("quera_qasm_simulator") + # # Submit the QASM2 string to the qBraid simulator + # quera_qasm_simulator = self.provider.get_device("quera_qasm_simulator") - return quera_qasm_simulator.run(qasm2_prog, shots=shots, tags=tags) + # return quera_qasm_simulator.run(qasm2_prog, shots=shots, tags=tags) diff --git a/test/analysis/address/test_analysis.py b/test/analysis/address/test_analysis.py index 0d205b8d..c4642e49 100644 --- a/test/analysis/address/test_analysis.py +++ b/test/analysis/address/test_analysis.py @@ -52,8 +52,8 @@ def test_unwrap(): fold_pass = Fold(squin_with_qasm_core) fold_pass(constructed_method) - frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False + frame, _ = address.AddressAnalysis(constructed_method.dialects).run( + constructed_method ) address_wires = [] @@ -115,8 +115,8 @@ def test_multiple_unwrap(): fold_pass = Fold(squin_with_qasm_core) fold_pass(constructed_method) - frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False + frame, _ = address.AddressAnalysis(constructed_method.dialects).run( + constructed_method ) address_wire_parent_qubit_0 = [] @@ -180,10 +180,10 @@ def test_multiple_wire_apply(): fold_pass(constructed_method) # const_prop = const.Propagate(squin_with_qasm_core) - # frame, _ = const_prop.run_analysis(method=constructed_method, no_raise=False) + # frame, _ = const_prop.run(method=constructed_method, no_raise=False) - frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False + frame, _ = address.AddressAnalysis(constructed_method.dialects).run( + constructed_method ) address_wire_parent_qubit_0 = []