Skip to content

Commit beb0f37

Browse files
apkilleKrastanov
andauthored
Broadcasting for semiclassical objects (#404)
Co-authored-by: Stefan Krastanov <github.acc@krastanov.org> Co-authored-by: Stefan Krastanov <stefan@krastanov.org>
1 parent f730f1a commit beb0f37

File tree

6 files changed

+148
-16
lines changed

6 files changed

+148
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ RecursiveArrayTools = "2, 3"
3636
Reexport = "0.2, 1.0"
3737
StochasticDiffEq = "6"
3838
WignerSymbols = "1, 2"
39-
julia = "1.3"
39+
julia = "1.10"
4040

4141
[extras]
4242
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/QuantumOptics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module QuantumOptics
33
using Reexport
44
@reexport using QuantumOpticsBase
55
using SparseArrays, LinearAlgebra
6+
import RecursiveArrayTools
67

78
export
89
ylm,

src/semiclassical.jl

Lines changed: 95 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module semiclassical
22

33
using QuantumOpticsBase
4-
import Base: ==
4+
import QuantumOpticsBase: IncompatibleBases
5+
import Base: ==, isapprox, +, -, *, /
56
import ..timeevolution: integrate, recast!, jump, integrate_mcwf, jump_callback,
67
JumpRNGState, threshold, roll!, as_vector, QO_CHECKS
78
import LinearAlgebra: normalize, normalize!
9+
import RecursiveArrayTools
810

911
using Random, LinearAlgebra
1012
import OrdinaryDiffEq
@@ -31,26 +33,104 @@ mutable struct State{B,T,C}
3133
new{B,T,C}(quantum, classical)
3234
end
3335
end
34-
35-
Base.length(state::State) = length(state.quantum) + length(state.classical)
36-
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
37-
Base.eltype(state::State) = promote_type(eltype(state.quantum),eltype(state.classical))
38-
normalize!(state::State) = (normalize!(state.quantum); state)
39-
normalize(state::State) = State(normalize(state.quantum),copy(state.classical))
40-
41-
function ==(a::State, b::State)
42-
QuantumOpticsBase.samebases(a.quantum, b.quantum) &&
43-
length(a.classical)==length(b.classical) &&
44-
(a.classical==b.classical) &&
45-
(a.quantum==b.quantum)
46-
end
36+
State{B}(q::T, c::C) where {B,T<:QuantumState{B},C} = State(q,c)
37+
38+
# Standard interfaces
39+
Base.zero(x::State) = State(zero(x.quantum), zero(x.classical))
40+
Base.length(x::State) = length(x.quantum) + length(x.classical)
41+
Base.axes(x::State) = (Base.OneTo(length(x)),)
42+
Base.size(x::State) = size(x.quantum)
43+
Base.ndims(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = ndims(T)
44+
Base.copy(x::State) = State(copy(x.quantum), copy(x.classical))
45+
Base.copyto!(x::State, y::State) = (copyto!(x.quantum, y.quantum); copyto!(x.classical, y.classical); x)
46+
Base.fill!(x::State, a) = (fill!(x.quantum, a), fill!(x.classical, a))
47+
Base.eltype(x::State) = promote_type(eltype(x.quantum),eltype(x.classical))
48+
Base.eltype(x::Type{<:State{B,T,C}}) where {B,T<:QuantumState{B},C} = promote_type(eltype(T), eltype(C))
49+
Base.similar(x::State, ::Type{T} = eltype(x)) where {T} = State(similar(x.quantum, T), similar(x.classical, T))
50+
Base.getindex(x::State, idx) = idx <= length(x.quantum) ? getindex(x.quantum, idx) : getindex(x.classical, idx-length(x.quantum))
51+
52+
normalize!(x::State) = (normalize!(x.quantum); x)
53+
normalize(x::State) = State(normalize(x.quantum),copy(x.classical))
54+
LinearAlgebra.norm(x::State) = LinearAlgebra.norm(x.quantum)
55+
56+
==(x::State{B}, y::State{B}) where {B} = (x.classical==y.classical) && (x.quantum==y.quantum)
57+
==(x::State, y::State) = false
58+
59+
isapprox(x::State{B}, y::State{B}; kwargs...) where {B} = isapprox(x.quantum,y.quantum; kwargs...) && isapprox(x.classical,y.classical; kwargs...)
60+
isapprox(x::State, y::State; kwargs...) = false
4761

4862
QuantumOpticsBase.expect(op, state::State) = expect(op, state.quantum)
4963
QuantumOpticsBase.variance(op, state::State) = variance(op, state.quantum)
5064
QuantumOpticsBase.ptrace(state::State, indices) = State(ptrace(state.quantum, indices), state.classical)
51-
5265
QuantumOpticsBase.dm(x::State) = State(dm(x.quantum), x.classical)
5366

67+
Base.broadcastable(x::State) = x
68+
69+
# Custom broadcasting style
70+
struct StateStyle{B} <: Broadcast.BroadcastStyle end
71+
72+
# Style precedence rules
73+
Broadcast.BroadcastStyle(::Type{<:State{B}}) where {B} = StateStyle{B}()
74+
Broadcast.BroadcastStyle(::StateStyle{B1}, ::StateStyle{B2}) where {B1,B2} = throw(IncompatibleBases())
75+
Broadcast.BroadcastStyle(::StateStyle{B}, ::Broadcast.DefaultArrayStyle{0}) where {B} = StateStyle{B}()
76+
Broadcast.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::StateStyle{B}) where {B} = StateStyle{B}()
77+
78+
# Out-of-place broadcasting
79+
@inline function Base.copy(bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
80+
bcf = Broadcast.flatten(bc)
81+
# extract quantum object from broadcast container
82+
qobj = find_quantum(bcf)
83+
data_q = zeros(eltype(qobj), size(qobj)...)
84+
Nq = length(qobj)
85+
# allocate quantum data from broadcast container
86+
@inbounds @simd for I in 1:Nq
87+
data_q[I] = bcf[I]
88+
end
89+
# extract classical object from broadcast container
90+
cobj = find_classical(bcf)
91+
data_c = zeros(eltype(cobj), length(cobj))
92+
Nc = length(cobj)
93+
# allocate classical data from broadcast container
94+
@inbounds @simd for I in 1:Nc
95+
data_c[I] = bcf[I+Nq]
96+
end
97+
type = eval(nameof(typeof(qobj)))
98+
return State{B}(type(basis(qobj), data_q), data_c)
99+
end
100+
101+
for f [:find_quantum, :find_classical]
102+
@eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args)
103+
@eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args))
104+
@eval ($f)(x) = x
105+
@eval ($f)(::Any, rest) = ($f)(rest)
106+
end
107+
find_quantum(x::State, rest) = x.quantum
108+
find_classical(x::State, rest) = x.classical
109+
110+
# In-place broadcasting
111+
@inline function Base.copyto!(dest::State{B}, bc::Broadcast.Broadcasted{<:StateStyle{B},Axes,F,Args}) where {B,Axes,F,Args<:Tuple}
112+
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
113+
bc′ = Base.Broadcast.preprocess(dest, bc)
114+
# write broadcasted quantum data to dest
115+
qobj = dest.quantum
116+
@inbounds @simd for I in 1:length(qobj)
117+
qobj.data[I] = bc′[I]
118+
end
119+
# write broadcasted classical data to dest
120+
cobj = dest.classical
121+
@inbounds @simd for I in 1:length(cobj)
122+
cobj[I] = bc′[I+length(qobj)]
123+
end
124+
return dest
125+
end
126+
@inline Base.copyto!(dest::State{B1}, bc::Broadcast.Broadcasted{<:StateStyle{B2},Axes,F,Args}) where {B1,B2,Axes,F,Args<:Tuple} =
127+
throw(IncompatibleBases())
128+
129+
Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::State, i) = Base.getindex(x, i)
130+
RecursiveArrayTools.recursive_unitless_bottom_eltype(x::State) = eltype(x)
131+
RecursiveArrayTools.recursivecopy!(dest::State, src::State) = copyto!(dest, src)
132+
RecursiveArrayTools.recursivecopy(x::State) = copy(x)
133+
RecursiveArrayTools.recursivefill!(x::State, a) = fill!(x, a)
54134

55135
"""
56136
semiclassical.schroedinger_dynamic(tspan, state0, fquantum, fclassical[; fout, ...])

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ names = [
2323

2424
"test_timeevolution_abstractdata.jl",
2525

26+
"test_sciml_broadcast_interfaces.jl",
2627
"test_ForwardDiff.jl"
2728
]
2829

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Test
2+
using QuantumOptics
3+
using OrdinaryDiffEq
4+
5+
@testset "sciml interface" begin
6+
7+
# semiclassical ODE problem
8+
b = SpinBasis(1//2)
9+
psi0 = spindown(b)
10+
u0 = ComplexF64[0.5, 0.75]
11+
sc = semiclassical.State(psi0, u0)
12+
t₀, t₁ = (0.0, pi)
13+
σx = sigmax(b)
14+
15+
fquantum(t, q, u) = σx + cos(u[1])*identityoperator(σx)
16+
fclassical!(du, u, q, t) = (du[1] = sin(u[2]); du[2] = 2*u[1])
17+
f!(dstate, state, p, t) = semiclassical.dschroedinger_dynamic!(dstate, fquantum, fclassical!, state, t)
18+
prob = ODEProblem(f!, sc, (t₀, t₁))
19+
20+
sol = solve(prob, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false)
21+
tout, ψt = semiclassical.schroedinger_dynamic([t₀, t₁], sc, fquantum, fclassical!; reltol = 1.0e-8, abstol = 1.0e-10)
22+
23+
@test sol[end] ψt[end]
24+
25+
end

test/test_semiclassical.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
using QuantumOptics
33
using LinearAlgebra
4+
using QuantumOpticsBase: IncompatibleBases
45

56
@testset "semiclassical" begin
67

@@ -175,4 +176,28 @@ after_jump = findlast(t-> !(t∈T), tout4)
175176
@test ψt4[before_jump].quantum == ψ0.quantum
176177
@test ψt4[after_jump].quantum == spindown(ba)fockstate(bf,0)
177178

179+
# Test broadcasting interface
180+
b = FockBasis(10)
181+
bn = FockBasis(20)
182+
u0 = ComplexF64[0.7, 0.2]
183+
psi = fockstate(b, 2)
184+
psin = fockstate(bn, 2)
185+
rho = dm(psi)
186+
187+
sc_ket = semiclassical.State(psi, u0)
188+
sc_ketn = semiclassical.State(psin, u0)
189+
sc_dm = semiclassical.State(rho, u0)
190+
191+
@test Base.size(sc_dm) == Base.size(rho)
192+
@test (copy_sc = copy(sc_ket); Base.fill!(copy_sc, 0.0); copy_sc) == semiclassical.State(fill!(copy(psi), 0.0), fill!(copy(u0), 0.0))
193+
@test Base.similar(sc_ket, Int) isa semiclassical.State
194+
@test normalize!(copy(sc_ket)) == semiclassical.State(normalize!(copy(psi)), u0)
195+
@test !(sc_ket == sc_ketn)
196+
@test !(isapprox(sc_ket, sc_ketn))
197+
@test sc_ket .* 1.0 == sc_ket
198+
@test sc_dm .* 1.0 == sc_dm
199+
@test sc_ket .+ 2.0 == semiclassical.State(psi .+ 2.0, u0 .+ 2.0)
200+
@test sc_dm .+ 2.0 == semiclassical.State(rho .+ 2.0, u0 .+ 2.0)
201+
@test_throws IncompatibleBases sc_ket .+ semiclassical.State(spinup(SpinBasis(10)), u0)
202+
178203
end # testsets

0 commit comments

Comments
 (0)