diff --git a/demos/FiniteVolume/CMakeLists.txt b/demos/FiniteVolume/CMakeLists.txt index b8d79691d..2f731b436 100644 --- a/demos/FiniteVolume/CMakeLists.txt +++ b/demos/FiniteVolume/CMakeLists.txt @@ -16,6 +16,7 @@ level_set.cpp:finite-volume-level-set level_set_from_scratch.cpp:finite-volume-level-set-from-scratch advection_1d.cpp:finite-volume-advection-1d advection_2d.cpp:finite-volume-advection-2d +advection_2d_load_balancing.cpp:finite-volume-advection-2d-load-balancing advection_2d_user_bc.cpp:finite-volume-advection-2d-user-bc scalar_burgers_2d.cpp:finite-volume-scalar-burgers-2d linear_convection.cpp:finite-volume-linear-convection diff --git a/demos/FiniteVolume/advection_2d_load_balancing.cpp b/demos/FiniteVolume/advection_2d_load_balancing.cpp new file mode 100644 index 000000000..7ceb9f1a0 --- /dev/null +++ b/demos/FiniteVolume/advection_2d_load_balancing.cpp @@ -0,0 +1,305 @@ +// Copyright 2018-2025 the samurai's authors +// SPDX-License-Identifier: BSD-3-Clause + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +namespace fs = std::filesystem; + +template +void init(Field& u) /* Resizes u to its mesh, then sets u to 1 on cells whose centre lies within the disc of radius 0.2 centred at (0.3, 0.3) and to 0 on all other cells. */ +{ + auto& mesh = u.mesh(); + u.resize(); + + samurai::for_each_cell( + mesh, + [&](auto& cell) + { + auto center = cell.center(); + const double radius = .2; + const double x_center = 0.3; + const double y_center = 0.3; + if (((center[0] - x_center) * (center[0] - x_center) + (center[1] - y_center) * (center[1] - y_center)) <= radius * radius) + { + u[cell] = 1; + } + else + { + u[cell] = 0; + } + }); +} + +template +void flux_correction(double dt, const std::array& a, const Field& u, Field& unp1) /* For every level below max_level and for each of the four stencils (-1,0), (1,0), (0,-1), (0,1): selects the cells of `level` covered by the level + 1 cells shifted by the stencil, and adds to unp1 (right/up) or subtracts from it (left/down) dt/dx * (coarse flux - half of each of the two fine-level fluxes). */ +{ + using mesh_t = typename Field::mesh_t; + using mesh_id_t = typename mesh_t::mesh_id_t; + using interval_t = typename mesh_t::interval_t; + constexpr std::size_t dim = Field::dim; + + auto mesh = u.mesh(); /* NOTE(review): plain `auto` makes a copy of the mesh here, whereas init() binds it with `auto&` - confirm the copy is intended. */ + + for (std::size_t level = mesh.min_level(); level < 
mesh.max_level(); ++level) + { + xt::xtensor_fixed> stencil; + + stencil = { + {-1, 0} + }; + + auto subset_right = samurai::intersection(samurai::translate(mesh[mesh_id_t::cells][level + 1], stencil), + mesh[mesh_id_t::cells][level]) + .on(level); + + subset_right( + [&](const auto& i, const auto& index) + { + auto j = index[0]; + const double dx = mesh.cell_length(level); + + unp1(level, i, j) = unp1(level, i, j) + + dt / dx + * (samurai::upwind_op(level, i, j).right_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i + 1, 2 * j).right_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i + 1, 2 * j + 1).right_flux(a, u)); + }); + + stencil = { + {1, 0} + }; + + auto subset_left = samurai::intersection(samurai::translate(mesh[mesh_id_t::cells][level + 1], stencil), + mesh[mesh_id_t::cells][level]) + .on(level); + + subset_left( + [&](const auto& i, const auto& index) + { + auto j = index[0]; + const double dx = mesh.cell_length(level); + + unp1(level, i, j) = unp1(level, i, j) + - dt / dx + * (samurai::upwind_op(level, i, j).left_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i, 2 * j).left_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i, 2 * j + 1).left_flux(a, u)); + }); + + stencil = { + {0, -1} + }; + + auto subset_up = samurai::intersection(samurai::translate(mesh[mesh_id_t::cells][level + 1], stencil), mesh[mesh_id_t::cells][level]) + .on(level); + + subset_up( + [&](const auto& i, const auto& index) + { + auto j = index[0]; + const double dx = mesh.cell_length(level); + + unp1(level, i, j) = unp1(level, i, j) + + dt / dx + * (samurai::upwind_op(level, i, j).up_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i, 2 * j + 1).up_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i + 1, 2 * j + 1).up_flux(a, u)); + }); + + stencil = { + {0, 1} + }; + + auto subset_down = samurai::intersection(samurai::translate(mesh[mesh_id_t::cells][level + 1], stencil), + mesh[mesh_id_t::cells][level]) + .on(level); + + subset_down( + [&](const 
auto& i, const auto& index) + { + auto j = index[0]; + const double dx = mesh.cell_length(level); + + unp1(level, i, j) = unp1(level, i, j) + - dt / dx + * (samurai::upwind_op(level, i, j).down_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i, 2 * j).down_flux(a, u) + - .5 * samurai::upwind_op(level + 1, 2 * i + 1, 2 * j).down_flux(a, u)); + }); + } +} + +template +void save(const fs::path& path, const std::string& filename, const Field& u, const std::string& suffix = "") /* Creates `path` if it does not exist, fills a per-cell "level" field, and saves the mesh with u and that field. With SAMURAI_WITH_MPI the file name also carries the communicator size; without it a "_restart" dump of mesh and u is written as well. */ +{ + auto mesh = u.mesh(); + auto level_ = samurai::make_scalar_field("level", mesh); + + if (!fs::exists(path)) + { + fs::create_directory(path); + } + + samurai::for_each_cell(mesh, + [&](const auto& cell) + { + level_[cell] = cell.level; + }); +#ifdef SAMURAI_WITH_MPI + mpi::communicator world; + samurai::save(path, fmt::format("{}_size_{}{}", filename, world.size(), suffix), mesh, u, level_); +#else + samurai::save(path, fmt::format("{}{}", filename, suffix), mesh, u, level_); + samurai::dump(path, fmt::format("{}_restart{}", filename, suffix), mesh, u); +#endif +} + +int main(int argc, char* argv[]) +{ + auto& app = samurai::initialize("Finite volume example for the advection equation in 2d using multiresolution", argc, argv); + + constexpr std::size_t dim = 2; + using Config = samurai::MRConfig; + + // Simulation parameters + xt::xtensor_fixed> min_corner = {0., 0.}; + xt::xtensor_fixed> max_corner = {1., 1.}; + std::array a{ + {1, 1} + }; + double Tf = .1; + double cfl = 0.5; + double t = 0.; + std::string restart_file; + + // Multiresolution parameters + std::size_t min_level = 4; + std::size_t max_level = 10; + double mr_epsilon = 2.e-4; // Threshold used by multiresolution + double mr_regularity = 1.; // Regularity guess for multiresolution + bool correction = false; + + // Output parameters + fs::path path = fs::current_path(); + std::string filename = "FV_advection_2d"; + std::size_t nfiles = 1; +#ifdef SAMURAI_WITH_MPI + std::size_t nt_loadbalance = 1; +#endif + + 
app.add_option("--min-corner", min_corner, "The min corner of the box")->capture_default_str()->group("Simulation parameters"); + app.add_option("--max-corner", max_corner, "The max corner of the box")->capture_default_str()->group("Simulation parameters"); + app.add_option("--velocity", a, "The velocity of the advection equation")->capture_default_str()->group("Simulation parameters"); + app.add_option("--cfl", cfl, "The CFL")->capture_default_str()->group("Simulation parameters"); + app.add_option("--Ti", t, "Initial time")->capture_default_str()->group("Simulation parameters"); + app.add_option("--Tf", Tf, "Final time")->capture_default_str()->group("Simulation parameters"); + app.add_option("--restart-file", restart_file, "Restart file")->capture_default_str()->group("Simulation parameters"); + app.add_option("--min-level", min_level, "Minimum level of the multiresolution")->capture_default_str()->group("Multiresolution"); + app.add_option("--max-level", max_level, "Maximum level of the multiresolution")->capture_default_str()->group("Multiresolution"); + app.add_option("--mr-eps", mr_epsilon, "The epsilon used by the multiresolution to adapt the mesh") + ->capture_default_str() + ->group("Multiresolution"); + app.add_option("--mr-reg", + mr_regularity, + "The regularity criteria used by the multiresolution to " + "adapt the mesh") + ->capture_default_str() + ->group("Multiresolution"); + app.add_option("--with-correction", correction, "Apply flux correction at the interface of two refinement levels") + ->capture_default_str() + ->group("Multiresolution"); + app.add_option("--path", path, "Output path")->capture_default_str()->group("Output"); + app.add_option("--filename", filename, "File name prefix")->capture_default_str()->group("Output"); + app.add_option("--nfiles", nfiles, "Number of output files")->capture_default_str()->group("Output"); +#ifdef SAMURAI_WITH_MPI + app.add_option("--nt-loadbalance", nt_loadbalance, "load balance each nt 
steps")->capture_default_str()->group("Multiresolution"); +#endif + SAMURAI_PARSE(argc, argv); + + const samurai::Box box(min_corner, max_corner); + samurai::MRMesh mesh; + auto u = samurai::make_scalar_field("u", mesh); + + if (restart_file.empty()) + { + mesh = {box, min_level, max_level}; + init(u); + } + else + { + samurai::load(restart_file, mesh, u); + } + samurai::make_bc>(u, 0.); + + double dt = cfl * mesh.cell_length(max_level); + const double dt_save = Tf / static_cast(nfiles); + + auto unp1 = samurai::make_scalar_field("unp1", mesh); + + auto MRadaptation = samurai::make_MRAdapt(u); + MRadaptation(mr_epsilon, mr_regularity); + save(path, filename, u, "_init"); + + std::size_t nsave = 1; + std::size_t nt = 0; + +#ifdef SAMURAI_WITH_MPI + samurai::DiffusionLoadBalancer balancer; +#endif + + while (t != Tf) + { +#ifdef SAMURAI_WITH_MPI + if (((nt % nt_loadbalance == 0) && nt > 1) || nt == 1) /* Rebalance at iteration 1, then at every iteration nt > 1 that is a multiple of nt_loadbalance. NOTE(review): the modulo is evaluated before any guard, so --nt-loadbalance 0 would be a modulo by zero. */ + { + auto weights = samurai::Weight::uniform(mesh); + balancer.load_balance(mesh, weights, u); + } +#endif + + MRadaptation(mr_epsilon, mr_regularity); + + t += dt; + if (t > Tf) + { + dt += Tf - t; + t = Tf; + } + + std::cout << fmt::format("iteration {}: t = {}, dt = {}", nt++, t, dt) << std::endl; + + samurai::update_ghost_mr(u); + unp1.resize(); + unp1 = u - dt * samurai::upwind(a, u); + if (correction) + { + flux_correction(dt, a, u, unp1); + } + + std::swap(u.array(), unp1.array()); + + if (t >= static_cast(nsave + 1) * dt_save || t == Tf) + { + const std::string suffix = (nfiles != 1) ? 
fmt::format("_ite_{}", nsave++) : ""; + save(path, filename, u, suffix); + } + } + samurai::finalize(); + return 0; +} diff --git a/include/samurai/load_balancing.hpp b/include/samurai/load_balancing.hpp new file mode 100644 index 000000000..764fb34f2 --- /dev/null +++ b/include/samurai/load_balancing.hpp @@ -0,0 +1,419 @@ +#pragma once + +#include +#include +#include +#include + +#include "algorithm.hpp" +#include "algorithm/utils.hpp" +#include "mesh.hpp" +#include "mr/mesh.hpp" +#include "timers.hpp" + +#ifdef SAMURAI_WITH_MPI +#include +#endif + +#ifdef SAMURAI_WITH_MPI +namespace samurai +{ + enum BalanceElement_t + { + CELL + }; + + class Weight + { + public: + + template + static auto from_field(const Field& f) /* Builds a scalar field named "weight" on f's mesh, filled with 0 and then set to f's value on every cell visited by for_each_cell. */ + { + auto weight = samurai::make_scalar_field("weight", f.mesh()); + weight.fill(0.); + samurai::for_each_cell(f.mesh(), + [&](auto cell) + { + weight[cell] = f[cell]; + }); + return weight; + } + + template + static auto uniform(const Mesh& mesh) /* Builds a scalar field named "weight" on mesh, filled with 1. */ + { + auto weight = samurai::make_scalar_field("weight", mesh); + weight.fill(1.); + + return weight; + } + + template + static double compute_load(const Mesh_t& mesh, const Field_t& weight) /* Returns the sum of weight over the cells of mesh[mesh_id_t::cells]. */ + { + using mesh_id_t = typename Mesh_t::mesh_id_t; + const auto& current_mesh = mesh[mesh_id_t::cells]; + double current_process_load = 0.; + // cell-based load with weight. 
+ samurai::for_each_cell(current_mesh, + [&](const auto& cell) + { + current_process_load += weight[cell]; + }); + return current_process_load; + } + }; + + template + class LoadBalancer /* Common driver for load balancers: load_balance() calls load_balance_impl() through a static_cast of *this (presumably to the derived balancer type - the cast target is not visible here), rebuilds the mesh from the returned flags and migrates the fields. */ + { + public: + + int nloadbalancing; + + // Exchange only the CellArray of meshes (cells part) + template + auto exchange_meshes(const Mesh_t& new_mesh, const Mesh_t& old_mesh) /* Sends this rank's new (tag 0) and old (tag 1) cell arrays to every entry of new_mesh.mpi_neighbourhood() and receives theirs; returns the pair (neighbours' new cells, neighbours' old cells), both indexed like mpi_neighbourhood(). */ + { + samurai::times::timers.start("load_balancing_exchange_meshes"); + + using CellArray_t = typename Mesh_t::ca_type; + + boost::mpi::communicator world; + + const auto& neighbours = new_mesh.mpi_neighbourhood(); + std::size_t nb_neigh = neighbours.size(); + + std::vector all_new_cells(nb_neigh); + std::vector all_old_cells(nb_neigh); + std::vector reqs; + + // Phase 1: non-blocking receptions of CellArrays + for (std::size_t idx = 0; idx < nb_neigh; ++idx) + { + const auto& nbr = neighbours[idx]; + reqs.push_back(world.irecv(nbr.rank, 0, all_new_cells[idx])); + reqs.push_back(world.irecv(nbr.rank, 1, all_old_cells[idx])); + } + + // Phase 2: non-blocking sends of CellArrays + for (const auto& nbr : neighbours) + { + reqs.push_back(world.isend(nbr.rank, 0, new_mesh[Mesh_t::mesh_id_t::cells])); + reqs.push_back(world.isend(nbr.rank, 1, old_mesh[Mesh_t::mesh_id_t::cells])); + } + + // Finalize communications + mpi::wait_all(reqs.begin(), reqs.end()); + + samurai::times::timers.stop("load_balancing_exchange_meshes"); + + return std::make_pair(std::move(all_new_cells), std::move(all_old_cells)); + } + + template + void update_field(Mesh_t& new_mesh, + Field_t& field, + const std::vector& all_new_cells, + const std::vector& all_old_cells) /* Rebuilds `field` on new_mesh: copies values on cells present in both the old and new local meshes, sends values of old local cells that appear in a neighbour's new cells, receives values for new local cells that appear in a neighbour's old cells, then swaps the new array into `field`. */ + { + samurai::times::timers.start("load_balancing_update_field"); + using mesh_id_t = typename Mesh_t::mesh_id_t; + using value_t = typename Field_t::value_type; + boost::mpi::communicator world; + + Field_t new_field("new_f", new_mesh); + new_field.fill(0); + + auto& old_mesh = field.mesh(); + // TODO : check if this is correct + auto min_level = old_mesh.min_level(); + 
auto max_level = old_mesh.max_level(); + + // Copy data of intervals that didn't move + for (std::size_t level = min_level; level <= max_level; ++level) + { + auto intersect_old_new = intersection(old_mesh[mesh_id_t::cells][level], new_mesh[mesh_id_t::cells][level]); + intersect_old_new.apply_op(samurai::copy(new_field, field)); + } + + std::vector req; + std::vector> to_send(static_cast(world.size())); + + // Build payload of field that has been sent to neighbour, so compare old mesh with new neighbour mesh + for (size_t neighbour_idx = 0; neighbour_idx < all_new_cells.size(); ++neighbour_idx) + { + auto& neighbour_new_cells = all_new_cells[neighbour_idx]; + + for (std::size_t level = min_level; level <= max_level; ++level) + { + if (!old_mesh[mesh_id_t::cells][level].empty() && !neighbour_new_cells[level].empty()) + { + auto intersect_old_mesh_new_neigh = intersection(old_mesh[mesh_id_t::cells][level], neighbour_new_cells[level]); + intersect_old_mesh_new_neigh( + [&](const auto& interval, const auto& index) + { + std::copy(field(level, interval, index).begin(), + field(level, interval, index).end(), + std::back_inserter(to_send[neighbour_idx])); + }); + } + } + + if (to_send[neighbour_idx].size() != 0) + { + auto neighbour_rank = new_mesh.mpi_neighbourhood()[neighbour_idx].rank; + req.push_back(world.isend(neighbour_rank, neighbour_rank, to_send[neighbour_idx])); /* The message tag is the destination rank; the matching recv below uses world.rank() as its tag. */ + } + } + + // Build payload of field that I need to receive from neighbour, so compare NEW mesh with OLD neighbour mesh + for (size_t neighbour_idx = 0; neighbour_idx < all_old_cells.size(); ++neighbour_idx) + { + bool isintersect = false; + for (std::size_t level = min_level; level <= max_level; ++level) + { + if (!new_mesh[mesh_id_t::cells][level].empty() && !all_old_cells[neighbour_idx][level].empty()) + { + std::vector to_recv; + + auto in_interface = intersection(all_old_cells[neighbour_idx][level], new_mesh[mesh_id_t::cells][level]); + + in_interface( + [&]([[maybe_unused]] const auto& i, 
[[maybe_unused]] const auto& index) + { + isintersect = true; + }); + + if (isintersect) + { + break; + } + } + } + + if (isintersect) + { + std::ptrdiff_t count = 0; + std::vector to_recv; + world.recv(new_mesh.mpi_neighbourhood()[neighbour_idx].rank, world.rank(), to_recv); /* NOTE(review): this blocking receive is posted only when the meshes intersect, while the sender sends only a non-empty payload - both sides must reach the same decision. */ + + for (std::size_t level = min_level; level <= max_level; ++level) + { + if (!new_mesh[mesh_id_t::cells][level].empty() && !all_old_cells[neighbour_idx][level].empty()) + { + auto in_interface = intersection(all_old_cells[neighbour_idx][level], new_mesh[mesh_id_t::cells][level]); + + in_interface( + [&](const auto& i, const auto& index) + { + std::copy(to_recv.begin() + count, + to_recv.begin() + count + static_cast(i.size() * field.n_comp), + new_field(level, i, index).begin()); + count += static_cast(i.size() * field.n_comp); + }); + } + } + } + } + + if (!req.empty()) + { + mpi::wait_all(req.begin(), req.end()); + } + + std::swap(field.array(), new_field.array()); + samurai::times::timers.stop("load_balancing_update_field"); + } + + template + void update_fields(Mesh_t& new_mesh, Field_t& field, Fields_t&... kw) + { + // Exchange meshes once for all fields + auto [all_new_cells, all_old_cells] = exchange_meshes(new_mesh, field.mesh()); + + // Update all fields using already exchanged CellArrays + update_field(new_mesh, field, all_new_cells, all_old_cells); + update_fields_impl(new_mesh, all_new_cells, all_old_cells, kw...); + } + + template + void update_fields_impl(Mesh_t& new_mesh, + const std::vector& all_new_cells, + const std::vector& all_old_cells, + Field_t& field, + Fields_t&... 
kw) + { + update_field(new_mesh, field, all_new_cells, all_old_cells); + update_fields_impl(new_mesh, all_new_cells, all_old_cells, kw...); + } + + template + void update_fields_impl([[maybe_unused]] Mesh_t& new_mesh, + [[maybe_unused]] const std::vector& all_new_cells, + [[maybe_unused]] const std::vector& all_old_cells) + { + } + + public: + + LoadBalancer() + { + boost::mpi::communicator world; + nloadbalancing = 0; + } + + template + Mesh_t update_mesh(Mesh_t& mesh, const Field_t& flags) /* Builds the new local mesh from `flags` (one destination rank per cell): cells flagged with this rank are kept, the others are queued for the rank named by their flag, and intervals received from neighbours are added before the new mesh is constructed. */ + { + samurai::times::timers.start("load_balancing_mesh_update"); + + using CellList_t = typename Mesh_t::cl_type; + using CellArray_t = typename Mesh_t::ca_type; + + boost::mpi::communicator world; + + CellList_t new_cl; + std::vector payload(static_cast(world.size())); + + // Phase 1: build payload (cell sorting) + samurai::times::timers.start("load_balancing_build_payload"); + samurai::for_each_cell( + mesh[Mesh_t::mesh_id_t::cells], + [&](const auto& cell) + { + if (flags[cell] == world.rank()) + { + if constexpr (Mesh_t::dim == 1) + { + new_cl[cell.level][{}].add_point(cell.indices[0]); + } + if constexpr (Mesh_t::dim == 2) + { + new_cl[cell.level][{cell.indices[1]}].add_point(cell.indices[0]); + } + if constexpr (Mesh_t::dim == 3) + { + new_cl[cell.level][{cell.indices[1], cell.indices[2]}].add_point(cell.indices[0]); + } + } + else + { + assert(static_cast(flags[cell]) < payload.size()); + + if constexpr (Mesh_t::dim == 1) + { + payload[static_cast(flags[cell])][cell.level][{}].add_point(cell.indices[0]); + } + if constexpr (Mesh_t::dim == 2) + { + payload[static_cast(flags[cell])][cell.level][{cell.indices[1]}].add_point(cell.indices[0]); + } + if constexpr (Mesh_t::dim == 3) + { + payload[static_cast(flags[cell])][cell.level][{cell.indices[1], cell.indices[2]}].add_point( + cell.indices[0]); + } + } + }); + samurai::times::timers.stop("load_balancing_build_payload"); + + std::vector req; + + // Actual data exchange **only** with known neighbours of the mesh + const 
auto& neighbours = mesh.mpi_neighbourhood(); + + // Phase 2: non-blocking cell sends + samurai::times::timers.start("load_balancing_send_cells"); + // Non-blocking send to each neighbour (possibly empty message) + for (const auto& nbr : neighbours) + { + int rank = nbr.rank; + if (rank == world.rank()) + { + continue; + } + + CellArray_t to_send = {payload[static_cast(rank)], false}; + req.push_back(world.isend(rank, 17, to_send)); /* NOTE(review): only ranks listed in mpi_neighbourhood() are sent to, so payload built for any other rank is never transmitted. */ + } + samurai::times::timers.stop("load_balancing_send_cells"); + + // Phase 3: cell reception + samurai::times::timers.start("load_balancing_recv_cells"); + // Blocking reception from each neighbour + for (const auto& nbr : neighbours) + { + int rank = nbr.rank; + if (rank == world.rank()) + { + continue; + } + + CellArray_t to_rcv; + world.recv(rank, 17, to_rcv); + + samurai::for_each_interval(to_rcv, + [&](std::size_t level, const auto& interval, const auto& index) + { + new_cl[level][index].add_interval(interval); + }); + } + samurai::times::timers.stop("load_balancing_recv_cells"); + + samurai::times::timers.start("load_balancing_wait"); + boost::mpi::wait_all(req.begin(), req.end()); + samurai::times::timers.stop("load_balancing_wait"); + + samurai::times::timers.start("load_balancing_construct_mesh"); + Mesh_t new_mesh(new_cl, mesh); + samurai::times::timers.stop("load_balancing_construct_mesh"); + + samurai::times::timers.stop("load_balancing_mesh_update"); + + return new_mesh; + } + + template + void load_balance(Mesh_t& mesh, Weight_t& weight, Field_t& field, Fields&... 
kw) /* One load-balancing pass: returns immediately when the communicator has a single process; otherwise computes the flags, builds the new mesh, migrates field and kw..., swaps the new mesh into `mesh`, increments nloadbalancing and prints the local cell count and load. */ + { + // Early check: no load balancing with single process + boost::mpi::communicator world; + if (world.size() <= 1) + { + std::cout << "Process " << world.rank() << " : Single MPI process detected, load balancing ignored" << std::endl; + return; + } + + samurai::times::timers.start("load_balancing"); + + // Compute flags for this single pass + auto flags = static_cast(*this).load_balance_impl(mesh, weight); + + // Update mesh + auto new_mesh = update_mesh(mesh, flags); + + // Update physical fields (excluding weights) + update_fields(new_mesh, field, kw...); + + // Replace reference mesh + mesh.swap(new_mesh); + + nloadbalancing += 1; + + samurai::times::timers.stop("load_balancing"); + + // Final display of cell count after load balancing + { + using mesh_id_t = typename Mesh_t::mesh_id_t; + double total_weight = Weight::compute_load(field.mesh(), weight); + auto nb_cells = field.mesh().nb_cells(mesh_id_t::cells); + std::cout << "Process " << world.rank() << " : " << nb_cells << " cells (total weight " << total_weight + << ") after load balancing" << std::endl; + } + } + }; + +} // namespace samurai +#endif diff --git a/include/samurai/load_balancing_diffusion.hpp b/include/samurai/load_balancing_diffusion.hpp new file mode 100644 index 000000000..bc424a1ab --- /dev/null +++ b/include/samurai/load_balancing_diffusion.hpp @@ -0,0 +1,171 @@ +#include "field.hpp" +#include "load_balancing.hpp" +#include "timers.hpp" + +#include +#include + +#ifdef SAMURAI_WITH_MPI +namespace samurai +{ + + class DiffusionLoadBalancer : public samurai::LoadBalancer + { + private: + + template + std::vector compute_fluxes(Mesh_t& mesh, const Field_t& weight, int niterations) /* Returns one accumulated load transfer per entry of mesh.mpi_neighbourhood(): negative means this rank must send load, positive means it receives. Each of at most niterations rounds gathers every rank's load and, for each neighbour in turn, moves half of the current load difference; the loop stops early once every rank reports only zero transfers. */ + { + samurai::times::timers.start("load_balancing_flux_computation"); + + using mpi_subdomain_t = typename Mesh_t::mpi_subdomain_t; + boost::mpi::communicator world; + std::vector& neighbourhood = mesh.mpi_neighbourhood(); + size_t n_neighbours = neighbourhood.size(); + + // Load of current process + double 
my_load = samurai::Weight::compute_load(mesh, weight); + // Fluxes between processes + std::vector fluxes(n_neighbours, 0.); + // Load of each process (all processes not only neighbours) + std::vector loads; + int iteration_count = 0; + while (iteration_count < niterations) + { + boost::mpi::all_gather(world, my_load, loads); + + // Compute updated my_load for current process based on its neighbourhood + double my_load_new = my_load; + bool all_fluxes_zero = true; + for (std::size_t neighbour_idx = 0; neighbour_idx < n_neighbours; ++neighbour_idx) + { + std::size_t neighbour_rank = static_cast(neighbourhood[neighbour_idx].rank); + double neighbour_load = loads[neighbour_rank]; + double diff_load = neighbour_load - my_load_new; + + // If transferLoad < 0 -> need to send data, if transferLoad > 0 need to receive data + // TODO : Use diffusion factor 1/(deg+1) for stability + double transfertLoad = 0.5 * diff_load; + + // Accumulate total flux on current edge + fluxes[neighbour_idx] += transfertLoad; + + // Mark if a non-zero transfer was performed + if (transfertLoad != 0) + { + all_fluxes_zero = false; + } + + // Update intermediate local load before processing next neighbour + my_load_new += transfertLoad; + } + + // Update reference load for next iteration + my_load = my_load_new; + + // Check if all processes have reached convergence + bool global_convergence = boost::mpi::all_reduce(world, all_fluxes_zero, std::logical_and()); + + // If all processes have zero fluxes, state will no longer change + if (global_convergence) + { + std::cout << "Process " << world.rank() << " : Global convergence reached at iteration " << iteration_count << std::endl; + break; + } + + iteration_count++; + } + + samurai::times::timers.stop("load_balancing_flux_computation"); + + return fluxes; + } + + public: + + DiffusionLoadBalancer() = default; + + template + auto load_balance_impl(Mesh_t& mesh, const Weight_t& weight) /* Returns a per-cell field of destination ranks, initialised to this rank. Cells are sorted by decreasing centre y, then increasing centre x; for each neighbour with a negative flux, cells are re-flagged from the front of that order (neighbour rank above this rank) or from the back (otherwise) until the accumulated weight reaches -flux. NOTE(review): the comparator reads center(1), so this presumably requires at least two dimensions. */ + { + using mesh_id_t = typename Mesh_t::mesh_id_t; + 
boost::mpi::communicator world; + + auto flags = samurai::make_scalar_field("diffusion_flag", mesh); + flags.fill(world.rank()); + + // Compute fluxes in terms of load to transfer/receive + std::vector fluxes = compute_fluxes(mesh, weight, 50); + + using cell_t = typename Mesh_t::cell_t; + std::vector cells; + samurai::for_each_cell(mesh[mesh_id_t::cells], + [&](auto cell) + { + cells.emplace_back(cell); + }); + + if (cells.empty()) + { + return flags; + } + + // Sort cells from "top" to "bottom", then from "left" to "right" + std::sort(cells.begin(), + cells.end(), + [](const cell_t& a, const cell_t& b) + { + auto center_a = a.center(); + auto center_b = b.center(); + if (center_a(1) != center_b(1)) + { + return center_a(1) > center_b(1); // First, cells with highest y coordinate + } + return center_a(0) < center_b(0); // Then, cells with lowest x coordinate + }); + + auto& neighbourhood = mesh.mpi_neighbourhood(); + + std::size_t top_index = 0; + std::size_t bottom_index = cells.size() - 1; + + for (std::size_t i = 0; i < neighbourhood.size(); ++i) + { + double flux = fluxes[i]; + auto neighbour_rank = neighbourhood[i].rank; + + if (flux < 0) // We must send cells + { + double weight_to_send = -flux; + double accumulated_weight = 0; + + // Send from the "top" to higher ranks, and from the "bottom" to lower ranks + if (neighbour_rank > world.rank()) + { + while (top_index <= bottom_index && accumulated_weight < weight_to_send) + { + accumulated_weight += weight[cells[top_index]]; + flags[cells[top_index]] = neighbour_rank; + top_index++; + } + } + else + { + while (bottom_index >= top_index && accumulated_weight < weight_to_send) + { + accumulated_weight += weight[cells[bottom_index]]; + flags[cells[bottom_index]] = neighbour_rank; + if (bottom_index == 0) + { + break; // Avoid underflow of the unsigned index. NOTE(review): cells[0] has just been flagged but bottom_index still refers to it, so a later neighbour could re-flag it. + } + bottom_index--; + } + } + } + } + return flags; + } + }; +} +#endif