@@ -13,11 +13,11 @@ Calculate MCWF trajectory where the Hamiltonian is given in hermitian form.
13
13
For more information see: [`mcwf`](@ref)
14
14
"""
15
15
function mcwf_h (tspan, psi0:: Ket , H:: AbstractOperator , J;
16
- seed= rand (UInt), rates= nothing ,
17
- fout= nothing , Jdagger= dagger .(J),
18
- tmp= copy (psi0),
19
- display_beforeevent= false , display_afterevent= false ,
20
- kwargs... )
16
+ seed= rand (UInt), rates= nothing ,
17
+ fout= nothing , Jdagger= dagger .(J),
18
+ tmp= copy (psi0),
19
+ display_beforeevent= false , display_afterevent= false ,
20
+ kwargs... )
21
21
_check_const (H)
22
22
_check_const .(J)
23
23
_check_const .(Jdagger)
@@ -47,9 +47,9 @@ H_{nh} = H - \\frac{i}{2} \\sum_k J^†_k J_k
47
47
For more information see: [`mcwf`](@ref)
48
48
"""
49
49
function mcwf_nh (tspan, psi0:: Ket , Hnh:: AbstractOperator , J;
50
- seed= rand (UInt), fout= nothing ,
51
- display_beforeevent= false , display_afterevent= false ,
52
- kwargs... )
50
+ seed= rand (UInt), fout= nothing ,
51
+ display_beforeevent= false , display_afterevent= false ,
52
+ kwargs... )
53
53
_check_const (Hnh)
54
54
_check_const .(J)
55
55
check_mcwf (psi0, Hnh, J, J, nothing )
@@ -107,10 +107,10 @@ of the jump operators with which the jump occured, respectively.
107
107
* `kwargs...`: Further arguments are passed on to the ode solver.
108
108
"""
109
109
function mcwf (tspan, psi0:: Ket , H:: AbstractOperator , J;
110
- seed= rand (UInt), rates= nothing ,
111
- fout= nothing , Jdagger= dagger .(J),
112
- display_beforeevent= false , display_afterevent= false ,
113
- kwargs... )
110
+ seed= rand (UInt), rates= nothing ,
111
+ fout= nothing , Jdagger= dagger .(J),
112
+ display_beforeevent= false , display_afterevent= false ,
113
+ kwargs... )
114
114
_check_const (H)
115
115
_check_const .(J)
116
116
_check_const .(Jdagger)
@@ -132,12 +132,12 @@ function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
132
132
else
133
133
Hnh = copy (H)
134
134
if isa (rates, Nothing)
135
- for i= 1 : length (J)
136
- Hnh -= complex (float (eltype (H)))(0.5im )* Jdagger[i]* J[i]
135
+ for i = 1 : length (J)
136
+ Hnh -= complex (float (eltype (H)))(0.5im ) * Jdagger[i] * J[i]
137
137
end
138
138
else
139
- for i= 1 : length (J)
140
- Hnh -= complex (float (eltype (H)))(0.5im * rates[i])* Jdagger[i]* J[i]
139
+ for i = 1 : length (J)
140
+ Hnh -= complex (float (eltype (H)))(0.5im * rates[i]) * Jdagger[i] * J[i]
141
141
end
142
142
end
143
143
dmcwf_nh_ = let Hnh = Hnh # Hnh type often not inferrable
@@ -322,25 +322,26 @@ Integrate a single Monte Carlo wave function trajectory.
322
322
* `kwargs`: Further arguments are passed on to the ode solver.
323
323
"""
324
324
function integrate_mcwf (dmcwf:: T , jumpfun:: J , tspan,
325
- psi0, seed, fout;
326
- display_beforeevent= false , display_afterevent= false ,
327
- display_jumps= false ,
328
- rng_state= nothing ,
329
- save_everystep= false , callback= nothing ,
330
- saveat= tspan,
331
- alg= OrdinaryDiffEq. DP5 (),
332
- kwargs... ) where {T, J}
325
+ psi0, seed, fout;
326
+ display_beforeevent= false , display_afterevent= false ,
327
+ display_jumps= false ,
328
+ rng_state= nothing ,
329
+ save_everystep= false , callback= nothing ,
330
+ saveat= tspan,
331
+ alg= OrdinaryDiffEq. DP5 (),
332
+ return_problem= false ,
333
+ kwargs... ) where {T,J}
333
334
334
335
tspan_ = convert (Vector{float (eltype (tspan))}, tspan)
335
336
# Display before or after events
336
- function save_func! (affect!,integrator)
337
+ function save_func! (affect!, integrator)
337
338
affect!. saveiter += 1
338
339
copyat_or_push! (affect!. saved_values. t, affect!. saveiter, integrator. t)
339
340
copyat_or_push! (affect!. saved_values. saveval, affect!. saveiter,
340
- affect!. save_func (integrator. u, integrator. t, integrator),Val{false })
341
+ affect!. save_func (integrator. u, integrator. t, integrator), Val{false })
341
342
return nothing
342
343
end
343
- no_save_func! (affect!,integrator) = nothing
344
+ no_save_func! (affect!, integrator) = nothing
344
345
save_before! = display_beforeevent ? save_func! : no_save_func!
345
346
save_after! = display_afterevent ? save_func! : no_save_func!
346
347
@@ -349,8 +350,8 @@ function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
349
350
jump_index = Int[]
350
351
351
352
function jump_saver (t, i)
352
- push! (jump_t,t)
353
- push! (jump_index,i)
353
+ push! (jump_t, t)
354
+ push! (jump_index, i)
354
355
return nothing
355
356
end
356
357
no_jump_saver (t, i) = nothing
@@ -362,51 +363,59 @@ function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
362
363
363
364
fout_ = let state = state, fout = fout
364
365
function fout_ (x, t, integrator)
365
- recast! (state,x)
366
+ recast! (state, x)
366
367
fout (t, state)
367
368
end
368
369
end
369
370
370
371
out_type = pure_inference (fout, Tuple{eltype (tspan_),typeof (state)})
371
- out = DiffEqCallbacks. SavedValues (eltype (tspan_),out_type)
372
- scb = DiffEqCallbacks. SavingCallback (fout_,out,saveat= tspan_,
373
- save_everystep= save_everystep,
374
- save_start = false )
372
+ out = DiffEqCallbacks. SavedValues (eltype (tspan_), out_type)
373
+ scb = DiffEqCallbacks. SavingCallback (fout_, out, saveat= tspan_,
374
+ save_everystep= save_everystep,
375
+ save_start= false )
375
376
376
377
cb = jump_callback (jumpfun, seed, scb, save_before!, save_after!, save_t_index, psi0, rng_state)
377
- full_cb = OrdinaryDiffEq. CallbackSet (callback,cb,scb)
378
+ full_cb = OrdinaryDiffEq. CallbackSet (callback, cb, scb)
378
379
379
380
df_ = let state = state, dstate = dstate # help inference along
380
381
function df_ (dx, x, p, t)
381
- recast! (state,x)
382
- recast! (dstate,dx)
382
+ recast! (state, x)
383
+ recast! (dstate, dx)
383
384
dmcwf (t, state, dstate)
384
- recast! (dx,dstate)
385
+ recast! (dx, dstate)
385
386
return nothing
386
387
end
387
388
end
388
389
389
- prob = OrdinaryDiffEq. ODEProblem {true} (df_, as_vector (psi0), (tspan_[1 ],tspan_[end ]))
390
-
391
- sol = OrdinaryDiffEq. solve (
392
- prob,
393
- alg;
394
- reltol = 1.0e-6 ,
395
- abstol = 1.0e-8 ,
396
- save_everystep = false , save_start = false ,
397
- save_end = false ,
398
- callback= full_cb, kwargs... )
390
+ prob = OrdinaryDiffEq. ODEProblem {true} (df_, as_vector (psi0), (tspan_[1 ], tspan_[end ]))
399
391
400
- if display_jumps
401
- return out. t, out. saveval, jump_t, jump_index
392
+ if return_problem
393
+ if display_jumps
394
+ return Dict (" out" => out, " jump_t" => jump_t, " jump_index" => jump_index, " prob" => prob, " alg" => alg, " solve_kwargs" => (reltol= 1.0e-6 , abstol= 1.0e-8 , save_everystep= false , save_start= false , save_end= false , callback= full_cb, kwargs... ))
395
+ else
396
+ return Dict (" out" => out, " prob" => prob, " alg" => alg, " solve_kwargs" => (reltol= 1.0e-6 , abstol= 1.0e-8 , save_everystep= false , save_start= false , save_end= false , callback= full_cb, kwargs... ))
397
+ end
402
398
else
403
- return out. t, out. saveval
399
+ sol = OrdinaryDiffEq. solve (
400
+ prob,
401
+ alg;
402
+ reltol= 1.0e-6 ,
403
+ abstol= 1.0e-8 ,
404
+ save_everystep= false , save_start= false ,
405
+ save_end= false ,
406
+ callback= full_cb, kwargs... )
407
+
408
+ if display_jumps
409
+ return out. t, out. saveval, jump_t, jump_index
410
+ else
411
+ return out. t, out. saveval
412
+ end
404
413
end
405
414
end
406
415
407
416
function integrate_mcwf (dmcwf, jumpfun, tspan,
408
- psi0, seed, fout:: Nothing ;
409
- kwargs... )
417
+ psi0, seed, fout:: Nothing ;
418
+ kwargs... )
410
419
function fout_ (t, x)
411
420
return normalize (x)
412
421
end
@@ -427,43 +436,43 @@ mutable struct JumpRNGState{T<:Real,R<:AbstractRNG}
427
436
rng:: R
428
437
threshold:: T
429
438
end
430
- function JumpRNGState (:: Type{T} , seed) where T
439
+ function JumpRNGState (:: Type{T} , seed) where {T}
431
440
rng = MersenneTwister (seed)
432
441
threshold = rand (rng, T)
433
442
JumpRNGState (rng, threshold)
434
443
end
435
- roll! (s:: JumpRNGState{T} ) where T = (s. threshold = rand (s. rng, T))
444
+ roll! (s:: JumpRNGState{T} ) where {T} = (s. threshold = rand (s. rng, T))
436
445
threshold (s:: JumpRNGState ) = s. threshold
437
446
438
447
function jump_callback (jumpfun:: F , seed, scb, save_before!:: G ,
439
- save_after!:: H , save_t_index:: I , psi0, rng_state:: JumpRNGState ) where {F,G,H,I}
448
+ save_after!:: H , save_t_index:: I , psi0, rng_state:: JumpRNGState ) where {F,G,H,I}
440
449
441
450
tmp = copy (psi0)
442
451
psi_tmp = copy (psi0)
443
452
444
- djumpnorm (x, t, integrator) = norm (x)^ 2 - (1 - threshold (rng_state))
453
+ djumpnorm (x, t, integrator) = norm (x)^ 2 - (1 - threshold (rng_state))
445
454
446
455
function dojump (integrator)
447
456
x = integrator. u
448
457
t = integrator. t
449
458
450
459
affect! = scb. affect!
451
- save_before! (affect!,integrator)
452
- recast! (psi_tmp,x)
460
+ save_before! (affect!, integrator)
461
+ recast! (psi_tmp, x)
453
462
i = jumpfun (rng_state. rng, t, psi_tmp, tmp)
454
463
x .= tmp. data
455
- save_after! (affect!,integrator)
456
- save_t_index (t,i)
464
+ save_after! (affect!, integrator)
465
+ save_t_index (t, i)
457
466
458
467
roll! (rng_state)
459
468
return nothing
460
469
end
461
470
462
- return OrdinaryDiffEq. ContinuousCallback (djumpnorm,dojump,
463
- save_positions = (false ,false ))
471
+ return OrdinaryDiffEq. ContinuousCallback (djumpnorm, dojump,
472
+ save_positions= (false , false ))
464
473
end
465
474
jump_callback (jumpfun, seed, scb, save_before!,
466
- save_after!, save_t_index, psi0, :: Nothing ) =
475
+ save_after!, save_t_index, psi0, :: Nothing ) =
467
476
jump_callback (jumpfun, seed, scb, save_before!,
468
477
save_after!, save_t_index, psi0, JumpRNGState (real (eltype (psi0)), seed))
469
478
@@ -483,13 +492,13 @@ Default jump function.
483
492
* `probs_tmp`: Temporary array for holding jump probailities.
484
493
"""
485
494
function jump (rng, t, psi, J, psi_new, probs_tmp, rates:: Nothing )
486
- if length (J)== 1
487
- QuantumOpticsBase. mul! (psi_new,J[1 ],psi,true ,false )
495
+ if length (J) == 1
496
+ QuantumOpticsBase. mul! (psi_new, J[1 ], psi, true , false )
488
497
psi_new. data ./= norm (psi_new)
489
- i= 1
498
+ i = 1
490
499
else
491
- for i= 1 : length (J)
492
- QuantumOpticsBase. mul! (psi_new,J[i],psi,true ,false )
500
+ for i = 1 : length (J)
501
+ QuantumOpticsBase. mul! (psi_new, J[i], psi, true , false )
493
502
probs_tmp[i] = real (dot (psi_new. data, psi_new. data))
494
503
end
495
504
r = rand (rng)
@@ -501,19 +510,19 @@ function jump(rng, t, psi, J, psi_new, probs_tmp, rates::Nothing)
501
510
cumulative_prob += p / total
502
511
cumulative_prob > r && break
503
512
end
504
- QuantumOpticsBase. mul! (psi_new,J[i],psi,eltype (psi)(1 / sqrt (probs_tmp[i])),zero (eltype (psi)))
513
+ QuantumOpticsBase. mul! (psi_new, J[i], psi, eltype (psi)(1 / sqrt (probs_tmp[i])), zero (eltype (psi)))
505
514
end
506
515
return i
507
516
end
508
517
509
518
function jump (rng, t, psi, J, psi_new, probs_tmp, rates:: AbstractVector )
510
- if length (J)== 1
511
- QuantumOpticsBase. mul! (psi_new,J[1 ],psi,eltype (psi)(sqrt (rates[1 ])),zero (eltype (psi)))
519
+ if length (J) == 1
520
+ QuantumOpticsBase. mul! (psi_new, J[1 ], psi, eltype (psi)(sqrt (rates[1 ])), zero (eltype (psi)))
512
521
psi_new. data ./= norm (psi_new)
513
- i= 1
522
+ i = 1
514
523
else
515
- for i= 1 : length (J)
516
- QuantumOpticsBase. mul! (psi_new,J[i],psi,eltype (psi)(sqrt (rates[i])),zero (eltype (psi)))
524
+ for i = 1 : length (J)
525
+ QuantumOpticsBase. mul! (psi_new, J[i], psi, eltype (psi)(sqrt (rates[i])), zero (eltype (psi)))
517
526
probs_tmp[i] = real (dot (psi_new. data, psi_new. data))
518
527
end
519
528
r = rand (rng)
@@ -525,7 +534,7 @@ function jump(rng, t, psi, J, psi_new, probs_tmp, rates::AbstractVector)
525
534
cumulative_prob += p / total
526
535
cumulative_prob > r && break
527
536
end
528
- QuantumOpticsBase. mul! (psi_new,J[i],psi,eltype (psi)(sqrt (rates[i]/ probs_tmp[i])),zero (eltype (psi)))
537
+ QuantumOpticsBase. mul! (psi_new, J[i], psi, eltype (psi)(sqrt (rates[i] / probs_tmp[i])), zero (eltype (psi)))
529
538
end
530
539
return i
531
540
end
@@ -540,19 +549,19 @@ the jump operators J.
540
549
See also: [`mcwf`](@ref)
541
550
"""
542
551
function dmcwf_h! (dpsi, H, J, Jdagger, rates:: Nothing , psi, dpsi_cache)
543
- QuantumOpticsBase. mul! (dpsi,H, psi,eltype (psi)(- im),zero (eltype (psi)))
544
- for i= 1 : length (J)
545
- QuantumOpticsBase. mul! (dpsi_cache,J[i],psi,true ,false )
546
- QuantumOpticsBase. mul! (dpsi,Jdagger[i],dpsi_cache,eltype (psi)(- 0.5 ),one (eltype (psi)))
552
+ QuantumOpticsBase. mul! (dpsi, H, psi, eltype (psi)(- im), zero (eltype (psi)))
553
+ for i = 1 : length (J)
554
+ QuantumOpticsBase. mul! (dpsi_cache, J[i], psi, true , false )
555
+ QuantumOpticsBase. mul! (dpsi, Jdagger[i], dpsi_cache, eltype (psi)(- 0.5 ), one (eltype (psi)))
547
556
end
548
557
return dpsi
549
558
end
550
559
551
560
function dmcwf_h! (dpsi, H, J, Jdagger, rates:: AbstractVector , psi, dpsi_cache)
552
- QuantumOpticsBase. mul! (dpsi,H, psi,eltype (psi)(- im),zero (eltype (psi)))
553
- for i= 1 : length (J)
554
- QuantumOpticsBase. mul! (dpsi_cache,J[i],psi,eltype (psi)(rates[i]),zero (eltype (psi)))
555
- QuantumOpticsBase. mul! (dpsi,Jdagger[i],dpsi_cache,eltype (psi)(- 0.5 ),one (eltype (psi)))
561
+ QuantumOpticsBase. mul! (dpsi, H, psi, eltype (psi)(- im), zero (eltype (psi)))
562
+ for i = 1 : length (J)
563
+ QuantumOpticsBase. mul! (dpsi_cache, J[i], psi, eltype (psi)(rates[i]), zero (eltype (psi)))
564
+ QuantumOpticsBase. mul! (dpsi, Jdagger[i], dpsi_cache, eltype (psi)(- 0.5 ), one (eltype (psi)))
556
565
end
557
566
return dpsi
558
567
end
@@ -568,13 +577,13 @@ function check_mcwf(psi0, H, J, Jdagger, rates)
568
577
if ! (isa (H, DenseOpType) || isa (H, SparseOpType))
569
578
isreducible = false
570
579
end
571
- for j= J
580
+ for j = J
572
581
@assert isa (j, AbstractOperator)
573
582
if ! (isa (j, DenseOpType) || isa (j, SparseOpType))
574
583
isreducible = false
575
584
end
576
585
end
577
- for j= Jdagger
586
+ for j = Jdagger
578
587
@assert isa (j, AbstractOperator)
579
588
if ! (isa (j, DenseOpType) || isa (j, SparseOpType))
580
589
isreducible = false
@@ -605,11 +614,11 @@ corresponding set of jump operators is calculated.
605
614
function diagonaljumps (rates:: AbstractMatrix , J)
606
615
@assert length (J) == size (rates)[1 ] == size (rates)[2 ]
607
616
d, v = eigen (rates)
608
- d, [sum ([v[j, i]* J[j] for j= 1 : length (d)]) for i= 1 : length (d)]
617
+ d, [sum ([v[j, i] * J[j] for j = 1 : length (d)]) for i = 1 : length (d)]
609
618
end
610
619
611
- function diagonaljumps (rates:: AbstractMatrix , J:: Vector{T} ) where T<: Union{LazySum,LazyTensor,LazyProduct}
620
+ function diagonaljumps (rates:: AbstractMatrix , J:: Vector{T} ) where { T<: Union{LazySum,LazyTensor,LazyProduct} }
612
621
@assert length (J) == size (rates)[1 ] == size (rates)[2 ]
613
622
d, v = eigen (rates)
614
- d, [LazySum ([v[j, i]* J[j] for j= 1 : length (d)]. .. ) for i= 1 : length (d)]
623
+ d, [LazySum ([v[j, i] * J[j] for j = 1 : length (d)]. .. ) for i = 1 : length (d)]
615
624
end
0 commit comments