Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .groups import logical as logical
3 changes: 3 additions & 0 deletions src/bloqade/gemini/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .logical_validation.analysis import (
GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis,
)
1 change: 1 addition & 0 deletions src/bloqade/gemini/analysis/logical_validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import impls as impls, analysis as analysis # NOTE: register methods
17 changes: 17 additions & 0 deletions src/bloqade/gemini/analysis/logical_validation/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from kirin import ir

from bloqade import squin
from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis


class GeminiLogicalValidationAnalysis(ValidationAnalysis):
keys = ["gemini.validate.logical"]

first_gate = True

def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
if isinstance(stmt, squin.gate.stmts.Gate):
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
self.first_gate = False

return super().eval_stmt_fallback(frame, stmt)
101 changes: 101 additions & 0 deletions src/bloqade/gemini/analysis/logical_validation/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from kirin import ir, interp as _interp
from kirin.analysis import const
from kirin.dialects import scf, func

from bloqade.squin import gate
from bloqade.validation.analysis import ValidationFrame
from bloqade.validation.analysis.lattice import Error

from .analysis import GeminiLogicalValidationAnalysis


@scf.dialect.register(key="gemini.validate.logical")
class __ScfGeminiLogicalValidation(_interp.MethodTable):

@_interp.impl(scf.IfElse)
def if_else(
self,
interp: GeminiLogicalValidationAnalysis,
frame: ValidationFrame,
stmt: scf.IfElse,
):
frame.errors.append(
ir.ValidationError(
stmt, "If statements are not supported in logical Gemini programs!"
)
)
return (
Error(
message="If statements are not supported in logical Gemini programs!"
),
)

@_interp.impl(scf.For)
def for_loop(
self,
interp: GeminiLogicalValidationAnalysis,
frame: ValidationFrame,
stmt: scf.For,
):
if isinstance(stmt.iterable.hints.get("const"), const.Value):
return (interp.lattice.top(),)

frame.errors.append(
ir.ValidationError(
stmt,
"Non-constant iterable in for loop is not supported in Gemini logical programs!",
)
)

return (
Error(
message="Non-constant iterable in for loop is not supported in Gemini logical programs!"
),
)


@func.dialect.register(key="gemini.validate.logical")
class __FuncGeminiLogicalValidation(_interp.MethodTable):
@_interp.impl(func.Invoke)
def invoke(
self,
interp: GeminiLogicalValidationAnalysis,
frame: ValidationFrame,
stmt: func.Invoke,
):
frame.errors.append(
ir.ValidationError(
stmt,
"Function invocations not supported in logical Gemini program!",
help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls",
)
)

return tuple(
Error(
message="Function invocations not supported in logical Gemini program!"
)
for _ in stmt.results
)


@gate.dialect.register(key="gemini.validate.logical")
class __GateGeminiLogicalValidation(_interp.MethodTable):
@_interp.impl(gate.stmts.U3)
def u3(
self,
interp: GeminiLogicalValidationAnalysis,
frame: ValidationFrame,
stmt: gate.stmts.U3,
):
if interp.first_gate:
interp.first_gate = False
return ()

frame.errors.append(
ir.ValidationError(
stmt,
"U3 gate can only be used for initial state preparation, i.e. as the first gate!",
)
)
return ()
67 changes: 67 additions & 0 deletions src/bloqade/gemini/groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Annotated

from kirin import ir
from kirin.passes import Default
from kirin.prelude import structural_no_opt
from kirin.dialects import py, func, ilist
from typing_extensions import Doc
from kirin.passes.inline import InlinePass

from bloqade.squin import gate, qubit
from bloqade.validation import KernelValidation
from bloqade.rewrite.passes import AggressiveUnroll

from .analysis import GeminiLogicalValidationAnalysis


@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist]))
def logical(self):
"""Compile a function to a Gemini logical kernel."""

def run_pass(
mt,
*,
verify: Annotated[
bool, Doc("run `verify` before running passes, default is `True`")
] = True,
typeinfer: Annotated[
bool,
Doc("run type inference and apply the inferred type to IR, default `True`"),
] = True,
fold: Annotated[bool, Doc("run folding passes")] = True,
aggressive: Annotated[
bool, Doc("run aggressive folding passes if `fold=True`")
] = False,
inline: Annotated[bool, Doc("inline function calls, default `True`")] = True,
aggressive_unroll: Annotated[
bool,
Doc(
"Run aggressive inlining and unrolling pass on the IR, default `False`"
),
] = False,
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
) -> None:

if inline and not aggressive_unroll:
InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt)

if aggressive_unroll:
AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt)
else:
default_pass = Default(
self,
verify=verify,
fold=fold,
aggressive=aggressive,
typeinfer=typeinfer,
no_raise=no_raise,
)

default_pass.fixpoint(mt)

if verify:
validator = KernelValidation(GeminiLogicalValidationAnalysis)
validator.run(mt, no_raise=no_raise)
mt.verify()

return run_pass
14 changes: 10 additions & 4 deletions src/bloqade/squin/gate/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@


@statement
class SingleQubitGate(ir.Statement):
class Gate(ir.Statement):
# NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class
pass


@statement
class SingleQubitGate(Gate):
traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])

Expand Down Expand Up @@ -59,7 +65,7 @@ class SqrtY(SingleQubitNonHermitianGate):


@statement
class RotationGate(ir.Statement):
class RotationGate(Gate):
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
traits = frozenset({lowering.FromPythonCall()})
angle: ir.SSAValue = info.argument(types.Float)
Expand All @@ -85,7 +91,7 @@ class Rz(RotationGate):


@statement
class ControlledGate(ir.Statement):
class ControlledGate(Gate):
traits = frozenset({lowering.FromPythonCall()})
controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
Expand All @@ -110,7 +116,7 @@ class CZ(ControlledGate):


@statement(dialect=dialect)
class U3(ir.Statement):
class U3(Gate):
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
traits = frozenset({lowering.FromPythonCall()})
theta: ir.SSAValue = info.argument(types.Float)
Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import analysis as analysis
from .kernel_validation import KernelValidation as KernelValidation
5 changes: 5 additions & 0 deletions src/bloqade/validation/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from . import lattice as lattice
from .analysis import (
ValidationFrame as ValidationFrame,
ValidationAnalysis as ValidationAnalysis,
)
41 changes: 41 additions & 0 deletions src/bloqade/validation/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from abc import ABC
from dataclasses import field, dataclass

from kirin import ir
from kirin.analysis import ForwardExtra, ForwardFrame

from .lattice import ErrorType


@dataclass
class ValidationFrame(ForwardFrame[ErrorType]):
# NOTE: cannot be set[Error] since that's not hashable
errors: list[ir.ValidationError] = field(default_factory=list)
"""List of all ecnountered errors.

Append a `kirin.ir.ValidationError` to this list in the method implementation
in order for it to get picked up by the `KernelValidation` run.
"""


@dataclass
class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC):
"""Analysis pass that indicates errors in the IR according to the respective method tables.

If you need to implement validation for a dialect shared by many groups (for example, if you need to ascertain if statements have a specific form)
you'll need to inherit from this class.
"""

lattice = ErrorType

def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]):
return self.run_callable(method.code, (self.lattice.top(),) + args)

def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
# NOTE: default to no errors
return tuple(self.lattice.top() for _ in stmt.results)

def initialize_frame(
self, code: ir.Statement, *, has_parent_access: bool = False
) -> ValidationFrame:
return ValidationFrame(code, has_parent_access=has_parent_access)
58 changes: 58 additions & 0 deletions src/bloqade/validation/analysis/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
IsSubsetEqMixin,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class ErrorType(
SimpleJoinMixin["ErrorType"],
SimpleMeetMixin["ErrorType"],
IsSubsetEqMixin["ErrorType"],
BoundedLattice["ErrorType"],
):

@classmethod
def bottom(cls) -> "ErrorType":
return InvalidErrorType()

@classmethod
def top(cls) -> "ErrorType":
return NoError()


@final
@dataclass
class InvalidErrorType(ErrorType, metaclass=SingletonMeta):
"""Bottom to represent when we encounter an error running the analysis.

When this is encountered, it means there might be an error, but we were unable to tell.
"""

pass


@final
@dataclass
class Error(ErrorType):
"""Indicates an error in the IR."""

message: str = ""
"""Optional error message to show in the IR.

NOTE: this is just to show a message when printing the IR. Actual errors
are collected by appending ir.ValidationError to the frame in the method
implementation.
"""


@final
@dataclass
class NoError(ErrorType, metaclass=SingletonMeta):
pass
Loading