Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <random>
#include <regex>
#include <string>
#include <string_view>
#include <thread>
#include <vector>

Expand Down Expand Up @@ -1020,7 +1021,37 @@ struct test_case {
return t;
}

bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
// Checks an op against the test filter, which is a comma separated list of OP names or specific variations
bool matches_filter(ggml_tensor* op, const char* op_names_filter) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, this method could be made const. That would require making test_case::op_desc() and test_case::vars() const as well, which is a trivial, but noisy change.

If reviewers agree with that, I'm happy to make the additional change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not worth it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be a welcome change in a separate PR.

if (op_names_filter) {
const auto op_name = op_desc(op);
const auto op_full_name = op_name + "(" + vars() + ")";
std::string_view filter(op_names_filter);
while (!filter.empty()) {
auto comma_pos = filter.find_first_of(',');
const auto lparen_pos = filter.find_first_of('(');
if (lparen_pos < comma_pos) {
auto rparen_pos = filter.find_first_of(')');
comma_pos = filter.find_first_of(',', rparen_pos);
const auto op_filter = filter.substr(0, comma_pos);
if (op_filter == op_full_name) {
return true;
}
} else {
const auto op_filter = filter.substr(0, comma_pos);
if (op_filter == op_name) {
return true;
}
}
filter = comma_pos != std::string_view::npos ? filter.substr(comma_pos + 1) : "";
}
return false;
} else {
return true;
}
}

bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
mode = MODE_TEST;

ggml_init_params params = {
Expand All @@ -1038,7 +1069,7 @@ struct test_case {

ggml_tensor * out = build_graph(ctx);
std::string current_op_name = op_desc(out);
if (op_name != nullptr && current_op_name != op_name) {
if (!matches_filter(out, op_names_filter)) {
//printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
Expand Down Expand Up @@ -1185,7 +1216,7 @@ struct test_case {
return test_passed;
}

bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) {
bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
mode = MODE_PERF;

static const size_t graph_nodes = 8192;
Expand All @@ -1200,7 +1231,7 @@ struct test_case {

ggml_tensor * out = build_graph(ctx.get());
std::string current_op_name = op_desc(out);
if (op_name != nullptr && current_op_name != op_name) {
if (!matches_filter(out, op_names_filter)) {
//printf(" %s: skipping\n", op_desc(out).c_str());
return true;
}
Expand Down Expand Up @@ -1315,7 +1346,7 @@ struct test_case {
return true;
}

bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
bool eval_support(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
mode = MODE_SUPPORT;

static const size_t graph_nodes = 8192;
Expand All @@ -1330,7 +1361,7 @@ struct test_case {

ggml_tensor * out = build_graph(ctx.get());
std::string current_op_name = op_desc(out);
if (op_name != nullptr && current_op_name != op_name) {
if (!matches_filter(out, op_names_filter)) {
return true;
}

Expand All @@ -1347,7 +1378,7 @@ struct test_case {
return true;
}

bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
bool eval_grad(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
mode = MODE_GRAD;
const std::vector<float> expect = grad_expect();

Expand All @@ -1364,7 +1395,7 @@ struct test_case {

ggml_tensor * out = build_graph(ctx.get());

if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
if (!matches_filter(out, op_names_filter) || out->op == GGML_OP_OPT_STEP_ADAMW) {
return true;
}

Expand Down Expand Up @@ -5881,7 +5912,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
return test_cases;
}

static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter,
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
printer * output_printer) {
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
if (params_filter == nullptr) {
Expand Down Expand Up @@ -5913,7 +5944,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op

size_t n_ok = 0;
for (auto & test : test_cases) {
if (test->eval(backend, backend_cpu, op_name, output_printer)) {
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
n_ok++;
}
}
Expand All @@ -5929,7 +5960,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
filter_test_cases(test_cases, params_filter);
size_t n_ok = 0;
for (auto & test : test_cases) {
if (test->eval_grad(backend, op_name, output_printer)) {
if (test->eval_grad(backend, op_names_filter, output_printer)) {
n_ok++;
}
}
Expand All @@ -5942,7 +5973,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
auto test_cases = make_test_cases_perf();
filter_test_cases(test_cases, params_filter);
for (auto & test : test_cases) {
test->eval_perf(backend, op_name, output_printer);
test->eval_perf(backend, op_names_filter, output_printer);
}
return true;
}
Expand All @@ -5951,7 +5982,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
auto test_cases = make_test_cases_eval();
filter_test_cases(test_cases, params_filter);
for (auto & test : test_cases) {
test->eval_support(backend, op_name, output_printer);
test->eval_support(backend, op_names_filter, output_printer);
}
return true;
}
Expand All @@ -5973,7 +6004,7 @@ static void usage(char ** argv) {
int main(int argc, char ** argv) {
test_mode mode = MODE_TEST;
output_formats output_format = CONSOLE;
const char * op_name_filter = nullptr;
const char * op_names_filter = nullptr;
const char * backend_filter = nullptr;
const char * params_filter = nullptr;

Expand All @@ -5988,7 +6019,7 @@ int main(int argc, char ** argv) {
mode = MODE_SUPPORT;
} else if (strcmp(argv[i], "-o") == 0) {
if (i + 1 < argc) {
op_name_filter = argv[++i];
op_names_filter = argv[++i];
} else {
usage(argv);
return 1;
Expand Down Expand Up @@ -6069,7 +6100,7 @@ int main(int argc, char ** argv) {
false, "", ggml_backend_dev_description(dev),
total / 1024 / 1024, free / 1024 / 1024, true));

bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get());

if (ok) {
n_ok++;
Expand Down