diff --git a/src/bloqade/gemini/__init__.py b/src/bloqade/gemini/__init__.py new file mode 100644 index 00000000..c03aeeb7 --- /dev/null +++ b/src/bloqade/gemini/__init__.py @@ -0,0 +1 @@ +from .groups import logical as logical diff --git a/src/bloqade/gemini/analysis/__init__.py b/src/bloqade/gemini/analysis/__init__.py new file mode 100644 index 00000000..8c94d180 --- /dev/null +++ b/src/bloqade/gemini/analysis/__init__.py @@ -0,0 +1,3 @@ +from .logical_validation.analysis import ( + GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis, +) diff --git a/src/bloqade/gemini/analysis/logical_validation/__init__.py b/src/bloqade/gemini/analysis/logical_validation/__init__.py new file mode 100644 index 00000000..1b289d8c --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/__init__.py @@ -0,0 +1 @@ +from . import impls as impls, analysis as analysis # NOTE: register methods diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py new file mode 100644 index 00000000..14a03cbf --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -0,0 +1,17 @@ +from kirin import ir + +from bloqade import squin +from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis + + +class GeminiLogicalValidationAnalysis(ValidationAnalysis): + keys = ["gemini.validate.logical"] + + first_gate = True + + def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + if isinstance(stmt, squin.gate.stmts.Gate): + # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here + self.first_gate = False + + return super().eval_stmt_fallback(frame, stmt) diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py new file mode 100644 index 00000000..cf4bd87b --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -0,0 +1,101 @@ +from kirin import ir, interp as _interp +from kirin.analysis import const +from kirin.dialects import scf, func + +from bloqade.squin import gate +from bloqade.validation.analysis import ValidationFrame +from bloqade.validation.analysis.lattice import Error + +from .analysis import GeminiLogicalValidationAnalysis + + +@scf.dialect.register(key="gemini.validate.logical") +class __ScfGeminiLogicalValidation(_interp.MethodTable): + + @_interp.impl(scf.IfElse) + def if_else( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ValidationFrame, + stmt: scf.IfElse, + ): + frame.errors.append( + ir.ValidationError( + stmt, "If statements are not supported in logical Gemini programs!" + ) + ) + return ( + Error( + message="If statements are not supported in logical Gemini programs!" + ), + ) + + @_interp.impl(scf.For) + def for_loop( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ValidationFrame, + stmt: scf.For, + ): + if isinstance(stmt.iterable.hints.get("const"), const.Value): + return (interp.lattice.top(),) + + frame.errors.append( + ir.ValidationError( + stmt, + "Non-constant iterable in for loop is not supported in Gemini logical programs!", + ) + ) + + return ( + Error( + message="Non-constant iterable in for loop is not supported in Gemini logical programs!" + ), + ) + + +@func.dialect.register(key="gemini.validate.logical") +class __FuncGeminiLogicalValidation(_interp.MethodTable): + @_interp.impl(func.Invoke) + def invoke( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ValidationFrame, + stmt: func.Invoke, + ): + frame.errors.append( + ir.ValidationError( + stmt, + "Function invocations not supported in logical Gemini program!", + help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls", + ) + ) + + return tuple( + Error( + message="Function invocations not supported in logical Gemini program!" + ) + for _ in stmt.results + ) + + +@gate.dialect.register(key="gemini.validate.logical") +class __GateGeminiLogicalValidation(_interp.MethodTable): + @_interp.impl(gate.stmts.U3) + def u3( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ValidationFrame, + stmt: gate.stmts.U3, + ): + if interp.first_gate: + interp.first_gate = False + return () + + frame.errors.append( + ir.ValidationError( + stmt, + "U3 gate can only be used for initial state preparation, i.e. as the first gate!", + ) + ) + return () diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/groups.py new file mode 100644 index 00000000..90441099 --- /dev/null +++ b/src/bloqade/gemini/groups.py @@ -0,0 +1,67 @@ +from typing import Annotated + +from kirin import ir +from kirin.passes import Default +from kirin.prelude import structural_no_opt +from kirin.dialects import py, func, ilist +from typing_extensions import Doc +from kirin.passes.inline import InlinePass + +from bloqade.squin import gate, qubit +from bloqade.validation import KernelValidation +from bloqade.rewrite.passes import AggressiveUnroll + +from .analysis import GeminiLogicalValidationAnalysis + + +@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist])) +def logical(self): + """Compile a function to a Gemini logical kernel.""" + + def run_pass( + mt, + *, + verify: Annotated[ + bool, Doc("run `verify` before running passes, default is `True`") + ] = True, + typeinfer: Annotated[ + bool, + Doc("run type inference and apply the inferred type to IR, default `True`"), + ] = True, + fold: Annotated[bool, Doc("run folding passes")] = True, + aggressive: Annotated[ + bool, Doc("run aggressive folding passes if `fold=True`") + ] = False, + inline: Annotated[bool, Doc("inline function calls, default `True`")] = True, + aggressive_unroll: Annotated[ + bool, + Doc( + "Run aggressive inlining and unrolling pass on the IR, default `False`" + ), + ] = False, + no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True, + ) -> None: + + if inline and not aggressive_unroll: + InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt) + + if aggressive_unroll: + AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt) + else: + default_pass = Default( + self, + verify=verify, + fold=fold, + aggressive=aggressive, + typeinfer=typeinfer, + no_raise=no_raise, + ) + + default_pass.fixpoint(mt) + + if verify: + validator = KernelValidation(GeminiLogicalValidationAnalysis) + validator.run(mt, no_raise=no_raise) + mt.verify() + + return run_pass diff --git a/src/bloqade/squin/gate/stmts.py b/src/bloqade/squin/gate/stmts.py index 960ae95d..afc91897 100644 --- a/src/bloqade/squin/gate/stmts.py +++ b/src/bloqade/squin/gate/stmts.py @@ -8,7 +8,13 @@ @statement -class SingleQubitGate(ir.Statement): +class Gate(ir.Statement): + # NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class + pass + + +@statement +class SingleQubitGate(Gate): traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) @@ -59,7 +65,7 @@ class SqrtY(SingleQubitNonHermitianGate): @statement -class RotationGate(ir.Statement): +class RotationGate(Gate): # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg traits = frozenset({lowering.FromPythonCall()}) angle: ir.SSAValue = info.argument(types.Float) @@ -85,7 +91,7 @@ class Rz(RotationGate): @statement -class ControlledGate(ir.Statement): +class ControlledGate(Gate): traits = frozenset({lowering.FromPythonCall()}) controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) @@ -110,7 +116,7 @@ class CZ(ControlledGate): @statement(dialect=dialect) -class U3(ir.Statement): +class U3(Gate): # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg traits = frozenset({lowering.FromPythonCall()}) theta: ir.SSAValue = info.argument(types.Float) diff --git a/src/bloqade/validation/__init__.py b/src/bloqade/validation/__init__.py new file mode 100644 index 00000000..e0992c23 --- /dev/null +++ b/src/bloqade/validation/__init__.py @@ -0,0 +1,2 @@ +from . import analysis as analysis +from .kernel_validation import KernelValidation as KernelValidation diff --git a/src/bloqade/validation/analysis/__init__.py b/src/bloqade/validation/analysis/__init__.py new file mode 100644 index 00000000..fc9a0b49 --- /dev/null +++ b/src/bloqade/validation/analysis/__init__.py @@ -0,0 +1,5 @@ +from . import lattice as lattice +from .analysis import ( + ValidationFrame as ValidationFrame, + ValidationAnalysis as ValidationAnalysis, +) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py new file mode 100644 index 00000000..323cbd40 --- /dev/null +++ b/src/bloqade/validation/analysis/analysis.py @@ -0,0 +1,41 @@ +from abc import ABC +from dataclasses import field, dataclass + +from kirin import ir +from kirin.analysis import ForwardExtra, ForwardFrame + +from .lattice import ErrorType + + +@dataclass +class ValidationFrame(ForwardFrame[ErrorType]): + # NOTE: cannot be set[Error] since that's not hashable + errors: list[ir.ValidationError] = field(default_factory=list) + """List of all ecnountered errors. + + Append a `kirin.ir.ValidationError` to this list in the method implementation + in order for it to get picked up by the `KernelValidation` run. + """ + + +@dataclass +class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC): + """Analysis pass that indicates errors in the IR according to the respective method tables. + + If you need to implement validation for a dialect shared by many groups (for example, if you need to ascertain if statements have a specific form) + you'll need to inherit from this class. + """ + + lattice = ErrorType + + def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): + return self.run_callable(method.code, (self.lattice.top(),) + args) + + def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + # NOTE: default to no errors + return tuple(self.lattice.top() for _ in stmt.results) + + def initialize_frame( + self, code: ir.Statement, *, has_parent_access: bool = False + ) -> ValidationFrame: + return ValidationFrame(code, has_parent_access=has_parent_access) diff --git a/src/bloqade/validation/analysis/lattice.py b/src/bloqade/validation/analysis/lattice.py new file mode 100644 index 00000000..d4c46469 --- /dev/null +++ b/src/bloqade/validation/analysis/lattice.py @@ -0,0 +1,58 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + IsSubsetEqMixin, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class ErrorType( + SimpleJoinMixin["ErrorType"], + SimpleMeetMixin["ErrorType"], + IsSubsetEqMixin["ErrorType"], + BoundedLattice["ErrorType"], +): + + @classmethod + def bottom(cls) -> "ErrorType": + return InvalidErrorType() + + @classmethod + def top(cls) -> "ErrorType": + return NoError() + + +@final +@dataclass +class InvalidErrorType(ErrorType, metaclass=SingletonMeta): + """Bottom to represent when we encounter an error running the analysis. + + When this is encountered, it means there might be an error, but we were unable to tell. + """ + + pass + + +@final +@dataclass +class Error(ErrorType): + """Indicates an error in the IR.""" + + message: str = "" + """Optional error message to show in the IR. + + NOTE: this is just to show a message when printing the IR. Actual errors + are collected by appending ir.ValidationError to the frame in the method + implementation. + """ + + +@final +@dataclass +class NoError(ErrorType, metaclass=SingletonMeta): + pass diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py new file mode 100644 index 00000000..84159352 --- /dev/null +++ b/src/bloqade/validation/kernel_validation.py @@ -0,0 +1,63 @@ +import sys +from dataclasses import dataclass + +from kirin import ir, exception +from rich.console import Console + +from .analysis import ValidationAnalysis + + +class ValidationErrorGroup(BaseException): + def __init__(self, *args: object, errors=[]) -> None: + super().__init__(*args) + self.errors = errors + + +# TODO: this overrides kirin's exception handler and should be upstreamed +def exception_handler(exc_type, exc_value, exc_tb): + if issubclass(exc_type, ValidationErrorGroup): + console = Console(force_terminal=True) + for i, err in enumerate(exc_value.errors): + with console.capture() as capture: + console.print(f"==== Error {i} ====") + console.print(f"[bold red]{type(err).__name__}:[/bold red]", end="") + print(capture.get(), *err.args, file=sys.stderr) + if err.source: + print("Source Traceback:", file=sys.stderr) + print(err.hint(), file=sys.stderr, end="") + console.print("=" * 40) + console.print( + "[bold red]Kernel validation failed:[/bold red] There were multiple errors encountered during validation, see above" + ) + return + + return exception.exception_handler(exc_type, exc_value, exc_tb) + + +sys.excepthook = exception_handler + + +@dataclass +class KernelValidation: + """Validate a kernel according to a `ValidationAnalysis`. + + This is a simple wrapper around the analysis that runs the analysis, checks + the `ValidationFrame` for errors and throws them if there are any. + """ + + validation_analysis_cls: type[ValidationAnalysis] + """The analysis that you want to run in order to validate the kernel.""" + + def run(self, mt: ir.Method, **kwargs) -> None: + validation_analysis = self.validation_analysis_cls(mt.dialects) + validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) + + errors = validation_frame.errors + + if len(errors) == 0: + # Valid program + return + elif len(errors) == 1: + raise errors[0] + else: + raise ValidationErrorGroup(errors=errors) diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py new file mode 100644 index 00000000..ce8e1a34 --- /dev/null +++ b/test/gemini/test_logical_validation.py @@ -0,0 +1,135 @@ +import pytest +from kirin import ir + +from bloqade import squin, gemini +from bloqade.types import Qubit +from bloqade.validation import KernelValidation +from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis +from bloqade.validation.kernel_validation import ValidationErrorGroup + + +def test_if_stmt_invalid(): + @gemini.logical(verify=False) + def main(): + q = squin.qalloc(3) + + squin.h(q[0]) + + for i in range(10): + squin.x(q[1]) + + m = squin.qubit.measure(q[1]) + + q2 = squin.qalloc(5) + squin.x(q2[0]) + + if m: + squin.x(q[1]) + + m2 = squin.qubit.measure(q[2]) + if m2: + squin.y(q[2]) + + frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_analysis( + main, no_raise=False + ) + + main.print(analysis=frame.entries) + + validator = KernelValidation(GeminiLogicalValidationAnalysis) + + with pytest.raises(ValidationErrorGroup): + validator.run(main) + + +def test_for_loop(): + + @gemini.logical + def valid_loop(): + q = squin.qalloc(3) + + for i in range(3): + squin.x(q[i]) + + valid_loop.print() + + with pytest.raises(ir.ValidationError): + + @gemini.logical + def invalid_loop(n: int): + q = squin.qalloc(3) + + for i in range(n): + squin.x(q[i]) + + invalid_loop.print() + + +def test_func(): + @gemini.logical + def sub_kernel(q: Qubit): + squin.x(q) + + @gemini.logical + def main(): + q = squin.qalloc(3) + sub_kernel(q[0]) + + main.print() + + with pytest.raises(ValidationErrorGroup): + + @gemini.logical(inline=False) + def invalid(): + q = squin.qalloc(3) + sub_kernel(q[0]) + + +def test_clifford_gates(): + @gemini.logical + def main(): + q = squin.qalloc(2) + squin.u3(0.123, 0.253, 1.2, q[0]) + + squin.h(q[0]) + squin.cx(q[0], q[1]) + + with pytest.raises(ir.ValidationError): + + @gemini.logical(no_raise=False) + def invalid(): + q = squin.qalloc(2) + + squin.h(q[0]) + squin.cx(q[0], q[1]) + squin.u3(0.123, 0.253, 1.2, q[0]) + + frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis( + invalid, no_raise=False + ) + + invalid.print(analysis=frame.entries) + + +def test_multiple_errors(): + did_error = False + try: + + @gemini.logical + def main(n: int): + q = squin.qalloc(3) + m = squin.qubit.measure(q[0]) + squin.x(q[1]) + if m: + squin.x(q[0]) + + for k in range(n): + squin.h(q[k]) + + squin.u3(0.1, 0.2, 0.3, q[1]) + + except ValidationErrorGroup as e: + did_error = True + assert len(e.errors) == 3 + + assert did_error