Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ to turn off this check")
endif()

set(GINKGO_CHECKOUT_VERSION
"ogl_0600_gko190"
"ogl_0600_gko110"
CACHE STRING "Use specific version of ginkgo")

if(OGL_DEVEL_TOOLS)
Expand Down
101 changes: 84 additions & 17 deletions include/OGL/MatrixWrapper/Distributed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "OGL/MatrixWrapper/HostMatrix.hpp"
#include "OGL/Repartitioner.hpp"


/* The RepartDistMatrix class is a wrapper around Ginkgos distributed Matrix
* class
*
Expand All @@ -38,20 +37,60 @@ class RepartDistMatrix
using reorder_map_type =
std::tuple<std::shared_ptr<gko::array<label>>, scalar *,
std::shared_ptr<gko::array<label>>>;
using all_to_all_data = std::tuple<label, AllToAllPattern, scalar *>;
using all_to_all_data = std::tuple<label, AllToAllPattern, scalar *, label>;

struct pairwise_data {
label id; // original interface id on orig rank
label send; // 0 - send, 1, receive, 2 same_rank
label comm_rank; // other side of communication
label length; // length of the interface to communicate
label send_id; //
scalar *recv_ptr; //
label id; // original interface id on orig rank
label send; // 0 - send, 1, receive, 2 same_rank
label comm_rank; // other side of communication
label length; // length of the interface to communicate
label send_id; //
scalar *recv_ptr; // where to put the received data (linop data)
label recv_offset; // offset to begin of linop
};

using gko::EnableLinOp<RepartDistMatrix>::convert_to;
using gko::EnableLinOp<RepartDistMatrix>::move_to;

/* @brief replaces all data ptrs in all_to_all_data with new_ptr
*/
std::vector<reorder_map_type> update_reorder_map_ptr(
std::vector<reorder_map_type> in, scalar *new_ptr) const
{
std::vector<reorder_map_type> out;
auto [map, _, pad] = in[0]; //) {
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The commented code fragment '//) {' appears to be leftover from refactoring and should be removed.

Suggested change
auto [map, _, pad] = in[0]; //) {
auto [map, _, pad] = in[0];

Copilot uses AI. Check for mistakes.

out.emplace_back(map, new_ptr, pad);
// }
return out;
}

/* @brief replaces all data ptrs in all_to_all_data with new_ptr
*/
std::vector<all_to_all_data> update_all_to_all_recv_ptr(
std::vector<all_to_all_data> in, scalar *new_ptr) const
{
std::vector<all_to_all_data> out;
for (auto [id, comm_pattern, data_ptr, offset] : in) {
out.emplace_back(id, comm_pattern, new_ptr, offset);
}
return out;
}

/* @brief replaces all data ptrs in pairwise_data with new_ptr
*/
std::vector<pairwise_data> update_pairwise_recv_ptr(
std::vector<pairwise_data> in, scalar *new_ptr) const
{
std::vector<pairwise_data> out;
for (auto [linop_id, mode, comm_rank, length, send_id, recv_ptr,
offset] : in) {
out.push_back(pairwise_data{linop_id, mode, comm_rank, length,
send_id, new_ptr, offset});
}
return out;
}


std::shared_ptr<const gko::LinOp> get_dist_matrix() const
{
return this->dist_mtx_;
Expand Down Expand Up @@ -80,6 +119,33 @@ class RepartDistMatrix
}
}

std::shared_ptr<RepartDistMatrix> clone() const
{
auto new_dist_mtx = gko::share(this->dist_mtx_->clone());
auto new_local_ptr = gko::as<gko::matrix::Csr<scalar, label>>(
new_dist_mtx->get_local_matrix());
const scalar *new_data_ptr = new_local_ptr->get_const_values();

auto all_to_all_data = update_all_to_all_recv_ptr(
this->all_to_all_update_data_, const_cast<scalar *>(new_data_ptr));

auto new_non_local_ptr = gko::as<gko::matrix::Csr<scalar, label>>(
new_dist_mtx->get_non_local_matrix());
const scalar *new_non_local_data_ptr =
new_non_local_ptr->get_const_values();
auto pairwise_update_data = update_pairwise_recv_ptr(
this->pairwise_update_data_,
const_cast<scalar *>(new_non_local_data_ptr));

auto new_reorder_map = update_reorder_map_ptr(
this->reorder_maps_, const_cast<scalar *>(new_data_ptr));

return std::make_shared<RepartDistMatrix>(
this->get_executor(), this->get_communicator(),
this->matrix_format_, new_dist_mtx, this->repartitioner_,
this->fuse_, all_to_all_data, pairwise_update_data, new_reorder_map,
this->compress_to_global_);
}

/**
* Copy-assigns a CombinationMatrix matrix. Preserves executor, copies
Expand All @@ -101,7 +167,6 @@ class RepartDistMatrix
this->pairwise_update_data_ = other.pairwise_update_data_;
this->reorder_maps_ = other.reorder_maps_;
this->compress_to_global_ = other.compress_to_global_;
this->linops_ = other.linops_;
}
return *this;
}
Expand All @@ -126,7 +191,6 @@ class RepartDistMatrix
std::move(other.pairwise_update_data_);
this->reorder_maps_ = std::move(other.reorder_maps_);
this->compress_to_global_ = std::move(other.compress_to_global_);
this->linops_ = std::move(other.linops_);
}
return *this;
}
Expand All @@ -135,6 +199,9 @@ class RepartDistMatrix
void update(const ExecutorHandler &exec_handler,
std::shared_ptr<const HostMatrixWrapper> host_A, label verbose);

void update(const ExecutorHandler &exec_handler, const scalar *diag_ptr,
const scalar *face_ptr, label verbose);

word get_matrix_format() const { return matrix_format_; }

std::shared_ptr<const gko::LinOp> get_dist_mtx() const { return dist_mtx_; }
Expand All @@ -147,8 +214,10 @@ class RepartDistMatrix
std::vector<all_to_all_data> all_to_all_update_data,
std::vector<pairwise_data> pairwise_update_data,
std::vector<reorder_map_type> reorder_maps,
std::vector<label> compress_to_global,
std::map<label, scalar *> linops)
std::vector<label> compress_to_global
// ,
// std::map<label, scalar *> linops
)
: gko::EnableLinOp<RepartDistMatrix>(exec),
gko::experimental::distributed::DistributedBase(comm),
fuse_(fuse),
Expand All @@ -158,8 +227,7 @@ class RepartDistMatrix
pairwise_update_data_(pairwise_update_data),
repartitioner_(repartitioner),
reorder_maps_(reorder_maps),
compress_to_global_(compress_to_global),
linops_(linops)
compress_to_global_(compress_to_global)
{
this->set_size(dist_mtx_->get_size());
}
Expand Down Expand Up @@ -219,18 +287,17 @@ class RepartDistMatrix

std::shared_ptr<dist_mtx> dist_mtx_;

// id, comm_pattern, data_ptr, offset
std::vector<all_to_all_data> all_to_all_update_data_;

std::vector<pairwise_data> pairwise_update_data_;

std::shared_ptr<const Repartitioner> repartitioner_;

// map, data_ptr, pad
std::vector<reorder_map_type> reorder_maps_;

std::vector<label> compress_to_global_;

std::map<label, scalar *>
linops_; // map between linop id and shared ptr to linop
};

std::shared_ptr<const gko::LinOp> get_local(
Expand Down
Empty file.
Loading
Loading