Skip to content

Commit 27bfb4e

Browse files
authored
Introduce new methods of HEOMsolve_map for advanced usage (#195)
1 parent 0466625 commit 27bfb4e

File tree

1 file changed

+39
-24
lines changed

1 file changed

+39
-24
lines changed

src/evolution.jl

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,11 @@ function HEOMsolve_map(
418418

419419
# mapping initial states and parameters
420420
ados_iter = map-> _gen_ados_ode_vector(ρ, M), ρ0)
421-
iter =
422-
params isa NullParameters ? collect(Iterators.product(ados_iter, [params])) :
423-
collect(Iterators.product(ados_iter, params...))
424-
ntraj = length(iter)
421+
if params isa NullParameters
422+
iter = collect(Iterators.product(ados_iter, [params])) |> vec # convert nx1 Matrix into Vector
423+
else
424+
iter = collect(Iterators.product(ados_iter, params...))
425+
end
425426

426427
# we disable the progress bar of the HEOMsolveProblem because we use a global progress bar for all the trajectories
427428
prob = HEOMsolveProblem(
@@ -435,34 +436,48 @@ function HEOMsolve_map(
435436
kwargs...,
436437
)
437438

438-
# generate and solve ensemble problem
439-
_output_func = _ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) # setup global progress bar
439+
return HEOMsolve_map(prob, iter, alg, ensemblealg; progress_bar = progress_bar)
440+
end
441+
HEOMsolve_map(
442+
M::AbstractHEOMLSMatrix,
443+
ρ0::T_state,
444+
tlist::AbstractVector;
445+
kwargs...,
446+
) where {T_state<:Union{QuantumObject,ADOs}} = HEOMsolve_map(M, [ρ0], tlist; kwargs...)
447+
448+
# this method is for advanced usage
449+
# User can define their own iterator structure, prob_func and output_func
450+
# - `prob_func`: Function to use for generating the ODEProblem.
451+
# - `output_func`: a `Tuple` containing the `Function` to use for generating the output of a single trajectory, the (optional) `ProgressBar` object, and the (optional) `RemoteChannel` object.
452+
#
453+
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
454+
function HEOMsolve_map(
455+
prob::TimeEvolutionProblem{<:ODEProblem},
456+
iter::AbstractArray,
457+
alg::OrdinaryDiffEqAlgorithm = DP5(),
458+
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
459+
prob_func::Union{Function,Nothing} = nothing,
460+
output_func::Union{Tuple,Nothing} = nothing,
461+
progress_bar::Union{Val,Bool} = Val(true),
462+
)
463+
# generate ensemble problem
464+
ntraj = length(iter)
465+
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
466+
_output_func =
467+
isnothing(output_func) ?
468+
_ensemble_dispatch_output_func(ensemblealg, progress_bar, ntraj, _standard_output_func) : output_func
440469
ens_prob = TimeEvolutionProblem(
441-
EnsembleProblem(
442-
prob.prob,
443-
prob_func = (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter),
444-
output_func = _output_func[1],
445-
safetycopy = false,
446-
),
470+
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
447471
prob.times,
448472
prob.dimensions,
449473
(progr = _output_func[2], channel = _output_func[3]),
450474
)
475+
451476
sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)
452477

453478
# handle solution and make it become an Array of TimeEvolutionHEOMSol
454-
sol_vec = [_gen_HEOMsolve_solution(sol[:, i], prob.times, M) for i in eachindex(sol)] # map is type unstable
455-
if params isa NullParameters # if no parameters specified, just return a Vector
456-
return sol_vec
457-
else
458-
return reshape(sol_vec, size(iter))
459-
end
479+
sol_vec = [_gen_HEOMsolve_solution(sol[:, i], prob.times, prob.kwargs.M) for i in eachindex(sol)] # map is type unstable
480+
return reshape(sol_vec, size(iter))
460481
end
461-
HEOMsolve_map(
462-
M::AbstractHEOMLSMatrix,
463-
ρ0::T_state,
464-
tlist::AbstractVector;
465-
kwargs...,
466-
) where {T_state<:Union{QuantumObject,ADOs}} = HEOMsolve_map(M, [ρ0], tlist; kwargs...)
467482

468483
const heomsolve_map = HEOMsolve_map # a synonym to align with QuantumToolbox.jl

0 commit comments

Comments
 (0)