Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 84 additions & 36 deletions test/analysis/measure_id/test_measure_id.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest
from kirin.passes import HintConst, inline
from kirin.dialects import scf
from kirin.passes.inline import InlinePass

from bloqade import squin
from bloqade.analysis.measure_id import MeasurementIDAnalysis
from bloqade.stim.passes.flatten import Flatten
from bloqade.analysis.measure_id.lattice import (
NotMeasureId,
MeasureIdBool,
MeasureIdTuple,
InvalidMeasureId,
Expand All @@ -15,7 +16,16 @@ def results_at(kern, block_id, stmt_id):
return kern.code.body.blocks[block_id].stmts.at(stmt_id).results # type: ignore


@pytest.mark.xfail
def results_of_variables(kernel, variable_names):
results = {}
for stmt in kernel.callable_region.stmts():
for result in stmt.results:
if result.name in variable_names:
results[result.name] = result

return results


def test_add():
@squin.kernel
def test():
Expand All @@ -28,6 +38,8 @@ def test():
ml2 = squin.broadcast.measure(ql2)
return ml1 + ml2

Flatten(test.dialects).fixpoint(test)

frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)

measure_id_tuples = [
Expand All @@ -41,7 +53,6 @@ def test():
assert measure_id_tuples[-1] == expected_measure_id_tuple


@pytest.mark.xfail
def test_measure_alias():

@squin.kernel
Expand All @@ -52,28 +63,33 @@ def test():

return ml_alias

Flatten(test.dialects).fixpoint(test)
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)

test.print(analysis=frame.entries)

# Collect MeasureIdTuples
measure_id_tuples = [
value for value in frame.entries.values() if isinstance(value, MeasureIdTuple)
]

# construct expected MeasureIdTuple
expected_measure_id_tuple = MeasureIdTuple(
# construct expected MeasureIdTuples
measure_id_tuple_with_id_bools = MeasureIdTuple(
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
)
measure_id_tuple_with_not_measures = MeasureIdTuple(
data=tuple([NotMeasureId() for _ in range(5)])
)

assert len(measure_id_tuples) == 2
assert len(measure_id_tuples) == 3
# New qubit.new semantics cause a MeasureIdTuple to be generated full of NotMeasureIds because
# qubit.new is actually an ilist.map that invokes single qubit allocation multiple times
# and puts them into an ilist.
assert measure_id_tuples[0] == measure_id_tuple_with_not_measures
assert all(
measure_id_tuple == expected_measure_id_tuple
for measure_id_tuple in measure_id_tuples
measure_id_tuple == measure_id_tuple_with_id_bools
for measure_id_tuple in measure_id_tuples[1:]
)


@pytest.mark.xfail
def test_measure_count_at_if_else():

@squin.kernel
Expand All @@ -88,6 +104,7 @@ def test():
if ms[3]:
squin.y(q[1])

Flatten(test.dialects).fixpoint(test)
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)

assert all(
Expand All @@ -96,32 +113,29 @@ def test():
)


@pytest.mark.xfail
def test_scf_cond_true():
@squin.kernel
def test():
q = squin.qalloc(1)
q = squin.qalloc(3)
squin.x(q[2])

ms = None
cond = True
if cond:
ms = squin.broadcast.measure(q)
ms = squin.measure(q[1])
else:
ms = squin.measure(q[0])

return ms

HintConst(dialects=test.dialects).unsafe_run(test)
InlinePass(test.dialects).fixpoint(test)
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)

# MeasureIdTuple(data=MeasureIdBool(idx=1),) should occur twice:
# MeasureIdBool(idx=1) should occur twice:
# First from the measurement in the true branch, then
# the result of the scf.IfElse itself
analysis_results = [
val
for val in frame.entries.values()
if val == MeasureIdTuple(data=(MeasureIdBool(idx=1),))
val for val in frame.entries.values() if val == MeasureIdBool(idx=1)
]
assert len(analysis_results) == 2

Expand All @@ -136,16 +150,16 @@ def test():
ms = None
cond = False
if cond:
ms = squin.broadcast.measure(q)
ms = squin.measure(q[1])
else:
ms = squin.qubit.measure(q[0])
ms = squin.measure(q[0])

return ms

inline.InlinePass(test.dialects).fixpoint(test)

HintConst(dialects=test.dialects).unsafe_run(test)
# need to preserve the scf.IfElse but need things like qalloc to be inlined
InlinePass(test.dialects).fixpoint(test)
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
test.print(analysis=frame.entries)

# MeasureIdBool(idx=1) should occur twice:
# First from the measurement in the false branch, then
Expand All @@ -156,7 +170,37 @@ def test():
assert len(analysis_results) == 2


@pytest.mark.xfail
def test_scf_cond_unknown():

@squin.kernel
def test(cond: bool):
q = squin.qalloc(5)
squin.x(q[2])

if cond:
ms = squin.broadcast.measure(q)
else:
ms = squin.measure(q[0])

return ms

# We can use Flatten here because the variable condition for the scf.IfElse
# means it cannot be simplified.
Flatten(test.dialects).fixpoint(test)
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
analysis_results = [
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
]
# Both branches of the scf.IfElse should be properly traversed and contain the following
# analysis results.
expected_full_register_measurement = MeasureIdTuple(
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
)
expected_else_measurement = MeasureIdTuple(data=(MeasureIdBool(idx=6),))
assert expected_full_register_measurement in analysis_results
assert expected_else_measurement in analysis_results


def test_slice():
@squin.kernel
def test():
Expand All @@ -170,19 +214,23 @@ def test():

return ms_final

Flatten(test.dialects).fixpoint(test)
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)

test.print(analysis=frame.entries)
results = results_of_variables(test, ("msi", "msi2", "ms_final"))

assert [frame.entries[result] for result in results_at(test, 0, 7)] == [
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))))
]
assert [frame.entries[result] for result in results_at(test, 0, 9)] == [
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))))
]
assert [frame.entries[result] for result in results_at(test, 0, 11)] == [
MeasureIdTuple(data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)))
]
# This is an assertion against `msi` NOT the initial list of measurements
assert frame.get(results["msi"]) == MeasureIdTuple(
data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7)))
)
# msi2
assert frame.get(results["msi2"]) == MeasureIdTuple(
data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7)))
)
# ms_final
assert frame.get(results["ms_final"]) == MeasureIdTuple(
data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5))
)


def test_getitem_no_hint():
Expand Down