Skip to content

Commit e12958a

Browse files
committed
use optimizer as default
1 parent 56754de commit e12958a

File tree

4 files changed

+21
-12
lines changed

4 files changed

+21
-12
lines changed

include/OGL/DevicePersistent/ExecutorHandler.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ class ExecutorHandler
311311

312312
label get_ranks_per_gpu() const { return device_id_handler_.ranks_per_gpu; }
313313

314+
void set_ranks_per_gpu(label ranks_per_gpu) { device_id_handler_.ranks_per_gpu= ranks_per_gpu; }
315+
314316
label get_owner_rank() const { return device_id_handler_.global_owner(); }
315317

316318
const std::shared_ptr<gko::Executor> get_device_exec() const

include/OGL/StoppingCriterion.hpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,34 +201,30 @@ class StoppingCriterion {
201201
bool export_res, label prev_solve_iters,
202202
scalar prev_rel_cost) const
203203
{
204+
word frequencyMode = "optimizer";
204205
label minIter = minIter_;
205206
label frequency = frequency_;
206207
// in case of export_res all residuals need to be computed
207-
std::cout << __FILE__ << "adapt minIter and frequency0 \n";
208208
if (!export_res) {
209-
std::cout << __FILE__ << "adapt minIter and frequency1 \n";
210209
if (prev_solve_iters > 0 && adapt_minIter_ && prev_rel_cost > 0) {
211-
std::cout << __FILE__ << "adapt minIter and frequency2 \n";
212210
minIter = prev_solve_iters * relaxationFactor_;
213-
if (frequencyMode_ == "optimizer") {
211+
if (frequencyMode == "optimizer") {
214212
auto alpha = sqrt(
215213
1.0 / (prev_solve_iters * (1.0 - relaxationFactor_)) *
216214
prev_rel_cost);
217215
frequency = min(norm_eval_limit_, max(1, label(1 / alpha)));
218216
}
219-
// if (frequencyMode_ == "relative") {
220-
frequency = label(prev_solve_iters * 0.075) + 1;
221-
// }
217+
if (frequencyMode == "relative") {
218+
frequency = label(prev_solve_iters * 0.075) + 1;
219+
}
222220
}
223221
}
224222

225223
word msg = "Creating stopping criterion with minIter " +
226224
std::to_string(minIter) + " frequency " +
227225
std::to_string(frequency) + " prev_solve_iters " +
228-
std::to_string(prev_solve_iters) + +" adapt_minIter_ " +
229-
std::to_string(adapt_minIter_) + +" prev_rel_cost " +
230-
std::to_string(prev_rel_cost) + " prev_solve_iters*0.075 " +
231-
std::to_string(prev_solve_iters * 0.075);
226+
std::to_string(prev_solve_iters) + " adapt_minIter_ " +
227+
std::to_string(adapt_minIter_) + " prev_rel_cost ";
232228

233229
MLOG_0(verbose, msg)
234230

src/MatrixWrapper/Distributed.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,6 @@ std::shared_ptr<RepartDistMatrix> create_impl(
527527
label rank = exec_handler.get_host_rank();
528528
auto exec = exec_handler.get_ref_exec();
529529
auto host_comm = *exec_handler.get_host_comm().get();
530-
exec_handler.init_device_comm();
531530
auto device_comm = *exec_handler.get_device_comm().get();
532531
bool owner = repartitioner->is_owner(exec_handler);
533532

test/unit/MatrixWrapper/Distributed.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class Environment : public testing::Environment {
5656
Foam::IOobject::MUST_READ),
5757
false);
5858

59+
// FIXME this needs the device_id_handler
5960
exec = std::make_shared<ExecutorHandler>(runTime_->thisDb(), dict,
6061
"dummy", true);
6162

@@ -172,6 +173,7 @@ TEST_P(DistMatL2D, canCreateDistributedMatrix)
172173
{
173174
/* The test mesh is 6x6 grid decomposed into 4 3x3 subdomains */
174175
auto [ranks_per_gpu, matrix_format, fused] = GetParam();
176+
exec.set_ranks_per_gpu(ranks_per_gpu);
175177

176178
auto mesh = ((Environment *)global_env)->mesh;
177179
auto hostMatrix = ((Environment *)global_env)->hostMatrix;
@@ -182,6 +184,7 @@ TEST_P(DistMatL2D, canCreateDistributedMatrix)
182184
gko::dim<2> global_vec_dim{repartitioner->get_orig_partition()->get_size(),
183185
1};
184186
gko::dim<2> local_vec_dim{repartitioner->get_repart_dim()[0], 1};
187+
exec.init_device_comm();
185188

186189
auto distributed = create_distributed(exec, repartitioner, hostMatrix,
187190
matrix_format, fused, 0);
@@ -200,6 +203,8 @@ TEST_P(DistMatL2D, hasCorrectLocalMatrix)
200203
{
201204
/* The test mesh is 6x6 grid decomposed into 4 3x3 subdomains */
202205
auto [ranks_per_gpu, matrix_format, fused] = GetParam();
206+
exec.set_ranks_per_gpu(ranks_per_gpu);
207+
exec.init_device_comm();
203208
auto mesh = ((Environment *)global_env)->mesh;
204209
auto hostMatrix = ((Environment *)global_env)->hostMatrix;
205210
auto repartitioner = std::make_shared<Repartitioner>(
@@ -254,6 +259,8 @@ TEST_P(DistMatL2D, hasCorrectNonLocalMatrix)
254259
{
255260
/* The test mesh is 6x6 grid decomposed into 4 3x3 subdomains */
256261
auto [ranks_per_gpu, matrix_format, fused] = GetParam();
262+
exec.set_ranks_per_gpu(ranks_per_gpu);
263+
exec.init_device_comm();
257264
auto mesh = ((Environment *)global_env)->mesh;
258265
auto hostMatrix = ((Environment *)global_env)->hostMatrix;
259266
auto name = ((Environment *)global_env)->name_;
@@ -292,6 +299,8 @@ TEST_P(DistMatL2D, hasCorrectNonLocalMatrix)
292299
TEST_P(DistMatL2D, canApplyCorrectly)
293300
{
294301
auto [ranks_per_gpu, format, fused] = GetParam();
302+
exec.set_ranks_per_gpu(ranks_per_gpu);
303+
exec.init_device_comm();
295304
auto mesh = ((Environment *)global_env)->mesh;
296305
auto hostMatrix = ((Environment *)global_env)->hostMatrix;
297306
auto name = ((Environment *)global_env)->name_;
@@ -314,12 +323,15 @@ TEST_P(DistMatL2D, canApplyCorrectly)
314323
x->fill(0);
315324

316325
// Act
326+
bool active = repartitioner->get_repart_size() != 0;
327+
if (active){
317328
distributed->apply(b, x);
318329
auto res_x = std::vector<scalar>(
319330
x->get_local_vector()->get_const_values(),
320331
x->get_local_vector()->get_const_values() + local_vec_dim[0]);
321332

322333
ASSERT_EQ(res_x, exp_x[name][fused][ranks_per_gpu][rank]);
334+
}
323335
}
324336

325337
int main(int argc, char *argv[])

0 commit comments

Comments
 (0)