From c862a982f980eda7f826d9ffba0d1c6b1093dc66 Mon Sep 17 00:00:00 2001 From: "Lee, Chon Ming" Date: Thu, 16 Oct 2025 19:10:56 +0800 Subject: [PATCH] moe_scatter_reduction --- .../primitives/moe_scatter_reduction.hpp | 65 +++++ .../ocl_v2/moe/moe_scatter_reduction.cpp | 149 ++++++++++++ .../ocl_v2/moe/moe_scatter_reduction.hpp | 55 +++++ .../impls/ocl_v2/moe_scatter_reduction_ref.cl | 97 ++++++++ .../include/moe_scatter_reduction_inst.h | 46 ++++ .../src/graph/moe_scatter_reduction.cpp | 65 +++++ .../registry/moe_scatter_reduction_impls.cpp | 29 +++ .../intel_gpu/src/graph/registry/registry.hpp | 1 + .../src/kernel_selector/common_types.h | 3 +- .../tests/unit/test_cases/moe_test.cpp | 224 ++++++++++++++++++ 10 files changed, 733 insertions(+), 1 deletion(-) create mode 100644 src/plugins/intel_gpu/include/intel_gpu/primitives/moe_scatter_reduction.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.cpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_scatter_reduction_ref.cl create mode 100644 src/plugins/intel_gpu/src/graph/include/moe_scatter_reduction_inst.h create mode 100644 src/plugins/intel_gpu/src/graph/moe_scatter_reduction.cpp create mode 100644 src/plugins/intel_gpu/src/graph/registry/moe_scatter_reduction_impls.cpp create mode 100644 src/plugins/intel_gpu/tests/unit/test_cases/moe_test.cpp diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_scatter_reduction.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_scatter_reduction.hpp new file mode 100644 index 00000000000000..29b751ff5f90f9 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_scatter_reduction.hpp @@ -0,0 +1,65 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "primitive.hpp" + +namespace cldnn { + +/// @brief +/// @details 
+struct moe_scatter_reduction : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(moe_scatter_reduction) + + moe_scatter_reduction() : primitive_base("", {}) {} + + /// @brief Constructs moe_scatter_reduction primitive. + /// + /// @param id This primitive id. + /// @param input Input data primitive id. + /// @param experts_per_token sorted topk expert id per token + /// @param expert_weights_per_token sorted topk expert id weight per token + /// @param tokens_per_expert tokens per expert + /// @param experts_info_offsets offset of each expert's info from the tokens_per_expert + /// @param tokens_len_per_expert tokens len_per_expert + moe_scatter_reduction(const primitive_id& id, + const input_info& data, + const input_info& experts_per_token, + const input_info& expert_weights_per_token, + const input_info& tokens_per_expert, + const input_info& experts_info_offsets, + const input_info& tokens_len_per_expert, + int32_t num_active_experts_per_token = 0) + : primitive_base(id, {data, experts_per_token, expert_weights_per_token, tokens_per_expert, + experts_info_offsets, tokens_len_per_expert}), num_active_experts_per_token(num_active_experts_per_token) {} + + int32_t num_active_experts_per_token = 0; + + size_t hash() const override { + size_t seed = primitive::hash(); + seed = hash_combine(seed, num_active_experts_per_token); + return seed; + } + + bool operator==(const primitive& rhs) const override { + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return num_active_experts_per_token == rhs_casted.num_active_experts_per_token; + } + + void save(BinaryOutputBuffer& ob) const override { + primitive_base::save(ob); + ob << num_active_experts_per_token; + } + + void load(BinaryInputBuffer& ib) override { + primitive_base::load(ib); + ib >> num_active_experts_per_token; + } +}; +} + diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.cpp 
b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.cpp new file mode 100644 index 00000000000000..b91637b288b6e6 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "moe_scatter_reduction.hpp" + +#include "../common_utils/dispatch_utils.hpp" +#include "../common_utils/jitter.hpp" +#include "intel_gpu/primitives/moe_scatter_reduction.hpp" +#include "../primitive_ocl_base.hpp" +#include "../utils/kernel_generator.hpp" + +namespace ov::intel_gpu::ocl { +namespace { + +class MoeScatterReductionRefGenerator : public KernelGenerator { +public: + MoeScatterReductionRefGenerator() : KernelGenerator("moe_scatter_reduction_ref") {} + +protected: + static size_t GetBlockSize(const RuntimeParams& params) { + const auto& input = params.get_input_layout(0); + size_t vec_size = 1; + switch (input.data_type) { + case ov::element::i8: + case ov::element::u8: + vec_size = 16; + break; + case ov::element::f16: + vec_size = 8; + break; + case ov::element::f32: + case ov::element::i32: + vec_size = 4; + break; + case ov::element::i64: + vec_size = 2; + break; + default: + vec_size = 1; + break; + } + return vec_size; + } + + static auto calc_thread_count(RuntimeParams& params, const int vector_size, const int hidden_size) { + auto max_wgs = params.get_program().get_engine().get_device_info().max_work_group_size; + const uint64_t threads_needed = (hidden_size + vector_size - 1) / vector_size; + size_t local_threads_needed = std::min(threads_needed, max_wgs); + size_t batches_per_thread = 1; + size_t unaligned_elements = 0; + + if (threads_needed <= max_wgs) { + batches_per_thread = 1; + unaligned_elements = hidden_size % vector_size; + } else { + batches_per_thread = (threads_needed + max_wgs - 1) / max_wgs; + auto new_block_size = batches_per_thread * vector_size; + unaligned_elements = hidden_size % 
new_block_size; + + local_threads_needed = hidden_size / new_block_size; + auto partialblock = (hidden_size % new_block_size != 0) ? 1 : 0; + local_threads_needed += partialblock; + } + + return std::tuple{local_threads_needed, batches_per_thread, unaligned_elements}; + } + + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto in_l = params.input_layouts[0]; + auto hidden_size = extract_channel(ChannelName::FEATURE, in_l); + auto block_size = GetBlockSize(params); + auto [local_threads_count, batches_per_thread, unaligned_elements] = calc_thread_count( + const_cast(params), block_size, hidden_size); + + const auto& desc = params.typed_desc(); + + jit.make("ACTIVE_EXPERTS", desc->num_active_experts_per_token); + jit.make("HIDDEN_SIZE", hidden_size); + jit.make("VEC_BLK_SIZE", block_size); + jit.make("BATCHES_PER_THREAD", batches_per_thread); + jit.make("UNALIGNED_ELEMENTS", unaligned_elements); + + return jit; + } + + Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + if (params.is_dynamic()) { + args.push_back({ArgumentDescriptor::Types::SHAPE_INFO, 0}); + } + + uint32_t num_of_inputs = 6; + + for (uint32_t i = 0; i < num_of_inputs; i++) { + args.push_back({ArgumentDescriptor::Types::INPUT, i}); + } + + args.push_back({ArgumentDescriptor::Types::OUTPUT, 0}); + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{[](const RuntimeParams& params, KernelData& kd, ImplRuntimeParams* rt_params) { + auto& wgs = kd.params.workGroups; + + if (!params.is_dynamic()) { + auto hidden_size = extract_channel(ChannelName::FEATURE, params.input_layouts[0]); + auto block_size = GetBlockSize(params); + auto [local_threads_count, batches_per_thread, unaligned_elements] = calc_thread_count( + const_cast(params), block_size, hidden_size); + + auto num_tokens = 
extract_channel(ChannelName::BATCH, params.input_layouts[1]); + + wgs.global = {num_tokens * local_threads_count, 1, 1}; + wgs.local = { local_threads_count, 1, 1}; + } + }}; + } +}; + +class MoeScatterReductionRefImpl : public PrimitiveImplOCL { +public: + DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::ocl::MoeScatterReductionRefImpl) + + Stage::Ptr moe_scatter_reduction = make_stage(); + + MoeScatterReductionRefImpl() : PrimitiveImplOCL(MoeScatterReductionRef::get_type_info_static()) {} + MoeScatterReductionRefImpl(const program_node& node, const RuntimeParams& params) : MoeScatterReductionRefImpl() { + add_stage(moe_scatter_reduction, params); + } + + [[nodiscard]] std::unique_ptr clone() const override { + return make_deep_copy(this); + } +}; + +} // namespace + +std::unique_ptr MoeScatterReductionRef::create_impl(const program_node& node, const RuntimeParams& params) const { + assert(node.is_type()); + return std::make_unique(node, params); +} + +} // namespace ov::intel_gpu::ocl + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_scatter_reduction) +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::ocl::MoeScatterReductionRefImpl) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.hpp new file mode 100644 index 00000000000000..8090ef340ae0ee --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_scatter_reduction.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "program_node.h" +#include "registry/implementation_manager.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned + +namespace ov::intel_gpu::ocl { + +struct MoeScatterReductionRef : public ImplementationManager { + OV_GPU_PRIMITIVE_IMPL("ocl::moe_scatter_reduction") + explicit MoeScatterReductionRef(shape_types shape_type, ValidateFunc vf = nullptr) : 
ImplementationManager(impl_types::ocl, shape_type, std::move(vf)) {} + [[nodiscard]] std::unique_ptr create_impl(const program_node& node, const RuntimeParams& params) const override; + [[nodiscard]] bool validate_impl(const program_node& node) const override { + static constexpr std::array supported_fmts = { + format::bfyx, + }; + + static constexpr std::array supported_types = { + ov::element::f32, + ov::element::f16, + ov::element::i32, + ov::element::i64, + ov::element::i8, + ov::element::u8, + }; + + const auto& in0_layout = node.get_input_layout(0); + const auto& out_layout = node.get_output_layout(0); + const auto& input_pshapes = in0_layout.get_partial_shape(); + + if (input_pshapes.rank().get_length() != 2 || input_pshapes[1].is_dynamic()) { + return false; + } + + if (!one_of(in0_layout.format, supported_fmts) || !one_of(out_layout.format, supported_fmts)) { + return false; + } + + if (!one_of(in0_layout.data_type, supported_types) || !one_of(out_layout.data_type, supported_types)) { + return false; + } + + return true; + } +}; + +} // namespace ov::intel_gpu::ocl diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_scatter_reduction_ref.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_scatter_reduction_ref.cl new file mode 100644 index 00000000000000..5b772a3b65cb9c --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_scatter_reduction_ref.cl @@ -0,0 +1,97 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "include/batch_headers/common.cl" +#include "include/fetch_utils.cl" + +#define VLOAD CAT(vload, VEC_BLK_SIZE) +#define VSTORE CAT(vstore, VEC_BLK_SIZE) +#define INPUT_VEC_TYPE MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_BLK_SIZE) +#define OUTPUT_VEC_TYPE MAKE_VECTOR_TYPE(OUTPUT_TYPE, VEC_BLK_SIZE) + +KERNEL(moe_scatter_reduction_ref)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input, + const __global INPUT1_TYPE* experts_per_token, + const __global INPUT2_TYPE* 
expert_weights, + const __global INPUT3_TYPE* tokens_per_expert, + const __global INPUT4_TYPE* experts_start_offset, + const __global INPUT5_TYPE* tokens_len_per_expert, + __global OUTPUT_TYPE* output +) +{ + const uint token_group_id = (uint)get_group_id(0); + const uint threads_index = (uint)get_local_id(0); + + OUTPUT_VEC_TYPE output_vec[BATCHES_PER_THREAD]; + +#if UNALIGNED_ELEMENTS > 0 + OUTPUT_TYPE output_scalar[UNALIGNED_ELEMENTS]; +#endif + + uint dest_index = token_group_id * HIDDEN_SIZE; + uint output_pos = dest_index + threads_index * VEC_BLK_SIZE * BATCHES_PER_THREAD; + + for (uint i = 0; i < BATCHES_PER_THREAD; i++) { + output_vec[i] = TO_OUTPUT_TYPE(0); + } + +#if UNALIGNED_ELEMENTS > 0 + for (uint i = 0; i < UNALIGNED_ELEMENTS; i++) { + output_scalar[i] = TO_OUTPUT_TYPE(0); + } +#endif + + for (uint i = 0; i < ACTIVE_EXPERTS; i++) { + INPUT1_TYPE expert_id = experts_per_token[token_group_id * ACTIVE_EXPERTS + i]; + INPUT2_TYPE expert_weight = expert_weights[token_group_id * ACTIVE_EXPERTS + i]; + INPUT5_TYPE token_len = tokens_len_per_expert[expert_id]; + INPUT4_TYPE expert_offset = experts_start_offset[expert_id]; + + uint input_offset = 0; + for (uint j = 0; j < token_len; j++) { + if (tokens_per_expert[expert_offset + j] == token_group_id) { + input_offset = expert_offset + j; + break; + } + } + + for (uint j = 0; j < BATCHES_PER_THREAD; j++) { + const uint input_pos = input_offset * HIDDEN_SIZE + j * VEC_BLK_SIZE + threads_index * VEC_BLK_SIZE * BATCHES_PER_THREAD; + +#if UNALIGNED_ELEMENTS > 0 + if ((threads_index == get_local_size(0) - 1) && (j == 0)) { + uint input_pos_unaligned = input_pos; + for (uint k = 0; k < UNALIGNED_ELEMENTS; k++) { + output_scalar[k] += input[input_pos_unaligned] * expert_weight; + input_pos_unaligned++; + } + } else { +#endif + INPUT_VEC_TYPE input_data = VLOAD(0, &input[input_pos]); + input_data *= expert_weight; + output_vec[j] += input_data; +#if UNALIGNED_ELEMENTS > 0 + } +#endif + } + } + +#if UNALIGNED_ELEMENTS 
> 0 + if ((threads_index == get_local_size(0) - 1)) { + uint output_pos_unaligned = output_pos; + for (uint s = 0; s < UNALIGNED_ELEMENTS; s++) { + output[output_pos_unaligned] = output_scalar[s]; + output_pos_unaligned++; + } + } else { +#endif + for (uint v = 0; v < BATCHES_PER_THREAD; v++) { + const uint out_pos = output_pos + v * VEC_BLK_SIZE; + VSTORE(output_vec[v], 0, &output[out_pos]); + } +#if UNALIGNED_ELEMENTS > 0 + } +#endif +} diff --git a/src/plugins/intel_gpu/src/graph/include/moe_scatter_reduction_inst.h b/src/plugins/intel_gpu/src/graph/include/moe_scatter_reduction_inst.h new file mode 100644 index 00000000000000..6cf9f70560c2b3 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/moe_scatter_reduction_inst.h @@ -0,0 +1,46 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "intel_gpu/primitives/moe_scatter_reduction.hpp" +#include "primitive_inst.h" + +#include +#include + +namespace cldnn { + +template <> +struct typed_program_node : public typed_program_node_base { + using parent = typed_program_node_base; + typed_program_node(const std::shared_ptr prim, program& prog) : parent(prim, prog) { + support_padding_all(true); + } + +public: + using parent::parent; + + std::vector get_shape_infer_dependencies() const override { return {}; } + program_node& input() const { return get_dependency(0); } +}; + +using moe_scatter_reduction_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + +public: + template + static std::vector calc_output_layouts(moe_scatter_reduction_node const& /*node*/, const kernel_impl_params& impl_param); + static layout calc_output_layout(moe_scatter_reduction_node const& node, kernel_impl_params const& impl_param); + static std::string to_string(moe_scatter_reduction_node const& node); + + typed_primitive_inst(network& network, 
moe_scatter_reduction_node const& node); +}; + +using moe_scatter_reduction_inst = typed_primitive_inst; +} // namespace cldnn + diff --git a/src/plugins/intel_gpu/src/graph/moe_scatter_reduction.cpp b/src/plugins/intel_gpu/src/graph/moe_scatter_reduction.cpp new file mode 100644 index 00000000000000..469b6e272cc508 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/moe_scatter_reduction.cpp @@ -0,0 +1,65 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "moe_scatter_reduction_inst.h" +#include "primitive_type_base.h" +#include "json_object.h" +#include "to_string_utils.h" +#include +#include + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(moe_scatter_reduction) + +layout moe_scatter_reduction_inst::calc_output_layout(moe_scatter_reduction_node const& node, kernel_impl_params const& impl_param) { + auto output_layouts = calc_output_layouts(node, impl_param); + return output_layouts[0]; +} + +template +std::vector moe_scatter_reduction_inst::calc_output_layouts(moe_scatter_reduction_node const& /*node*/, const kernel_impl_params& impl_param) { + const auto& desc = impl_param.typed_desc(); + const auto num_active_experts_per_token = desc->num_active_experts_per_token; + + const auto& input0_layout = impl_param.get_input_layout(0); + + const auto& input_shapes = impl_param.input_layouts[0].get(); + const auto& hidden_size = input_shapes[1]; + OPENVINO_ASSERT(hidden_size.is_static(), impl_param.desc->id, " hidden size dimension (shape[1]) must be static"); + + const auto& input0_pshape = input0_layout.get_partial_shape(); + auto input0_len = input0_pshape.size(); + + input0_len = input0_len; + if (impl_param.input_layouts[0].is_dynamic()) + return {layout{ov::PartialShape{ov::Dimension::dynamic(), ov::Dimension(hidden_size)}, + impl_param.input_layouts[0].data_type, impl_param.input_layouts[0].format}}; + const auto num_tokens = impl_param.input_layouts[0].get_shape()[0] / + num_active_experts_per_token; + const 
auto& out_shape = ov::PartialShape{ov::Dimension(num_tokens), ov::Dimension(hidden_size)}; + return {layout{out_shape, impl_param.input_layouts[0].data_type, impl_param.input_layouts[0].format}}; +} + +template std::vector moe_scatter_reduction_inst::calc_output_layouts(moe_scatter_reduction_node const& node, + const kernel_impl_params& impl_param); + +std::string moe_scatter_reduction_inst::to_string(moe_scatter_reduction_node const& node) { + auto node_info = node.desc_to_json(); + auto desc = node.get_primitive(); + + std::stringstream primitive_description; + + json_composite moe_scatter_reduction_info; + moe_scatter_reduction_info.add("input id", node.input().id()); + if (desc->output_data_types[0].has_value()) + moe_scatter_reduction_info.add("out dt: ", dt_to_str(*desc->output_data_types[0])); + node_info->add("moe_scatter_reduction_info", moe_scatter_reduction_info); + node_info->dump(primitive_description); + + return primitive_description.str(); +} + +moe_scatter_reduction_inst::typed_primitive_inst(network& network, moe_scatter_reduction_node const& node) : parent(network, node) { } +} // namespace cldnn + diff --git a/src/plugins/intel_gpu/src/graph/registry/moe_scatter_reduction_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/moe_scatter_reduction_impls.cpp new file mode 100644 index 00000000000000..1fefd5b4b4bdf7 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/registry/moe_scatter_reduction_impls.cpp @@ -0,0 +1,29 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "registry.hpp" +#include "intel_gpu/primitives/moe_scatter_reduction.hpp" +#include "primitive_inst.h" + +#if OV_GPU_WITH_OCL + #include "impls/ocl_v2/moe/moe_scatter_reduction.hpp" +#endif + +namespace ov { +namespace intel_gpu { + +using namespace cldnn; + +const std::vector>& Registry::get_implementations() { + static const std::vector> impls = { + OV_GPU_CREATE_INSTANCE_OCL(ocl::MoeScatterReductionRef, shape_types::dynamic_shape) 
+ OV_GPU_CREATE_INSTANCE_OCL(ocl::MoeScatterReductionRef, shape_types::static_shape) + }; + + return impls; +} + +} // namespace intel_gpu +} // namespace ov + diff --git a/src/plugins/intel_gpu/src/graph/registry/registry.hpp b/src/plugins/intel_gpu/src/graph/registry/registry.hpp index 341503a2322125..d231a2bab7ea1b 100644 --- a/src/plugins/intel_gpu/src/graph/registry/registry.hpp +++ b/src/plugins/intel_gpu/src/graph/registry/registry.hpp @@ -166,6 +166,7 @@ REGISTER_IMPLS(strided_slice); REGISTER_IMPLS(tile); REGISTER_IMPLS(col2im); REGISTER_IMPLS(vl_sdpa); +REGISTER_IMPLS(moe_scatter_reduction); REGISTER_DEFAULT_IMPLS(assign, CPU_S, CPU_D); REGISTER_DEFAULT_IMPLS(read_value, CPU_S, CPU_D); diff --git a/src/plugins/intel_gpu/src/kernel_selector/common_types.h b/src/plugins/intel_gpu/src/kernel_selector/common_types.h index 01d9f019d3ef08..726fb6f9349073 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/common_types.h +++ b/src/plugins/intel_gpu/src/kernel_selector/common_types.h @@ -109,7 +109,8 @@ enum class KernelType { STFT, ISTFT, COL2IM, - LORA + LORA, + MOE_SCATTER_REDUCTION }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/moe_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/moe_test.cpp new file mode 100644 index 00000000000000..045cee09b13580 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/test_cases/moe_test.cpp @@ -0,0 +1,224 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +#include +#include +#include "test_utils.h" +#include "random_generator.hpp" + +#include +#include + +using namespace cldnn; +using namespace ::tests; + +template +std::vector flatten (const std::vector>& vec2d) { + return std::accumulate(vec2d.begin(), vec2d.end(), std::vector{}, + [](auto& a, auto& b) { a.insert(a.end(), b.begin(), b.end()); return a; }); +} + +template +auto 
create_layout(ShapeType shape) {
+    if constexpr (std::is_same_v<T, float>)
+        return layout{shape, data_types::f32, format::bfyx};
+    else
+        return layout{shape, data_types::f16, format::bfyx};
+}
+
+template <typename T>
+void test_moe_scatter_reduction(bool is_caching_test, size_t k) {
+    auto& engine = get_test_engine();
+    // num_tokens 30
+    // hidden_size 64
+    // num total experts 32
+    // num_active_experts_per_token 2
+    // num_actual_used_experts 7
+    // input0 activation [30, 64]
+    // input1 experts_info_offset [7]
+    // input2 tokens_per_expert [30*2*64]
+    size_t num_tokens = 30;
+    size_t num_total_experts = 32;
+    size_t hidden_size = k;
+    size_t num_active_experts_per_token = 2;
+
+    auto input_activation_shape = ov::PartialShape{ov::Dimension::dynamic(), ov::Dimension(hidden_size)};
+    auto input_activation_layout = create_layout<T>(input_activation_shape);
+
+    auto experts_per_token_shape = ov::PartialShape{ov::Dimension::dynamic()};
+    auto experts_per_token_layout = layout{experts_per_token_shape, data_types::i32, format::bfyx};
+
+    auto expert_weights_per_token_shape = ov::PartialShape{ov::Dimension::dynamic()};
+    auto expert_weights_per_token_layout = create_layout<T>(expert_weights_per_token_shape);
+
+    auto experts_info_offsets_shape = ov::PartialShape{ov::Dimension::dynamic()};
+    auto experts_info_offsets_layout = layout{experts_info_offsets_shape, data_types::i32, format::bfyx};
+
+    auto tokens_per_expert_shape = ov::PartialShape{ov::Dimension::dynamic()};
+    auto tokens_per_expert_layout = layout{tokens_per_expert_shape, data_types::i32, format::bfyx};
+
+    auto tokens_len_per_expert_shape = ov::PartialShape{ov::Dimension::dynamic()};
+    auto tokens_len_per_expert_layout = layout{tokens_len_per_expert_shape, data_types::i32, format::bfyx};
+
+    topology topology(
+        input_layout("input", input_activation_layout),
+        input_layout("experts_per_token", experts_per_token_layout),
+        input_layout("expert_weights_per_token", expert_weights_per_token_layout),
+        input_layout("tokens_per_expert", tokens_per_expert_layout),
+        input_layout("experts_info_offsets", experts_info_offsets_layout),
+        input_layout("tokens_len_per_expert", tokens_len_per_expert_layout),
+        moe_scatter_reduction("moe_scatter_reduction", input_info("input"), input_info("experts_per_token"),
+            input_info("expert_weights_per_token"), input_info("tokens_per_expert"), input_info("experts_info_offsets"),
+            input_info("tokens_len_per_expert"), static_cast<int32_t>(num_active_experts_per_token))
+    );
+    auto input_data_shape = ov::PartialShape{ov::Dimension(num_tokens * num_active_experts_per_token), ov::Dimension(hidden_size)};
+    auto input_data_layout = create_layout<T>(input_data_shape);
+
+    // Row i holds the constant value i, so each gathered row is identifiable.
+    std::vector<T> input_data;
+    for (size_t i = 0; i < num_tokens * num_active_experts_per_token; ++i) {
+        for (size_t h = 0; h < hidden_size; ++h)
+            input_data.push_back(static_cast<T>(i));
+    }
+    auto input_mem = engine.allocate_memory(input_data_layout);
+    set_values(input_mem, input_data);
+
+    // topk result
+    std::vector<std::vector<int32_t>> experts_per_token = {{0, 5}, {5, 7}, {0, 10}, {11, 20}, {7, 10}, {0, 7}, {20, 31}, {11, 31}, {11, 20}, {7, 10},
+                                                           {0, 5}, {11, 31}, {0, 7}, {0, 20}, {10, 31}, {10, 20}, {7, 31}, {0, 31}, {5, 31}, {7, 31},
+                                                           {7, 20}, {0, 10}, {0, 5}, {5, 11}, {7, 11}, {5, 31}, {7, 31}, {0, 31}, {0, 10}, {11, 20}};
+
+    std::vector<std::vector<T>> expert_weights_per_token = {{1.0, .9}, {.8, .7}, {.6, .5}, {.4, .3}, {.2, .1}, {.9, .8}, {.7, .6}, {.5, .4}, {.3, .2}, {.1, .0},
+                                                            {1.0, .9}, {.8, .7}, {.6, .5}, {.4, .3}, {.2, .1}, {.9, .8}, {.7, .6}, {.5, .4}, {.3, .2}, {.1, .0},
+                                                            {1.0, .9}, {.8, .7}, {.6, .5}, {.4, .3}, {.2, .1}, {.9, .8}, {.7, .6}, {.5, .4}, {.3, .2}, {.1, .0}};
+
+    std::vector<std::vector<int32_t>> tokens_per_expert_tmp(num_total_experts, std::vector<int32_t>{});
+
+    for (size_t i = 0; i < experts_per_token.size(); ++i) {
+        for (size_t j = 0; j < experts_per_token[i].size(); ++j)
+            tokens_per_expert_tmp[experts_per_token[i][j]].push_back(i);
+    }
+
+    std::vector<int32_t> tokens_per_expert_data;
+    std::vector<int32_t> tokens_len_per_expert_data;
+
+    for (size_t i = 0; i < tokens_per_expert_tmp.size(); ++i) {
+        tokens_len_per_expert_data.push_back(tokens_per_expert_tmp[i].size());
+        if (tokens_per_expert_tmp[i].empty())
+            continue;
+        for (size_t j = 0; j < tokens_per_expert_tmp[i].size(); ++j) {
+            tokens_per_expert_data.push_back(tokens_per_expert_tmp[i][j]);
+        }
+    }
+
+    std::vector<int32_t> expert_info_start_idx(tokens_len_per_expert_data.size());
+    std::exclusive_scan(tokens_len_per_expert_data.begin(), tokens_len_per_expert_data.end(),
+                        expert_info_start_idx.begin(), 0);
+
+    // tokens per expert
+    // experts 0, 5, 7, 10, 11, 20, 31 are used
+    // experts[0] offset : 0 {0, 2, 5, 10, 12, 13, 17, 21, 22, 27, 28}
+    // experts[5] offset : 11 {0, 1, 10, 18, 22, 23, 25}
+    // experts[7] offset : 18 {1, 4, 5, 9, 12, 16, 19, 20, 24, 26}
+    // experts[10] offset : 28 {2, 4, 9, 14, 15, 21, 28}
+    // experts[11] offset : 35 {3, 7, 8, 11, 23, 24, 29}
+    // experts[20] offset : 42 {3, 6, 8, 13, 15, 20, 29}
+    // experts[31] offset : 49 {6, 7, 11, 14, 16, 17, 18, 19, 25, 26, 27}
+
+    auto experts_per_token_data_shape = ov::PartialShape{ov::Dimension(num_tokens), ov::Dimension(num_active_experts_per_token)};
+    auto experts_per_token_data_layout = layout{experts_per_token_data_shape, data_types::i32, format::bfyx};
+    auto experts_per_token_data_mem = engine.allocate_memory(experts_per_token_data_layout);
+    set_values(experts_per_token_data_mem, flatten(experts_per_token));
+
+    auto expert_weights_per_token_data_shape = ov::PartialShape{ov::Dimension(num_tokens), ov::Dimension(num_active_experts_per_token)};
+    auto expert_weights_per_token_data_layout = create_layout<T>(expert_weights_per_token_data_shape);
+    auto expert_weights_per_token_data_mem = engine.allocate_memory(expert_weights_per_token_data_layout);
+    set_values(expert_weights_per_token_data_mem, flatten(expert_weights_per_token));
+
+    auto tokens_per_expert_data_shape = ov::PartialShape{ov::Dimension(tokens_per_expert_data.size())};
+    auto tokens_per_expert_data_layout =
layout{tokens_per_expert_data_shape, data_types::i32, format::bfyx}; + auto tokens_per_expert_data_mem = engine.allocate_memory(tokens_per_expert_data_layout); + set_values(tokens_per_expert_data_mem, tokens_per_expert_data); + + auto expert_info_offsets_data_shape = ov::PartialShape{ov::Dimension(expert_info_start_idx.size())}; + auto expert_info_offsets_data_layout = layout{expert_info_offsets_data_shape, data_types::i32, format::bfyx}; + auto expert_info_offsets_data_mem = engine.allocate_memory(expert_info_offsets_data_layout); + set_values(expert_info_offsets_data_mem, expert_info_start_idx); + + auto tokens_len_per_expert_data_shape = ov::PartialShape{ov::Dimension(tokens_len_per_expert_data.size())}; + auto tokens_len_per_expert_data_layout = layout{expert_info_offsets_data_shape, data_types::i32, format::bfyx}; + auto tokens_len_per_expert_data_mem = engine.allocate_memory(tokens_len_per_expert_data_layout); + set_values(tokens_len_per_expert_data_mem, tokens_len_per_expert_data); + + auto config = get_test_default_config(engine); + config.set_property(ov::intel_gpu::optimize_data(true)); + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); + + network network(engine, topology, config, is_caching_test); + network.set_input_data("input", input_mem); + network.set_input_data("experts_per_token", experts_per_token_data_mem); + network.set_input_data("expert_weights_per_token", expert_weights_per_token_data_mem); + network.set_input_data("tokens_per_expert", tokens_per_expert_data_mem); + network.set_input_data("experts_info_offsets", expert_info_offsets_data_mem); + network.set_input_data("tokens_len_per_expert", tokens_len_per_expert_data_mem); + auto outputs = network.execute(); + auto output = outputs.begin()->second.get_memory(); + cldnn::mem_lock output_ptr(output, get_test_stream()); + std::vector ref_output; + for (size_t i = 0; i < num_tokens; ++i) { + std::vector token_output(hidden_size, 0.0f); + for (size_t j = 0; j < 
num_active_experts_per_token; j++) { + size_t expert_idx = experts_per_token[i][j]; + float expert_weight = expert_weights_per_token[i][j]; + for (size_t k = 0; k < tokens_per_expert_tmp[expert_idx].size(); k++) { + if (i == tokens_per_expert_tmp[expert_idx][k]) { + size_t input_idx = expert_info_start_idx[expert_idx] + k; + // copy out the data and multiply the weight then add to token output + std::vector token_data(hidden_size); + std::copy(input_data.begin() + input_idx * hidden_size, input_data.begin() + (input_idx + 1)*hidden_size, token_data.begin()); + std::transform(token_data.begin(), token_data.end(), token_data.begin(), [&expert_weight](auto& c){return c*expert_weight;}); + std::transform(token_data.begin(), token_data.end(), token_output.begin(), token_output.begin(), std::plus()); + break; + } + } + } + ref_output.insert(ref_output.end(), token_output.begin(), token_output.end()); + } + + for (size_t i = 0; i < num_tokens * hidden_size; i++) { + EXPECT_NEAR(ref_output[i], output_ptr[i], 1e-1); + } +} + +TEST(moe_unit, moe_scatter_reduction_test_one_batch_aligned_f32) { + test_moe_scatter_reduction(false, 64); +} + +TEST(moe_unit, moe_scatter_reduction_test_one_batch_unaligned_f32) { + test_moe_scatter_reduction(false, 66); +} + +TEST(moe_unit, moe_scatter_reduction_test_multi_batch_aligned_f32) { + test_moe_scatter_reduction(false, 2880); +} + +TEST(moe_unit, moe_scatter_reduction_test_multi_batch_unaligned_f32) { + test_moe_scatter_reduction(false, 2882); +} + +TEST(moe_unit, moe_scatter_reduction_test_one_batch_aligned_f16) { + test_moe_scatter_reduction(false, 64); +} +TEST(moe_unit, moe_scatter_reduction_test_one_batch_unaligned_f16) { + test_moe_scatter_reduction(false, 66); +} + +TEST(moe_unit, moe_scatter_reduction_test_multi_batch_aligned_f16) { + test_moe_scatter_reduction(false, 2880); +} + +TEST(moe_unit, moe_scatter_reduction_test_multi_batch_unaligned_f16) { + test_moe_scatter_reduction(false, 2882); +} + +TEST(moe_unit, 
moe_scatter_reduction_test_cached) { + test_moe_scatter_reduction(true, 64); +}