1- import pytest
2- from kirin .passes import HintConst , inline
31from kirin .dialects import scf
2+ from kirin .passes .inline import InlinePass
43
54from bloqade import squin
65from bloqade .analysis .measure_id import MeasurementIDAnalysis
6+ from bloqade .stim .passes .flatten import Flatten
77from bloqade .analysis .measure_id .lattice import (
8+ NotMeasureId ,
89 MeasureIdBool ,
910 MeasureIdTuple ,
1011 InvalidMeasureId ,
@@ -15,7 +16,16 @@ def results_at(kern, block_id, stmt_id):
1516 return kern .code .body .blocks [block_id ].stmts .at (stmt_id ).results # type: ignore
1617
1718
18- @pytest .mark .xfail
19+ def results_of_variables (kernel , variable_names ):
20+ results = {}
21+ for stmt in kernel .callable_region .stmts ():
22+ for result in stmt .results :
23+ if result .name in variable_names :
24+ results [result .name ] = result
25+
26+ return results
27+
28+
1929def test_add ():
2030 @squin .kernel
2131 def test ():
@@ -28,6 +38,8 @@ def test():
2838 ml2 = squin .broadcast .measure (ql2 )
2939 return ml1 + ml2
3040
41+ Flatten (test .dialects ).fixpoint (test )
42+
3143 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
3244
3345 measure_id_tuples = [
@@ -41,7 +53,6 @@ def test():
4153 assert measure_id_tuples [- 1 ] == expected_measure_id_tuple
4254
4355
44- @pytest .mark .xfail
4556def test_measure_alias ():
4657
4758 @squin .kernel
@@ -52,28 +63,33 @@ def test():
5263
5364 return ml_alias
5465
66+ Flatten (test .dialects ).fixpoint (test )
5567 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
5668
57- test .print (analysis = frame .entries )
58-
5969 # Collect MeasureIdTuples
6070 measure_id_tuples = [
6171 value for value in frame .entries .values () if isinstance (value , MeasureIdTuple )
6272 ]
6373
64- # construct expected MeasureIdTuple
65- expected_measure_id_tuple = MeasureIdTuple (
74+ # construct expected MeasureIdTuples
75+ measure_id_tuple_with_id_bools = MeasureIdTuple (
6676 data = tuple ([MeasureIdBool (idx = i ) for i in range (1 , 6 )])
6777 )
78+ measure_id_tuple_with_not_measures = MeasureIdTuple (
79+ data = tuple ([NotMeasureId () for _ in range (5 )])
80+ )
6881
69- assert len (measure_id_tuples ) == 2
82+ assert len (measure_id_tuples ) == 3
83+ # New qubit.new semantics cause a MeasureIdTuple to be generated full of NotMeasureIds because
84+ # qubit.new is actually an ilist.map that invokes single qubit allocation multiple times
85+ # and puts them into an ilist.
86+ assert measure_id_tuples [0 ] == measure_id_tuple_with_not_measures
7087 assert all (
71- measure_id_tuple == expected_measure_id_tuple
72- for measure_id_tuple in measure_id_tuples
88+ measure_id_tuple == measure_id_tuple_with_id_bools
89+ for measure_id_tuple in measure_id_tuples [ 1 :]
7390 )
7491
7592
76- @pytest .mark .xfail
7793def test_measure_count_at_if_else ():
7894
7995 @squin .kernel
@@ -88,6 +104,7 @@ def test():
88104 if ms [3 ]:
89105 squin .y (q [1 ])
90106
107+ Flatten (test .dialects ).fixpoint (test )
91108 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
92109
93110 assert all (
@@ -96,32 +113,29 @@ def test():
96113 )
97114
98115
99- @pytest .mark .xfail
100116def test_scf_cond_true ():
101117 @squin .kernel
102118 def test ():
103- q = squin .qalloc (1 )
119+ q = squin .qalloc (3 )
104120 squin .x (q [2 ])
105121
106122 ms = None
107123 cond = True
108124 if cond :
109- ms = squin .broadcast . measure (q )
125+ ms = squin .measure (q [ 1 ] )
110126 else :
111127 ms = squin .measure (q [0 ])
112128
113129 return ms
114130
115- HintConst ( dialects = test .dialects ).unsafe_run (test )
131+ InlinePass ( test .dialects ).fixpoint (test )
116132 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
117133
118- # MeasureIdTuple(data= MeasureIdBool(idx=1), ) should occur twice:
134+ # MeasureIdBool(idx=1) should occur twice:
119135 # First from the measurement in the true branch, then
120136 # the result of the scf.IfElse itself
121137 analysis_results = [
122- val
123- for val in frame .entries .values ()
124- if val == MeasureIdTuple (data = (MeasureIdBool (idx = 1 ),))
138+ val for val in frame .entries .values () if val == MeasureIdBool (idx = 1 )
125139 ]
126140 assert len (analysis_results ) == 2
127141
@@ -136,16 +150,16 @@ def test():
136150 ms = None
137151 cond = False
138152 if cond :
139- ms = squin .broadcast . measure (q )
153+ ms = squin .measure (q [ 1 ] )
140154 else :
141- ms = squin .qubit . measure (q [0 ])
155+ ms = squin .measure (q [0 ])
142156
143157 return ms
144158
145- inline .InlinePass (test .dialects ).fixpoint (test )
146-
147- HintConst (dialects = test .dialects ).unsafe_run (test )
159+ # need to preserve the scf.IfElse but need things like qalloc to be inlined
160+ InlinePass (test .dialects ).fixpoint (test )
148161 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
162+ test .print (analysis = frame .entries )
149163
150164 # MeasureIdBool(idx=1) should occur twice:
151165 # First from the measurement in the false branch, then
@@ -156,7 +170,37 @@ def test():
156170 assert len (analysis_results ) == 2
157171
158172
159- @pytest .mark .xfail
173+ def test_scf_cond_unknown ():
174+
175+ @squin .kernel
176+ def test (cond : bool ):
177+ q = squin .qalloc (5 )
178+ squin .x (q [2 ])
179+
180+ if cond :
181+ ms = squin .broadcast .measure (q )
182+ else :
183+ ms = squin .measure (q [0 ])
184+
185+ return ms
186+
187+ # We can use Flatten here because the variable condition for the scf.IfElse
188+ # means it cannot be simplified.
189+ Flatten (test .dialects ).fixpoint (test )
190+ frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
191+ analysis_results = [
192+ val for val in frame .entries .values () if isinstance (val , MeasureIdTuple )
193+ ]
194+ # Both branches of the scf.IfElse should be properly traversed and contain the following
195+ # analysis results.
196+ expected_full_register_measurement = MeasureIdTuple (
197+ data = tuple ([MeasureIdBool (idx = i ) for i in range (1 , 6 )])
198+ )
199+ expected_else_measurement = MeasureIdTuple (data = (MeasureIdBool (idx = 6 ),))
200+ assert expected_full_register_measurement in analysis_results
201+ assert expected_else_measurement in analysis_results
202+
203+
160204def test_slice ():
161205 @squin .kernel
162206 def test ():
@@ -170,19 +214,23 @@ def test():
170214
171215 return ms_final
172216
217+ Flatten (test .dialects ).fixpoint (test )
173218 frame , _ = MeasurementIDAnalysis (test .dialects ).run_analysis (test )
174219
175- test . print ( analysis = frame . entries )
220+ results = results_of_variables ( test , ( "msi" , "msi2" , "ms_final" ) )
176221
177- assert [frame .entries [result ] for result in results_at (test , 0 , 7 )] == [
178- MeasureIdTuple (data = tuple (list (MeasureIdBool (idx = i ) for i in range (2 , 7 ))))
179- ]
180- assert [frame .entries [result ] for result in results_at (test , 0 , 9 )] == [
181- MeasureIdTuple (data = tuple (list (MeasureIdBool (idx = i ) for i in range (3 , 7 ))))
182- ]
183- assert [frame .entries [result ] for result in results_at (test , 0 , 11 )] == [
184- MeasureIdTuple (data = (MeasureIdBool (idx = 3 ), MeasureIdBool (idx = 5 )))
185- ]
222+ # This is an assertion against `msi` NOT the initial list of measurements
223+ assert frame .get (results ["msi" ]) == MeasureIdTuple (
224+ data = tuple (list (MeasureIdBool (idx = i ) for i in range (2 , 7 )))
225+ )
226+ # msi2
227+ assert frame .get (results ["msi2" ]) == MeasureIdTuple (
228+ data = tuple (list (MeasureIdBool (idx = i ) for i in range (3 , 7 )))
229+ )
230+ # ms_final
231+ assert frame .get (results ["ms_final" ]) == MeasureIdTuple (
232+ data = (MeasureIdBool (idx = 3 ), MeasureIdBool (idx = 5 ))
233+ )
186234
187235
188236def test_getitem_no_hint ():
0 commit comments