Skip to content

Commit bda6219

Browse files
authored
test-backend-ops : extend test case filtering (#14865)
* Extend test case filtering 1. Allow passing multiple (comma-separated?) ops to test-backend-ops. This can be convenient when working on a set of ops, when you'd want to test them together (but without having to run every single op). For example: `test-backend-ops.exe test -o "ADD,RMS_NORM,ROPE,SILU,SOFT_MAX"` 2. Support full test-case variation string in addition to basic op names. This would make it easy to select a single variation, either for testing or for benchmarking. It can be particularly useful for profiling a particular variation (ex. a CUDA kernel), for example: `test-backend-ops.exe perf -b CUDA0 -o "MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=2)"` These two can be combined. As the current `-o`, this change doesn't try to detect/report an error if an filter doesn't name existing ops (ex. misspelled) * Updating the usage help text * Update tests/test-backend-ops.cpp
1 parent c556418 commit bda6219

File tree

1 file changed

+50
-18
lines changed

1 file changed

+50
-18
lines changed

tests/test-backend-ops.cpp

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <random>
3636
#include <regex>
3737
#include <string>
38+
#include <string_view>
3839
#include <thread>
3940
#include <vector>
4041

@@ -1047,7 +1048,37 @@ struct test_case {
10471048
return t;
10481049
}
10491050

1050-
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
1051+
// Checks an op against the test filter, which is a comma separated list of OP names or specific variations
1052+
bool matches_filter(ggml_tensor * op, const char * op_names_filter) {
1053+
if (op_names_filter) {
1054+
const auto op_name = op_desc(op);
1055+
const auto op_full_name = op_name + "(" + vars() + ")";
1056+
std::string_view filter(op_names_filter);
1057+
while (!filter.empty()) {
1058+
auto comma_pos = filter.find_first_of(',');
1059+
const auto lparen_pos = filter.find_first_of('(');
1060+
if (lparen_pos < comma_pos) {
1061+
auto rparen_pos = filter.find_first_of(')');
1062+
comma_pos = filter.find_first_of(',', rparen_pos);
1063+
const auto op_filter = filter.substr(0, comma_pos);
1064+
if (op_filter == op_full_name) {
1065+
return true;
1066+
}
1067+
} else {
1068+
const auto op_filter = filter.substr(0, comma_pos);
1069+
if (op_filter == op_name) {
1070+
return true;
1071+
}
1072+
}
1073+
filter = comma_pos != std::string_view::npos ? filter.substr(comma_pos + 1) : "";
1074+
}
1075+
return false;
1076+
} else {
1077+
return true;
1078+
}
1079+
}
1080+
1081+
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
10511082
mode = MODE_TEST;
10521083

10531084
ggml_init_params params = {
@@ -1065,7 +1096,7 @@ struct test_case {
10651096

10661097
ggml_tensor * out = build_graph(ctx);
10671098
std::string current_op_name = op_desc(out);
1068-
if (op_name != nullptr && current_op_name != op_name) {
1099+
if (!matches_filter(out, op_names_filter)) {
10691100
//printf(" %s: skipping\n", op_desc(out).c_str());
10701101
ggml_free(ctx);
10711102
return true;
@@ -1212,7 +1243,7 @@ struct test_case {
12121243
return test_passed;
12131244
}
12141245

1215-
bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) {
1246+
bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
12161247
mode = MODE_PERF;
12171248

12181249
static const size_t graph_nodes = 8192;
@@ -1227,7 +1258,7 @@ struct test_case {
12271258

12281259
ggml_tensor * out = build_graph(ctx.get());
12291260
std::string current_op_name = op_desc(out);
1230-
if (op_name != nullptr && current_op_name != op_name) {
1261+
if (!matches_filter(out, op_names_filter)) {
12311262
//printf(" %s: skipping\n", op_desc(out).c_str());
12321263
return true;
12331264
}
@@ -1342,7 +1373,7 @@ struct test_case {
13421373
return true;
13431374
}
13441375

1345-
bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
1376+
bool eval_support(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
13461377
mode = MODE_SUPPORT;
13471378

13481379
static const size_t graph_nodes = 8192;
@@ -1357,7 +1388,7 @@ struct test_case {
13571388

13581389
ggml_tensor * out = build_graph(ctx.get());
13591390
std::string current_op_name = op_desc(out);
1360-
if (op_name != nullptr && current_op_name != op_name) {
1391+
if (!matches_filter(out, op_names_filter)) {
13611392
return true;
13621393
}
13631394

@@ -1374,7 +1405,7 @@ struct test_case {
13741405
return true;
13751406
}
13761407

1377-
bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
1408+
bool eval_grad(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
13781409
mode = MODE_GRAD;
13791410
const std::vector<float> expect = grad_expect();
13801411

@@ -1391,7 +1422,7 @@ struct test_case {
13911422

13921423
ggml_tensor * out = build_graph(ctx.get());
13931424

1394-
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1425+
if (!matches_filter(out, op_names_filter) || out->op == GGML_OP_OPT_STEP_ADAMW) {
13951426
return true;
13961427
}
13971428

@@ -5922,7 +5953,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
59225953
return test_cases;
59235954
}
59245955

5925-
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter,
5956+
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
59265957
printer * output_printer) {
59275958
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
59285959
if (params_filter == nullptr) {
@@ -5954,7 +5985,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59545985

59555986
size_t n_ok = 0;
59565987
for (auto & test : test_cases) {
5957-
if (test->eval(backend, backend_cpu, op_name, output_printer)) {
5988+
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
59585989
n_ok++;
59595990
}
59605991
}
@@ -5970,7 +6001,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59706001
filter_test_cases(test_cases, params_filter);
59716002
size_t n_ok = 0;
59726003
for (auto & test : test_cases) {
5973-
if (test->eval_grad(backend, op_name, output_printer)) {
6004+
if (test->eval_grad(backend, op_names_filter, output_printer)) {
59746005
n_ok++;
59756006
}
59766007
}
@@ -5983,7 +6014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59836014
auto test_cases = make_test_cases_perf();
59846015
filter_test_cases(test_cases, params_filter);
59856016
for (auto & test : test_cases) {
5986-
test->eval_perf(backend, op_name, output_printer);
6017+
test->eval_perf(backend, op_names_filter, output_printer);
59876018
}
59886019
return true;
59896020
}
@@ -5992,7 +6023,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59926023
auto test_cases = make_test_cases_eval();
59936024
filter_test_cases(test_cases, params_filter);
59946025
for (auto & test : test_cases) {
5995-
test->eval_support(backend, op_name, output_printer);
6026+
test->eval_support(backend, op_names_filter, output_printer);
59966027
}
59976028
return true;
59986029
}
@@ -6001,20 +6032,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
60016032
}
60026033

60036034
static void usage(char ** argv) {
6004-
printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
6035+
printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
60056036
printf(" valid modes:\n");
60066037
printf(" - test (default, compare with CPU backend for correctness)\n");
60076038
printf(" - grad (compare gradients from backpropagation with method of finite differences)\n");
60086039
printf(" - perf (performance evaluation)\n");
60096040
printf(" - support (probe backend operation support)\n");
6010-
printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
6041+
printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n");
6042+
printf(" optionally including the full test case string (e.g. \"ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\")\n");
60116043
printf(" --output specifies output format (default: console, options: console, sql, csv)\n");
60126044
}
60136045

60146046
int main(int argc, char ** argv) {
60156047
test_mode mode = MODE_TEST;
60166048
output_formats output_format = CONSOLE;
6017-
const char * op_name_filter = nullptr;
6049+
const char * op_names_filter = nullptr;
60186050
const char * backend_filter = nullptr;
60196051
const char * params_filter = nullptr;
60206052

@@ -6029,7 +6061,7 @@ int main(int argc, char ** argv) {
60296061
mode = MODE_SUPPORT;
60306062
} else if (strcmp(argv[i], "-o") == 0) {
60316063
if (i + 1 < argc) {
6032-
op_name_filter = argv[++i];
6064+
op_names_filter = argv[++i];
60336065
} else {
60346066
usage(argv);
60356067
return 1;
@@ -6110,7 +6142,7 @@ int main(int argc, char ** argv) {
61106142
false, "", ggml_backend_dev_description(dev),
61116143
total / 1024 / 1024, free / 1024 / 1024, true));
61126144

6113-
bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
6145+
bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get());
61146146

61156147
if (ok) {
61166148
n_ok++;

0 commit comments

Comments
 (0)