diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index c1b836a1ff1..74717d2a720 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1032,7 +1032,7 @@ def _with_rescoped_keys_( for moment in self.moments: new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys) moments.append(new_moment) - bindable_keys |= protocols.measurement_key_objs(new_moment) + bindable_keys |= new_moment.measurement_keys return self._from_moments(moments, tags=self.tags) def _qid_shape_(self) -> tuple[int, ...]: @@ -1273,10 +1273,7 @@ def to_text_diagram_drawer( """ qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) cbits = tuple( - sorted( - set(key for op in self.all_operations() for key in protocols.control_keys(op)), - key=str, - ) + sorted(set(key for op in self.all_operations() for key in op.control_keys), key=str) ) labels = qubits + cbits label_map = {labels[i]: i for i in range(len(labels))} @@ -1659,14 +1656,19 @@ def factorize(self) -> Iterable[Self]: for qubits in qubit_factors ) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: + @property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset().union(*(m.measurement_keys for m in self.moments)) + + @property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: measures: set[cirq.MeasurementKey] = set() controls: set[cirq.MeasurementKey] = set() for op in self.all_operations(): # Only require keys that haven't already been measured earlier - controls.update(k for k in protocols.control_keys(op) if k not in measures) + controls.update(op.control_keys - measures) # Record any measurement keys produced by this op - measures.update(protocols.measurement_key_objs(op)) + measures.update(op.measurement_keys) return frozenset(controls) @@ -2133,19 +2135,14 @@ def earliest_available_moment( end_moment_index = len(self.moments) last_available = end_moment_index k = end_moment_index - op_control_keys = protocols.control_keys(op) - op_measurement_keys = protocols.measurement_key_objs(op) - op_qubits = op.qubits while k > 0: k -= 1 moment = self._moments[k] - if moment.operates_on(op_qubits): - return last_available - moment_measurement_keys = moment._measurement_key_objs_() if ( - not op_measurement_keys.isdisjoint(moment_measurement_keys) - or not op_control_keys.isdisjoint(moment_measurement_keys) - or not moment._control_keys_().isdisjoint(op_measurement_keys) + moment.operates_on(op.qubits) + or not op.measurement_keys.isdisjoint(moment.measurement_keys) + or not op.control_keys.isdisjoint(moment.measurement_keys) + or not moment.control_keys.isdisjoint(op.measurement_keys) ): return last_available if self._can_add_op_at(k, op): @@ -2964,8 +2961,8 @@ def get_earliest_accommodating_moment_index( The integer index of the earliest moment that can accommodate the given moment or operation. """ mop_qubits = moment_or_operation.qubits - mop_mkeys = protocols.measurement_key_objs(moment_or_operation) - mop_ckeys = protocols.control_keys(moment_or_operation) + mop_mkeys = moment_or_operation.measurement_keys + mop_ckeys = moment_or_operation.control_keys if isinstance(moment_or_operation, Moment): # For consistency with `Circuit.append`, moments always get placed at the end of a circuit. diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 66d37c3898d..7762ba6031e 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -198,9 +198,7 @@ def __init__( if mapped_repeat_until: if self._use_repetition_ids or self._repetitions != 1: raise ValueError('Cannot use repetitions with repeat_until') - if protocols.measurement_key_objs(self._mapped_single_loop()).isdisjoint( - mapped_repeat_until.keys - ): + if self._mapped_single_loop().measurement_keys.isdisjoint(mapped_repeat_until.keys): raise ValueError('Infinite loop: condition is not modified in subcircuit.') @property @@ -310,8 +308,8 @@ def _ensure_deterministic_loop_count(self): raise ValueError('Cannot unroll circuit due to nondeterministic repetitions') @cached_property - def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: - circuit_keys = protocols.measurement_key_objs(self.circuit) + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + circuit_keys = self.circuit.measurement_keys if circuit_keys and self.use_repetition_ids: self._ensure_deterministic_loop_count() if self.repetition_ids is not None: @@ -328,27 +326,18 @@ def _measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: for key in circuit_keys ) - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: - return self._measurement_key_objs - - def _measurement_key_names_(self) -> frozenset[str]: - return frozenset(str(key) for key in self._measurement_key_objs_()) - @cached_property - def _control_keys(self) -> frozenset[cirq.MeasurementKey]: + def control_keys(self) -> frozenset[cirq.MeasurementKey]: keys = ( frozenset() - if not protocols.control_keys(self.circuit) - else protocols.control_keys(self._mapped_single_loop()) + if not self.circuit.control_keys + else self._mapped_single_loop().control_keys ) mapped_repeat_until = self._mapped_repeat_until if mapped_repeat_until is not None: - keys |= frozenset(mapped_repeat_until.keys) - self._measurement_key_objs_() + keys |= frozenset(mapped_repeat_until.keys) - self.measurement_keys return keys - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - return self._control_keys - def _is_parameterized_(self) -> bool: return any(self._parameter_names_generator()) @@ -396,9 +385,7 @@ def _mapped_repeat_until(self) -> cirq.Condition | None: repeat_until, self.param_resolver, recursive=False ) return protocols.with_rescoped_keys( - repeat_until, - self.parent_path, - bindable_keys=self._extern_keys | self._measurement_key_objs, + repeat_until, self.parent_path, bindable_keys=self._extern_keys | self.measurement_keys ) def mapped_circuit(self, deep: bool = False) -> cirq.Circuit: diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 1f26fcab78c..37043ee2d85 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4915,6 +4915,7 @@ def test_create_speed() -> None: c = cirq.Circuit(ops) duration = time.perf_counter() - t assert len(c) == moments + print(duration) assert duration < 4 @@ -4937,6 +4938,7 @@ def test_append_speed() -> None: c.append(xs[q]) duration = time.perf_counter() - t assert len(c) == moments + print(duration) assert duration < 5 diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 91d312e15fc..d861967123c 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -156,10 +156,6 @@ def all_measurement_key_objs(self) -> frozenset[cirq.MeasurementKey]: def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: return self.all_measurement_key_objs() - @_compat.cached_method - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - return super()._control_keys_() - @_compat.cached_method def are_all_measurements_terminal(self) -> bool: return super().are_all_measurements_terminal() @@ -224,3 +220,11 @@ def to_op(self) -> cirq.CircuitOperation: from cirq.circuits import CircuitOperation return CircuitOperation(self) + + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return super().measurement_keys + + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return super().control_keys diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 3eaec877251..90e23f7899f 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -108,8 +108,6 @@ def __init__( raise ValueError(f'Overlapping operations: {self.operations}') self._qubit_to_op[q] = op - self._measurement_key_objs: frozenset[cirq.MeasurementKey] | None = None - self._control_keys: frozenset[cirq.MeasurementKey] | None = None self._tags = tags @classmethod @@ -222,10 +220,8 @@ def with_operation(self, operation: cirq.Operation) -> cirq.Moment: m._sorted_operations = None m._qubit_to_op = {**self._qubit_to_op, **{q: operation for q in operation.qubits}} - m._measurement_key_objs = self._measurement_key_objs_().union( - protocols.measurement_key_objs(operation) - ) - m._control_keys = self._control_keys_().union(protocols.control_keys(operation)) + m.__setattr__('measurement_keys', self.measurement_keys | operation.measurement_keys) + m.__setattr__('control_keys', self.control_keys | operation.control_keys) return m @@ -260,11 +256,12 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment: m._operations = self._operations + flattened_contents m._sorted_operations = None - m._measurement_key_objs = self._measurement_key_objs_().union( - set(itertools.chain(*(protocols.measurement_key_objs(op) for op in flattened_contents))) + m.__setattr__( + 'measurement_keys', + self.measurement_keys.union(*(op.measurement_keys for op in flattened_contents)), ) - m._control_keys = self._control_keys_().union( - set(itertools.chain(*(protocols.control_keys(op) for op in flattened_contents))) + m.__setattr__( + 'control_keys', self.control_keys.union(*(op.control_keys for op in flattened_contents)) ) return m @@ -323,23 +320,13 @@ def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): for op in self.operations ) - @_compat.cached_method() - def _measurement_key_names_(self) -> frozenset[str]: - return frozenset(str(key) for key in self._measurement_key_objs_()) - - def _measurement_key_objs_(self) -> frozenset[cirq.MeasurementKey]: - if self._measurement_key_objs is None: - self._measurement_key_objs = frozenset( - key for op in self.operations for key in protocols.measurement_key_objs(op) - ) - return self._measurement_key_objs + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset().union(*(op.measurement_keys for op in self.operations)) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - if self._control_keys is None: - self._control_keys = frozenset( - k for op in self.operations for k in protocols.control_keys(op) - ) - return self._control_keys + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset().union(*(op.control_keys for op in self.operations)) def _sorted_operations_(self) -> tuple[cirq.Operation, ...]: if self._sorted_operations is None: diff --git a/cirq-core/cirq/circuits/moment_test.py b/cirq-core/cirq/circuits/moment_test.py index 59d53194761..a530e1d43c9 100644 --- a/cirq-core/cirq/circuits/moment_test.py +++ b/cirq-core/cirq/circuits/moment_test.py @@ -413,19 +413,16 @@ def test_measurement_keys() -> None: def test_measurement_key_objs_caching() -> None: q0, q1, q2, q3 = cirq.LineQubit.range(4) m = cirq.Moment(cirq.measure(q0, key='foo')) - assert m._measurement_key_objs is None + assert m.measurement_keys == {cirq.MeasurementKey(name='foo')} key_objs = cirq.measurement_key_objs(m) - assert m._measurement_key_objs == key_objs + assert m.measurement_keys == key_objs # Make sure it gets updated when adding an operation. m = m.with_operation(cirq.measure(q1, key='bar')) - assert m._measurement_key_objs == { - cirq.MeasurementKey(name='bar'), - cirq.MeasurementKey(name='foo'), - } + assert m.measurement_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')} # Or multiple operations. m = m.with_operations(cirq.measure(q2, key='doh'), cirq.measure(q3, key='baz')) - assert m._measurement_key_objs == { + assert m.measurement_keys == { cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo'), cirq.MeasurementKey(name='doh'), @@ -436,18 +433,18 @@ def test_measurement_key_objs_caching() -> None: def test_control_keys_caching() -> None: q0, q1, q2, q3 = cirq.LineQubit.range(4) m = cirq.Moment(cirq.X(q0).with_classical_controls('foo')) - assert m._control_keys is None + assert m.control_keys == {cirq.MeasurementKey(name='foo')} keys = cirq.control_keys(m) - assert m._control_keys == keys + assert m.control_keys == keys # Make sure it gets updated when adding an operation. m = m.with_operation(cirq.X(q1).with_classical_controls('bar')) - assert m._control_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')} + assert m.control_keys == {cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo')} # Or multiple operations. m = m.with_operations( cirq.X(q2).with_classical_controls('doh'), cirq.X(q3).with_classical_controls('baz') ) - assert m._control_keys == { + assert m.control_keys == { cirq.MeasurementKey(name='bar'), cirq.MeasurementKey(name='foo'), cirq.MeasurementKey(name='doh'), diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index 06b9771fd1d..4bfbf3cc574 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence, Set +from functools import cached_property from typing import Any, TYPE_CHECKING import sympy @@ -90,7 +91,7 @@ def __init__( ValueError: If an unsupported gate is being classically controlled. """ - if protocols.measurement_key_objs(sub_operation): + if sub_operation.measurement_keys: raise ValueError( f'Cannot conditionally run operations with measurements: {sub_operation}' ) @@ -223,11 +224,10 @@ def _with_rescoped_keys_( sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) return sub_operation.with_classical_controls(*conds) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - local_keys: frozenset[cirq.MeasurementKey] = frozenset( - k for condition in self._conditions for k in condition.keys - ) - return local_keys.union(protocols.control_keys(self._sub_operation)) + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + local_keys = frozenset(k for condition in self._conditions for k in condition.keys) + return local_keys | self._sub_operation.control_keys def _qasm_(self, args: cirq.QasmArgs) -> str | None: args.validate_version('2.0', '3.0') diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 33929591a2f..17bce0d240e 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -20,6 +20,7 @@ import re import warnings from collections.abc import Collection, Mapping, Sequence, Set +from functools import cached_property from types import NotImplementedType from typing import Any, cast, Self, TYPE_CHECKING, TypeVar @@ -376,5 +377,9 @@ def controlled_by( control_qid_shape=tuple(q.dimension for q in qubits), ).on(*(qubits + self._qubits)) + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return self.gate.measurement_keys + TV = TypeVar('TV', bound=raw_types.Gate) diff --git a/cirq-core/cirq/ops/kraus_channel.py b/cirq-core/cirq/ops/kraus_channel.py index 315afcaae5d..3d129d8022f 100644 --- a/cirq-core/cirq/ops/kraus_channel.py +++ b/cirq-core/cirq/ops/kraus_channel.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping +from functools import cached_property from typing import Any, TYPE_CHECKING import numpy as np @@ -79,15 +80,11 @@ def num_qubits(self) -> int: def _kraus_(self): return self._kraus_ops - def _measurement_key_name_(self) -> str: + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: if self._key is None: - return NotImplemented - return str(self._key) - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - if self._key is None: - return NotImplemented - return self._key + return frozenset() + return frozenset([self._key]) def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): if self._key is None: diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index dde17ee3a91..43919cbf3fd 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property from typing import Any, TYPE_CHECKING import numpy as np @@ -171,11 +172,9 @@ def full_invert_mask(self) -> tuple[bool, ...]: def _is_measurement_(self) -> bool: return True - def _measurement_key_name_(self) -> str: - return self.key - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - return self.mkey + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset([self.mkey]) def _kraus_(self): size = np.prod(self._qid_shape, dtype=np.int64) diff --git a/cirq-core/cirq/ops/mixed_unitary_channel.py b/cirq-core/cirq/ops/mixed_unitary_channel.py index b70795efdeb..b6c36f3463c 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping +from functools import cached_property from typing import Any, TYPE_CHECKING import numpy as np @@ -84,15 +85,11 @@ def num_qubits(self) -> int: def _mixture_(self): return self._mixture - def _measurement_key_name_(self) -> str: + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: if self._key is None: - return NotImplemented - return str(self._key) - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - if self._key is None: - return NotImplemented - return self._key + return frozenset() + return frozenset([self._key]) def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): if self._key is None: diff --git a/cirq-core/cirq/ops/pauli_measurement_gate.py b/cirq-core/cirq/ops/pauli_measurement_gate.py index d87fee97710..eded5deaec2 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, Sequence +from functools import cached_property from typing import Any, cast, TYPE_CHECKING from cirq import protocols, value @@ -124,11 +125,9 @@ def with_observable( def _is_measurement_(self) -> bool: return True - def _measurement_key_name_(self) -> str: - return self.key - - def _measurement_key_obj_(self) -> cirq.MeasurementKey: - return self.mkey + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return frozenset([self.mkey]) def observable(self) -> cirq.DensePauliString: """Pauli observable which should be measured by the gate.""" diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index df6c4514c96..0b798a55c24 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -19,6 +19,7 @@ import abc import functools from collections.abc import Callable, Collection, Hashable, Iterable, Mapping, Sequence, Set +from functools import cached_property from types import NotImplementedType from typing import Any, cast, overload, TYPE_CHECKING @@ -444,6 +445,10 @@ def _qid_shape_(self) -> tuple[int, ...]: """ raise NotImplementedError + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return protocols.measurement_key_objs(self, _skip_property_check=True) + def _equal_up_to_global_phase_( self, other: Any, atol: float = 1e-8 ) -> NotImplementedType | bool: @@ -741,6 +746,14 @@ def without_classical_controls(self) -> cirq.Operation: """ return self + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return protocols.measurement_key_objs(self, _skip_property_check=True) + + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return protocols.control_keys(self, _skip_property_check=True) + @value.value_equality class TaggedOperation(Operation): @@ -973,8 +986,13 @@ def with_classical_controls(self, *conditions): return self return self.sub_operation.with_classical_controls(*conditions) - def _control_keys_(self) -> frozenset[cirq.MeasurementKey]: - return protocols.control_keys(self.sub_operation) + @cached_property + def measurement_keys(self) -> frozenset[cirq.MeasurementKey]: + return self.sub_operation.measurement_keys + + @cached_property + def control_keys(self) -> frozenset[cirq.MeasurementKey]: + return self.sub_operation.control_keys @value.value_equality @@ -1091,10 +1109,10 @@ def _operations_commutes_impl( False: `ops1` and `ops2` do not commute. NotImplemented: The commutativity cannot be determined here. """ - ops1_keys = frozenset(k for op in ops1 for k in protocols.measurement_key_objs(op)) - ops2_keys = frozenset(k for op in ops2 for k in protocols.measurement_key_objs(op)) - ops1_control_keys = frozenset(k for op in ops1 for k in protocols.control_keys(op)) - ops2_control_keys = frozenset(k for op in ops2 for k in protocols.control_keys(op)) + ops1_keys = frozenset(k for op in ops1 for k in op.measurement_keys) + ops2_keys = frozenset(k for op in ops2 for k in op.measurement_keys) + ops1_control_keys = frozenset(k for op in ops1 for k in op.control_keys) + ops2_control_keys = frozenset(k for op in ops2 for k in op.control_keys) if ( not ops1_keys.isdisjoint(ops2_keys) or not ops1_control_keys.isdisjoint(ops2_keys) diff --git a/cirq-core/cirq/protocols/control_key_protocol.py b/cirq-core/cirq/protocols/control_key_protocol.py index a02af1508e9..045b64e9529 100644 --- a/cirq-core/cirq/protocols/control_key_protocol.py +++ b/cirq-core/cirq/protocols/control_key_protocol.py @@ -44,7 +44,7 @@ def _control_keys_(self) -> frozenset[cirq.MeasurementKey] | NotImplementedType """ -def control_keys(val: Any) -> frozenset[cirq.MeasurementKey]: +def control_keys(val: Any, _skip_property_check=False) -> frozenset[cirq.MeasurementKey]: """Gets the keys that the value is classically controlled by. Args: @@ -61,6 +61,11 @@ def control_keys(val: Any) -> frozenset[cirq.MeasurementKey]: the subcircuit are still required externally and thus appear in the result. """ + if not _skip_property_check: + attr = getattr(val, 'control_keys', None) + if attr is not None: + return attr + getter = getattr(val, '_control_keys_', None) result = NotImplemented if getter is None else getter() if result is not NotImplemented and result is not None: diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index a141e759107..74442bdf896 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -197,10 +197,14 @@ def measurement_key_name(val, default=RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( - val: Any, + val: Any, _skip_property_check=False ) -> frozenset[cirq.MeasurementKey] | NotImplementedType | None: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" + if not _skip_property_check: + attr = getattr(val, 'measurement_keys', None) + if attr is not None: + return attr getter = getattr(val, '_measurement_key_objs_', None) result = NotImplemented if getter is None else getter() @@ -232,7 +236,7 @@ def _measurement_key_names_from_magic_methods( return result -def measurement_key_objs(val: Any) -> frozenset[cirq.MeasurementKey]: +def measurement_key_objs(val: Any, _skip_property_check=False) -> frozenset[cirq.MeasurementKey]: """Gets the measurement key objects of measurements within the given value. Args: @@ -242,7 +246,7 @@ def measurement_key_objs(val: Any) -> frozenset[cirq.MeasurementKey]: The measurement key objects of the value. If the value has no measurement, the result is the empty set. """ - result = _measurement_key_objs_from_magic_methods(val) + result = _measurement_key_objs_from_magic_methods(val, _skip_property_check) if result is not NotImplemented and result is not None: return result key_strings = _measurement_key_names_from_magic_methods(val) diff --git a/cirq-core/cirq/transformers/stratify.py b/cirq-core/cirq/transformers/stratify.py index f6d84096d3b..9e0bff77105 100644 --- a/cirq-core/cirq/transformers/stratify.py +++ b/cirq-core/cirq/transformers/stratify.py @@ -168,9 +168,9 @@ def _stratify_circuit( # Update qubit, measurement key, and control key moments. for qubit in op.qubits: qubit_time_index[qubit] = time_index - for key in protocols.measurement_key_objs(op): + for key in op.measurement_keys: measurement_time_index[key] = time_index - for key in protocols.control_keys(op): + for key in op.control_keys: control_time_index[key] = time_index return circuits.Circuit(circuits.Moment(moment) for moment in new_moments if moment) diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py index 3c4abc39918..783dd161765 100644 --- a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py @@ -50,11 +50,11 @@ def find_terminal_measurements(circuit: cirq.AbstractCircuit) -> list[tuple[int, op is not None and open_qubits.issuperset(op.qubits) and protocols.is_measurement(op) - and not (seen_control_keys & protocols.measurement_key_objs(op)) + and not (seen_control_keys & op.measurement_keys) ): terminal_measurements.add((i, op)) open_qubits -= moment.qubits - seen_control_keys |= protocols.control_keys(moment) + seen_control_keys |= moment.control_keys if not open_qubits: break return list(terminal_measurements)