From 8cee16f0c3fdbed4150e14cb46b353cb4fb53cb8 Mon Sep 17 00:00:00 2001 From: Ervin Tasnadi Date: Wed, 23 Jul 2025 12:44:34 +0200 Subject: [PATCH 1/7] Enables perf/eval testing of composite ops in test-backend-ops. --- ggml/include/ggml-backend.h | 4 + ggml/src/ggml-backend.cpp | 47 ++++ tests/test-backend-ops.cpp | 520 +++++++++++++++++++++++++++++++++--- 3 files changed, 540 insertions(+), 31 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a2977ea2e56d9..4363d1a4efeb2 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -340,6 +340,10 @@ extern "C" { // Compare the output of two backends GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); + GGML_API bool ggml_backend_compare_graph_backend_node(ggml_backend_t backend1, ggml_backend_t backend2, + struct ggml_cgraph * graph1, struct ggml_cgraph * graph2, + ggml_backend_eval_callback callback, void * user_data, + const char * op_name_out_1, const char * op_name_out_2); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index b7498b8d40238..fb7a470dddd9e 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -114,6 +114,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { return buffer->size; } + void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { // get_base is optional if the buffer is zero-sized if (buffer->size == 0) { @@ -1867,6 +1868,52 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } +bool ggml_backend_compare_graph_backend_node(ggml_backend_t backend1, ggml_backend_t backend2, + struct ggml_cgraph * graph1, struct ggml_cgraph * graph2, + ggml_backend_eval_callback callback, void * user_data, + const char * op_name_out_1, const char * op_name_out_2) { + ggml_tensor * out1 = NULL; + ggml_tensor * out2 = NULL; + + struct ggml_cgraph * g1 = graph1; + struct ggml_cgraph * g2 = graph2; + + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + ggml_backend_graph_compute(backend1, &g1v); + if (ggml_is_view_op(t1->op)) { + continue; + } + if (strcmp(t1->name, op_name_out_1) == 0) { + out1 = t1; + } + } + + for (int i = 0; i < g2->n_nodes; i++) { + struct ggml_tensor * t2 = g2->nodes[i]; + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + ggml_backend_graph_compute(backend2, &g2v); + if (ggml_is_view_op(t2->op)) { + continue; + } + if (strcmp(t2->name, op_name_out_2) == 0) { + out2 = t2; + } + } + + assert(out1 != NULL); + assert(out2 != NULL); + assert(ggml_are_same_layout(out1, out2)); + + // compare results, calculate rms etc + if (!callback(0, out1, out2, user_data)) { + return false; + } + + return true; +} + // CPU backend - buffer static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 4898094c918e1..132a24e8a2c22 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -952,6 +953,9 @@ struct test_case { } } + // Returns the names of the inputs. 
+    // Used when comparing different ops.
+    virtual std::vector<std::string> get_input_names() { return {}; }
+
     virtual size_t op_size(ggml_tensor * t) {
         size_t size = ggml_nbytes(t);
         // add source tensors
@@ -1020,7 +1024,7 @@ struct test_case {
         return t;
     }
 
-    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
+    virtual bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
         mode = MODE_TEST;
 
         ggml_init_params params = {
@@ -1161,6 +1165,7 @@ struct test_case {
                 //exit(1);
                 ud->ok = false;
             }
+
             return true;
 
             GGML_UNUSED(index);
@@ -1185,17 +1190,17 @@ struct test_case {
         return test_passed;
     }
 
-    bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) {
+    virtual bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) {
         mode = MODE_PERF;
 
         static const size_t graph_nodes = 8192;
 
         ggml_init_params params = {
-            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
+            /* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead_custom(graph_nodes, false),
             /* .mem_base = */ NULL,
             /* .no_alloc = */ true,
         };
-        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
+        ggml_context_ptr ctx(ggml_init(params));  // smart ptr
         GGML_ASSERT(ctx);
 
         ggml_tensor * out = build_graph(ctx.get());
@@ -1205,7 +1210,18 @@ struct test_case {
             return true;
         }
 
-        if (!ggml_backend_supports_op(backend, out)) {
+        uint32_t number_of_nodes = 0;
+        // Check if all nodes in the graph are supported
+        bool supported = true;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
+            number_of_nodes += 1;
+            if (!ggml_backend_supports_op(backend, t)) {
+                supported = false;
+                break;
+            }
+        }
+
+        if (!supported) {
             // Create test result for unsupported performance test
             test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false,
                                "not supported");
@@ -1216,7 +1232,7 @@ struct test_case {
         }
 
         // allocate
-        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
+        ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend));  // smart ptr
 
         if (buf == NULL) {
             printf("failed to allocate tensors\n");
@@ -1233,37 +1249,67 @@ struct test_case {
         // warmup run
         ggml_status status = ggml_backend_graph_compute(backend, gf);
         if (status != GGML_STATUS_SUCCESS) {
-            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+            fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__,
+                    ggml_status_to_string(status));
             return false;
         }
 
         // determine number of runs
-        int n_runs;
+        int  n_runs;
         bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
+
+        // how many nodes are added by the test case
+        uint32_t nodes_per_op = 1;
+
+        // If the list of input names is not empty, this is a composite op: the bulk
+        // of the computation is not necessarily in the output node, so every node
+        // in the graph except the input nodes has to be added per run.
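+        // For example, the CONV_2D test cases below report the two input
+        // tensors "kernel" and "input", so each duplicated op adds
+        // (number_of_nodes - 2) graph nodes per run.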
+        if (get_input_names().size() > 0) {
+            nodes_per_op = number_of_nodes - get_input_names().size();
+        }
+
         if (op_flops(out) > 0) {
             // based on flops
-            const uint64_t GFLOP = 1000 * 1000 * 1000;
-            const uint64_t target_flops_cpu = 8ULL * GFLOP;
+            const uint64_t GFLOP            = 1000 * 1000 * 1000;
+            const uint64_t target_flops_cpu = 8ULL * GFLOP;
             const uint64_t target_flops_gpu = 100ULL * GFLOP;
-            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
-            n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
+            uint64_t       target_flops     = is_cpu ? target_flops_cpu : target_flops_gpu;
+            n_runs = std::min((ggml_graph_size(gf) - ggml_graph_n_nodes(gf)) / nodes_per_op,
+                              target_flops / op_flops(out)) +
+                     1;
         } else {
             // based on memory size
-            const size_t GB = 1ULL << 30;
-            const size_t target_size_cpu = 8 * GB;
+            const size_t GB              = 1ULL << 30;
+            const size_t target_size_cpu = 8 * GB;
             const size_t target_size_gpu = 32 * GB;
-            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
-            n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
+            size_t       target_size     = is_cpu ? target_size_cpu : target_size_gpu;
+            n_runs = std::min((ggml_graph_size(gf) - ggml_graph_n_nodes(gf)) / nodes_per_op,
+                              target_size / op_size(out)) +
+                     1;
         }
 
-        // duplicate the op
+        // duplicate the op:
+        std::vector<std::string> input_names = get_input_names();
         for (int i = 1; i < n_runs; i++) {
-            ggml_graph_add_node(gf, out);
+            // If it is a composite op, we need to add all ops that are not input nodes
+            if (input_names.size() > 0) {
+                uint32_t added_nodes = 0;
+                for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL;
+                     t = ggml_get_next_tensor(ctx.get(), t)) {
+                    if (std::find(input_names.begin(), input_names.end(), t->name) == input_names.end()) {
+                        added_nodes += 1;
+                        ggml_graph_add_node(gf, t);
+                    }
+                }
+                assert(added_nodes == nodes_per_op);
+            } else {
+                ggml_graph_add_node(gf, out);
+            }
         }
 
         // calculate memory
-        size_t mem = n_runs * op_size(out);
-        auto tensor_op_size = [](ggml_tensor * t) {
+        size_t mem            = n_runs * op_size(out);
+        auto   tensor_op_size = [](ggml_tensor * t) {
             size_t size = ggml_nbytes(t);
             // add source tensors
             for (int i = 0; i < GGML_MAX_SRC; i++) {
@@ -1282,13 +1328,14 @@ struct test_case {
 
         // run
         int64_t total_time_us = 0;
-        int64_t total_mem = 0;
-        int total_runs = 0;
+        int64_t total_mem     = 0;
+        int     total_runs    = 0;
         do {
-            int64_t start_time = ggml_time_us();
-            ggml_status status = ggml_backend_graph_compute(backend, gf);
+            int64_t     start_time = ggml_time_us();
+            ggml_status status     = ggml_backend_graph_compute(backend, gf);
             if (status != GGML_STATUS_SUCCESS) {
-                fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__, ggml_status_to_string(status));
+                fprintf(stderr, "%s: ggml_backend_graph_compute failed. status=%s \n", __func__,
+                        ggml_status_to_string(status));
                 return false;
             }
             int64_t end_time = ggml_time_us();
 
             total_time_us += end_time - start_time;
             total_mem += mem;
             total_runs += n_runs;
-        } while (total_time_us < 1000*1000); // run for at least 1 second
+        } while (total_time_us < 1000 * 1000);  // run for at least 1 second
 
         // Create test result
         double avg_time_us = (double) total_time_us / total_runs;
@@ -1600,6 +1647,8 @@ struct test_case {
 
 };
 
+
+
 // ###################################
 // ## Section 2: GGML Op Defintions ##
 // ###################################
@@ -1650,6 +1699,274 @@ struct test_example : public test_case {
     // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
 };
 
+// Receives two test cases, initializes test_case1, copies the inputs
+// to test_case2 based on provided input name assignments, then compares
+// the results.
+struct test_case_compare : public test_case {
+  protected:
+    test_case_compare(test_case * case1, test_case * case2) : case1(case1), case2(case2) {}
+  public:
+    test_case * case1;
+    test_case * case2;
+    ggml_cgraph * gf2 = nullptr;
+
+    std::string vars() override { return "(" + case1->vars() + "),(" + case2->vars() + ")"; }
+
+    ggml_tensor * build_graph(ggml_context * ctx) {
+        GGML_UNUSED(ctx);
+        return nullptr;
+    }
+
+    // Copies the inputs from test_case1 to test_case2.
+    virtual void copy_inputs(ggml_context * ctx, ggml_context * ctx2,
+                             std::vector<std::pair<std::string, std::string>> input_names) {
+        std::map<std::string, ggml_tensor *> inputs;
+        std::map<std::string, ggml_tensor *> inputs2;
+
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            for (auto inp_assignment : input_names) {
+                if (inp_assignment.first == t->name) {
+                    inputs[inp_assignment.first] = t;
+                }
+            }
+        }
+
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx2); t != nullptr; t = ggml_get_next_tensor(ctx2, t)) {
+            for (auto inp_assignment : input_names) {
+                if (inp_assignment.second == t->name) {
+                    inputs2[inp_assignment.second] = t;
+                }
+            }
+        }
+
+        for (auto inp_assignment : input_names) {
+            GGML_ASSERT(inputs.count(inp_assignment.first) == 1);
+            GGML_ASSERT(inputs2.count(inp_assignment.second) == 1);
+            std::vector<uint8_t> buf(ggml_nbytes(inputs[inp_assignment.first]));
+            ggml_backend_tensor_get(inputs[inp_assignment.first], buf.data(), 0, ggml_nbytes(inputs[inp_assignment.first]));
+            ggml_backend_tensor_set(inputs2[inp_assignment.second], buf.data(), 0, buf.size());
+        }
+    }
+
+    std::map<ggml_context *, std::vector<ggml_tensor *>> sentinels;
+
+    void add_sentinel(ggml_context * ctx) {
+        if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
+            return;
+        }
+        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
+        ggml_format_name(sentinel, "sent_%zu", sentinels[ctx].size());
+        sentinels[ctx].push_back(sentinel);
+    }
+
+    // Compares the output of the actual graph to the output of the reference
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name,
+              std::vector<std::pair<std::string, std::string>> input_names, std::string output_name_1,
+              std::string output_name_2, printer * output_printer) {
+        mode = MODE_TEST;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx  = ggml_init(params);
+        ggml_context * ctx2 = ggml_init(params);
+        GGML_ASSERT(ctx);
+        GGML_ASSERT(ctx2);
+
+        gf  = ggml_new_graph(ctx);
+        gf2 = ggml_new_graph(ctx2);
+
+        // pre-graph sentinel
+        add_sentinel(ctx);
+        add_sentinel(ctx2);
+
+        ggml_tensor * out  = case1->build_graph(ctx);
+        ggml_tensor * out2 = case2->build_graph(ctx2);
+
+        std::string current_op_name = op_desc(out);
+
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
+            ggml_free(ctx);
+            return true;
+        }
+
+        // check if the backends support the ops
+        bool supported = true;
+        ggml_backend * backend_tested = nullptr;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (!ggml_backend_supports_op(backend1, t)) {
+                supported      = false;
+                backend_tested = backend1;
+                break;
+            }
+        }
+
+        if (supported) {
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx2); t != NULL; t = ggml_get_next_tensor(ctx2, t)) {
+                if (!ggml_backend_supports_op(backend2, t)) {
+                    supported      = false;
+                    backend_tested = backend2;
+                    break;
+                }
+            }
+        }
+
+        if (!supported) {
+            // Create test result for unsupported operation
+            test_result result(ggml_backend_name(backend_tested), current_op_name, vars(), "test", false, false,
+                               "not supported");
+            if (output_printer) {
+                output_printer->print_test_result(result);
+            }
+
+            ggml_free(ctx);
+            return true;
+        }
+
+        // post-graph sentinel
+        add_sentinel(ctx);
+        add_sentinel(ctx2);
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
+
+        if (buf == NULL) {
+            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
+            ggml_free(ctx);
+            return false;
+        }
+
+        ggml_backend_buffer_t buf_ref = ggml_backend_alloc_ctx_tensors(ctx2, backend2);
+        if (buf_ref == NULL) {
+            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend2));
+            ggml_free(ctx2);
+            return false;
+        }
+
+        // build graph
+        ggml_build_forward_expand(gf, out);
+        ggml_build_forward_expand(gf2, out2);
+
+        // add sentinels as graph nodes so that they are checked in the callback
+        for (ggml_tensor * sentinel : sentinels[ctx]) {
+            ggml_graph_add_node(gf, sentinel);
+        }
+
+        for (ggml_tensor * sentinel : sentinels[ctx2]) {
+            ggml_graph_add_node(gf2, sentinel);
+        }
+
+        // randomize tensors
+        initialize_tensors(ctx);
+        copy_inputs(ctx, ctx2, input_names);
+
+        // compare
+        struct callback_userdata {
+            bool           ok;
+            double         max_err;
+            ggml_backend_t backend1;
+            ggml_backend_t backend2;
+        };
+
+        callback_userdata ud{ true, max_nmse_err(), backend1, backend2 };
+
+        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
+            callback_userdata * ud  = (callback_userdata *) user_data;
+            const char *        bn1 = ggml_backend_name(ud->backend1);
+            const char *        bn2 = ggml_backend_name(ud->backend2);
+
+            if (t1->op == GGML_OP_NONE) {
+                // sentinels must be unchanged
+                std::vector<uint8_t> t1_data(ggml_nbytes(t1));
+                std::vector<uint8_t> t2_data(ggml_nbytes(t2));
+                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
+                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));
+
+                if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
+                    printf("sentinel mismatch: %s ", t1->name);
+                    ud->ok = false;
+                    return true;
+                }
+            }
+
+            std::vector<float> f1 = tensor_to_float(t1);
+            std::vector<float> f2 = tensor_to_float(t2);
+
+            for (size_t i = 0; i < f1.size(); i++) {
+                // check for nans
+                if (std::isnan(f1[i]) || std::isnan(f2[i])) {
+                    printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);
+                    ud->ok = false;
+                    return true;
+                }
+                // check for infs: both must be inf of the same sign, or both must be finite
+                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
+                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
+                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {
+                            printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
+                            ud->ok = false;
+                            return true;
+                        }
+                    } else {
+                        printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
+                        ud->ok = false;
+                        return true;
+                    }
+                }
+            }
+
+            double err = nmse(f1.data(), f2.data(), f1.size());
+            if (err > ud->max_err) {
+                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
+                //for (int i = 0; i < (int) f1.size(); i++) {
+                //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
+                //}
+                //printf("\n");
+                //exit(1);
+                ud->ok = false;
+            }
+            return true;
+
+            GGML_UNUSED(index);
+        };
+
+        const bool cmp_ok = ggml_backend_compare_graph_backend_node(backend1, backend2, gf, gf2, callback, &ud,
+                                                                    output_name_1.c_str(), output_name_2.c_str());
+
+        if (!cmp_ok) {
+            printf("compare failed ");
+        }
+
+        ggml_backend_buffer_free(buf);
+        ggml_backend_buffer_free(buf_ref);
+
+        ggml_free(ctx);
+        ggml_free(ctx2);
+
+        // Create test result
+        bool        test_passed = ud.ok && cmp_ok;
+        std::string error_msg   = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed");
+        test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed,
+                           error_msg);
+
+        if (output_printer) {
+            output_printer->print_test_result(result);
+        }
+
+        return test_passed;
+    }
+
+    bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) override {
+        printf("This test case does not support performance evaluation.\n");
+        GGML_UNUSED(backend);
+        GGML_UNUSED(op_name);
+        GGML_UNUSED(output_printer);
+        return true;
+    }
+};
 
 // GGML_OP_UNARY
 struct test_unary : public test_case {
@@ -3703,8 +4020,7 @@ struct test_im2col : public test_case {
     }
 };
 
-// CONV_2D
-struct test_conv_2d : public test_case {
+struct test_conv_2d_im2col : public test_case {
     const std::array<int64_t, 4> ne_input;
     const std::array<int64_t, 4> ne_kernel;
     const int stride0;
@@ -3723,6 +4039,102 @@
     // * if the program is called with -o CONV_2D_INDIRECT_IMPL, the
     //   IM2COL -> MUL_MM graph will be built.
 
+    virtual std::string op_desc(ggml_tensor * t) override {
+        (void) t;
+        return std::string("CONV_2D_IM2COL");
+    }
+
+    bool run_whole_graph() { return false; }
+
+    std::string vars() override {
+        return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops
+
+        // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+        };
+
+        int64_t W    = ne_input[0];
+        int64_t H    = ne_input[1];
+        int64_t KW   = ne_kernel[0];
+        int64_t KH   = ne_kernel[1];
+        int64_t Cin  = ne_kernel[2];
+        int64_t Cout = ne_kernel[3];
+        int64_t N    = ne_input[3];
+        int64_t OH   = calc_conv_output_size(H, KH, stride0, padding0, dilation0);
+        int64_t OW   = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
+
+        int64_t K   = Cout;
+        int64_t CRS = Cin * KH * KW;
+        int64_t NPQ = N * OH * OW;
+
+        return K * NPQ * (2 * CRS - 1);
+    }
+
+    test_conv_2d_im2col(std::array<int64_t, 4> ne_input = { 64, 64, 16, 1 },
+                        std::array<int64_t, 4> ne_kernel = { 3, 3, 1, 16 }, int stride0 = 1, int stride1 = 1,
+                        int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
+        ne_input(ne_input),
+        ne_kernel(ne_kernel),
+        stride0(stride0),
+        stride1(stride1),
+        padding0(padding0),
+        padding1(padding1),
+        dilation0(dilation0),
+        dilation1(dilation1),
+        cwhn(cwhn) {}
+
+    std::vector<std::string> get_input_names() { return { "kernel", "input" }; }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        //printf("Building conv_2d_im2col graph...\n");
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        GGML_ASSERT(cwhn == false);
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input  = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input  = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * conv2d_out =
+            ggml_conv_2d(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
+
+        ggml_tensor * out = ggml_cont(ctx, conv2d_out);
+
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// CONV_2D
+struct test_conv_2d : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const int stride0;
+    const int stride1;
+    const int padding0;
+    const int padding1;
+    const int dilation0;
+    const int dilation1;
+    // Whether the inputs are contiguous in the channel dim or the width dim
+    const bool cwhn;
 
     std::string vars() override {
         return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
     }
@@ -3767,7 +4179,10 @@ struct test_conv_2d : public test_case {
         dilation1(dilation1),
         cwhn(cwhn) {}
 
+    std::vector<std::string> get_input_names() { return { "kernel", "input" }; }
+
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        //printf("Building conv_2d graph...\n");
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
         ggml_set_name(input, "input");
 
@@ -3790,6 +4205,33 @@ struct test_conv_2d : public test_case {
     }
 };
 
+// Tests CONV_2D by comparing it to the IM2COL -> MUL_MM
+// reference implementation.
+struct test_conv_2d_compare : public test_case_compare {
+    // Input tensor names in (actual graph, reference graph)
+    std::vector<std::pair<std::string, std::string>> input_names = {
+        { "input",  "input"  },
+        { "kernel", "kernel" }
+    };
+
+    std::string output_name_1 = "out";
+    std::string output_name_2 = "out";
+
+    virtual std::string op_desc(ggml_tensor * t) override {
+        (void) t;
+        return std::string("CONV_2D_COMPARE");
+    }
+
+    test_conv_2d_compare(test_conv_2d_im2col * conv_2d_im2col, test_conv_2d * conv_2d_direct) :
+        test_case_compare(conv_2d_im2col, conv_2d_direct) {}
+
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name,
+              printer * output_printer) override {
+        return test_case_compare::eval(backend1, backend2, op_name, input_names, output_name_1, output_name_2,
+                                       output_printer);
+    }
+};
+
 // GGML_OP_CONV_2D_DW
 struct test_conv_2d_dw : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -5089,7 +5531,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     // extra tests for im2col 2D
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 4, 1, 32}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
@@ -5136,6 +5578,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_conv_2d(
             { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
             { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+        test_cases.emplace_back(new test_conv_2d_im2col(
+            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+        test_cases.emplace_back(
+            new test_conv_2d_compare({ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                                     { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1,
+                                     1, 0, 0, 1, 1, false, true));
     }
 #endif
 
@@ -5161,8 +5610,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                         for (uint32_t W : { 1, 141 }) {
                             if (calc_conv_output_size(W, KW, s0, p0, d0) > 0 &&
                                 calc_conv_output_size(H, KH, s1, p1, d1) > 0) {
-                                test_cases.emplace_back(new test_conv_2d(
-                                    { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false));
+                                auto test_case_conv_2d = new test_conv_2d(
+                                    { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false);
+                                test_cases.emplace_back(test_case_conv_2d);
+                                auto test_case_conv_2d_im2col = new test_conv_2d_im2col(
+                                    { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false);
+                                test_cases.emplace_back(test_case_conv_2d_im2col);
+                                //test_cases.emplace_back(new test_conv_2d_compare(test_case_conv_2d, test_case_conv_2d_im2col);
                             }
                         }
                     }
@@ -5812,6 +6266,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         test_cases.emplace_back(new test_conv_2d(
             { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
             { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+        // Indirect CONV_2D (uses im2col + sgemm)
+        test_cases.emplace_back(new test_conv_2d_im2col(
+            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
     }
 
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));

From 1a2dfad832f6bc969d8d20c334bb544b0c59fdd0 Mon Sep 17 00:00:00 2001
From: Ervin Tasnadi
Date: Wed, 23 Jul 2025 22:59:11 +0200
Subject: [PATCH 2/7] #include added, example tests removed

---
 tests/test-backend-ops.cpp | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 132a24e8a2c22..6bb44d6798afa 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -38,6 +38,7 @@
 #include 
 #include 
 #include 
+#include 
 
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
     size_t nels = ggml_nelements(tensor);
@@ -5531,7 +5532,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     // extra tests for im2col 2D
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 4, 1, 32}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
@@ -5575,16 +5576,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     };
 
     for (auto act_case : cases) {
-        test_cases.emplace_back(new test_conv_2d(
-            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
-            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+        test_cases.emplace_back(
+            new test_conv_2d({ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                             { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                             1,
+                             1,
+                             0,
+                             0,
+                             1,
+                             1,
+                             false));
+        /*
+        // Example test for testing a composite op
+        test_cases.emplace_back(new test_conv_2d_im2col(
+            { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+            { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1, 1, 0, 0, 1, 1, false));
+        // Example test for testing a regular op against a composite op
+        test_cases.emplace_back(
+            new test_conv_2d_compare({ act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                                     { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1,
+                                     1, 0, 0, 1, 1, false, true));
+        */
     }
 #endif
 
@@ -5615,7 +5627,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                 test_cases.emplace_back(test_case_conv_2d);
                                 auto test_case_conv_2d_im2col = new test_conv_2d_im2col(
                                     { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false);
-                                test_cases.emplace_back(test_case_conv_2d_im2col);
+                                //test_cases.emplace_back(test_case_conv_2d_im2col);
                                 //test_cases.emplace_back(new test_conv_2d_compare(test_case_conv_2d, test_case_conv_2d_im2col);
                             }
                         }

From 5f4fc489fab8b5325d5416149259154513ec9296 Mon Sep 17 00:00:00 2001
From: Ervin Tasnadi
Date: Wed, 23 Jul 2025 23:02:10 +0200
Subject: [PATCH 3/7] Unused variable removed.

---
 tests/test-backend-ops.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 6bb44d6798afa..c1d301c9045a0 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -5625,8 +5625,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                                 auto test_case_conv_2d = new test_conv_2d(
                                     { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false);
                                 test_cases.emplace_back(test_case_conv_2d);
-                                auto test_case_conv_2d_im2col = new test_conv_2d_im2col(
-                                    { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false);
+                                //auto test_case_conv_2d_im2col = new test_conv_2d_im2col(
+                                //    { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false);
                                 //test_cases.emplace_back(test_case_conv_2d_im2col);
                                 //test_cases.emplace_back(new test_conv_2d_compare(test_case_conv_2d, test_case_conv_2d_im2col);
                             }

From e08e8fcf340fa924e447098c75270d6246b0d5e6 Mon Sep 17 00:00:00 2001
From: Ervin Tasnadi
Date: Wed, 23 Jul 2025 23:04:23 +0200
Subject: [PATCH 4/7] Overrides added.

---
 tests/test-backend-ops.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c1d301c9045a0..f00b20ae95390 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4091,7 +4091,7 @@ struct test_conv_2d_im2col : public test_case {
         dilation1(dilation1),
         cwhn(cwhn) {}
 
-    std::vector<std::string> get_input_names() { return { "kernel", "input" }; }
+    std::vector<std::string> get_input_names() override { return { "kernel", "input" }; }
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         //printf("Building conv_2d_im2col graph...\n");
@@ -4180,7 +4180,7 @@ struct test_conv_2d : public test_case {
         dilation1(dilation1),
         cwhn(cwhn) {}
 
-    std::vector<std::string> get_input_names() { return { "kernel", "input" }; }
+    std::vector<std::string> get_input_names() override { return { "kernel", "input" }; }
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         //printf("Building conv_2d graph...\n");

From 41e9ec092bab94e8ecce5c5d1458c20f236f08a7 Mon Sep 17 00:00:00 2001
From: Ervin Tasnadi
Date: Thu, 24 Jul 2025 10:36:24 +0200
Subject: [PATCH 5/7] Overrides added, eval() renamed to prevent shadowing.
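
Background on the rename, with an illustrative sketch (hypothetical names,
not code from this series): a derived-class eval() with a different
parameter list hides, rather than overloads, the base-class eval().

    struct base {
        virtual bool eval(int a) { return true; }
    };
    struct derived : base {
        // Hides base::eval(int): derived{}.eval(1) no longer compiles
        // unless the base overload is re-exposed with "using base::eval;".
        bool eval(int a, int b) { return a == b; }
    };

Renaming the derived-class function to eval_compare() sidesteps this.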
---
 tests/test-backend-ops.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f00b20ae95390..319f814d290d3 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1303,6 +1303,7 @@ struct test_case {
                     }
                 }
                 assert(added_nodes == nodes_per_op);
+                (void) added_nodes;
             } else {
                 ggml_graph_add_node(gf, out);
             }
@@ -1713,7 +1714,7 @@ struct test_case_compare : public test_case {
 
     std::string vars() override { return "(" + case1->vars() + "),(" + case2->vars() + ")"; }
 
-    ggml_tensor * build_graph(ggml_context * ctx) {
+    ggml_tensor * build_graph(ggml_context * ctx) override {
         GGML_UNUSED(ctx);
         return nullptr;
     }
@@ -1761,7 +1762,7 @@ struct test_case_compare : public test_case {
     }
 
     // Compares the output of the actual graph to the output of the reference
-    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name,
+    bool eval_compare(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name,
               std::vector<std::pair<std::string, std::string>> input_names, std::string output_name_1,
               std::string output_name_2, printer * output_printer) {
         mode = MODE_TEST;
@@ -4045,7 +4046,7 @@ struct test_conv_2d_im2col : public test_case {
         return std::string("CONV_2D_IM2COL");
     }
 
-    bool run_whole_graph() { return false; }
+    bool run_whole_graph() override { return false; }
 
     std::string vars() override {
         return VARS_TO_STR9(ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
@@ -4228,7 +4229,7 @@ struct test_conv_2d_compare : public test_case_compare {
 
     bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name,
               printer * output_printer) override {
-        return test_case_compare::eval(backend1, backend2, op_name, input_names, output_name_1, output_name_2,
+        return test_case_compare::eval_compare(backend1, backend2, op_name, input_names, output_name_1, output_name_2,
                                        output_printer);
     }
 };

From 2dc1a6a4339d8ee6ef6462d7d559f1a22aae8d91 Mon Sep 17 00:00:00 2001
From: Ervin Tasnadi
Date: Thu, 24 Jul 2025 14:03:23 +0200
Subject: [PATCH 6/7] Replaces asserts with GGML_ASSERT and (void) with
 GGML_UNUSED.
---
 ggml/src/ggml-backend.cpp  | 6 +++---
 tests/test-backend-ops.cpp | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index fb7a470dddd9e..009b8740a2395 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -1902,9 +1902,9 @@ bool ggml_backend_compare_graph_backend_node(ggml_backend_t backend1, ggml_backe
         }
     }
 
-    assert(out1 != NULL);
-    assert(out2 != NULL);
-    assert(ggml_are_same_layout(out1, out2));
+    GGML_ASSERT(out1 != NULL);
+    GGML_ASSERT(out2 != NULL);
+    GGML_ASSERT(ggml_are_same_layout(out1, out2));
 
     // compare results, calculate rms etc
     if (!callback(0, out1, out2, user_data)) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 319f814d290d3..981c4fc55b1e7 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1302,8 +1302,8 @@ struct test_case {
                         ggml_graph_add_node(gf, t);
                     }
                 }
-                assert(added_nodes == nodes_per_op);
-                (void) added_nodes;
+                GGML_ASSERT(added_nodes == nodes_per_op);
+                GGML_UNUSED(added_nodes);
             } else {
                 ggml_graph_add_node(gf, out);
             }

From 21108b825115ddacf345fc32d159661186297d27 Mon Sep 17 00:00:00 2001
From: Ervin Tasnadi
Date: Thu, 24 Jul 2025 17:23:03 +0200
Subject: [PATCH 7/7] assert include removed, 2 further (void) casts replaced
 with GGML_UNUSED

---
 tests/test-backend-ops.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 981c4fc55b1e7..7615e54e281a8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -20,7 +20,6 @@
 #include 
 #include 
-#include <cassert>
 #include 
 #include 
 #include 
@@ -4042,7 +4041,7 @@ struct test_conv_2d_im2col : public test_case {
     //   IM2COL -> MUL_MM graph will be built.
 
     virtual std::string op_desc(ggml_tensor * t) override {
-        (void) t;
+        GGML_UNUSED(t);
         return std::string("CONV_2D_IM2COL");
     }
 
@@ -4220,7 +4219,7 @@ struct test_conv_2d_compare : public test_case_compare {
     std::string output_name_2 = "out";
 
     virtual std::string op_desc(ggml_tensor * t) override {
-        (void) t;
+        GGML_UNUSED(t);
         return std::string("CONV_2D_COMPARE");
     }
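
Usage sketch (illustrative only, not part of the patches): how a caller might
drive the ggml_backend_compare_graph_backend_node() API added in PATCH 1. The
backends, the graphs g1/g2, and the "out" tensor names are assumptions taken
from the test cases above; error handling is omitted.

    // Callback invoked once with the two located output tensors; the API has
    // already asserted that they share the same layout. Returning false
    // makes the comparison fail.
    static bool eval_cb(int index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
        (void) index; (void) user_data;
        // Compare t1 against t2 here (e.g. NMSE over the dequantized data).
        return ggml_nbytes(t1) == ggml_nbytes(t2);
    }

    // Computes both graphs node by node, then compares the node named "out"
    // of the fused graph (g1, on backend1) against the node named "out" of
    // its decomposed reference (g2, on backend2).
    bool ok = ggml_backend_compare_graph_backend_node(backend1, backend2, g1, g2,
                                                      eval_cb, /*user_data=*/NULL,
                                                      "out", "out");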