Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions src/bloqade/cirq_utils/noise/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def validate_moments(moments: Iterable[cirq.Moment]):
continue

gate = operation.gate
for allowed_family in allowed_target_gates:
for allowed_family in set(allowed_target_gates).union({cirq.GateFamily(gate=cirq.ops.common_channels.ResetChannel, ignore_global_phase=True)}):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just add that to the allowed gates with e.g.

reset_family = cirq.GateFamily(gate=cirq.ResetChannel, ignore_global_phase=True)
allowed_target_gates: frozenset[cirq.GateFamily] = cirq.CZTargetGateset(additional_gates=[reset_family]).gates

Your approach here works as well though.

if gate in allowed_family:
break
else:
Expand Down Expand Up @@ -246,14 +246,18 @@ def noisy_moment(self, moment, system_qubits):
original_moment = moment

# Check if the moment is empty
if len(moment.operations) == 0:
if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]):
move_noise_ops = []
gate_noise_ops = []
# Check if the moment contains 1-qubit gates or 2-qubit gates
elif len(moment.operations[0].qubits) == 1:
gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])):
move_noise_ops = []
gate_noise_ops = []
else:
gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
elif len(moment.operations[0].qubits) == 2:
control_qubits = [op.qubits[0] for op in moment.operations]
target_qubits = [op.qubits[1] for op in moment.operations]
Expand Down Expand Up @@ -319,20 +323,26 @@ def noisy_moments(

# Split into moments with only 1Q and 2Q gates
moments_1q = [
cirq.Moment([op for op in moment.operations if len(op.qubits) == 1])
cirq.Moment([op for op in moment.operations if (len(op.qubits) == 1) and (not cirq.is_measurement(op)) and (not isinstance(op.gate, cirq.ResetChannel))])
for moment in moments
]
moments_2q = [
cirq.Moment([op for op in moment.operations if len(op.qubits) == 2])
cirq.Moment([op for op in moment.operations if (len(op.qubits) == 2) and (not cirq.is_measurement(op))])
for moment in moments
]

assert len(moments_1q) == len(moments_2q)
moments_measurement = [
cirq.Moment([op for op in moment.operations if (cirq.is_measurement(op)) or (isinstance(op.gate, cirq.ResetChannel))])
for moment in moments
]

assert len(moments_1q) == len(moments_2q) == len(moments_measurement)

interleaved_moments = []
for idx, moment in enumerate(moments_1q):
interleaved_moments.append(moment)
interleaved_moments.append(moments_2q[idx])
interleaved_moments.append(moments_measurement[idx])

interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments)

Expand Down Expand Up @@ -368,14 +378,17 @@ def noisy_moment(self, moment, system_qubits):
"all qubits in the circuit must be defined as cirq.GridQubit objects."
)
# Check if the moment is empty
if len(moment.operations) == 0:
if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]):
move_moments = []
gate_noise_ops = []
# Check if the moment contains 1-qubit gates or 2-qubit gates
elif len(moment.operations[0].qubits) == 1:
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])):
gate_noise_ops = []
else:
gate_noise_ops, _ = self._single_qubit_moment_noise_ops(
moment, system_qubits
)
move_moments = []
elif len(moment.operations[0].qubits) == 2:
cg = OneZoneConflictGraph(moment)
Expand Down
Loading