diff --git a/src/bloqade/cirq_utils/noise/model.py b/src/bloqade/cirq_utils/noise/model.py index 4e3c064d..ab694f0f 100644 --- a/src/bloqade/cirq_utils/noise/model.py +++ b/src/bloqade/cirq_utils/noise/model.py @@ -106,7 +106,7 @@ def validate_moments(moments: Iterable[cirq.Moment]): continue gate = operation.gate - for allowed_family in allowed_target_gates: + for allowed_family in set(allowed_target_gates).union({cirq.GateFamily(gate=cirq.ops.common_channels.ResetChannel, ignore_global_phase=True)}): if gate in allowed_family: break else: @@ -246,14 +246,18 @@ def noisy_moment(self, moment, system_qubits): original_moment = moment # Check if the moment is empty - if len(moment.operations) == 0: + if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]): move_noise_ops = [] gate_noise_ops = [] # Check if the moment contains 1-qubit gates or 2-qubit gates elif len(moment.operations[0].qubits) == 1: - gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops( - moment, system_qubits - ) + if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])): + move_noise_ops = [] + gate_noise_ops = [] + else: + gate_noise_ops, move_noise_ops = self._single_qubit_moment_noise_ops( + moment, system_qubits + ) elif len(moment.operations[0].qubits) == 2: control_qubits = [op.qubits[0] for op in moment.operations] target_qubits = [op.qubits[1] for op in moment.operations] @@ -319,20 +323,26 @@ def noisy_moments( # Split into moments with only 1Q and 2Q gates moments_1q = [ - cirq.Moment([op for op in moment.operations if len(op.qubits) == 1]) + cirq.Moment([op for op in moment.operations if (len(op.qubits) == 1) and (not cirq.is_measurement(op)) and (not isinstance(op.gate, cirq.ResetChannel))]) for moment in moments ] moments_2q = [ - cirq.Moment([op for op in moment.operations if len(op.qubits) == 2]) + cirq.Moment([op for op in moment.operations if (len(op.qubits) == 2) and (not cirq.is_measurement(op))]) for moment in moments ] - assert len(moments_1q) == len(moments_2q) + moments_measurement = [ + cirq.Moment([op for op in moment.operations if (cirq.is_measurement(op)) or (isinstance(op.gate, cirq.ResetChannel))]) + for moment in moments + ] + + assert len(moments_1q) == len(moments_2q) == len(moments_measurement) interleaved_moments = [] for idx, moment in enumerate(moments_1q): interleaved_moments.append(moment) interleaved_moments.append(moments_2q[idx]) + interleaved_moments.append(moments_measurement[idx]) interleaved_circuit = cirq.Circuit.from_moments(*interleaved_moments) @@ -368,14 +378,17 @@ def noisy_moment(self, moment, system_qubits): "all qubits in the circuit must be defined as cirq.GridQubit objects." ) # Check if the moment is empty - if len(moment.operations) == 0: + if len(moment.operations) == 0 or cirq.is_measurement(moment.operations[0]): move_moments = [] gate_noise_ops = [] # Check if the moment contains 1-qubit gates or 2-qubit gates elif len(moment.operations[0].qubits) == 1: - gate_noise_ops, _ = self._single_qubit_moment_noise_ops( - moment, system_qubits - ) + if (isinstance(moment.operations[0].gate, cirq.ResetChannel)) or (cirq.is_measurement(moment.operations[0])): + gate_noise_ops = [] + else: + gate_noise_ops, _ = self._single_qubit_moment_noise_ops( + moment, system_qubits + ) move_moments = [] elif len(moment.operations[0].qubits) == 2: cg = OneZoneConflictGraph(moment)