diff --git a/src/bloqade/native/_prelude.py b/src/bloqade/native/_prelude.py index 0776de0c..e9f42b99 100644 --- a/src/bloqade/native/_prelude.py +++ b/src/bloqade/native/_prelude.py @@ -7,10 +7,10 @@ from bloqade import qubit -from .dialects import gates +from .dialects import gate -@ir.dialect_group(structural_no_opt.union([gates, qubit])) +@ir.dialect_group(structural_no_opt.union([gate, qubit])) def kernel(self): """Compile a function to a native kernel.""" diff --git a/src/bloqade/native/dialects/gate/__init__.py b/src/bloqade/native/dialects/gate/__init__.py new file mode 100644 index 00000000..79898df3 --- /dev/null +++ b/src/bloqade/native/dialects/gate/__init__.py @@ -0,0 +1,2 @@ +from . import stmts as stmts +from ._dialect import dialect as dialect diff --git a/src/bloqade/native/dialects/gate/_dialect.py b/src/bloqade/native/dialects/gate/_dialect.py new file mode 100644 index 00000000..873f9911 --- /dev/null +++ b/src/bloqade/native/dialects/gate/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("native.gate") diff --git a/src/bloqade/native/dialects/gates/_interface.py b/src/bloqade/native/dialects/gate/_interface.py similarity index 64% rename from src/bloqade/native/dialects/gates/_interface.py rename to src/bloqade/native/dialects/gate/_interface.py index 6d68a589..a36e2e9a 100644 --- a/src/bloqade/native/dialects/gates/_interface.py +++ b/src/bloqade/native/dialects/gate/_interface.py @@ -12,21 +12,21 @@ @lowering.wraps(CZ) def cz( - ctrls: ilist.IList[qubit.Qubit, Len], - qargs: ilist.IList[qubit.Qubit, Len], + controls: ilist.IList[qubit.Qubit, Len], + targets: ilist.IList[qubit.Qubit, Len], ): ... @lowering.wraps(R) def r( - inputs: ilist.IList[qubit.Qubit, typing.Any], axis_angle: float, rotation_angle: float, + qubits: ilist.IList[qubit.Qubit, typing.Any], ): ... @lowering.wraps(Rz) def rz( - inputs: ilist.IList[qubit.Qubit, typing.Any], rotation_angle: float, + qubits: ilist.IList[qubit.Qubit, typing.Any], ): ... diff --git a/src/bloqade/native/dialects/gates/stmts.py b/src/bloqade/native/dialects/gate/stmts.py similarity index 72% rename from src/bloqade/native/dialects/gates/stmts.py rename to src/bloqade/native/dialects/gate/stmts.py index 458af507..c73462e5 100644 --- a/src/bloqade/native/dialects/gates/stmts.py +++ b/src/bloqade/native/dialects/gate/stmts.py @@ -12,20 +12,20 @@ @statement(dialect=dialect) class CZ(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) - qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) + controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) + targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) @statement(dialect=dialect) class R(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) axis_angle: ir.SSAValue = info.argument(types.Float) rotation_angle: ir.SSAValue = info.argument(types.Float) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) @statement(dialect=dialect) class Rz(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) rotation_angle: ir.SSAValue = info.argument(types.Float) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) diff --git a/src/bloqade/native/dialects/gates/__init__.py b/src/bloqade/native/dialects/gates/__init__.py deleted file mode 100644 index 34d5f8df..00000000 --- a/src/bloqade/native/dialects/gates/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .stmts import CZ as CZ, R as R, Rz as Rz -from ._dialect import dialect as dialect -from ._interface import r as r, cz as cz, rz as rz diff --git a/src/bloqade/native/dialects/gates/_dialect.py b/src/bloqade/native/dialects/gates/_dialect.py deleted file mode 100644 index 7679809b..00000000 --- a/src/bloqade/native/dialects/gates/_dialect.py +++ /dev/null @@ -1,3 +0,0 @@ -from kirin import ir - -dialect = ir.Dialect("bloqade.native") diff --git a/src/bloqade/native/stdlib/broadcast.py b/src/bloqade/native/stdlib/broadcast.py index e5a10ccc..e8942a9a 100644 --- a/src/bloqade/native/stdlib/broadcast.py +++ b/src/bloqade/native/stdlib/broadcast.py @@ -5,7 +5,7 @@ from bloqade import qubit from bloqade.native._prelude import kernel -from bloqade.native.dialects.gates import _interface as native +from bloqade.native.dialects.gate import _interface as native @kernel @@ -29,7 +29,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): angle (float): Rotation angle in radians. qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - native.r(qubits, 0.0, _radian_to_turn(angle)) + native.r(0.0, _radian_to_turn(angle), qubits) @kernel @@ -70,7 +70,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): angle (float): Rotation angle in radians. qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - native.r(qubits, 0.25, _radian_to_turn(angle)) + native.r(0.25, _radian_to_turn(angle), qubits) @kernel @@ -111,7 +111,7 @@ def rz(angle: float, qubits: ilist.IList[qubit.Qubit, Any]): angle (float): Rotation angle in radians. qubits (ilist.IList[qubit.Qubit, Any]): Target qubits. """ - native.rz(qubits, _radian_to_turn(angle)) + native.rz(_radian_to_turn(angle), qubits) @kernel diff --git a/src/bloqade/pyqrack/native.py b/src/bloqade/pyqrack/native.py index c68ef8ac..88b26119 100644 --- a/src/bloqade/pyqrack/native.py +++ b/src/bloqade/pyqrack/native.py @@ -7,29 +7,29 @@ from pyqrack import Pauli from bloqade.pyqrack import PyQrackQubit from bloqade.pyqrack.base import PyQrackInterpreter -from bloqade.native.dialects import gates +from bloqade.native.dialects.gate import stmts -@gates.dialect.register(key="pyqrack") +@stmts.dialect.register(key="pyqrack") class NativeMethods(interp.MethodTable): - @interp.impl(gates.CZ) - def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.CZ): - ctrls = frame.get_casted(stmt.ctrls, ilist.IList[PyQrackQubit, Any]) - qargs = frame.get_casted(stmt.qargs, ilist.IList[PyQrackQubit, Any]) + @interp.impl(stmts.CZ) + def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.CZ): + controls = frame.get_casted(stmt.controls, ilist.IList[PyQrackQubit, Any]) + targets = frame.get_casted(stmt.targets, ilist.IList[PyQrackQubit, Any]) - for ctrl, qarg in zip(ctrls, qargs): - if ctrl.is_active() and qarg.is_active(): - ctrl.sim_reg.mcz([ctrl.addr], qarg.addr) + for ctrl, trgt in zip(controls, targets): + if ctrl.is_active() and trgt.is_active(): + ctrl.sim_reg.mcz([ctrl.addr], trgt.addr) return () - @interp.impl(gates.R) - def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.R): - inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any]) + @interp.impl(stmts.R) + def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.R): + qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any]) rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float) axis_angle = 2 * math.pi * frame.get_casted(stmt.axis_angle, float) - for qubit in inputs: + for qubit in qubits: if qubit.is_active(): qubit.sim_reg.r(Pauli.PauliZ, axis_angle, qubit.addr) qubit.sim_reg.r(Pauli.PauliX, rotation_angle, qubit.addr) @@ -37,12 +37,12 @@ def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.R): return () - @interp.impl(gates.Rz) - def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.Rz): - inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any]) + @interp.impl(stmts.Rz) + def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.Rz): + qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any]) rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float) - for qubit in inputs: + for qubit in qubits: if qubit.is_active(): qubit.sim_reg.r(Pauli.PauliZ, rotation_angle, qubit.addr) diff --git a/test/native/upstream/test_squin2native.py b/test/native/upstream/test_squin2native.py index 4a823905..7d250e67 100644 --- a/test/native/upstream/test_squin2native.py +++ b/test/native/upstream/test_squin2native.py @@ -5,7 +5,7 @@ from bloqade import squin from bloqade.squin import gate from bloqade.pyqrack import StackMemorySimulator -from bloqade.native.dialects import gates +from bloqade.native.dialects import gate as native_gate from bloqade.native.upstream import GateRule, SquinToNative @@ -33,14 +33,14 @@ def main(): new_main = SquinToNative().emit(main, no_raise=True) new_callgraph = callgraph.CallGraph(new_main) - # make sure all kernels have been converted to native gates + # make sure all kernels have been converted to native gate all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers) for ker in all_kernels: assert gate.dialect not in ker.dialects - assert gates.dialect in ker.dialects + assert native_gate.dialect in ker.dialects # test to make sure the statevectors are the same - # before and after conversion to native gates + # before and after conversion to native gate old_sv = np.asarray(StackMemorySimulator(min_qubits=n).state_vector(main)) old_sv /= old_sv[imax := np.abs(old_sv).argmax()] / np.abs(old_sv[imax])