Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 12 additions & 28 deletions src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import TypeVar
from dataclasses import field

from kirin import ir, interp
from kirin.analysis import Forward, const
from kirin.analysis import Forward
from kirin.analysis.forward import ForwardFrame

from bloqade.types import QubitType
Expand All @@ -15,7 +14,7 @@ class AddressAnalysis(Forward[Address]):
This analysis pass can be used to track the global addresses of qubits and wires.
"""

keys = ["qubit.address"]
keys = ("qubit.address",)
lattice = Address
next_address: int = field(init=False)

Expand All @@ -24,37 +23,22 @@ def initialize(self):
self.next_address: int = 0
return self

@property
def qubit_count(self) -> int:
"""Total number of qubits found by the analysis."""
return self.next_address

T = TypeVar("T")

def get_const_value(self, typ: type[T], value: ir.SSAValue) -> T:
if isinstance(hint := value.hints.get("const"), const.Value):
data = hint.data
if isinstance(data, typ):
return hint.data
raise interp.InterpreterError(
f"Expected constant value <type = {typ}>, got {data}"
)
raise interp.InterpreterError(
f"Expected constant value <type = {typ}>, got {value}"
)
def method_self(self, method: ir.Method) -> Address:
return self.lattice.bottom()

def eval_stmt_fallback(
self, frame: ForwardFrame[Address], stmt: ir.Statement
) -> tuple[Address, ...] | interp.SpecialValue[Address]:
def eval_fallback(
self, frame: ForwardFrame[Address], node: ir.Statement
) -> interp.StatementResult[Address]:
return tuple(
(
self.lattice.top()
if result.type.is_subseteq(QubitType)
else self.lattice.bottom()
)
for result in stmt.results
for result in node.results
)

def run_method(self, method: ir.Method, args: tuple[Address, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
@property
def qubit_count(self) -> int:
"""Total number of qubits found by the analysis."""
return self.next_address
19 changes: 6 additions & 13 deletions src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from kirin import interp
from kirin.analysis import ForwardFrame, const
from kirin.dialects import cf, py, scf, func, ilist
from kirin.dialects import py, scf, func, ilist

from bloqade import squin

Expand Down Expand Up @@ -74,7 +74,7 @@ class PyIndexing(interp.MethodTable):
@interp.impl(py.GetItem)
def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem):
# Integer index into the thing being indexed
idx = interp.get_const_value(int, stmt.index)
idx = interp.expect_const(stmt.index, int)
# The object being indexed into
obj = frame.get(stmt.obj)
# The `data` attributes holds onto other Address types
Expand Down Expand Up @@ -116,15 +116,8 @@ def invoke(self, interp_: AddressAnalysis, frame: interp.Frame, stmt: func.Invok
# TODO: support lambda?


@cf.dialect.register(key="qubit.address")
class Cf(cf.typeinfer.TypeInfer):
# NOTE: cf just re-use the type infer method table
# it's the same process as type infer.
pass


@scf.dialect.register(key="qubit.address")
class Scf(scf.absint.Methods):
class Scf(interp.MethodTable):

@interp.impl(scf.For)
def for_loop(
Expand All @@ -134,7 +127,7 @@ def for_loop(
stmt: scf.For,
):
if not isinstance(hint := stmt.iterable.hints.get("const"), const.Value):
return interp_.eval_stmt_fallback(frame, stmt)
return interp_.eval_fallback(frame, stmt)

iterable = hint.data
loop_vars = frame.get_values(stmt.initializers)
Expand All @@ -144,7 +137,7 @@ def for_loop(
# NOTE: we need to actually run iteration in case there are
# new allocations/re-assign in the loop body.
for _ in iterable:
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
with interp_.new_frame(stmt) as body_frame:
body_frame.entries.update(frame.entries)
body_frame.set_values(
block_args,
Expand Down Expand Up @@ -218,7 +211,7 @@ def new(
frame: ForwardFrame[Address],
stmt: squin.qubit.New,
):
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
n_qubits = interp_.expect_const(stmt.n_qubits, int)
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
interp_.next_address += n_qubits
return (addr,)
4 changes: 2 additions & 2 deletions src/bloqade/qasm2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from bloqade.types import Qubit as Qubit, QubitType as QubitType

Check failure on line 1 in src/bloqade/qasm2/__init__.py

View workflow job for this annotation

GitHub Actions / build

Imports are incorrectly sorted and/or formatted.

from . import (
emit as emit,
glob as glob,
parse as parse,
dialects as dialects,
parallel as parallel,
emit as emit,
)
from .types import (
Bit as Bit,
Expand All @@ -15,6 +15,6 @@
CRegType as CRegType,
QRegType as QRegType,
)
from .groups import gate as gate, main as main, extended as extended
from .groups import main as main, extended as extended
from ._wrappers import * # noqa: F403
from ._qasm_loading import loads as loads, loadfile as loadfile
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import _emit as _emit, address as address, _typeinfer as _typeinfer
from . import address as address, _typeinfer as _typeinfer, _emit as _emit

Check failure on line 1 in src/bloqade/qasm2/dialects/core/__init__.py

View workflow job for this annotation

GitHub Actions / build

Imports are incorrectly sorted and/or formatted.
from .stmts import * # noqa: F403
from ._dialect import dialect as dialect
50 changes: 22 additions & 28 deletions src/bloqade/qasm2/dialects/core/_emit.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,62 @@
from __future__ import annotations

from kirin import interp

from bloqade.qasm2.emit import QASM2, Frame
from bloqade.qasm2.parse import ast
from bloqade.qasm2.emit.main import EmitQASM2Main, EmitQASM2Frame

from . import stmts
from ._dialect import dialect


@dialect.register(key="emit.qasm2.main")
@dialect.register(key="emit.qasm2")
class Core(interp.MethodTable):

@interp.impl(stmts.CRegNew)
def emit_creg_new(
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.CRegNew
):
n_bits = emit.assert_node(ast.Number, frame.get(stmt.n_bits))
def emit_creg_new(self, emit: QASM2, frame: Frame, stmt: stmts.CRegNew):
n_bits = frame.get_casted(stmt.n_bits, ast.Number)
# check if its int first, because Int.is_integer() is added for >=3.12
assert isinstance(n_bits.value, int), "expected integer"
name = emit.ssa_id[stmt.result]
name = frame.ssa[stmt.result]
frame.body.append(ast.CReg(name=name, size=int(n_bits.value)))
return (ast.Name(name),)

@interp.impl(stmts.QRegNew)
def emit_qreg_new(
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.QRegNew
):
n_bits = emit.assert_node(ast.Number, frame.get(stmt.n_qubits))
def emit_qreg_new(self, emit: QASM2, frame: Frame, stmt: stmts.QRegNew):
n_bits = frame.get_casted(stmt.n_qubits, ast.Number)
assert isinstance(n_bits.value, int), "expected integer"
name = emit.ssa_id[stmt.result]
name = frame.ssa[stmt.result]
frame.body.append(ast.QReg(name=name, size=int(n_bits.value)))
return (ast.Name(name),)

@interp.impl(stmts.Reset)
def emit_reset(self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.Reset):
qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg))
def emit_reset(self, emit: QASM2, frame: Frame, stmt: stmts.Reset):
qarg: ast.Name | ast.Bit = frame.get(stmt.qarg) # type: ignore
frame.body.append(ast.Reset(qarg=qarg))
return ()

@interp.impl(stmts.Measure)
def emit_measure(
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.Measure
):
qarg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.qarg))
carg = emit.assert_node((ast.Bit, ast.Name), frame.get(stmt.carg))
def emit_measure(self, emit: QASM2, frame: Frame, stmt: stmts.Measure):
qarg: ast.Bit | ast.Name = frame.get(stmt.qarg) # type: ignore
carg: ast.Name | ast.Bit = frame.get(stmt.carg) # type: ignore
frame.body.append(ast.Measure(qarg=qarg, carg=carg))
return ()

@interp.impl(stmts.CRegEq)
def emit_creg_eq(
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: stmts.CRegEq
):
lhs = emit.assert_node(ast.Expr, frame.get(stmt.lhs))
rhs = emit.assert_node(ast.Expr, frame.get(stmt.rhs))
def emit_creg_eq(self, emit: QASM2, frame: Frame, stmt: stmts.CRegEq):
lhs = frame.get_casted(stmt.lhs, ast.Expr)
rhs = frame.get_casted(stmt.rhs, ast.Expr)
return (ast.Cmp(lhs=lhs, rhs=rhs),)

@interp.impl(stmts.CRegGet)
@interp.impl(stmts.QRegGet)
def emit_qreg_get(
self,
emit: EmitQASM2Main,
frame: EmitQASM2Frame,
emit: QASM2,
frame: Frame,
stmt: stmts.QRegGet | stmts.CRegGet,
):
reg = emit.assert_node(ast.Name, frame.get(stmt.reg))
idx = emit.assert_node(ast.Number, frame.get(stmt.idx))
reg = frame.get_casted(stmt.reg, ast.Name)
idx = frame.get_casted(stmt.idx, ast.Number)
assert isinstance(idx.value, int)
return (ast.Bit(reg, int(idx.value)),)
4 changes: 2 additions & 2 deletions src/bloqade/qasm2/dialects/core/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ def new(
frame: interp.Frame[Address],
stmt: QRegNew,
):
n_qubits = interp.get_const_value(int, stmt.n_qubits)
n_qubits = interp.expect_const(stmt.n_qubits, int)
addr = AddressReg(range(interp.next_address, interp.next_address + n_qubits))
interp.next_address += n_qubits
return (addr,)

@interp.impl(QRegGet)
def get(self, interp: AddressAnalysis, frame: interp.Frame[Address], stmt: QRegGet):
addr = frame.get(stmt.reg)
pos = interp.get_const_value(int, stmt.idx)
pos = interp.expect_const(stmt.idx, int)
if isinstance(addr, AddressReg):
global_idx = addr.data[pos]
return (AddressQubit(global_idx),)
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/expr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import _emit as _emit, _interp as _interp, _from_python as _from_python
from . import _interp as _interp, _from_python as _from_python, _emit as _emit

Check failure on line 1 in src/bloqade/qasm2/dialects/expr/__init__.py

View workflow job for this annotation

GitHub Actions / build

Imports are incorrectly sorted and/or formatted.
from .stmts import * # noqa: F403
from ._dialect import dialect as dialect
65 changes: 20 additions & 45 deletions src/bloqade/qasm2/dialects/expr/_emit.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,40 @@
from __future__ import annotations

Check failure on line 1 in src/bloqade/qasm2/dialects/expr/_emit.py

View workflow job for this annotation

GitHub Actions / build

Imports are incorrectly sorted and/or formatted.
from typing import Literal

from kirin import interp

from bloqade.qasm2.parse import ast
from bloqade.qasm2.types import QubitType
from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame
from bloqade.qasm2.emit import QASM2, Frame

from . import stmts
from ._dialect import dialect


@dialect.register(key="emit.qasm2.gate")
@dialect.register(key="emit.qasm2")
class EmitExpr(interp.MethodTable):

@interp.impl(stmts.GateFunction)
def emit_func(
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
):

args: list[ast.Node] = []
cparams, qparams = [], []
for arg in stmt.body.blocks[0].args:
assert arg.name is not None

args.append(ast.Name(id=arg.name))
if arg.type.is_subseteq(QubitType):
qparams.append(arg.name)
else:
cparams.append(arg.name)

emit.run_ssacfg_region(frame, stmt.body, tuple(args))
emit.output = ast.Gate(
name=stmt.sym_name,
cparams=cparams,
qparams=qparams,
body=frame.body,
)
return ()

@interp.impl(stmts.ConstInt)
@interp.impl(stmts.ConstFloat)
def emit_const_int(
self,
emit: EmitQASM2Gate,
frame: EmitQASM2Frame,
emit: QASM2,
frame: Frame,
stmt: stmts.ConstInt | stmts.ConstFloat,
):
return (ast.Number(stmt.value),)

@interp.impl(stmts.ConstPI)
def emit_const_pi(
self,
emit: EmitQASM2Gate,
frame: EmitQASM2Frame,
emit: QASM2,
frame: Frame,
stmt: stmts.ConstPI,
):
return (ast.Pi(),)

@interp.impl(stmts.Neg)
def emit_neg(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Neg):
arg = emit.assert_node(ast.Expr, frame.get(stmt.value))
def emit_neg(self, emit: QASM2, frame: Frame, stmt: stmts.Neg):
arg = frame.get_casted(stmt.value, ast.Expr)
return (ast.UnaryOp("-", arg),)

@interp.impl(stmts.Sin)
Expand All @@ -68,37 +43,37 @@
@interp.impl(stmts.Exp)
@interp.impl(stmts.Log)
@interp.impl(stmts.Sqrt)
def emit_sin(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
arg = emit.assert_node(ast.Expr, frame.get(stmt.value))
def emit_sin(self, emit: QASM2, frame: Frame, stmt):
arg = frame.get_casted(stmt.value, ast.Expr)
return (ast.Call(stmt.name, [arg]),)

def emit_binop(
self,
sym: Literal["+", "-", "*", "/", "^"],
emit: EmitQASM2Gate,
frame: EmitQASM2Frame,
emit: QASM2,
frame: Frame,
stmt,
):
lhs = emit.assert_node(ast.Expr, frame.get(stmt.lhs))
rhs = emit.assert_node(ast.Expr, frame.get(stmt.rhs))
lhs = frame.get_casted(stmt.lhs, ast.Expr)
rhs = frame.get_casted(stmt.rhs, ast.Expr)
return (ast.BinOp(sym, lhs, rhs),)

@interp.impl(stmts.Add)
def emit_add(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add):
def emit_add(self, emit: QASM2, frame: Frame, stmt: stmts.Add):
return self.emit_binop("+", emit, frame, stmt)

@interp.impl(stmts.Sub)
def emit_sub(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add):
def emit_sub(self, emit: QASM2, frame: Frame, stmt: stmts.Add):
return self.emit_binop("-", emit, frame, stmt)

@interp.impl(stmts.Mul)
def emit_mul(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add):
def emit_mul(self, emit: QASM2, frame: Frame, stmt: stmts.Add):
return self.emit_binop("*", emit, frame, stmt)

@interp.impl(stmts.Div)
def emit_div(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add):
def emit_div(self, emit: QASM2, frame: Frame, stmt: stmts.Add):
return self.emit_binop("/", emit, frame, stmt)

@interp.impl(stmts.Pow)
def emit_pow(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.Add):
def emit_pow(self, emit: QASM2, frame: Frame, stmt: stmts.Add):
return self.emit_binop("^", emit, frame, stmt)
Loading
Loading