Skip to content

Commit d16d59c

Browse files
committed
return problem feature added
1 parent beb0f37 commit d16d59c

8 files changed

+2311
-707
lines changed

Manifest.toml

Lines changed: 1560 additions & 0 deletions
Large diffs are not rendered by default.

src/mcwf.jl

Lines changed: 98 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ Calculate MCWF trajectory where the Hamiltonian is given in hermitian form.
1313
For more information see: [`mcwf`](@ref)
1414
"""
1515
function mcwf_h(tspan, psi0::Ket, H::AbstractOperator, J;
16-
seed=rand(UInt), rates=nothing,
17-
fout=nothing, Jdagger=dagger.(J),
18-
tmp=copy(psi0),
19-
display_beforeevent=false, display_afterevent=false,
20-
kwargs...)
16+
seed=rand(UInt), rates=nothing,
17+
fout=nothing, Jdagger=dagger.(J),
18+
tmp=copy(psi0),
19+
display_beforeevent=false, display_afterevent=false,
20+
kwargs...)
2121
_check_const(H)
2222
_check_const.(J)
2323
_check_const.(Jdagger)
@@ -47,9 +47,9 @@ H_{nh} = H - \\frac{i}{2} \\sum_k J^†_k J_k
4747
For more information see: [`mcwf`](@ref)
4848
"""
4949
function mcwf_nh(tspan, psi0::Ket, Hnh::AbstractOperator, J;
50-
seed=rand(UInt), fout=nothing,
51-
display_beforeevent=false, display_afterevent=false,
52-
kwargs...)
50+
seed=rand(UInt), fout=nothing,
51+
display_beforeevent=false, display_afterevent=false,
52+
kwargs...)
5353
_check_const(Hnh)
5454
_check_const.(J)
5555
check_mcwf(psi0, Hnh, J, J, nothing)
@@ -107,10 +107,10 @@ of the jump operators with which the jump occured, respectively.
107107
* `kwargs...`: Further arguments are passed on to the ode solver.
108108
"""
109109
function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
110-
seed=rand(UInt), rates=nothing,
111-
fout=nothing, Jdagger=dagger.(J),
112-
display_beforeevent=false, display_afterevent=false,
113-
kwargs...)
110+
seed=rand(UInt), rates=nothing,
111+
fout=nothing, Jdagger=dagger.(J),
112+
display_beforeevent=false, display_afterevent=false,
113+
kwargs...)
114114
_check_const(H)
115115
_check_const.(J)
116116
_check_const.(Jdagger)
@@ -132,12 +132,12 @@ function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
132132
else
133133
Hnh = copy(H)
134134
if isa(rates, Nothing)
135-
for i=1:length(J)
136-
Hnh -= complex(float(eltype(H)))(0.5im)*Jdagger[i]*J[i]
135+
for i = 1:length(J)
136+
Hnh -= complex(float(eltype(H)))(0.5im) * Jdagger[i] * J[i]
137137
end
138138
else
139-
for i=1:length(J)
140-
Hnh -= complex(float(eltype(H)))(0.5im*rates[i])*Jdagger[i]*J[i]
139+
for i = 1:length(J)
140+
Hnh -= complex(float(eltype(H)))(0.5im * rates[i]) * Jdagger[i] * J[i]
141141
end
142142
end
143143
dmcwf_nh_ = let Hnh = Hnh # Hnh type often not inferrable
@@ -322,25 +322,26 @@ Integrate a single Monte Carlo wave function trajectory.
322322
* `kwargs`: Further arguments are passed on to the ode solver.
323323
"""
324324
function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
325-
psi0, seed, fout;
326-
display_beforeevent=false, display_afterevent=false,
327-
display_jumps=false,
328-
rng_state=nothing,
329-
save_everystep=false, callback=nothing,
330-
saveat=tspan,
331-
alg=OrdinaryDiffEq.DP5(),
332-
kwargs...) where {T, J}
325+
psi0, seed, fout;
326+
display_beforeevent=false, display_afterevent=false,
327+
display_jumps=false,
328+
rng_state=nothing,
329+
save_everystep=false, callback=nothing,
330+
saveat=tspan,
331+
alg=OrdinaryDiffEq.DP5(),
332+
return_problem=false,
333+
kwargs...) where {T,J}
333334

334335
tspan_ = convert(Vector{float(eltype(tspan))}, tspan)
335336
# Display before or after events
336-
function save_func!(affect!,integrator)
337+
function save_func!(affect!, integrator)
337338
affect!.saveiter += 1
338339
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
339340
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
340-
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
341+
affect!.save_func(integrator.u, integrator.t, integrator), Val{false})
341342
return nothing
342343
end
343-
no_save_func!(affect!,integrator) = nothing
344+
no_save_func!(affect!, integrator) = nothing
344345
save_before! = display_beforeevent ? save_func! : no_save_func!
345346
save_after! = display_afterevent ? save_func! : no_save_func!
346347

@@ -349,8 +350,8 @@ function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
349350
jump_index = Int[]
350351

351352
function jump_saver(t, i)
352-
push!(jump_t,t)
353-
push!(jump_index,i)
353+
push!(jump_t, t)
354+
push!(jump_index, i)
354355
return nothing
355356
end
356357
no_jump_saver(t, i) = nothing
@@ -362,51 +363,59 @@ function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
362363

363364
fout_ = let state = state, fout = fout
364365
function fout_(x, t, integrator)
365-
recast!(state,x)
366+
recast!(state, x)
366367
fout(t, state)
367368
end
368369
end
369370

370371
out_type = pure_inference(fout, Tuple{eltype(tspan_),typeof(state)})
371-
out = DiffEqCallbacks.SavedValues(eltype(tspan_),out_type)
372-
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan_,
373-
save_everystep=save_everystep,
374-
save_start = false)
372+
out = DiffEqCallbacks.SavedValues(eltype(tspan_), out_type)
373+
scb = DiffEqCallbacks.SavingCallback(fout_, out, saveat=tspan_,
374+
save_everystep=save_everystep,
375+
save_start=false)
375376

376377
cb = jump_callback(jumpfun, seed, scb, save_before!, save_after!, save_t_index, psi0, rng_state)
377-
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)
378+
full_cb = OrdinaryDiffEq.CallbackSet(callback, cb, scb)
378379

379380
df_ = let state = state, dstate = dstate # help inference along
380381
function df_(dx, x, p, t)
381-
recast!(state,x)
382-
recast!(dstate,dx)
382+
recast!(state, x)
383+
recast!(dstate, dx)
383384
dmcwf(t, state, dstate)
384-
recast!(dx,dstate)
385+
recast!(dx, dstate)
385386
return nothing
386387
end
387388
end
388389

389-
prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0), (tspan_[1],tspan_[end]))
390-
391-
sol = OrdinaryDiffEq.solve(
392-
prob,
393-
alg;
394-
reltol = 1.0e-6,
395-
abstol = 1.0e-8,
396-
save_everystep = false, save_start = false,
397-
save_end = false,
398-
callback=full_cb, kwargs...)
390+
prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0), (tspan_[1], tspan_[end]))
399391

400-
if display_jumps
401-
return out.t, out.saveval, jump_t, jump_index
392+
if return_problem
393+
if display_jumps
394+
return Dict("out" => out, "jump_t" => jump_t, "jump_index" => jump_index, "prob" => prob, "alg" => alg, "solve_kwargs" => (reltol=1.0e-6, abstol=1.0e-8, save_everystep=false, save_start=false, save_end=false, callback=full_cb, kwargs...))
395+
else
396+
return Dict("out" => out, "prob" => prob, "alg" => alg, "solve_kwargs" => (reltol=1.0e-6, abstol=1.0e-8, save_everystep=false, save_start=false, save_end=false, callback=full_cb, kwargs...))
397+
end
402398
else
403-
return out.t, out.saveval
399+
sol = OrdinaryDiffEq.solve(
400+
prob,
401+
alg;
402+
reltol=1.0e-6,
403+
abstol=1.0e-8,
404+
save_everystep=false, save_start=false,
405+
save_end=false,
406+
callback=full_cb, kwargs...)
407+
408+
if display_jumps
409+
return out.t, out.saveval, jump_t, jump_index
410+
else
411+
return out.t, out.saveval
412+
end
404413
end
405414
end
406415

407416
function integrate_mcwf(dmcwf, jumpfun, tspan,
408-
psi0, seed, fout::Nothing;
409-
kwargs...)
417+
psi0, seed, fout::Nothing;
418+
kwargs...)
410419
function fout_(t, x)
411420
return normalize(x)
412421
end
@@ -427,43 +436,43 @@ mutable struct JumpRNGState{T<:Real,R<:AbstractRNG}
427436
rng::R
428437
threshold::T
429438
end
430-
function JumpRNGState(::Type{T}, seed) where T
439+
function JumpRNGState(::Type{T}, seed) where {T}
431440
rng = MersenneTwister(seed)
432441
threshold = rand(rng, T)
433442
JumpRNGState(rng, threshold)
434443
end
435-
roll!(s::JumpRNGState{T}) where T = (s.threshold = rand(s.rng, T))
444+
roll!(s::JumpRNGState{T}) where {T} = (s.threshold = rand(s.rng, T))
436445
threshold(s::JumpRNGState) = s.threshold
437446

438447
function jump_callback(jumpfun::F, seed, scb, save_before!::G,
439-
save_after!::H, save_t_index::I, psi0, rng_state::JumpRNGState) where {F,G,H,I}
448+
save_after!::H, save_t_index::I, psi0, rng_state::JumpRNGState) where {F,G,H,I}
440449

441450
tmp = copy(psi0)
442451
psi_tmp = copy(psi0)
443452

444-
djumpnorm(x, t, integrator) = norm(x)^2 - (1-threshold(rng_state))
453+
djumpnorm(x, t, integrator) = norm(x)^2 - (1 - threshold(rng_state))
445454

446455
function dojump(integrator)
447456
x = integrator.u
448457
t = integrator.t
449458

450459
affect! = scb.affect!
451-
save_before!(affect!,integrator)
452-
recast!(psi_tmp,x)
460+
save_before!(affect!, integrator)
461+
recast!(psi_tmp, x)
453462
i = jumpfun(rng_state.rng, t, psi_tmp, tmp)
454463
x .= tmp.data
455-
save_after!(affect!,integrator)
456-
save_t_index(t,i)
464+
save_after!(affect!, integrator)
465+
save_t_index(t, i)
457466

458467
roll!(rng_state)
459468
return nothing
460469
end
461470

462-
return OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
463-
save_positions = (false,false))
471+
return OrdinaryDiffEq.ContinuousCallback(djumpnorm, dojump,
472+
save_positions=(false, false))
464473
end
465474
jump_callback(jumpfun, seed, scb, save_before!,
466-
save_after!, save_t_index, psi0, ::Nothing) =
475+
save_after!, save_t_index, psi0, ::Nothing) =
467476
jump_callback(jumpfun, seed, scb, save_before!,
468477
save_after!, save_t_index, psi0, JumpRNGState(real(eltype(psi0)), seed))
469478

@@ -483,13 +492,13 @@ Default jump function.
483492
* `probs_tmp`: Temporary array for holding jump probailities.
484493
"""
485494
function jump(rng, t, psi, J, psi_new, probs_tmp, rates::Nothing)
486-
if length(J)==1
487-
QuantumOpticsBase.mul!(psi_new,J[1],psi,true,false)
495+
if length(J) == 1
496+
QuantumOpticsBase.mul!(psi_new, J[1], psi, true, false)
488497
psi_new.data ./= norm(psi_new)
489-
i=1
498+
i = 1
490499
else
491-
for i=1:length(J)
492-
QuantumOpticsBase.mul!(psi_new,J[i],psi,true,false)
500+
for i = 1:length(J)
501+
QuantumOpticsBase.mul!(psi_new, J[i], psi, true, false)
493502
probs_tmp[i] = real(dot(psi_new.data, psi_new.data))
494503
end
495504
r = rand(rng)
@@ -501,19 +510,19 @@ function jump(rng, t, psi, J, psi_new, probs_tmp, rates::Nothing)
501510
cumulative_prob += p / total
502511
cumulative_prob > r && break
503512
end
504-
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(1/sqrt(probs_tmp[i])),zero(eltype(psi)))
513+
QuantumOpticsBase.mul!(psi_new, J[i], psi, eltype(psi)(1 / sqrt(probs_tmp[i])), zero(eltype(psi)))
505514
end
506515
return i
507516
end
508517

509518
function jump(rng, t, psi, J, psi_new, probs_tmp, rates::AbstractVector)
510-
if length(J)==1
511-
QuantumOpticsBase.mul!(psi_new,J[1],psi,eltype(psi)(sqrt(rates[1])),zero(eltype(psi)))
519+
if length(J) == 1
520+
QuantumOpticsBase.mul!(psi_new, J[1], psi, eltype(psi)(sqrt(rates[1])), zero(eltype(psi)))
512521
psi_new.data ./= norm(psi_new)
513-
i=1
522+
i = 1
514523
else
515-
for i=1:length(J)
516-
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i])),zero(eltype(psi)))
524+
for i = 1:length(J)
525+
QuantumOpticsBase.mul!(psi_new, J[i], psi, eltype(psi)(sqrt(rates[i])), zero(eltype(psi)))
517526
probs_tmp[i] = real(dot(psi_new.data, psi_new.data))
518527
end
519528
r = rand(rng)
@@ -525,7 +534,7 @@ function jump(rng, t, psi, J, psi_new, probs_tmp, rates::AbstractVector)
525534
cumulative_prob += p / total
526535
cumulative_prob > r && break
527536
end
528-
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i]/probs_tmp[i])),zero(eltype(psi)))
537+
QuantumOpticsBase.mul!(psi_new, J[i], psi, eltype(psi)(sqrt(rates[i] / probs_tmp[i])), zero(eltype(psi)))
529538
end
530539
return i
531540
end
@@ -540,19 +549,19 @@ the jump operators J.
540549
See also: [`mcwf`](@ref)
541550
"""
542551
function dmcwf_h!(dpsi, H, J, Jdagger, rates::Nothing, psi, dpsi_cache)
543-
QuantumOpticsBase.mul!(dpsi,H,psi,eltype(psi)(-im),zero(eltype(psi)))
544-
for i=1:length(J)
545-
QuantumOpticsBase.mul!(dpsi_cache,J[i],psi,true,false)
546-
QuantumOpticsBase.mul!(dpsi,Jdagger[i],dpsi_cache,eltype(psi)(-0.5),one(eltype(psi)))
552+
QuantumOpticsBase.mul!(dpsi, H, psi, eltype(psi)(-im), zero(eltype(psi)))
553+
for i = 1:length(J)
554+
QuantumOpticsBase.mul!(dpsi_cache, J[i], psi, true, false)
555+
QuantumOpticsBase.mul!(dpsi, Jdagger[i], dpsi_cache, eltype(psi)(-0.5), one(eltype(psi)))
547556
end
548557
return dpsi
549558
end
550559

551560
function dmcwf_h!(dpsi, H, J, Jdagger, rates::AbstractVector, psi, dpsi_cache)
552-
QuantumOpticsBase.mul!(dpsi,H,psi,eltype(psi)(-im),zero(eltype(psi)))
553-
for i=1:length(J)
554-
QuantumOpticsBase.mul!(dpsi_cache,J[i],psi,eltype(psi)(rates[i]),zero(eltype(psi)))
555-
QuantumOpticsBase.mul!(dpsi,Jdagger[i],dpsi_cache,eltype(psi)(-0.5),one(eltype(psi)))
561+
QuantumOpticsBase.mul!(dpsi, H, psi, eltype(psi)(-im), zero(eltype(psi)))
562+
for i = 1:length(J)
563+
QuantumOpticsBase.mul!(dpsi_cache, J[i], psi, eltype(psi)(rates[i]), zero(eltype(psi)))
564+
QuantumOpticsBase.mul!(dpsi, Jdagger[i], dpsi_cache, eltype(psi)(-0.5), one(eltype(psi)))
556565
end
557566
return dpsi
558567
end
@@ -568,13 +577,13 @@ function check_mcwf(psi0, H, J, Jdagger, rates)
568577
if !(isa(H, DenseOpType) || isa(H, SparseOpType))
569578
isreducible = false
570579
end
571-
for j=J
580+
for j = J
572581
@assert isa(j, AbstractOperator)
573582
if !(isa(j, DenseOpType) || isa(j, SparseOpType))
574583
isreducible = false
575584
end
576585
end
577-
for j=Jdagger
586+
for j = Jdagger
578587
@assert isa(j, AbstractOperator)
579588
if !(isa(j, DenseOpType) || isa(j, SparseOpType))
580589
isreducible = false
@@ -605,11 +614,11 @@ corresponding set of jump operators is calculated.
605614
function diagonaljumps(rates::AbstractMatrix, J)
606615
@assert length(J) == size(rates)[1] == size(rates)[2]
607616
d, v = eigen(rates)
608-
d, [sum([v[j, i]*J[j] for j=1:length(d)]) for i=1:length(d)]
617+
d, [sum([v[j, i] * J[j] for j = 1:length(d)]) for i = 1:length(d)]
609618
end
610619

611-
function diagonaljumps(rates::AbstractMatrix, J::Vector{T}) where T<:Union{LazySum,LazyTensor,LazyProduct}
620+
function diagonaljumps(rates::AbstractMatrix, J::Vector{T}) where {T<:Union{LazySum,LazyTensor,LazyProduct}}
612621
@assert length(J) == size(rates)[1] == size(rates)[2]
613622
d, v = eigen(rates)
614-
d, [LazySum([v[j, i]*J[j] for j=1:length(d)]...) for i=1:length(d)]
623+
d, [LazySum([v[j, i] * J[j] for j = 1:length(d)]...) for i = 1:length(d)]
615624
end

0 commit comments

Comments
 (0)