From e17f8755e85ac10c6c20428af8f8248ee8530450 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 7 Oct 2025 16:48:27 +0200 Subject: [PATCH] Support ZZPowGate when lowering from cirq --- src/bloqade/cirq_utils/lowering.py | 22 ++++++++++++- test/cirq_utils/test_cirq_to_squin.py | 45 +++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 5745027f..aa14bd03 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -169,7 +169,6 @@ def main(): | cirq.CSwapGate | cirq.XXPowGate | cirq.YYPowGate - | cirq.ZZPowGate | cirq.CCXPowGate | cirq.CCZPowGate ) @@ -523,6 +522,27 @@ def visit_CZPowGate( gate.stmts.CZ(controls=control_qarg, targets=target_qarg) ) + def visit_ZZPowGate( + self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation + ): + if node.gate.exponent % 2 == 0: + return + + qubit1, qubit2 = node.qubits + qarg1 = self.lower_qubit_getindices(state, (qubit1,)) + qarg2 = self.lower_qubit_getindices(state, (qubit2,)) + + if node.gate.exponent % 2 == 1: + state.current_frame.push(gate.stmts.X(qarg1)) + state.current_frame.push(gate.stmts.X(qarg2)) + return + + # NOTE: arbitrary exponent, write as CX * Rz * CX (up to global phase) + state.current_frame.push(gate.stmts.CX(qarg1, qarg2)) + angle = state.current_frame.push(py.Constant(0.5 * node.gate.exponent)) + state.current_frame.push(gate.stmts.Rz(angle.result, qarg2)) + state.current_frame.push(gate.stmts.CX(qarg1, qarg2)) + def visit_ControlledOperation( self, state: lowering.State[cirq.Circuit], node: cirq.ControlledOperation ): diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 6e0fc626..c72943b5 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -1,6 +1,7 @@ import math import cirq +import numpy as np import pytest from kirin import types from kirin.passes import inline @@ -414,3 +415,47 @@ def multi_arg(n: int, p: float): @pytest.mark.xfail def test_amplitude_damping(): test_circuit(amplitude_damping) + + +def test_trotter(): + + # NOTE: stolen from jonathan's tutorial + def trotter_layer( + qubits, dt: float = 0.01, J: float = 1, h: float = 1 + ) -> cirq.Circuit: + """ + Cirq builder function that returns a circuit of + a Trotter step of the 1D transverse Ising model + """ + op_zz = cirq.ZZ ** (dt * J / math.pi) + op_x = cirq.X ** (dt * h / math.pi) + circuit = cirq.Circuit() + for i in range(0, len(qubits) - 1, 2): + circuit.append(op_zz.on(qubits[i], qubits[i + 1])) + for i in range(1, len(qubits) - 1, 2): + circuit.append(op_zz.on(qubits[i], qubits[i + 1])) + for i in range(len(qubits)): + circuit.append(op_x.on(qubits[i])) + return circuit + + N = 4 + steps = 10 + dt = 0.01 + J = 1 + h = 1 + + qubits = cirq.LineQubit.range(N) + circuit = cirq.Circuit() + for _ in range(steps): + circuit += trotter_layer(qubits, dt, J, h) + + main = load_circuit(circuit) + + # actually run + cirq_statevector = cirq.Simulator().simulate(circuit).state_vector() + sim = DynamicMemorySimulator() + ket = sim.state_vector(main) + + assert math.isclose( + np.abs(np.dot(np.conj(ket), cirq_statevector)) ** 2, 1.0, abs_tol=1e-3 + )