Skip to content

Commit db67b8b

Browse files
johnzl-777david-pl
andauthored
Get measure_id tests to work (#587)
Co-authored-by: David Plankensteiner <david-pl@users.noreply.github.com> Co-authored-by: David Plankensteiner <da.plankensteiner@gmail.com>
1 parent 34e9679 commit db67b8b

File tree

1 file changed

+84
-36
lines changed

1 file changed

+84
-36
lines changed

test/analysis/measure_id/test_measure_id.py

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
import pytest
2-
from kirin.passes import HintConst, inline
31
from kirin.dialects import scf
2+
from kirin.passes.inline import InlinePass
43

54
from bloqade import squin
65
from bloqade.analysis.measure_id import MeasurementIDAnalysis
6+
from bloqade.stim.passes.flatten import Flatten
77
from bloqade.analysis.measure_id.lattice import (
8+
NotMeasureId,
89
MeasureIdBool,
910
MeasureIdTuple,
1011
InvalidMeasureId,
@@ -15,7 +16,16 @@ def results_at(kern, block_id, stmt_id):
1516
return kern.code.body.blocks[block_id].stmts.at(stmt_id).results # type: ignore
1617

1718

18-
@pytest.mark.xfail
19+
def results_of_variables(kernel, variable_names):
20+
results = {}
21+
for stmt in kernel.callable_region.stmts():
22+
for result in stmt.results:
23+
if result.name in variable_names:
24+
results[result.name] = result
25+
26+
return results
27+
28+
1929
def test_add():
2030
@squin.kernel
2131
def test():
@@ -28,6 +38,8 @@ def test():
2838
ml2 = squin.broadcast.measure(ql2)
2939
return ml1 + ml2
3040

41+
Flatten(test.dialects).fixpoint(test)
42+
3143
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
3244

3345
measure_id_tuples = [
@@ -41,7 +53,6 @@ def test():
4153
assert measure_id_tuples[-1] == expected_measure_id_tuple
4254

4355

44-
@pytest.mark.xfail
4556
def test_measure_alias():
4657

4758
@squin.kernel
@@ -52,28 +63,33 @@ def test():
5263

5364
return ml_alias
5465

66+
Flatten(test.dialects).fixpoint(test)
5567
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
5668

57-
test.print(analysis=frame.entries)
58-
5969
# Collect MeasureIdTuples
6070
measure_id_tuples = [
6171
value for value in frame.entries.values() if isinstance(value, MeasureIdTuple)
6272
]
6373

64-
# construct expected MeasureIdTuple
65-
expected_measure_id_tuple = MeasureIdTuple(
74+
# construct expected MeasureIdTuples
75+
measure_id_tuple_with_id_bools = MeasureIdTuple(
6676
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
6777
)
78+
measure_id_tuple_with_not_measures = MeasureIdTuple(
79+
data=tuple([NotMeasureId() for _ in range(5)])
80+
)
6881

69-
assert len(measure_id_tuples) == 2
82+
assert len(measure_id_tuples) == 3
83+
# New qubit.new semantics cause a MeasureIdTuple to be generated full of NotMeasureIds because
84+
# qubit.new is actually an ilist.map that invokes single qubit allocation multiple times
85+
# and puts them into an ilist.
86+
assert measure_id_tuples[0] == measure_id_tuple_with_not_measures
7087
assert all(
71-
measure_id_tuple == expected_measure_id_tuple
72-
for measure_id_tuple in measure_id_tuples
88+
measure_id_tuple == measure_id_tuple_with_id_bools
89+
for measure_id_tuple in measure_id_tuples[1:]
7390
)
7491

7592

76-
@pytest.mark.xfail
7793
def test_measure_count_at_if_else():
7894

7995
@squin.kernel
@@ -88,6 +104,7 @@ def test():
88104
if ms[3]:
89105
squin.y(q[1])
90106

107+
Flatten(test.dialects).fixpoint(test)
91108
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
92109

93110
assert all(
@@ -96,32 +113,29 @@ def test():
96113
)
97114

98115

99-
@pytest.mark.xfail
100116
def test_scf_cond_true():
101117
@squin.kernel
102118
def test():
103-
q = squin.qalloc(1)
119+
q = squin.qalloc(3)
104120
squin.x(q[2])
105121

106122
ms = None
107123
cond = True
108124
if cond:
109-
ms = squin.broadcast.measure(q)
125+
ms = squin.measure(q[1])
110126
else:
111127
ms = squin.measure(q[0])
112128

113129
return ms
114130

115-
HintConst(dialects=test.dialects).unsafe_run(test)
131+
InlinePass(test.dialects).fixpoint(test)
116132
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
117133

118-
# MeasureIdTuple(data=MeasureIdBool(idx=1),) should occur twice:
134+
# MeasureIdBool(idx=1) should occur twice:
119135
# First from the measurement in the true branch, then
120136
# the result of the scf.IfElse itself
121137
analysis_results = [
122-
val
123-
for val in frame.entries.values()
124-
if val == MeasureIdTuple(data=(MeasureIdBool(idx=1),))
138+
val for val in frame.entries.values() if val == MeasureIdBool(idx=1)
125139
]
126140
assert len(analysis_results) == 2
127141

@@ -136,16 +150,16 @@ def test():
136150
ms = None
137151
cond = False
138152
if cond:
139-
ms = squin.broadcast.measure(q)
153+
ms = squin.measure(q[1])
140154
else:
141-
ms = squin.qubit.measure(q[0])
155+
ms = squin.measure(q[0])
142156

143157
return ms
144158

145-
inline.InlinePass(test.dialects).fixpoint(test)
146-
147-
HintConst(dialects=test.dialects).unsafe_run(test)
159+
# need to preserve the scf.IfElse but need things like qalloc to be inlined
160+
InlinePass(test.dialects).fixpoint(test)
148161
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
162+
test.print(analysis=frame.entries)
149163

150164
# MeasureIdBool(idx=1) should occur twice:
151165
# First from the measurement in the false branch, then
@@ -156,7 +170,37 @@ def test():
156170
assert len(analysis_results) == 2
157171

158172

159-
@pytest.mark.xfail
173+
def test_scf_cond_unknown():
174+
175+
@squin.kernel
176+
def test(cond: bool):
177+
q = squin.qalloc(5)
178+
squin.x(q[2])
179+
180+
if cond:
181+
ms = squin.broadcast.measure(q)
182+
else:
183+
ms = squin.measure(q[0])
184+
185+
return ms
186+
187+
# We can use Flatten here because the variable condition for the scf.IfElse
188+
# means it cannot be simplified.
189+
Flatten(test.dialects).fixpoint(test)
190+
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
191+
analysis_results = [
192+
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
193+
]
194+
# Both branches of the scf.IfElse should be properly traversed and contain the following
195+
# analysis results.
196+
expected_full_register_measurement = MeasureIdTuple(
197+
data=tuple([MeasureIdBool(idx=i) for i in range(1, 6)])
198+
)
199+
expected_else_measurement = MeasureIdTuple(data=(MeasureIdBool(idx=6),))
200+
assert expected_full_register_measurement in analysis_results
201+
assert expected_else_measurement in analysis_results
202+
203+
160204
def test_slice():
161205
@squin.kernel
162206
def test():
@@ -170,19 +214,23 @@ def test():
170214

171215
return ms_final
172216

217+
Flatten(test.dialects).fixpoint(test)
173218
frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test)
174219

175-
test.print(analysis=frame.entries)
220+
results = results_of_variables(test, ("msi", "msi2", "ms_final"))
176221

177-
assert [frame.entries[result] for result in results_at(test, 0, 7)] == [
178-
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7))))
179-
]
180-
assert [frame.entries[result] for result in results_at(test, 0, 9)] == [
181-
MeasureIdTuple(data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7))))
182-
]
183-
assert [frame.entries[result] for result in results_at(test, 0, 11)] == [
184-
MeasureIdTuple(data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5)))
185-
]
222+
# This is an assertion against `msi` NOT the initial list of measurements
223+
assert frame.get(results["msi"]) == MeasureIdTuple(
224+
data=tuple(list(MeasureIdBool(idx=i) for i in range(2, 7)))
225+
)
226+
# msi2
227+
assert frame.get(results["msi2"]) == MeasureIdTuple(
228+
data=tuple(list(MeasureIdBool(idx=i) for i in range(3, 7)))
229+
)
230+
# ms_final
231+
assert frame.get(results["ms_final"]) == MeasureIdTuple(
232+
data=(MeasureIdBool(idx=3), MeasureIdBool(idx=5))
233+
)
186234

187235

188236
def test_getitem_no_hint():

0 commit comments

Comments
 (0)