35
35
#include < random>
36
36
#include < regex>
37
37
#include < string>
38
+ #include < string_view>
38
39
#include < thread>
39
40
#include < vector>
40
41
@@ -1047,7 +1048,37 @@ struct test_case {
1047
1048
return t;
1048
1049
}
1049
1050
1050
- bool eval (ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
1051
+ // Checks an op against the test filter, which is a comma separated list of OP names or specific variations
1052
+ bool matches_filter (ggml_tensor * op, const char * op_names_filter) {
1053
+ if (op_names_filter) {
1054
+ const auto op_name = op_desc (op);
1055
+ const auto op_full_name = op_name + " (" + vars () + " )" ;
1056
+ std::string_view filter (op_names_filter);
1057
+ while (!filter.empty ()) {
1058
+ auto comma_pos = filter.find_first_of (' ,' );
1059
+ const auto lparen_pos = filter.find_first_of (' (' );
1060
+ if (lparen_pos < comma_pos) {
1061
+ auto rparen_pos = filter.find_first_of (' )' );
1062
+ comma_pos = filter.find_first_of (' ,' , rparen_pos);
1063
+ const auto op_filter = filter.substr (0 , comma_pos);
1064
+ if (op_filter == op_full_name) {
1065
+ return true ;
1066
+ }
1067
+ } else {
1068
+ const auto op_filter = filter.substr (0 , comma_pos);
1069
+ if (op_filter == op_name) {
1070
+ return true ;
1071
+ }
1072
+ }
1073
+ filter = comma_pos != std::string_view::npos ? filter.substr (comma_pos + 1 ) : " " ;
1074
+ }
1075
+ return false ;
1076
+ } else {
1077
+ return true ;
1078
+ }
1079
+ }
1080
+
1081
+ bool eval (ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
1051
1082
mode = MODE_TEST;
1052
1083
1053
1084
ggml_init_params params = {
@@ -1065,7 +1096,7 @@ struct test_case {
1065
1096
1066
1097
ggml_tensor * out = build_graph (ctx);
1067
1098
std::string current_op_name = op_desc (out);
1068
- if (op_name != nullptr && current_op_name != op_name ) {
1099
+ if (! matches_filter (out, op_names_filter) ) {
1069
1100
// printf(" %s: skipping\n", op_desc(out).c_str());
1070
1101
ggml_free (ctx);
1071
1102
return true ;
@@ -1212,7 +1243,7 @@ struct test_case {
1212
1243
return test_passed;
1213
1244
}
1214
1245
1215
- bool eval_perf (ggml_backend_t backend, const char * op_name , printer * output_printer) {
1246
+ bool eval_perf (ggml_backend_t backend, const char * op_names_filter , printer * output_printer) {
1216
1247
mode = MODE_PERF;
1217
1248
1218
1249
static const size_t graph_nodes = 8192 ;
@@ -1227,7 +1258,7 @@ struct test_case {
1227
1258
1228
1259
ggml_tensor * out = build_graph (ctx.get ());
1229
1260
std::string current_op_name = op_desc (out);
1230
- if (op_name != nullptr && current_op_name != op_name ) {
1261
+ if (! matches_filter (out, op_names_filter) ) {
1231
1262
// printf(" %s: skipping\n", op_desc(out).c_str());
1232
1263
return true ;
1233
1264
}
@@ -1342,7 +1373,7 @@ struct test_case {
1342
1373
return true ;
1343
1374
}
1344
1375
1345
- bool eval_support (ggml_backend_t backend, const char * op_name , printer * output_printer) {
1376
+ bool eval_support (ggml_backend_t backend, const char * op_names_filter , printer * output_printer) {
1346
1377
mode = MODE_SUPPORT;
1347
1378
1348
1379
static const size_t graph_nodes = 8192 ;
@@ -1357,7 +1388,7 @@ struct test_case {
1357
1388
1358
1389
ggml_tensor * out = build_graph (ctx.get ());
1359
1390
std::string current_op_name = op_desc (out);
1360
- if (op_name != nullptr && current_op_name != op_name ) {
1391
+ if (! matches_filter (out, op_names_filter) ) {
1361
1392
return true ;
1362
1393
}
1363
1394
@@ -1374,7 +1405,7 @@ struct test_case {
1374
1405
return true ;
1375
1406
}
1376
1407
1377
- bool eval_grad (ggml_backend_t backend, const char * op_name , printer * output_printer) {
1408
+ bool eval_grad (ggml_backend_t backend, const char * op_names_filter , printer * output_printer) {
1378
1409
mode = MODE_GRAD;
1379
1410
const std::vector<float > expect = grad_expect ();
1380
1411
@@ -1391,7 +1422,7 @@ struct test_case {
1391
1422
1392
1423
ggml_tensor * out = build_graph (ctx.get ());
1393
1424
1394
- if ((op_name != nullptr && op_desc (out) != op_name ) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1425
+ if (! matches_filter (out, op_names_filter ) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1395
1426
return true ;
1396
1427
}
1397
1428
@@ -5922,7 +5953,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
5922
5953
return test_cases;
5923
5954
}
5924
5955
5925
- static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_name , const char * params_filter,
5956
+ static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_names_filter , const char * params_filter,
5926
5957
printer * output_printer) {
5927
5958
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
5928
5959
if (params_filter == nullptr ) {
@@ -5954,7 +5985,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
5954
5985
5955
5986
size_t n_ok = 0 ;
5956
5987
for (auto & test : test_cases) {
5957
- if (test->eval (backend, backend_cpu, op_name , output_printer)) {
5988
+ if (test->eval (backend, backend_cpu, op_names_filter , output_printer)) {
5958
5989
n_ok++;
5959
5990
}
5960
5991
}
@@ -5970,7 +6001,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
5970
6001
filter_test_cases (test_cases, params_filter);
5971
6002
size_t n_ok = 0 ;
5972
6003
for (auto & test : test_cases) {
5973
- if (test->eval_grad (backend, op_name , output_printer)) {
6004
+ if (test->eval_grad (backend, op_names_filter , output_printer)) {
5974
6005
n_ok++;
5975
6006
}
5976
6007
}
@@ -5983,7 +6014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
5983
6014
auto test_cases = make_test_cases_perf ();
5984
6015
filter_test_cases (test_cases, params_filter);
5985
6016
for (auto & test : test_cases) {
5986
- test->eval_perf (backend, op_name , output_printer);
6017
+ test->eval_perf (backend, op_names_filter , output_printer);
5987
6018
}
5988
6019
return true ;
5989
6020
}
@@ -5992,7 +6023,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
5992
6023
auto test_cases = make_test_cases_eval ();
5993
6024
filter_test_cases (test_cases, params_filter);
5994
6025
for (auto & test : test_cases) {
5995
- test->eval_support (backend, op_name , output_printer);
6026
+ test->eval_support (backend, op_names_filter , output_printer);
5996
6027
}
5997
6028
return true ;
5998
6029
}
@@ -6001,20 +6032,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
6001
6032
}
6002
6033
6003
6034
static void usage (char ** argv) {
6004
- printf (" Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n " , argv[0 ]);
6035
+ printf (" Usage: %s [mode] [-o <op,.. >] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n " , argv[0 ]);
6005
6036
printf (" valid modes:\n " );
6006
6037
printf (" - test (default, compare with CPU backend for correctness)\n " );
6007
6038
printf (" - grad (compare gradients from backpropagation with method of finite differences)\n " );
6008
6039
printf (" - perf (performance evaluation)\n " );
6009
6040
printf (" - support (probe backend operation support)\n " );
6010
- printf (" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n " );
6041
+ printf (" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n " );
6042
+ printf (" optionally including the full test case string (e.g. \" ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\" )\n " );
6011
6043
printf (" --output specifies output format (default: console, options: console, sql, csv)\n " );
6012
6044
}
6013
6045
6014
6046
int main (int argc, char ** argv) {
6015
6047
test_mode mode = MODE_TEST;
6016
6048
output_formats output_format = CONSOLE;
6017
- const char * op_name_filter = nullptr ;
6049
+ const char * op_names_filter = nullptr ;
6018
6050
const char * backend_filter = nullptr ;
6019
6051
const char * params_filter = nullptr ;
6020
6052
@@ -6029,7 +6061,7 @@ int main(int argc, char ** argv) {
6029
6061
mode = MODE_SUPPORT;
6030
6062
} else if (strcmp (argv[i], " -o" ) == 0 ) {
6031
6063
if (i + 1 < argc) {
6032
- op_name_filter = argv[++i];
6064
+ op_names_filter = argv[++i];
6033
6065
} else {
6034
6066
usage (argv);
6035
6067
return 1 ;
@@ -6110,7 +6142,7 @@ int main(int argc, char ** argv) {
6110
6142
false , " " , ggml_backend_dev_description (dev),
6111
6143
total / 1024 / 1024 , free / 1024 / 1024 , true ));
6112
6144
6113
- bool ok = test_backend (backend, mode, op_name_filter , params_filter, output_printer.get ());
6145
+ bool ok = test_backend (backend, mode, op_names_filter , params_filter, output_printer.get ());
6114
6146
6115
6147
if (ok) {
6116
6148
n_ok++;
0 commit comments