From 21064f631e548452a6e3d65a973f1561744b4b08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hans=20W=C3=BCrfel?=
Date: Sat, 15 Mar 2025 17:34:34 +0100
Subject: [PATCH 1/5] batch-spanning chunking in threaded execution

---
 src/construction.jl      |   4 +-
 src/coreloop.jl          | 111 +++++++++++++++++++++++++++++++++++++--
 src/executionstyles.jl   |   5 +-
 src/network_structure.jl |   6 ++-
 4 files changed, 117 insertions(+), 9 deletions(-)

diff --git a/src/construction.jl b/src/construction.jl
index f2e3e7f92..0f914f145 100644
--- a/src/construction.jl
+++ b/src/construction.jl
@@ -183,14 +183,14 @@ function Network(g::AbstractGraph,
     # create map for external inputs
     extmap = has_external_input(im) ? ExtMap(im) : nothing

-    nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches),
-                 typeof(mass_matrix),eltype(caches),typeof(gbufprovider),typeof(extmap)}(
+    nw = Network(
         vertexbatches,
         nl, im,
         caches,
         mass_matrix,
         gbufprovider,
         extmap,
+        execution,
     )
 end

diff --git a/src/coreloop.jl b/src/coreloop.jl
index 1162c1520..1e9b57479 100644
--- a/src/coreloop.jl
+++ b/src/coreloop.jl
@@ -88,15 +88,116 @@ end
     end
 end

-@inline function process_batches!(::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
-    unrolled_foreach(filt, batches) do batch
-        (du, u, o, p, t) = duopt
-        Threads.@threads for i in 1:length(batch)
-            _type = dispatchT(batch)
-            apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
-        end
-    end
-end
+@inline function process_batches!(ex::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
+    # unrolled_foreach(filt, batches) do batch
+    #     (du, u, o, p, t) = duopt
+    #     Threads.@threads for i in 1:length(batch)
+    #         _type = dispatchT(batch)
+    #         apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
+    #     end
+    # end
+    # return
+
+    Nchunks = Threads.nthreads()
+    # Nchunks = 4
+    # chunking is kinda expensive, so we cache it
+    key = hash((Base.objectid(batches), filt, fg, Nchunks))
+    chunks = get!(ex.chunk_cache, key) do
+        _chunk_batches(batches, filt, fg, Nchunks)
+    end
+
+    _progress_in_batch = function(batch, ci, processed, N)
+        (du, u, o, p, t) = duopt
+        _type = dispatchT(batch)
+        while ci ≤ length(batch) && processed < N
+            apply_comp!(_type, fg, batch, ci, du, u, o, inbufs, p, t)
+            ci += 1
+            processed += 1
+        end
+        processed, ci
+    end
+
+    Threads.@sync for chunk in chunks
+        chunk.N == 0 && continue
+        Threads.@spawn begin
+            local N = chunk.N
+            local bi = chunk.batch_i
+            local ci = chunk.comp_i
+            local processed = 0
+            while processed < N
+                batch = batches[bi]
+                filt(batch) || continue
+                processed, ci = @noinline _progress_in_batch(batch, ci, processed, N)
+                bi += 1
+                ci = 1
+            end
+        end
+    end
+end
+function _chunk_batches(batches, filt, fg, workers)
+    Ncomp = 0
+    total_eqs = 0
+    unrolled_foreach(filt, batches) do batch
+        Ncomp += length(batch)::Int
+        total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
+    end
+    chunks = Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}(undef, workers)
+
+    eqs_per_worker = total_eqs / workers
+    bi = 1
+    ci = 1
+    assigned = 0
+    eqs_assigned = 0
+    for w in 1:workers
+        ci_start = ci
+        bi_start = bi
+        eqs_in_worker = 0
+        assigned_in_worker = 0
+        while assigned < Ncomp
+            batch = batches[bi]
+            filt(batch) || continue
+
+            Neqs = _N_eqs(fg, batch)
+            stop_collecting = false
+            while true
+                if ci > length(batch)
+                    break
+                end
+
+                # compare whether adding the new component helps to come closer to eqs_per_worker
+                diff_now = abs(eqs_in_worker - eqs_per_worker)
+                diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
+                stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
+                if stop_collecting
+                    break
+                end
+
+                # add component to worker
+                eqs_assigned += Neqs
+                eqs_in_worker += Neqs
+                assigned_in_worker += 1
+                assigned += 1
+                ci += 1
+            end
+            # if the hard stop condition is reached, break; otherwise jump to the next batch and continue
+            stop_collecting && break
+
+            bi += 1
+            ci = 1
+        end
+        chunk = (; batch_i=bi_start, comp_i=ci_start, N=assigned_in_worker)
+        chunks[w] = chunk
+
+        # update eqs per worker estimate for the other workers
+        eqs_per_worker = (total_eqs - eqs_assigned) / (workers - w)
+    end
+    @assert assigned == Ncomp
+    return chunks
+end
+_N_eqs(::Val{:f}, batch) = Int(dim(batch))
+_N_eqs(::Val{:g}, batch) = Int(outdim(batch))
+_N_eqs(::Val{:fg}, batch) = Int(dim(batch)) + Int(outdim(batch))
+
 @inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbufs, duopt) where {F}
     unrolled_foreach(filt, batches) do batch

diff --git a/src/executionstyles.jl b/src/executionstyles.jl
index 3fc46a17c..4931bbb6d 100644
--- a/src/executionstyles.jl
+++ b/src/executionstyles.jl
@@ -41,7 +41,10 @@ struct PolyesterExecution{buffered} <: ExecutionStyle{buffered} end
 Parallel execution using Julia threads.
 For `buffered` see [`ExecutionStyle`](@ref).
 """
-struct ThreadedExecution{buffered} <: ExecutionStyle{buffered} end
+@kwdef struct ThreadedExecution{buffered} <: ExecutionStyle{buffered}
+    chunk_cache::Dict{UInt, Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}}=
+        Dict{UInt, Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}}()
+end

 usebuffer(::ExecutionStyle{buffered}) where {buffered} = buffered
 usebuffer(::Type{<:ExecutionStyle{buffered}}) where {buffered} = buffered

diff --git a/src/network_structure.jl b/src/network_structure.jl
index 5910f6f37..eabae40c0 100644
--- a/src/network_structure.jl
+++ b/src/network_structure.jl
@@ -83,8 +83,10 @@ struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT,EM}
     gbufprovider::GBT
     "map to gather external inputs"
     extmap::EM
+    "execution style"
+    executionstyle::EX
 end
-executionstyle(::Network{ex}) where {ex} = ex()
+executionstyle(nw::Network) = nw.executionstyle
 nvbatches(::Network) = length(vertexbatches)

 """
@@ -164,6 +166,8 @@ end
 @inline compf(b::ComponentBatch) = b.compf
 @inline compg(b::ComponentBatch) = b.compg
 @inline fftype(b::ComponentBatch) = b.ff
+@inline dim(b::ComponentBatch) = sum(b.statestride.strides)
+@inline outdim(b::ComponentBatch) = sum(b.outbufstride.strides)
 @inline pdim(b::ComponentBatch) = b.pstride.strides
 @inline extdim(b::ComponentBatch) = b.extbufstride.strides

From dc38e3ba295777241318c9f0f2d9ba7a674782a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hans=20W=C3=BCrfel?=
Date: Sat, 15 Mar 2025 18:09:34 +0100
Subject: [PATCH 2/5] chunk less low-level

---
 src/coreloop.jl        | 78 +++++++++++++++++++-----------------------
 src/executionstyles.jl |  4 +--
 2 files changed, 37 insertions(+), 45 deletions(-)

diff --git a/src/coreloop.jl b/src/coreloop.jl
index 1e9b57479..e09565657 100644
--- a/src/coreloop.jl
+++ b/src/coreloop.jl
@@ -99,37 +99,27 @@ end
     # return

     Nchunks = Threads.nthreads()
-    # Nchunks = 4
+    # Nchunks = 8
     # chunking is kinda expensive, so we cache it
     key = hash((Base.objectid(batches), filt, fg, Nchunks))
     chunks = get!(ex.chunk_cache, key) do
        _chunk_batches(batches, filt, fg, Nchunks)
     end

-    _progress_in_batch = function(batch, ci, processed, N)
+    _eval_batchportion = function (batch, idxs)
         (du, u, o, p, t) = duopt
         _type = dispatchT(batch)
-        while ci ≤ length(batch) && processed < N
-            apply_comp!(_type, fg, batch, ci, du, u, o, inbufs, p, t)
-            ci += 1
-            processed += 1
+        for i in idxs
+            apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
         end
-        processed, ci
     end

     Threads.@sync for chunk in chunks
-        chunk.N == 0 && continue
+        isempty(chunk) && continue
         Threads.@spawn begin
-            local N = chunk.N
-            local bi = chunk.batch_i
-            local ci = chunk.comp_i
-            local processed = 0
-            while processed < N
-                batch = batches[bi]
-                filt(batch) || continue
-                processed, ci = @noinline _progress_in_batch(batch, ci, processed, N)
-                bi += 1
-                ci = 1
+            for (; bi, idxs) in chunk
+                batch = batches[bi] # filtering done in chunks
+                @noinline _eval_batchportion(batch, idxs)
             end
         end
     end
 end
 function _chunk_batches(batches, filt, fg, workers)
         Ncomp += length(batch)::Int
         total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
     end
-    chunks = Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}(undef, workers)
+    chunks = Vector{Vector{@NamedTuple{bi::Int,idxs::UnitRange{Int64}}}}(undef, workers)

     eqs_per_worker = total_eqs / workers
     bi = 1
     ci = 1
     assigned = 0
     eqs_assigned = 0
     for w in 1:workers
+        chunk = @NamedTuple{bi::Int,idxs::UnitRange{Int64}}[]
         ci_start = ci
-        bi_start = bi
         eqs_in_worker = 0
         assigned_in_worker = 0
         while assigned < Ncomp
             batch = batches[bi]
-            filt(batch) || continue
-
-            Neqs = _N_eqs(fg, batch)
-            stop_collecting = false
-            while true
-                if ci > length(batch)
-                    break
-                end
-
-                # compare whether adding the new component helps to come closer to eqs_per_worker
-                diff_now = abs(eqs_in_worker - eqs_per_worker)
-                diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
-                stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
-                if stop_collecting
-                    break
-                end
-
-                # add component to worker
-                eqs_assigned += Neqs
-                eqs_in_worker += Neqs
-                assigned_in_worker += 1
-                assigned += 1
-                ci += 1
-            end
-            # if the hard stop condition is reached, break; otherwise jump to the next batch and continue
-            stop_collecting && break
+            if filt(batch) # only process if the batch is not filtered out
+                Neqs = _N_eqs(fg, batch)
+                stop_collecting = false
+                while true
+                    if ci > length(batch)
+                        break
+                    end
+
+                    # compare whether adding the new component helps to come closer to eqs_per_worker
+                    diff_now = abs(eqs_in_worker - eqs_per_worker)
+                    diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
+                    stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
+                    if stop_collecting
+                        break
+                    end
+
+                    # add component to worker
+                    eqs_assigned += Neqs
+                    eqs_in_worker += Neqs
+                    assigned_in_worker += 1
+                    assigned += 1
+                    ci += 1
+                end
+                if ci > ci_start # don't push empty chunks
+                    push!(chunk, (; bi, idxs=ci_start:ci-1))
+                end
+                stop_collecting && break
+            end

             bi += 1
             ci = 1
         end
-        chunk = (; batch_i=bi_start, comp_i=ci_start, N=assigned_in_worker)
         chunks[w] = chunk

         # update eqs per worker estimate for the other workers

diff --git a/src/executionstyles.jl b/src/executionstyles.jl
index 4931bbb6d..302afeadf 100644
--- a/src/executionstyles.jl
+++ b/src/executionstyles.jl
@@ -42,8 +42,8 @@ Parallel execution using Julia threads.
 For `buffered` see [`ExecutionStyle`](@ref).
 """
 @kwdef struct ThreadedExecution{buffered} <: ExecutionStyle{buffered}
-    chunk_cache::Dict{UInt, Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}}=
-        Dict{UInt, Vector{@NamedTuple{batch_i::Int, comp_i::Int, N::Int}}}()
+    chunk_cache::Dict{UInt, Vector{Vector{@NamedTuple{bi::Int,idxs::UnitRange{Int64}}}}} =
+        Dict{UInt, Vector{Vector{@NamedTuple{bi::Int,idxs::UnitRange{Int64}}}}}()
 end

 usebuffer(::ExecutionStyle{buffered}) where {buffered} = buffered

From ffaa103bd2df66b33a5b830a8ae9f5ac601395a9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hans=20W=C3=BCrfel?=
Date: Sat, 15 Mar 2025 18:59:17 +0100
Subject: [PATCH 3/5] try to make it slightly more type-stable

---
 src/coreloop.jl        | 46 +++++++++++++++++-----------------------
 src/executionstyles.jl |  3 +--
 2 files changed, 22 insertions(+), 27 deletions(-)

diff --git a/src/coreloop.jl b/src/coreloop.jl
index e09565657..be8df9cbe 100644
--- a/src/coreloop.jl
+++ b/src/coreloop.jl
@@ -89,38 +89,28 @@ end
 end

 @inline function process_batches!(ex::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F}
-    # unrolled_foreach(filt, batches) do batch
-    #     (du, u, o, p, t) = duopt
-    #     Threads.@threads for i in 1:length(batch)
-    #         _type = dispatchT(batch)
-    #         apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
-    #     end
-    # end
-    # return
-
-    Nchunks = Threads.nthreads()
-    # Nchunks = 8
+    Nchunks = Threads.nthreads()
+
     # chunking is kinda expensive, so we cache it
     key = hash((Base.objectid(batches), filt, fg, Nchunks))
     chunks = get!(ex.chunk_cache, key) do
         _chunk_batches(batches, filt, fg, Nchunks)
     end

-    _eval_batchportion = function (batch, idxs)
-        (du, u, o, p, t) = duopt
-        _type = dispatchT(batch)
-        for i in idxs
-            apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
-        end
-    end
+    # each chunk consists of an array or tuple [(batch, idxs), ...]
+    _eval_chunk = function(chunk)
+        unrolled_foreach(chunk) do ch
+            (; batch, idxs) = ch
+            (du, u, o, p, t) = duopt
+            _type = dispatchT(batch)
+            for i in idxs
+                apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t)
+            end
+        end
+    end

     Threads.@sync for chunk in chunks
-        isempty(chunk) && continue
         Threads.@spawn begin
-            for (; bi, idxs) in chunk
-                batch = batches[bi] # filtering done in chunks
-                @noinline _eval_batchportion(batch, idxs)
-            end
+            @noinline _eval_chunk(chunk)
         end
     end
 end
 function _chunk_batches(batches, filt, fg, workers)
         Ncomp += length(batch)::Int
         total_eqs += length(batch)::Int * _N_eqs(fg, batch)::Int
     end
-    chunks = Vector{Vector{@NamedTuple{bi::Int,idxs::UnitRange{Int64}}}}(undef, workers)
+    chunks = Vector{Any}(undef, workers)

     eqs_per_worker = total_eqs / workers
     bi = 1
     ci = 1
     assigned = 0
     eqs_assigned = 0
     for w in 1:workers
-        chunk = @NamedTuple{bi::Int,idxs::UnitRange{Int64}}[]
+        chunk = Vector{Any}()
         ci_start = ci
         eqs_in_worker = 0
         assigned_in_worker = 0
         while assigned < Ncomp
                 if ci > ci_start # don't push empty chunks
-                    push!(chunk, (; bi, idxs=ci_start:ci-1))
+                    push!(chunk, (; batch, idxs=ci_start:ci-1))
                 end
                 stop_collecting && break
             end

             bi += 1
             ci = 1
         end
-        chunks[w] = chunk
+
+        # narrow down type / make tuple
+        chunks[w] = if length(chunk) < 10
+            Tuple(chunk)
+        else
+            [c for c in chunk] # narrow down type
+        end

         # update eqs per worker estimate for the other workers
         eqs_per_worker = (total_eqs - eqs_assigned) / (workers - w)

diff --git a/src/executionstyles.jl b/src/executionstyles.jl
index 302afeadf..523be9dcb 100644
--- a/src/executionstyles.jl
+++ b/src/executionstyles.jl
@@ -42,8 +42,7 @@ Parallel execution using Julia threads.
 For `buffered` see [`ExecutionStyle`](@ref).
 """
 @kwdef struct ThreadedExecution{buffered} <: ExecutionStyle{buffered}
-    chunk_cache::Dict{UInt, Vector{Vector{@NamedTuple{bi::Int,idxs::UnitRange{Int64}}}}} =
-        Dict{UInt, Vector{Vector{@NamedTuple{bi::Int,idxs::UnitRange{Int64}}}}}()
+    chunk_cache::Dict{UInt, Vector} = Dict{UInt, Vector}()
 end

 usebuffer(::ExecutionStyle{buffered}) where {buffered} = buffered

From 8c6f9cf93ee04af6c2bd7be8de99321ee83ed613 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hans=20W=C3=BCrfel?=
Date: Sat, 15 Mar 2025 21:07:43 +0100
Subject: [PATCH 4/5] fix adapt and make chunking slightly greedy

---
 ext/CUDAExt.jl  |  6 ++----
 src/coreloop.jl | 12 ++++++++++--
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl
index f2f49c80c..b1d7dce21 100644
--- a/ext/CUDAExt.jl
+++ b/ext/CUDAExt.jl
@@ -34,12 +34,10 @@ function Adapt.adapt_structure(to, n::Network)
     caches = (;output = _adapt_diffcache(to, n.caches.output),
               aggregation = _adapt_diffcache(to, n.caches.aggregation),
               external = _adapt_diffcache(to, n.caches.external))
-    exT = typeof(executionstyle(n))
-    gT = typeof(n.im.g)
+    ex = executionstyle(n)
     extmap = adapt(to, n.extmap)

-    Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp),typeof(extmap)}(
-        vb, layer, n.im, caches, mm, gbp, extmap)
+    Network(vb, layer, n.im, caches, mm, gbp, extmap, ex)
 end

 Adapt.@adapt_structure NetworkLayer

diff --git a/src/coreloop.jl b/src/coreloop.jl
index be8df9cbe..0a9c288ce 100644
--- a/src/coreloop.jl
+++ b/src/coreloop.jl
@@ -124,11 +124,13 @@ function _chunk_batches(batches, filt, fg, workers)
     chunks = Vector{Any}(undef, workers)

     eqs_per_worker = total_eqs / workers
+    # println("Total eqs: $total_eqs in $Ncomp components, eqs per worker: $eqs_per_worker ($fg)")
     bi = 1
     ci = 1
     assigned = 0
     eqs_assigned = 0
     for w in 1:workers
+        # println("Assign worker $w: goal: $eqs_per_worker")
         chunk = Vector{Any}()
         ci_start = ci
         eqs_in_worker = 0
         assigned_in_worker = 0
         while assigned < Ncomp
             batch = batches[bi]
             if filt(batch) # only process if the batch is not filtered out
                 Neqs = _N_eqs(fg, batch)
                 stop_collecting = false
                 while true
                     if ci > length(batch)
                         break
                     end

                     # compare whether adding the new component helps to come closer to eqs_per_worker
                     diff_now = abs(eqs_in_worker - eqs_per_worker)
                     diff_next = abs(eqs_in_worker + Neqs - eqs_per_worker)
-                    stop_collecting = assigned == Ncomp || diff_now ≤ diff_next
+                    stop_collecting = assigned == Ncomp || diff_now < diff_next
                     if stop_collecting
                         break
                     end

                     # add component to worker
+                    # println(" - Assign component $ci ($Neqs eqs)")
                     eqs_assigned += Neqs
                     eqs_in_worker += Neqs
                     assigned_in_worker += 1
                     assigned += 1
                     ci += 1
                 end
                 if ci > ci_start # don't push empty chunks
-                    push!(chunk, (; batch, idxs=ci_start:ci-1))
+                    # println(" - Assign batch $(bi) -> $(ci_start:(ci-1)) $(length(ci_start:(ci-1))*Neqs) eqs)")
+                    push!(chunk, (; batch, idxs=ci_start:(ci-1)))
+                else
+                    # println(" - Skip empty batch $(bi) -> $(ci_start:(ci-1))")
                 end
                 stop_collecting && break
+            else
+                # println(" - Skip batch $(bi)")
             end

             bi += 1
             ci = 1

From ea60fb1a84b1d83fbd9a6d24ee086ba839f63845 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hans=20W=C3=BCrfel?=
Date: Mon, 17 Mar 2025 09:32:55 +0100
Subject: [PATCH 5/5] fix bug in chunking

---
 src/coreloop.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/coreloop.jl b/src/coreloop.jl
index 0a9c288ce..c6ccb8d15 100644
--- a/src/coreloop.jl
+++ b/src/coreloop.jl
@@ -124,7 +124,7 @@ function _chunk_batches(batches, filt, fg, workers)
     chunks = Vector{Any}(undef, workers)

     eqs_per_worker = total_eqs / workers
-    # println("Total eqs: $total_eqs in $Ncomp components, eqs per worker: $eqs_per_worker ($fg)")
+    # println("\nTotal eqs: $total_eqs in $Ncomp components, eqs per worker: $eqs_per_worker ($fg)")
     bi = 1
     ci = 1
     assigned = 0
     eqs_assigned = 0
     for w in 1:workers
         # println("Assign worker $w: goal: $eqs_per_worker")
         chunk = Vector{Any}()
-        ci_start = ci
         eqs_in_worker = 0
         assigned_in_worker = 0
         while assigned < Ncomp
             batch = batches[bi]
             if filt(batch) # only process if the batch is not filtered out
+                ci_start = ci
                 Neqs = _N_eqs(fg, batch)
                 stop_collecting = false
                 while true