Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/bloqade/native/_prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from bloqade import qubit

from .dialects import gates
from .dialects import gate


@ir.dialect_group(structural_no_opt.union([gates, qubit]))
@ir.dialect_group(structural_no_opt.union([gate, qubit]))
def kernel(self):
"""Compile a function to a native kernel."""

Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/native/dialects/gate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import stmts as stmts
from ._dialect import dialect as dialect
3 changes: 3 additions & 0 deletions src/bloqade/native/dialects/gate/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("native.gate")
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@

@lowering.wraps(CZ)
def cz(
ctrls: ilist.IList[qubit.Qubit, Len],
qargs: ilist.IList[qubit.Qubit, Len],
controls: ilist.IList[qubit.Qubit, Len],
targets: ilist.IList[qubit.Qubit, Len],
): ...


@lowering.wraps(R)
def r(
inputs: ilist.IList[qubit.Qubit, typing.Any],
axis_angle: float,
rotation_angle: float,
qubits: ilist.IList[qubit.Qubit, typing.Any],
): ...


@lowering.wraps(Rz)
def rz(
inputs: ilist.IList[qubit.Qubit, typing.Any],
rotation_angle: float,
qubits: ilist.IList[qubit.Qubit, typing.Any],
): ...
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@
@statement(dialect=dialect)
class CZ(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])


@statement(dialect=dialect)
class R(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
axis_angle: ir.SSAValue = info.argument(types.Float)
rotation_angle: ir.SSAValue = info.argument(types.Float)
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


@statement(dialect=dialect)
class Rz(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
rotation_angle: ir.SSAValue = info.argument(types.Float)
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
3 changes: 0 additions & 3 deletions src/bloqade/native/dialects/gates/__init__.py

This file was deleted.

3 changes: 0 additions & 3 deletions src/bloqade/native/dialects/gates/_dialect.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/bloqade/native/stdlib/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bloqade import qubit
from bloqade.native._prelude import kernel
from bloqade.native.dialects.gates import _interface as native
from bloqade.native.dialects.gate import _interface as native


@kernel
Expand All @@ -29,7 +29,7 @@ def rx(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
angle (float): Rotation angle in radians.
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
"""
native.r(qubits, 0.0, _radian_to_turn(angle))
native.r(0.0, _radian_to_turn(angle), qubits)


@kernel
Expand Down Expand Up @@ -70,7 +70,7 @@ def ry(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
angle (float): Rotation angle in radians.
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
"""
native.r(qubits, 0.25, _radian_to_turn(angle))
native.r(0.25, _radian_to_turn(angle), qubits)


@kernel
Expand Down Expand Up @@ -111,7 +111,7 @@ def rz(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
angle (float): Rotation angle in radians.
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
"""
native.rz(qubits, _radian_to_turn(angle))
native.rz(_radian_to_turn(angle), qubits)


@kernel
Expand Down
34 changes: 17 additions & 17 deletions src/bloqade/pyqrack/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,42 @@
from pyqrack import Pauli
from bloqade.pyqrack import PyQrackQubit
from bloqade.pyqrack.base import PyQrackInterpreter
from bloqade.native.dialects import gates
from bloqade.native.dialects.gate import stmts


@gates.dialect.register(key="pyqrack")
@stmts.dialect.register(key="pyqrack")
class NativeMethods(interp.MethodTable):

@interp.impl(gates.CZ)
def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.CZ):
ctrls = frame.get_casted(stmt.ctrls, ilist.IList[PyQrackQubit, Any])
qargs = frame.get_casted(stmt.qargs, ilist.IList[PyQrackQubit, Any])
@interp.impl(stmts.CZ)
def cz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.CZ):
controls = frame.get_casted(stmt.controls, ilist.IList[PyQrackQubit, Any])
targets = frame.get_casted(stmt.targets, ilist.IList[PyQrackQubit, Any])

for ctrl, qarg in zip(ctrls, qargs):
if ctrl.is_active() and qarg.is_active():
ctrl.sim_reg.mcz([ctrl.addr], qarg.addr)
for ctrl, trgt in zip(controls, targets):
if ctrl.is_active() and trgt.is_active():
ctrl.sim_reg.mcz([ctrl.addr], trgt.addr)

return ()

@interp.impl(gates.R)
def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.R):
inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any])
@interp.impl(stmts.R)
def r(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.R):
qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)
axis_angle = 2 * math.pi * frame.get_casted(stmt.axis_angle, float)
for qubit in inputs:
for qubit in qubits:
if qubit.is_active():
qubit.sim_reg.r(Pauli.PauliZ, axis_angle, qubit.addr)
qubit.sim_reg.r(Pauli.PauliX, rotation_angle, qubit.addr)
qubit.sim_reg.r(Pauli.PauliZ, -axis_angle, qubit.addr)

return ()

@interp.impl(gates.Rz)
def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: gates.Rz):
inputs = frame.get_casted(stmt.inputs, ilist.IList[PyQrackQubit, Any])
@interp.impl(stmts.Rz)
def rz(self, _interp: PyQrackInterpreter, frame: interp.Frame, stmt: stmts.Rz):
qubits = frame.get_casted(stmt.qubits, ilist.IList[PyQrackQubit, Any])
rotation_angle = 2 * math.pi * frame.get_casted(stmt.rotation_angle, float)

for qubit in inputs:
for qubit in qubits:
if qubit.is_active():
qubit.sim_reg.r(Pauli.PauliZ, rotation_angle, qubit.addr)

Expand Down
8 changes: 4 additions & 4 deletions test/native/upstream/test_squin2native.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from bloqade import squin
from bloqade.squin import gate
from bloqade.pyqrack import StackMemorySimulator
from bloqade.native.dialects import gates
from bloqade.native.dialects import gate as native_gate
from bloqade.native.upstream import GateRule, SquinToNative


Expand Down Expand Up @@ -33,14 +33,14 @@ def main():
new_main = SquinToNative().emit(main, no_raise=True)

new_callgraph = callgraph.CallGraph(new_main)
# make sure all kernels have been converted to native gates
# make sure all kernels have been converted to native gate
all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers)
for ker in all_kernels:
assert gate.dialect not in ker.dialects
assert gates.dialect in ker.dialects
assert native_gate.dialect in ker.dialects

# test to make sure the statevectors are the same
# before and after conversion to native gates
# before and after conversion to native gate
old_sv = np.asarray(StackMemorySimulator(min_qubits=n).state_vector(main))
old_sv /= old_sv[imax := np.abs(old_sv).argmax()] / np.abs(old_sv[imax])

Expand Down