From f2ac907438674b6cd45a488cd19eeea6f172fe7d Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 16 Oct 2025 09:42:18 +0200 Subject: [PATCH 01/12] Gemini logical dialect group --- src/bloqade/gemini/__init__.py | 1 + src/bloqade/gemini/groups.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 src/bloqade/gemini/__init__.py create mode 100644 src/bloqade/gemini/groups.py diff --git a/src/bloqade/gemini/__init__.py b/src/bloqade/gemini/__init__.py new file mode 100644 index 00000000..c03aeeb7 --- /dev/null +++ b/src/bloqade/gemini/__init__.py @@ -0,0 +1 @@ +from .groups import logical as logical diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/groups.py new file mode 100644 index 00000000..7a9e9d5b --- /dev/null +++ b/src/bloqade/gemini/groups.py @@ -0,0 +1,21 @@ +from kirin import ir +from kirin.prelude import structural_no_opt +from kirin.dialects import py, func, ilist + +from bloqade.squin import gate, qubit + +# from .passes import ValidateGeminiLogical +# from .analysis import GeminiLogicalValidationAnalysis +# from .analysis.logical_validation.analysis import ValidateInterpreter + + +@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist])) +def logical(self): + + def run_pass(mt: ir.Method, *, validate=True): + if validate: + # GeminiLogicalValidationAnalysis(mt.dialects).run_analysis(mt, no_raise=False) + # ValidateInterpreter(mt.dialects).run(mt, ()) + mt.verify() + + return run_pass From 83c4bfa72af878365c54cc23ad3062a2dabd41b0 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 21 Oct 2025 15:30:29 +0200 Subject: [PATCH 02/12] Draft implementation of validation --- src/bloqade/gemini/analysis/__init__.py | 3 + .../analysis/logical_validation/__init__.py | 1 + .../analysis/logical_validation/analysis.py | 17 +++++ .../analysis/logical_validation/impls.py | 71 +++++++++++++++++++ .../analysis/logical_validation/lattice.py | 60 ++++++++++++++++ src/bloqade/gemini/groups.py | 45 +++++++++--- src/bloqade/gemini/validation/__init__.py | 0 src/bloqade/gemini/validation/logical.py | 37 ++++++++++ test/gemini/test_logical.py | 40 +++++++++++ 9 files changed, 266 insertions(+), 8 deletions(-) create mode 100644 src/bloqade/gemini/analysis/__init__.py create mode 100644 src/bloqade/gemini/analysis/logical_validation/__init__.py create mode 100644 src/bloqade/gemini/analysis/logical_validation/analysis.py create mode 100644 src/bloqade/gemini/analysis/logical_validation/impls.py create mode 100644 src/bloqade/gemini/analysis/logical_validation/lattice.py create mode 100644 src/bloqade/gemini/validation/__init__.py create mode 100644 src/bloqade/gemini/validation/logical.py create mode 100644 test/gemini/test_logical.py diff --git a/src/bloqade/gemini/analysis/__init__.py b/src/bloqade/gemini/analysis/__init__.py new file mode 100644 index 00000000..8c94d180 --- /dev/null +++ b/src/bloqade/gemini/analysis/__init__.py @@ -0,0 +1,3 @@ +from .logical_validation.analysis import ( + GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis, +) diff --git a/src/bloqade/gemini/analysis/logical_validation/__init__.py b/src/bloqade/gemini/analysis/logical_validation/__init__.py new file mode 100644 index 00000000..1b289d8c --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/__init__.py @@ -0,0 +1 @@ +from . import impls as impls, analysis as analysis # NOTE: register methods diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py new file mode 100644 index 00000000..b4faf46d --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -0,0 +1,17 @@ +from kirin import ir +from kirin.analysis import Forward, ForwardFrame + +from .lattice import ErrorType + + +class GeminiLogicalValidationAnalysis(Forward[ErrorType]): + keys = ["gemini.validate.logical"] + lattice = ErrorType + + has_allocated_qubits: bool = False + + def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): + return self.run_callable(method.code, (self.lattice.bottom(),) + args) + + def eval_stmt_fallback(self, frame: ForwardFrame[ErrorType], stmt: ir.Statement): + return (self.lattice.top(),) diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py new file mode 100644 index 00000000..debea699 --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -0,0 +1,71 @@ +from kirin import ir, interp as _interp +from kirin.analysis import ForwardFrame, const +from kirin.dialects import scf + +from bloqade.squin import qubit + +from .lattice import Error +from .analysis import GeminiLogicalValidationAnalysis + + +@qubit.dialect.register(key="gemini.validate.logical") +class __QubitGeminiLogicalValidation(_interp.MethodTable): + + @_interp.impl(qubit.New) + def new( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ForwardFrame, + stmt: qubit.New, + ): + # TODO: this is actually tricky, since qalloc calls qubit.new multiple times and we have to make sure qalloc is only called once + # but it can technically contain many qubit.new calls + # if interp.has_allocated_qubits: + # raise ir.ValidationError( + # stmt, "Can only allocate qubits once in a logical Gemini program!" + # ) + + # interp.has_allocated_qubits = True + + pass + + +@scf.dialect.register(key="gemini.validate.logical") +class __ScfGeminiLogicalValidation(_interp.MethodTable): + + @_interp.impl(scf.IfElse) + def if_else( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ForwardFrame, + stmt: scf.IfElse, + ): + # raise ir.ValidationError( + # stmt, "if statements are not supported in logical Gemini programs!" + # ) + return ( + Error( + ir.ValidationError( + stmt, "if statements are not supported in logical Gemini programs!" + ) + ), + ) + + @_interp.impl(scf.For) + def for_loop( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ForwardFrame, + stmt: scf.For, + ): + if isinstance(stmt.iterable.hints.get("const"), const.Value): + return (interp.lattice.top(),) + + return ( + Error( + ir.ValidationError( + stmt, + "Non-constant iterable in for loop is not supported in Gemini logical programs!", + ) + ), + ) diff --git a/src/bloqade/gemini/analysis/logical_validation/lattice.py b/src/bloqade/gemini/analysis/logical_validation/lattice.py new file mode 100644 index 00000000..a166e734 --- /dev/null +++ b/src/bloqade/gemini/analysis/logical_validation/lattice.py @@ -0,0 +1,60 @@ +from typing import final +from dataclasses import dataclass + +from kirin import ir +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + IsSubsetEqMixin, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class ErrorType( + SimpleJoinMixin["ErrorType"], + SimpleMeetMixin["ErrorType"], + IsSubsetEqMixin["ErrorType"], + BoundedLattice["ErrorType"], +): + + @classmethod + def bottom(cls) -> "ErrorType": + return InvalidErrorType() + + @classmethod + def top(cls) -> "ErrorType": + return NoError() + + +@final +@dataclass +class InvalidErrorType(ErrorType, metaclass=SingletonMeta): + """Bottom to represent when we encounter an error running the analysis. + + When this is encountered, it means there might be an error, but we were unable to tell. + """ + + pass + + +# @final +# @dataclass +# class AnyErrorType(ErrorType, metaclass=SingletonMeta): +# """Top to indicate that there was an error, but we can't really tell where""" +# pass + + +@final +@dataclass +class Error(ErrorType): + """We found an error, here's a hopefully helpful message.""" + + error: ir.ValidationError + + +@final +@dataclass +class NoError(ErrorType, metaclass=SingletonMeta): + pass diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/groups.py index 7a9e9d5b..020b18f0 100644 --- a/src/bloqade/gemini/groups.py +++ b/src/bloqade/gemini/groups.py @@ -1,21 +1,50 @@ +from typing import Annotated + from kirin import ir +from kirin.passes import Default from kirin.prelude import structural_no_opt from kirin.dialects import py, func, ilist +from typing_extensions import Doc from bloqade.squin import gate, qubit -# from .passes import ValidateGeminiLogical -# from .analysis import GeminiLogicalValidationAnalysis -# from .analysis.logical_validation.analysis import ValidateInterpreter +from .analysis import GeminiLogicalValidationAnalysis +from .validation.logical import KernelValidation @ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist])) def logical(self): + """Compile a function to a Gemini logical kernel.""" + + def run_pass( + mt, + *, + verify: Annotated[ + bool, Doc("run `verify` before running passes, default is `True`") + ] = True, + typeinfer: Annotated[ + bool, + Doc("run type inference and apply the inferred type to IR, default `True`"), + ] = True, + fold: Annotated[bool, Doc("run folding passes")] = True, + aggressive: Annotated[ + bool, Doc("run aggressive folding passes if `fold=True`") + ] = False, + no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True, + ) -> None: + default_pass = Default( + self, + verify=verify, + fold=fold, + aggressive=aggressive, + typeinfer=typeinfer, + no_raise=no_raise, + ) + + default_pass.fixpoint(mt) - def run_pass(mt: ir.Method, *, validate=True): - if validate: - # GeminiLogicalValidationAnalysis(mt.dialects).run_analysis(mt, no_raise=False) - # ValidateInterpreter(mt.dialects).run(mt, ()) - mt.verify() + if verify: + validator = KernelValidation(GeminiLogicalValidationAnalysis) + validator.run(mt, no_raise=no_raise) return run_pass diff --git a/src/bloqade/gemini/validation/__init__.py b/src/bloqade/gemini/validation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/gemini/validation/logical.py b/src/bloqade/gemini/validation/logical.py new file mode 100644 index 00000000..afd04082 --- /dev/null +++ b/src/bloqade/gemini/validation/logical.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass + +from kirin import ir +from kirin.analysis import Forward, ForwardFrame + +from ..analysis.logical_validation.lattice import Error + + +@dataclass +class KernelValidation: + validation_analysis_cls: type[Forward] + + def run(self, mt: ir.Method, **kwargs) -> None: + validation_analysis = self.validation_analysis_cls(mt.dialects) + validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) + + errors = self.get_exceptions(mt, validation_frame) + + if len(errors) == 0: + # Valid program + return + + # TODO: Make something similar to an ExceptionGroup that pretty-prints ValidationErrors + raise errors[0] + + def get_exceptions(self, mt: ir.Method, validation_frame: ForwardFrame): + errors = [] + for value in validation_frame.entries.values(): + if not isinstance(value, Error): + continue + + if isinstance(value.error, ir.ValidationError): + value.error.attach(mt) + + errors.append(value.error) + + return errors diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py new file mode 100644 index 00000000..f71924a3 --- /dev/null +++ b/test/gemini/test_logical.py @@ -0,0 +1,40 @@ +import pytest +from kirin import ir + +from bloqade import squin, gemini +from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis +from bloqade.gemini.validation.logical import KernelValidation + + +def test_if_stmt_invalid(): + @gemini.logical(verify=False) + def main(): + q = squin.qubit.new(3) + + squin.h(q[0]) + + for i in range(10): + squin.x(q[1]) + + m = squin.qubit.measure(q[1]) + + q2 = squin.qubit.new(5) + squin.x(q2[0]) + + if m: + squin.x(q[1]) + + m2 = squin.qubit.measure(q[2]) + if m2: + squin.y(q[2]) + + frame, _ = GeminiLogicalValidationAnalysis(main.dialects).run_analysis( + main, no_raise=False + ) + + main.print(analysis=frame.entries) + + validator = KernelValidation(GeminiLogicalValidationAnalysis) + + with pytest.raises(ir.ValidationError): + validator.run(main) From 6bb087aab96090e9cc1ca89b1c0e9d582c5ba4e2 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 21 Oct 2025 15:59:25 +0200 Subject: [PATCH 03/12] Restructure and abstract the validation --- .../analysis/logical_validation/analysis.py | 8 +++---- .../analysis/logical_validation/impls.py | 19 +++++---------- src/bloqade/gemini/groups.py | 2 +- src/bloqade/gemini/validation/__init__.py | 0 src/bloqade/validation/__init__.py | 2 ++ src/bloqade/validation/analysis/__init__.py | 5 ++++ src/bloqade/validation/analysis/analysis.py | 24 +++++++++++++++++++ .../analysis}/lattice.py | 11 +++------ .../kernel_validation.py} | 14 +++++------ test/gemini/test_logical.py | 2 +- 10 files changed, 53 insertions(+), 34 deletions(-) delete mode 100644 src/bloqade/gemini/validation/__init__.py create mode 100644 src/bloqade/validation/__init__.py create mode 100644 src/bloqade/validation/analysis/__init__.py create mode 100644 src/bloqade/validation/analysis/analysis.py rename src/bloqade/{gemini/analysis/logical_validation => validation/analysis}/lattice.py (82%) rename src/bloqade/{gemini/validation/logical.py => validation/kernel_validation.py} (66%) diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index b4faf46d..6932dbdb 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -1,10 +1,10 @@ from kirin import ir -from kirin.analysis import Forward, ForwardFrame -from .lattice import ErrorType +from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis +from bloqade.validation.analysis.lattice import ErrorType -class GeminiLogicalValidationAnalysis(Forward[ErrorType]): +class GeminiLogicalValidationAnalysis(ValidationAnalysis): keys = ["gemini.validate.logical"] lattice = ErrorType @@ -13,5 +13,5 @@ class GeminiLogicalValidationAnalysis(Forward[ErrorType]): def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): return self.run_callable(method.code, (self.lattice.bottom(),) + args) - def eval_stmt_fallback(self, frame: ForwardFrame[ErrorType], stmt: ir.Statement): + def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): return (self.lattice.top(),) diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py index debea699..73b56385 100644 --- a/src/bloqade/gemini/analysis/logical_validation/impls.py +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -1,10 +1,10 @@ -from kirin import ir, interp as _interp +from kirin import interp as _interp from kirin.analysis import ForwardFrame, const from kirin.dialects import scf from bloqade.squin import qubit +from bloqade.validation.analysis.lattice import Error -from .lattice import Error from .analysis import GeminiLogicalValidationAnalysis @@ -40,15 +40,8 @@ def if_else( frame: ForwardFrame, stmt: scf.IfElse, ): - # raise ir.ValidationError( - # stmt, "if statements are not supported in logical Gemini programs!" - # ) return ( - Error( - ir.ValidationError( - stmt, "if statements are not supported in logical Gemini programs!" - ) - ), + Error(stmt, "if statements are not supported in logical Gemini programs!"), ) @_interp.impl(scf.For) @@ -62,10 +55,10 @@ def for_loop( return (interp.lattice.top(),) return ( - Error( - ir.ValidationError( + ( + Error( stmt, "Non-constant iterable in for loop is not supported in Gemini logical programs!", - ) + ), ), ) diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/groups.py index 020b18f0..dda7d177 100644 --- a/src/bloqade/gemini/groups.py +++ b/src/bloqade/gemini/groups.py @@ -7,9 +7,9 @@ from typing_extensions import Doc from bloqade.squin import gate, qubit +from bloqade.validation import KernelValidation from .analysis import GeminiLogicalValidationAnalysis -from .validation.logical import KernelValidation @ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist])) diff --git a/src/bloqade/gemini/validation/__init__.py b/src/bloqade/gemini/validation/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/bloqade/validation/__init__.py b/src/bloqade/validation/__init__.py new file mode 100644 index 00000000..e0992c23 --- /dev/null +++ b/src/bloqade/validation/__init__.py @@ -0,0 +1,2 @@ +from . import analysis as analysis +from .kernel_validation import KernelValidation as KernelValidation diff --git a/src/bloqade/validation/analysis/__init__.py b/src/bloqade/validation/analysis/__init__.py new file mode 100644 index 00000000..fc9a0b49 --- /dev/null +++ b/src/bloqade/validation/analysis/__init__.py @@ -0,0 +1,5 @@ +from . import lattice as lattice +from .analysis import ( + ValidationFrame as ValidationFrame, + ValidationAnalysis as ValidationAnalysis, +) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py new file mode 100644 index 00000000..f7f3cece --- /dev/null +++ b/src/bloqade/validation/analysis/analysis.py @@ -0,0 +1,24 @@ +from kirin import ir +from kirin.analysis import Forward, ForwardFrame + +from .lattice import ErrorType + +ValidationFrame = ForwardFrame[ErrorType] + + +class ValidationAnalysis(Forward[ErrorType]): + """Analysis pass that indicates errors in the IR according to the respective method tables. + + If you need to implement validation for a dialect shared by many groups (for example, if you need to ascertain if statements have a specific form) + you'll need to inherit from this class. + """ + + keys = ["validation"] + lattice = ErrorType + + def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): + return self.run_callable(method.code, (self.lattice.bottom(),) + args) + + def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + # NOTE: default to no errors + return (self.lattice.top(),) diff --git a/src/bloqade/gemini/analysis/logical_validation/lattice.py b/src/bloqade/validation/analysis/lattice.py similarity index 82% rename from src/bloqade/gemini/analysis/logical_validation/lattice.py rename to src/bloqade/validation/analysis/lattice.py index a166e734..6783a2e4 100644 --- a/src/bloqade/gemini/analysis/logical_validation/lattice.py +++ b/src/bloqade/validation/analysis/lattice.py @@ -39,19 +39,14 @@ class InvalidErrorType(ErrorType, metaclass=SingletonMeta): pass -# @final -# @dataclass -# class AnyErrorType(ErrorType, metaclass=SingletonMeta): -# """Top to indicate that there was an error, but we can't really tell where""" -# pass - - @final @dataclass class Error(ErrorType): """We found an error, here's a hopefully helpful message.""" - error: ir.ValidationError + stmt: ir.IRNode + msg: str + help: str | None = None @final diff --git a/src/bloqade/gemini/validation/logical.py b/src/bloqade/validation/kernel_validation.py similarity index 66% rename from src/bloqade/gemini/validation/logical.py rename to src/bloqade/validation/kernel_validation.py index afd04082..1aaeacb2 100644 --- a/src/bloqade/gemini/validation/logical.py +++ b/src/bloqade/validation/kernel_validation.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from kirin import ir -from kirin.analysis import Forward, ForwardFrame -from ..analysis.logical_validation.lattice import Error +from .analysis import ValidationFrame, ValidationAnalysis +from .analysis.lattice import Error @dataclass class KernelValidation: - validation_analysis_cls: type[Forward] + validation_analysis_cls: type[ValidationAnalysis] def run(self, mt: ir.Method, **kwargs) -> None: validation_analysis = self.validation_analysis_cls(mt.dialects) @@ -23,15 +23,15 @@ def run(self, mt: ir.Method, **kwargs) -> None: # TODO: Make something similar to an ExceptionGroup that pretty-prints ValidationErrors raise errors[0] - def get_exceptions(self, mt: ir.Method, validation_frame: ForwardFrame): + def get_exceptions(self, mt: ir.Method, validation_frame: ValidationFrame): errors = [] for value in validation_frame.entries.values(): if not isinstance(value, Error): continue - if isinstance(value.error, ir.ValidationError): - value.error.attach(mt) + error = ir.ValidationError(value.stmt, value.msg, help=value.help) + error.attach(mt) - errors.append(value.error) + errors.append(error) return errors diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py index f71924a3..0ec53d19 100644 --- a/test/gemini/test_logical.py +++ b/test/gemini/test_logical.py @@ -2,8 +2,8 @@ from kirin import ir from bloqade import squin, gemini +from bloqade.validation import KernelValidation from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis -from bloqade.gemini.validation.logical import KernelValidation def test_if_stmt_invalid(): From e9f3d94711ec35d01f2a994ec006c5b9fc4b0192 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 21 Oct 2025 16:00:41 +0200 Subject: [PATCH 04/12] Fix test --- test/gemini/test_logical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py index 0ec53d19..4c9bb0c0 100644 --- a/test/gemini/test_logical.py +++ b/test/gemini/test_logical.py @@ -9,7 +9,7 @@ def test_if_stmt_invalid(): @gemini.logical(verify=False) def main(): - q = squin.qubit.new(3) + q = squin.qalloc(3) squin.h(q[0]) @@ -18,7 +18,7 @@ def main(): m = squin.qubit.measure(q[1]) - q2 = squin.qubit.new(5) + q2 = squin.qalloc(5) squin.x(q2[0]) if m: From a329022627b9382441827878710443d71bcdbfad Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 09:36:38 +0200 Subject: [PATCH 05/12] Make analysis ABC --- src/bloqade/validation/analysis/analysis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index f7f3cece..0a08fa29 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -1,3 +1,5 @@ +from abc import ABC + from kirin import ir from kirin.analysis import Forward, ForwardFrame @@ -6,14 +8,13 @@ ValidationFrame = ForwardFrame[ErrorType] -class ValidationAnalysis(Forward[ErrorType]): +class ValidationAnalysis(Forward[ErrorType], ABC): """Analysis pass that indicates errors in the IR according to the respective method tables. If you need to implement validation for a dialect shared by many groups (for example, if you need to ascertain if statements have a specific form) you'll need to inherit from this class. """ - keys = ["validation"] lattice = ErrorType def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): From 84b349254ef8abb3cedb943288b4430d40b31013 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 10:21:14 +0200 Subject: [PATCH 06/12] Add some more implementations and tests --- .../analysis/logical_validation/impls.py | 52 +++++++++---------- src/bloqade/gemini/groups.py | 35 +++++++++---- test/gemini/test_logical.py | 47 +++++++++++++++++ 3 files changed, 97 insertions(+), 37 deletions(-) diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py index 73b56385..e5c63d20 100644 --- a/src/bloqade/gemini/analysis/logical_validation/impls.py +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -1,35 +1,13 @@ from kirin import interp as _interp -from kirin.analysis import ForwardFrame, const -from kirin.dialects import scf +from kirin.analysis import const +from kirin.dialects import scf, func -from bloqade.squin import qubit +from bloqade.validation.analysis import ValidationFrame from bloqade.validation.analysis.lattice import Error from .analysis import GeminiLogicalValidationAnalysis -@qubit.dialect.register(key="gemini.validate.logical") -class __QubitGeminiLogicalValidation(_interp.MethodTable): - - @_interp.impl(qubit.New) - def new( - self, - interp: GeminiLogicalValidationAnalysis, - frame: ForwardFrame, - stmt: qubit.New, - ): - # TODO: this is actually tricky, since qalloc calls qubit.new multiple times and we have to make sure qalloc is only called once - # but it can technically contain many qubit.new calls - # if interp.has_allocated_qubits: - # raise ir.ValidationError( - # stmt, "Can only allocate qubits once in a logical Gemini program!" - # ) - - # interp.has_allocated_qubits = True - - pass - - @scf.dialect.register(key="gemini.validate.logical") class __ScfGeminiLogicalValidation(_interp.MethodTable): @@ -37,7 +15,7 @@ class __ScfGeminiLogicalValidation(_interp.MethodTable): def if_else( self, interp: GeminiLogicalValidationAnalysis, - frame: ForwardFrame, + frame: ValidationFrame, stmt: scf.IfElse, ): return ( @@ -48,7 +26,7 @@ def if_else( def for_loop( self, interp: GeminiLogicalValidationAnalysis, - frame: ForwardFrame, + frame: ValidationFrame, stmt: scf.For, ): if isinstance(stmt.iterable.hints.get("const"), const.Value): @@ -59,6 +37,24 @@ def for_loop( Error( stmt, "Non-constant iterable in for loop is not supported in Gemini logical programs!", - ), + ) + ), + ) + + +@func.dialect.register(key="gemini.validate.logical") +class __FuncGeminiLogicalValidation(_interp.MethodTable): + @_interp.impl(func.Invoke) + def invoke( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ValidationFrame, + stmt: func.Invoke, + ): + return ( + Error( + stmt, + "Function invocations not supported in logical Gemini program!", + help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls", ), ) diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/groups.py index dda7d177..6ec4434b 100644 --- a/src/bloqade/gemini/groups.py +++ b/src/bloqade/gemini/groups.py @@ -5,9 +5,11 @@ from kirin.prelude import structural_no_opt from kirin.dialects import py, func, ilist from typing_extensions import Doc +from kirin.passes.inline import InlinePass from bloqade.squin import gate, qubit from bloqade.validation import KernelValidation +from bloqade.rewrite.passes import AggressiveUnroll from .analysis import GeminiLogicalValidationAnalysis @@ -30,21 +32,36 @@ def run_pass( aggressive: Annotated[ bool, Doc("run aggressive folding passes if `fold=True`") ] = False, + inline: Annotated[bool, Doc("inline function calls, default `True`")] = True, + aggressive_unroll: Annotated[ + bool, + Doc( + "Run aggressive inlining and unrolling pass on the IR, default `False`" + ), + ] = False, no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True, ) -> None: - default_pass = Default( - self, - verify=verify, - fold=fold, - aggressive=aggressive, - typeinfer=typeinfer, - no_raise=no_raise, - ) - default_pass.fixpoint(mt) + if aggressive_unroll: + AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt) + else: + default_pass = Default( + self, + verify=verify, + fold=fold, + aggressive=aggressive, + typeinfer=typeinfer, + no_raise=no_raise, + ) + + default_pass.fixpoint(mt) + + if inline and not aggressive_unroll: + InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt) if verify: validator = KernelValidation(GeminiLogicalValidationAnalysis) validator.run(mt, no_raise=no_raise) + mt.verify() return run_pass diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py index 4c9bb0c0..72601c3b 100644 --- a/test/gemini/test_logical.py +++ b/test/gemini/test_logical.py @@ -2,6 +2,7 @@ from kirin import ir from bloqade import squin, gemini +from bloqade.types import Qubit from bloqade.validation import KernelValidation from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis @@ -38,3 +39,49 @@ def main(): with pytest.raises(ir.ValidationError): validator.run(main) + + +def test_for_loop(): + + @gemini.logical + def valid_loop(): + q = squin.qalloc(3) + + for i in range(3): + squin.x(q[i]) + + valid_loop.print() + + with pytest.raises(ir.ValidationError): + + @gemini.logical + def invalid_loop(n: int): + q = squin.qalloc(3) + + for i in range(n): + squin.x(q[i]) + + invalid_loop.print() + + +def test_func(): + @gemini.logical + def sub_kernel(q: Qubit): + squin.x(q) + + @gemini.logical + def main(): + q = squin.qalloc(3) + sub_kernel(q[0]) + + main.print() + + with pytest.raises(ir.ValidationError): + + @gemini.logical(inline=False) + def invalid(): + q = squin.qalloc(3) + sub_kernel(q[0]) + + +test_func() From e7e2ef2414a6b65b457a3a97ee4a68987eb79854 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 11:35:51 +0200 Subject: [PATCH 07/12] Verify that U3 can only occur at first position --- .../analysis/logical_validation/analysis.py | 13 +++--- .../analysis/logical_validation/impls.py | 22 ++++++++++ src/bloqade/gemini/groups.py | 6 +-- src/bloqade/squin/gate/stmts.py | 14 +++++-- src/bloqade/validation/analysis/analysis.py | 40 ++++++++++++++++++- src/bloqade/validation/kernel_validation.py | 18 +++++++-- test/gemini/test_logical.py | 25 +++++++++++- 7 files changed, 119 insertions(+), 19 deletions(-) diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index 6932dbdb..93e97d8e 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -1,17 +1,18 @@ from kirin import ir +from bloqade import squin from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis -from bloqade.validation.analysis.lattice import ErrorType class GeminiLogicalValidationAnalysis(ValidationAnalysis): keys = ["gemini.validate.logical"] - lattice = ErrorType - has_allocated_qubits: bool = False - - def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + first_gate = True def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): + + if isinstance(stmt, squin.gate.stmts.Gate): + # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here + self.first_gate = False + return (self.lattice.top(),) diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py index e5c63d20..11231591 100644 --- a/src/bloqade/gemini/analysis/logical_validation/impls.py +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -2,6 +2,7 @@ from kirin.analysis import const from kirin.dialects import scf, func +from bloqade.squin import gate from bloqade.validation.analysis import ValidationFrame from bloqade.validation.analysis.lattice import Error @@ -58,3 +59,24 @@ def invoke( help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls", ), ) + + +@gate.dialect.register(key="gemini.validate.logical") +class __GateGeminiLogicalValidation(_interp.MethodTable): + @_interp.impl(gate.stmts.U3) + def u3( + self, + interp: GeminiLogicalValidationAnalysis, + frame: ValidationFrame, + stmt: gate.stmts.U3, + ): + if interp.first_gate: + interp.first_gate = False + return (interp.lattice.top(),) + + return ( + Error( + stmt, + "U3 gate can only be used for initial state preparation, i.e. as the first gate!", + ), + ) diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/groups.py index 6ec4434b..90441099 100644 --- a/src/bloqade/gemini/groups.py +++ b/src/bloqade/gemini/groups.py @@ -42,6 +42,9 @@ def run_pass( no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True, ) -> None: + if inline and not aggressive_unroll: + InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt) + if aggressive_unroll: AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt) else: @@ -56,9 +59,6 @@ def run_pass( default_pass.fixpoint(mt) - if inline and not aggressive_unroll: - InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt) - if verify: validator = KernelValidation(GeminiLogicalValidationAnalysis) validator.run(mt, no_raise=no_raise) diff --git a/src/bloqade/squin/gate/stmts.py b/src/bloqade/squin/gate/stmts.py index 960ae95d..afc91897 100644 --- a/src/bloqade/squin/gate/stmts.py +++ b/src/bloqade/squin/gate/stmts.py @@ -8,7 +8,13 @@ @statement -class SingleQubitGate(ir.Statement): +class Gate(ir.Statement): + # NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class + pass + + +@statement +class SingleQubitGate(Gate): traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) @@ -59,7 +65,7 @@ class SqrtY(SingleQubitNonHermitianGate): @statement -class RotationGate(ir.Statement): +class RotationGate(Gate): # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg traits = frozenset({lowering.FromPythonCall()}) angle: ir.SSAValue = info.argument(types.Float) @@ -85,7 +91,7 @@ class Rz(RotationGate): @statement -class ControlledGate(ir.Statement): +class ControlledGate(Gate): traits = frozenset({lowering.FromPythonCall()}) controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) @@ -110,7 +116,7 @@ class CZ(ControlledGate): @statement(dialect=dialect) -class U3(ir.Statement): +class U3(Gate): # NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg traits = frozenset({lowering.FromPythonCall()}) theta: ir.SSAValue = info.argument(types.Float) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 0a08fa29..25233ee7 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -1,6 +1,9 @@ from abc import ABC +from typing import Iterable +from dataclasses import field, dataclass from kirin import ir +from kirin.interp import AbstractFrame from kirin.analysis import Forward, ForwardFrame from .lattice import ErrorType @@ -8,6 +11,7 @@ ValidationFrame = ForwardFrame[ErrorType] +@dataclass class ValidationAnalysis(Forward[ErrorType], ABC): """Analysis pass that indicates errors in the IR according to the respective method tables. @@ -17,9 +21,43 @@ class ValidationAnalysis(Forward[ErrorType], ABC): lattice = ErrorType + additional_errors: list[ErrorType] = field(default_factory=list) + """List to store return values that are not associated with an SSA Value (e.g. when the statement has no ResultValue)""" + def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): - return self.run_callable(method.code, (self.lattice.bottom(),) + args) + return self.run_callable(method.code, (self.lattice.top(),) + args) def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): # NOTE: default to no errors return (self.lattice.top(),) + + def set_values( + self, + frame: AbstractFrame[ErrorType], + ssa: Iterable[ir.SSAValue], + results: Iterable[ErrorType], + ): + """Set the abstract values for the given SSA values in the frame. + + This method is overridden to account for additional errors we may + encounter when they are not associated to an SSA Value. + """ + + number_of_ssa_values = 0 + for ssa_value, result in zip(ssa, results): + number_of_ssa_values += 1 + if ssa_value in frame.entries: + frame.entries[ssa_value] = frame.entries[ssa_value].join(result) + else: + frame.entries[ssa_value] = result + + if isinstance(results, tuple): + # NOTE: usually what we have + self.additional_errors.extend(results[number_of_ssa_values:]) + + for i, result in enumerate(results): + # NOTE: only sure-fire way I found to get remaining values from an Iterable + if i < number_of_ssa_values: + continue + + self.additional_errors.append(result) diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 1aaeacb2..2333940c 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -1,9 +1,10 @@ +import itertools from dataclasses import dataclass from kirin import ir from .analysis import ValidationFrame, ValidationAnalysis -from .analysis.lattice import Error +from .analysis.lattice import Error, ErrorType @dataclass @@ -14,7 +15,9 @@ def run(self, mt: ir.Method, **kwargs) -> None: validation_analysis = self.validation_analysis_cls(mt.dialects) validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) - errors = self.get_exceptions(mt, validation_frame) + errors = self.get_exceptions( + mt, validation_frame, validation_analysis.additional_errors + ) if len(errors) == 0: # Valid program @@ -23,9 +26,16 @@ def run(self, mt: ir.Method, **kwargs) -> None: # TODO: Make something similar to an ExceptionGroup that pretty-prints ValidationErrors raise errors[0] - def get_exceptions(self, mt: ir.Method, validation_frame: ValidationFrame): + def get_exceptions( + self, + mt: ir.Method, + validation_frame: ValidationFrame, + additional_errors: list[ErrorType], + ): errors = [] - for value in validation_frame.entries.values(): + for value in itertools.chain( + validation_frame.entries.values(), additional_errors + ): if not isinstance(value, Error): continue diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py index 72601c3b..a8fa6cc6 100644 --- a/test/gemini/test_logical.py +++ b/test/gemini/test_logical.py @@ -84,4 +84,27 @@ def invalid(): sub_kernel(q[0]) -test_func() +def test_clifford_gates(): + @gemini.logical + def main(): + q = squin.qalloc(2) + squin.u3(0.123, 0.253, 1.2, q[0]) + + squin.h(q[0]) + squin.cx(q[0], q[1]) + + with pytest.raises(ir.ValidationError): + + @gemini.logical(no_raise=False) + def invalid(): + q = squin.qalloc(2) + + squin.h(q[0]) + squin.cx(q[0], q[1]) + squin.u3(0.123, 0.253, 1.2, q[0]) + + frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis( + invalid, no_raise=False + ) + + invalid.print(analysis=frame.entries) From e78b9606ac56f0f7ca7f19fa6881ea2905b9f646 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 13:23:14 +0200 Subject: [PATCH 08/12] Nice error messages when encountering multiple errors --- src/bloqade/validation/analysis/analysis.py | 4 +- src/bloqade/validation/kernel_validation.py | 41 +++++++++++++++++++-- test/gemini/test_logical.py | 18 +++++++++ 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 25233ee7..6c4fad79 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -53,11 +53,11 @@ def set_values( if isinstance(results, tuple): # NOTE: usually what we have - self.additional_errors.extend(results[number_of_ssa_values:]) + self.additional_errors.extend(results[number_of_ssa_values + 1 :]) for i, result in enumerate(results): # NOTE: only sure-fire way I found to get remaining values from an Iterable - if i < number_of_ssa_values: + if i <= number_of_ssa_values: continue self.additional_errors.append(result) diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 2333940c..111357dd 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -1,12 +1,44 @@ +import sys import itertools from dataclasses import dataclass -from kirin import ir +from kirin import ir, exception +from rich.console import Console from .analysis import ValidationFrame, ValidationAnalysis from .analysis.lattice import Error, ErrorType +class ValidationErrorGroup(BaseException): + def __init__(self, *args: object, errors=[]) -> None: + super().__init__(*args) + self.errors = errors + + +# TODO: this overrides kirin's exception handler and should be upstreamed +def exception_handler(exc_type, exc_value, exc_tb): + if issubclass(exc_type, ValidationErrorGroup): + console = Console(force_terminal=True) + for i, err in enumerate(exc_value.errors): + with console.capture() as capture: + console.print(f"==== Error {i} ====") + console.print(f"[bold red]{type(err).__name__}:[/bold red]", end="") + print(capture.get(), *err.args, file=sys.stderr) + if err.source: + print("Source Traceback:", file=sys.stderr) + print(err.hint(), file=sys.stderr, end="") + console.print("=" * 40) + console.print( + "[bold red]Kernel validation failed:[/bold red] There were multiple errors encountered during validation, see above" + ) + return + + return exception.exception_handler(exc_type, exc_value, exc_tb) + + +sys.excepthook = exception_handler + + @dataclass class KernelValidation: validation_analysis_cls: type[ValidationAnalysis] @@ -22,9 +54,10 @@ def run(self, mt: ir.Method, **kwargs) -> None: if len(errors) == 0: # Valid program return - - # TODO: Make something similar to an ExceptionGroup that pretty-prints ValidationErrors - raise errors[0] + elif len(errors) == 1: + raise errors[0] + else: + raise ValidationErrorGroup(errors=errors) def get_exceptions( self, diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py index a8fa6cc6..94947cd4 100644 --- a/test/gemini/test_logical.py +++ b/test/gemini/test_logical.py @@ -5,6 +5,7 @@ from bloqade.types import Qubit from bloqade.validation import KernelValidation from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis +from bloqade.validation.kernel_validation import ValidationErrorGroup def test_if_stmt_invalid(): @@ -108,3 +109,20 @@ def invalid(): ) invalid.print(analysis=frame.entries) + + +def test_multiple_errors(): + with pytest.raises(ValidationErrorGroup): + + @gemini.logical + def main(n: int): + q = squin.qalloc(3) + m = squin.qubit.measure(q[0]) + squin.x(q[1]) + if m: + squin.x(q[0]) + + for k in range(n): + squin.h(q[k]) + + squin.u3(0.1, 0.2, 0.3, q[1]) From f0a3516f74fe1cec8ccadf679ec7618d746f0e2d Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 15:00:31 +0200 Subject: [PATCH 09/12] Fix bug in collecting additional errors and tests --- src/bloqade/validation/analysis/analysis.py | 16 ++++++++-------- test/gemini/test_logical.py | 13 ++++++++++--- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 6c4fad79..11419f9b 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -53,11 +53,11 @@ def set_values( if isinstance(results, tuple): # NOTE: usually what we have - self.additional_errors.extend(results[number_of_ssa_values + 1 :]) - - for i, result in enumerate(results): - # NOTE: only sure-fire way I found to get remaining values from an Iterable - if i <= number_of_ssa_values: - continue - - self.additional_errors.append(result) + self.additional_errors.extend(results[number_of_ssa_values:]) + else: + for i, result in enumerate(results): + # NOTE: only sure-fire way I found to get remaining values from an Iterable + if i < number_of_ssa_values: + continue + + self.additional_errors.append(result) diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical.py index 94947cd4..ce8e1a34 100644 --- a/test/gemini/test_logical.py +++ b/test/gemini/test_logical.py @@ -38,7 +38,7 @@ def main(): validator = KernelValidation(GeminiLogicalValidationAnalysis) - with pytest.raises(ir.ValidationError): + with pytest.raises(ValidationErrorGroup): validator.run(main) @@ -77,7 +77,7 @@ def main(): main.print() - with pytest.raises(ir.ValidationError): + with pytest.raises(ValidationErrorGroup): @gemini.logical(inline=False) def invalid(): @@ -112,7 +112,8 @@ def invalid(): def test_multiple_errors(): - with pytest.raises(ValidationErrorGroup): + did_error = False + try: @gemini.logical def main(n: int): @@ -126,3 +127,9 @@ def main(n: int): squin.h(q[k]) squin.u3(0.1, 0.2, 0.3, q[1]) + + except ValidationErrorGroup as e: + did_error = True + assert len(e.errors) == 3 + + assert did_error From c5a2515fc5e15b1012d1c45561550b3825cc2290 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 15:01:09 +0200 Subject: [PATCH 10/12] Rename file --- test/gemini/{test_logical.py => test_logical_validation.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/gemini/{test_logical.py => test_logical_validation.py} (100%) diff --git a/test/gemini/test_logical.py b/test/gemini/test_logical_validation.py similarity index 100% rename from test/gemini/test_logical.py rename to test/gemini/test_logical_validation.py From 47ef35be8027c35b06110a23a8fde61659d5a4b9 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 15:54:25 +0200 Subject: [PATCH 11/12] Make the frame collect errors directly --- src/bloqade/validation/analysis/analysis.py | 60 +++++++++++---------- src/bloqade/validation/kernel_validation.py | 31 ++--------- 2 files changed, 36 insertions(+), 55 deletions(-) diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 11419f9b..6bcdd5aa 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -3,16 +3,20 @@ from dataclasses import field, dataclass from kirin import ir -from kirin.interp import AbstractFrame -from kirin.analysis import Forward, ForwardFrame +from kirin.analysis import ForwardExtra, ForwardFrame -from .lattice import ErrorType +from .lattice import Error, ErrorType -ValidationFrame = ForwardFrame[ErrorType] + +@dataclass +class ValidationFrame(ForwardFrame[ErrorType]): + # NOTE: cannot be set[Error] since that's not hashable + errors: list[Error] = field(default_factory=list) + """List of all ecnountered errors.""" @dataclass -class ValidationAnalysis(Forward[ErrorType], ABC): +class ValidationAnalysis(ForwardExtra[ValidationFrame, ErrorType], ABC): """Analysis pass that indicates errors in the IR according to the respective method tables. If you need to implement validation for a dialect shared by many groups (for example, if you need to ascertain if statements have a specific form) @@ -21,9 +25,6 @@ class ValidationAnalysis(Forward[ErrorType], ABC): lattice = ErrorType - additional_errors: list[ErrorType] = field(default_factory=list) - """List to store return values that are not associated with an SSA Value (e.g. when the statement has no ResultValue)""" - def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): return self.run_callable(method.code, (self.lattice.top(),) + args) @@ -33,31 +34,32 @@ def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): def set_values( self, - frame: AbstractFrame[ErrorType], + frame: ValidationFrame, ssa: Iterable[ir.SSAValue], results: Iterable[ErrorType], ): """Set the abstract values for the given SSA values in the frame. - This method is overridden to account for additional errors we may - encounter when they are not associated to an SSA Value. + This method is overridden to explicitly collect all errors we found in the + additional field of the frame. That also includes statements that don't + have an associated `ResultValue`. """ - number_of_ssa_values = 0 - for ssa_value, result in zip(ssa, results): - number_of_ssa_values += 1 - if ssa_value in frame.entries: - frame.entries[ssa_value] = frame.entries[ssa_value].join(result) - else: - frame.entries[ssa_value] = result - - if isinstance(results, tuple): - # NOTE: usually what we have - self.additional_errors.extend(results[number_of_ssa_values:]) - else: - for i, result in enumerate(results): - # NOTE: only sure-fire way I found to get remaining values from an Iterable - if i < number_of_ssa_values: - continue - - self.additional_errors.append(result) + ssa_value_list = list(ssa) + number_of_ssa_values = len(ssa_value_list) + for i, result in enumerate(results): + if isinstance(result, Error): + frame.errors.append(result) + + if i < number_of_ssa_values: + ssa_value = ssa_value_list[i] + + if ssa_value in frame.entries: + frame.entries[ssa_value] = frame.entries[ssa_value].join(result) + else: + frame.entries[ssa_value] = result + + def initialize_frame( + self, code: ir.Statement, *, has_parent_access: bool = False + ) -> ValidationFrame: + return ValidationFrame(code, has_parent_access=has_parent_access) diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 111357dd..07c4ffcb 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -1,12 +1,10 @@ import sys -import itertools from dataclasses import dataclass from kirin import ir, exception from rich.console import Console -from .analysis import ValidationFrame, ValidationAnalysis -from .analysis.lattice import Error, ErrorType +from .analysis import ValidationAnalysis class ValidationErrorGroup(BaseException): @@ -47,9 +45,10 @@ def run(self, mt: ir.Method, **kwargs) -> None: validation_analysis = self.validation_analysis_cls(mt.dialects) validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) - errors = self.get_exceptions( - mt, validation_frame, validation_analysis.additional_errors - ) + errors = [ + ir.ValidationError(err.stmt, err.msg, help=err.help) + for err in validation_frame.errors + ] if len(errors) == 0: # Valid program @@ -58,23 +57,3 @@ def run(self, mt: ir.Method, **kwargs) -> None: raise errors[0] else: raise ValidationErrorGroup(errors=errors) - - def get_exceptions( - self, - mt: ir.Method, - validation_frame: ValidationFrame, - additional_errors: list[ErrorType], - ): - errors = [] - for value in itertools.chain( - validation_frame.entries.values(), additional_errors - ): - if not isinstance(value, Error): - continue - - error = ir.ValidationError(value.stmt, value.msg, help=value.help) - error.attach(mt) - - errors.append(error) - - return errors From 9f66250a62ade4fc1137edd4d12fd94fb4639251 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 22 Oct 2025 16:25:13 +0200 Subject: [PATCH 12/12] Simplify things by just appending to the list of errors --- .../analysis/logical_validation/analysis.py | 3 +- .../analysis/logical_validation/impls.py | 47 +++++++++++++------ src/bloqade/validation/analysis/analysis.py | 40 ++++------------ src/bloqade/validation/analysis/lattice.py | 13 +++-- src/bloqade/validation/kernel_validation.py | 12 +++-- 5 files changed, 58 insertions(+), 57 deletions(-) diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index 93e97d8e..14a03cbf 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -10,9 +10,8 @@ class GeminiLogicalValidationAnalysis(ValidationAnalysis): first_gate = True def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): - if isinstance(stmt, squin.gate.stmts.Gate): # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here self.first_gate = False - return (self.lattice.top(),) + return super().eval_stmt_fallback(frame, stmt) diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py index 11231591..cf4bd87b 100644 --- a/src/bloqade/gemini/analysis/logical_validation/impls.py +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -1,4 +1,4 @@ -from kirin import interp as _interp +from kirin import ir, interp as _interp from kirin.analysis import const from kirin.dialects import scf, func @@ -19,8 +19,15 @@ def if_else( frame: ValidationFrame, stmt: scf.IfElse, ): + frame.errors.append( + ir.ValidationError( + stmt, "If statements are not supported in logical Gemini programs!" + ) + ) return ( - Error(stmt, "if statements are not supported in logical Gemini programs!"), + Error( + message="If statements are not supported in logical Gemini programs!" + ), ) @_interp.impl(scf.For) @@ -33,12 +40,16 @@ def for_loop( if isinstance(stmt.iterable.hints.get("const"), const.Value): return (interp.lattice.top(),) + frame.errors.append( + ir.ValidationError( + stmt, + "Non-constant iterable in for loop is not supported in Gemini logical programs!", + ) + ) + return ( - ( - Error( - stmt, - "Non-constant iterable in for loop is not supported in Gemini logical programs!", - ) + Error( + message="Non-constant iterable in for loop is not supported in Gemini logical programs!" ), ) @@ -52,12 +63,19 @@ def invoke( frame: ValidationFrame, stmt: func.Invoke, ): - return ( - Error( + frame.errors.append( + ir.ValidationError( stmt, "Function invocations not supported in logical Gemini program!", help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls", - ), + ) + ) + + return tuple( + Error( + message="Function invocations not supported in logical Gemini program!" + ) + for _ in stmt.results ) @@ -72,11 +90,12 @@ def u3( ): if interp.first_gate: interp.first_gate = False - return (interp.lattice.top(),) + return () - return ( - Error( + frame.errors.append( + ir.ValidationError( stmt, "U3 gate can only be used for initial state preparation, i.e. as the first gate!", - ), + ) ) + return () diff --git a/src/bloqade/validation/analysis/analysis.py b/src/bloqade/validation/analysis/analysis.py index 6bcdd5aa..323cbd40 100644 --- a/src/bloqade/validation/analysis/analysis.py +++ b/src/bloqade/validation/analysis/analysis.py @@ -1,18 +1,21 @@ from abc import ABC -from typing import Iterable from dataclasses import field, dataclass from kirin import ir from kirin.analysis import ForwardExtra, ForwardFrame -from .lattice import Error, ErrorType +from .lattice import ErrorType @dataclass class ValidationFrame(ForwardFrame[ErrorType]): # NOTE: cannot be set[Error] since that's not hashable - errors: list[Error] = field(default_factory=list) - """List of all ecnountered errors.""" + errors: list[ir.ValidationError] = field(default_factory=list) + """List of all ecnountered errors. + + Append a `kirin.ir.ValidationError` to this list in the method implementation + in order for it to get picked up by the `KernelValidation` run. + """ @dataclass @@ -30,34 +33,7 @@ def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]): def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement): # NOTE: default to no errors - return (self.lattice.top(),) - - def set_values( - self, - frame: ValidationFrame, - ssa: Iterable[ir.SSAValue], - results: Iterable[ErrorType], - ): - """Set the abstract values for the given SSA values in the frame. - - This method is overridden to explicitly collect all errors we found in the - additional field of the frame. That also includes statements that don't - have an associated `ResultValue`. - """ - - ssa_value_list = list(ssa) - number_of_ssa_values = len(ssa_value_list) - for i, result in enumerate(results): - if isinstance(result, Error): - frame.errors.append(result) - - if i < number_of_ssa_values: - ssa_value = ssa_value_list[i] - - if ssa_value in frame.entries: - frame.entries[ssa_value] = frame.entries[ssa_value].join(result) - else: - frame.entries[ssa_value] = result + return tuple(self.lattice.top() for _ in stmt.results) def initialize_frame( self, code: ir.Statement, *, has_parent_access: bool = False diff --git a/src/bloqade/validation/analysis/lattice.py b/src/bloqade/validation/analysis/lattice.py index 6783a2e4..d4c46469 100644 --- a/src/bloqade/validation/analysis/lattice.py +++ b/src/bloqade/validation/analysis/lattice.py @@ -1,7 +1,6 @@ from typing import final from dataclasses import dataclass -from kirin import ir from kirin.lattice import ( SingletonMeta, BoundedLattice, @@ -42,11 +41,15 @@ class InvalidErrorType(ErrorType, metaclass=SingletonMeta): @final @dataclass class Error(ErrorType): - """We found an error, here's a hopefully helpful message.""" + """Indicates an error in the IR.""" - stmt: ir.IRNode - msg: str - help: str | None = None + message: str = "" + """Optional error message to show in the IR. + + NOTE: this is just to show a message when printing the IR. Actual errors + are collected by appending ir.ValidationError to the frame in the method + implementation. + """ @final diff --git a/src/bloqade/validation/kernel_validation.py b/src/bloqade/validation/kernel_validation.py index 07c4ffcb..84159352 100644 --- a/src/bloqade/validation/kernel_validation.py +++ b/src/bloqade/validation/kernel_validation.py @@ -39,16 +39,20 @@ def exception_handler(exc_type, exc_value, exc_tb): @dataclass class KernelValidation: + """Validate a kernel according to a `ValidationAnalysis`. + + This is a simple wrapper around the analysis that runs the analysis, checks + the `ValidationFrame` for errors and throws them if there are any. + """ + validation_analysis_cls: type[ValidationAnalysis] + """The analysis that you want to run in order to validate the kernel.""" def run(self, mt: ir.Method, **kwargs) -> None: validation_analysis = self.validation_analysis_cls(mt.dialects) validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs) - errors = [ - ir.ValidationError(err.stmt, err.msg, help=err.help) - for err in validation_frame.errors - ] + errors = validation_frame.errors if len(errors) == 0: # Valid program