Skip to content
33 changes: 32 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Iterator, Sequence, Tuple, Union
from typing import Dict, Iterator, Sequence, Tuple, Union

import attrs
import cirq
Expand Down Expand Up @@ -137,5 +137,36 @@
yield self.target_gate(target[target_idx]).controlled_by(control)
yield cirq.CZ(*accumulator, target[target_idx])

def on_classical_vals(self, **vals ) -> Dict[str, 'ClassicalValT']:
if self.target_gate != cirq.X:
return NotImplemented
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return NotImplemented
control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']
if control:
max_selection = self.selection_registers[0].dtype.iteration_length - 1

Check failure on line 151 in qualtran/bloqs/multiplexers/selected_majorana_fermion.py

View workflow job for this annotation

GitHub Actions / mypy

"QCDType[Any]" has no attribute "iteration_length"; maybe "iteration_length_or_zero"? [attr-defined]
target = (2**(max_selection - selection)) ^ target
return {control_name: control, selection_name: selection, 'target': target}

def basis_state_phase(self, **vals ) -> Union[complex, None]:
if self.target_gate != cirq.X:
return None
if len(self.control_registers) > 1 or len(self.selection_registers) > 1:
return None
control_name = self.control_registers[0].name
control = vals[control_name]
selection_name = self.selection_registers[0].name
selection = vals[selection_name]
target = vals['target']
if control:
max_selection = self.selection_registers[0].dtype.iteration_length - 1

Check failure on line 166 in qualtran/bloqs/multiplexers/selected_majorana_fermion.py

View workflow job for this annotation

GitHub Actions / mypy

"QCDType[Any]" has no attribute "iteration_length"; maybe "iteration_length_or_zero"? [attr-defined]
num_phases = (target >> (max_selection - selection + 1)).bit_count()
return 1 if (num_phases % 2) == 0 else -1
return 1

def __str__(self):
return f'SelectedMajoranaFermion({self.target_gate})'
14 changes: 13 additions & 1 deletion qualtran/bloqs/multiplexers/selected_majorana_fermion_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -20,7 +20,7 @@
from qualtran._infra.gate_with_registers import get_named_qubits, total_bits
from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion
from qualtran.cirq_interop.testing import GateHelper
from qualtran.testing import assert_valid_bloq_decomposition
from qualtran.testing import assert_valid_bloq_decomposition, assert_consistent_phased_classical_action


@pytest.mark.slow
Expand Down Expand Up @@ -148,3 +148,15 @@
op = gate.on_registers(**get_named_qubits(gate.signature))
op2 = SelectedMajoranaFermion.make_on(target_gate=cirq.X, **get_named_qubits(gate.signature))
assert op == op2

@pytest.mark.parametrize("selection_bitsize, target_bitsize", [(2, 4), (3, 5)])
def test_selected_majorana_fermion_classical_action(selection_bitsize, target_bitsize):
gate = SelectedMajoranaFermion(
Register('selection', BQUInt(selection_bitsize, target_bitsize)), target_gate=cirq.X
)
assert_consistent_phased_classical_action(
gate,
selection=range(target_bitsize),
target=range(2**target_bitsize),
control=range(2)
)
26 changes: 26 additions & 0 deletions qualtran/testing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -40,6 +40,7 @@
)
from qualtran._infra.composite_bloq import _get_flat_dangling_soqs
from qualtran.symbolics import is_symbolic
from qualtran.simulation.classical_sim import do_phased_classical_simulation

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
Expand Down Expand Up @@ -716,3 +717,28 @@
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)

def assert_consistent_phased_classical_action(
bloq: Bloq,
**parameter_ranges: Union[NDArray, Sequence[int], Sequence[Union[Sequence[int], NDArray]]],
):
"""Check that the bloq has a phased classical action consistent with its decomposition.

Args:
bloq: bloq to test.
parameter_ranges: named arguments giving ranges for each of the registers of the bloq.
"""
cb = bloq.decompose_bloq()
parameter_names = tuple(parameter_ranges.keys())
for vals in itertools.product(*[parameter_ranges[p] for p in parameter_names]):
call_with = {p: v for p, v in zip(parameter_names, vals)}
bloq_res, bloq_phase = do_phased_classical_simulation(bloq, call_with)
decomposed_res, decomposed_phase = do_phased_classical_simulation(cb, call_with)
np.testing.assert_equal(
bloq_res, decomposed_res, err_msg=f'{bloq=} {call_with=} {bloq_res=} {decomposed_res=}'
)
np.testing.assert_equal(
bloq_phase,
decomposed_phase,
err_msg=f'{bloq=} {call_with=} {bloq_phase=} {decomposed_phase=}'
)
Loading