feat: Add BFloat16 support for Gemm and MatMul CPU operators #26317
base: main
Conversation
This commit introduces BFloat16 support for the Gemm and MatMul operators on the CPU execution provider.

Key changes:
- Added the BFloat16 data type and moved related files to onnxruntime/core/common.
- Implemented MlasBf16AccelerationSupported to detect hardware support for BFloat16.
- Added Gemm and MatMul kernels for BFloat16 using Eigen.
- Registered the new kernels for the CPU execution provider.
- Added unit tests for BFloat16 Gemm and MatMul.
- Fixed ambiguous comparison operators for BFloat16.
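For context on the data type: BFloat16 stores the upper 16 bits of an IEEE-754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits). A minimal standalone sketch of the usual conversions, assuming round-to-nearest-even on the narrowing path (ORT's own BFloat16 implementation may differ in details such as NaN handling):

```cpp
#include <cstdint>
#include <cstring>

// bf16 keeps float32's sign and exponent, so it covers the same range as
// float but with only about 2-3 significant decimal digits.
uint16_t FloatToBf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round to nearest even before truncating the low 16 bits.
  // (NaN payloads are not preserved by this shortcut.)
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>(bits >> 16);
}

float Bf16ToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;  // widen; low mantissa bits are zero
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```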
You can commit the suggested changes from lintrunner.
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
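These entries come from the CPU provider's kernel registration table; each one references a typed kernel defined elsewhere with the corresponding registration macro. A hedged sketch of what the definition for the BFloat16 MatMul kernel might look like (the exact KernelDefBuilder constraints are an assumption, not copied from the PR):

```cpp
// Sketch only: registers MatMul, opset 13, with T = BFloat16 on the CPU EP.
ONNX_CPU_OPERATOR_TYPED_KERNEL(
    MatMul,
    13,
    BFloat16,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<BFloat16>()),
    MatMul<BFloat16>);
```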
onnxruntime/core/util/math_cpu.cc (outdated)
ptrdiff_t N, ptrdiff_t K, BFloat16 alpha, const BFloat16* A, const BFloat16* B, BFloat16 beta,
BFloat16* C, ThreadPool*) {
onnxruntime/core/util/math_cpu.cc (outdated)
switch (TransB) {
  case CblasNoTrans:
    C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), N, K) *
                                       ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), K, M));
onnxruntime/core/util/math_cpu.cc (outdated)
    return;
  case CblasTrans:
    C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), K, N).transpose() *
                                       ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), K, M));
onnxruntime/core/util/math_cpu.cc (outdated)
switch (TransB) {
  case CblasNoTrans:
    C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), N, K) *
                                       ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), M, K).transpose());
onnxruntime/core/util/math_cpu.cc (outdated)
    return;
  case CblasTrans:
    C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), K, N).transpose() *
                                       ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), M, K).transpose());
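Taken together, the snippets above cover all four TransA/TransB combinations. The operand order looks reversed because ORT tensors are row-major while Eigen maps default to column-major: viewing row-major C (MxN) as column-major yields C^T (NxM), and C^T = op(B)^T * op(A)^T. A minimal standalone sketch of the same pattern (the function name and bool flags are illustrative, not the PR's code; assumes Eigen 3.4's Eigen::bfloat16):

```cpp
#include <cstddef>
#include <Eigen/Core>

using BF16Map =
    Eigen::Map<const Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>>;

// Accumulates C += alpha * op(A) * op(B), where op(X) = X or X^T per the
// flags; C is row-major MxN, op(A) is MxK, op(B) is KxN.
void AccumulateBf16Gemm(std::ptrdiff_t M, std::ptrdiff_t N, std::ptrdiff_t K,
                        Eigen::bfloat16 alpha,
                        const Eigen::bfloat16* A, bool trans_a,
                        const Eigen::bfloat16* B, bool trans_b,
                        Eigen::bfloat16* C) {
  // The column-major view of row-major C is its transpose, hence (N, M).
  Eigen::Map<Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>> C_mat(C, N, M);
  if (!trans_a && !trans_b) {
    C_mat.noalias() += alpha * (BF16Map(B, N, K) * BF16Map(A, K, M));
  } else if (!trans_a && trans_b) {
    C_mat.noalias() += alpha * (BF16Map(B, K, N).transpose() * BF16Map(A, K, M));
  } else if (trans_a && !trans_b) {
    C_mat.noalias() += alpha * (BF16Map(B, N, K) * BF16Map(A, M, K).transpose());
  } else {
    C_mat.noalias() += alpha * (BF16Map(B, K, N).transpose() * BF16Map(A, M, K).transpose());
  }
}
```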
VECTOR_HEAD(X_bf16), VECTOR_HEAD(W_bf16), kZero_bf16, VECTOR_HEAD(Y_bf16),
tp.get());

VECTOR_HEAD(X_fp32), VECTOR_HEAD(W_fp32), 0.0f, VECTOR_HEAD(Y_ref),
tp.get());
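These calls run the same multiplication once in BFloat16 and once in float to produce a reference. A hedged sketch of the comparison that typically follows (the tolerance is an assumption derived from bf16's 8 significant bits, not a value from the PR; Y_bf16 and Y_ref are assumed to be std::vector<BFloat16> and std::vector<float>):

```cpp
#include <cmath>

// bf16 keeps ~8 significant bits, so allow a relative error near 2^-8,
// plus a small absolute term for results close to zero.
for (size_t i = 0; i < Y_ref.size(); ++i) {
  const float actual = Y_bf16[i].ToFloat();  // assumes BFloat16::ToFloat()
  const float expected = Y_ref[i];
  EXPECT_NEAR(actual, expected, std::abs(expected) * 0.0078125f + 1e-2f);
}
```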
…nitialize from MLAS bf16 check
You can commit the suggested changes from lintrunner.
InitializeCpuInfo();
}
*/
void WindowsEnv::InitializeCpuInfo() {
  // Initialize cpuinfo once on Windows similar to PosixEnv constructor.
  (void)cpuinfo_initialize();  // Ignore the error if it failed to initialize
- I am not seeing tests specifically for BFloat16.
- I think we should separate the refactoring from the BFloat16 changes.
void WindowsEnv::InitializeCpuInfo() {
  // Initialize cpuinfo once on Windows similar to PosixEnv constructor.
Do we need a macro here to check whether cpuinfo is supported?
I added code to force the library to be available on Windows.
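With cpuinfo guaranteed to be present, the hardware probe can be a thin wrapper over it. A hedged sketch of what MlasBf16AccelerationSupported could look like (cpuinfo_initialize, cpuinfo_has_x86_avx512bf16, and cpuinfo_has_arm_bf16 are real cpuinfo APIs; whether the PR's implementation uses exactly these checks is an assumption):

```cpp
#include <cpuinfo.h>

// Sketch: report bf16 acceleration only when the CPU advertises a native
// bf16 instruction set (AVX512-BF16 on x86-64, the BF16 extension on ARM64).
bool MlasBf16AccelerationSupported() {
  if (!cpuinfo_initialize()) {
    return false;  // no CPU info available -> assume no acceleration
  }
#if defined(__x86_64__) || defined(_M_X64)
  return cpuinfo_has_x86_avx512bf16();
#elif defined(__aarch64__) || defined(_M_ARM64)
  return cpuinfo_has_arm_bf16();
#else
  return false;
#endif
}
```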
ptrdiff_t N, ptrdiff_t K, BFloat16 alpha, const BFloat16* A, const BFloat16* B, BFloat16 beta,
BFloat16* C, ThreadPool*) {
  auto C_mat = EigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<Eigen::bfloat16*>(C), N, M);
  if (beta == BFloat16(0.f)) {
Is beta an input, a local variable, or a mistyped output? It is not clear from the code.
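For reference, beta is the standard GEMM input scalar in C = alpha * op(A) * op(B) + beta * C, so the beta == BFloat16(0.f) check above is the usual fast path. A hedged sketch of how that branch typically continues (assuming it mirrors the float Gemm path; BFloat16::ToFloat() is assumed available):

```cpp
if (beta == BFloat16(0.f)) {
  // Overwrite rather than scale: beta * C would propagate NaNs if C holds
  // uninitialized memory, so the zero case must not read the old C values.
  C_mat.setZero();
} else {
  C_mat *= Eigen::bfloat16(beta.ToFloat());
}
```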
I added tests.
And I tested them on my local machine, which has BFloat16 support.
I will split the refactoring (the renaming of the files) into a new PR.
Move `endian.h`, `float16.h`, and `float8.h` from `core/framework/` to `core/common/` to avoid circular dependencies and improve architectural layering.

## Motivation

These headers define fundamental data types that are used across multiple low-level libraries:

- `onnxruntime_common` (foundation layer)
- `onnxruntime_mlas` (math library, depends on common)
- `onnxruntime_util` (utilities, depends on common)
- `onnxruntime_graph` (graph IR, depends on common)

Previously, these types were in `core/framework/`, which is part of the `onnxruntime_framework` library that sits at a higher architectural level. This created circular dependency issues, since mlas uses `float16.h`.

## Changes

### File Moves (3 files)

- `include/onnxruntime/core/framework/endian.h` → `include/onnxruntime/core/common/endian.h`
- `include/onnxruntime/core/framework/float16.h` → `include/onnxruntime/core/common/float16.h`
- `include/onnxruntime/core/framework/float8.h` → `include/onnxruntime/core/common/float8.h`

### Include Path Updates (53 files)

Updated all references from:

- `core/framework/endian.h` → `core/common/endian.h`
- `core/framework/float16.h` → `core/common/float16.h`
- `core/framework/float8.h` → `core/common/float8.h`

Affected components:

- Contrib ops (CPU, CUDA, ROCm)
- Core framework and utilities
- Providers (CPU, CUDA, CANN, QNN, OpenVINO, MIGraphX)
- Tests
- Training code

## Architectural Benefits

This change establishes clearer architectural boundaries:

```
Level 0 (Foundation): onnxruntime_common (includes endian, float16, float8)
                      onnxruntime_mlas      → depends on common
Level 1 (Core):       onnxruntime_util      → depends on common
                      onnxruntime_graph     → depends on common
Level 2 (Framework):  onnxruntime_framework → depends on common
```

By placing fundamental types in `common`, we ensure:

1. No circular dependencies between library targets
2. Lower-level libraries can access these types without pulling in framework
3. Clear separation between fundamental types (common) and framework-specific types like int4 and float4 (framework)

This PR is split from #26317 as suggested by the reviewer.