
Conversation

@snnn (Member) commented Oct 15, 2025

This commit introduces BFloat16 support for Gemm and MatMul operators on the CPU execution provider.

Key changes:

  • Added BFloat16 data type and moved related files to onnxruntime/core/common.
  • Implemented MlasBf16AccelerationSupported to detect hardware support for BFloat16 (see the sketch after this list).
  • Added Gemm and MatMul kernels for BFloat16 using Eigen.
  • Registered the new kernels for the CPU execution provider.
  • Added unit tests for BFloat16 Gemm and MatMul.
  • Fixed ambiguous comparison operators for BFloat16.
  • Moved endian.h/float8.h/float16.h from onnxruntime_framework.lib to onnxruntime_common.lib, because onnxruntime_util.lib depends on these headers; this avoids a circular dependency.
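
A minimal sketch of how such a hardware check can sit on top of the cpuinfo library (which this PR also wires up on Windows). The body of MlasBf16AccelerationSupported is not shown in this conversation, so the function below is an assumption based on cpuinfo's public API, not the PR's actual code:

```cpp
// Hedged sketch: detect native BFloat16 support via the cpuinfo library.
// The real MlasBf16AccelerationSupported may use different checks.
#include <cpuinfo.h>

bool Bf16AccelerationSupportedSketch() {
  // cpuinfo_initialize() is safe to call repeatedly; if it fails, every
  // feature query would report false anyway.
  if (!cpuinfo_initialize()) {
    return false;
  }
#if defined(__x86_64__) || defined(_M_X64)
  return cpuinfo_has_x86_avx512bf16();  // AVX512-BF16 (Cooper Lake and newer)
#elif defined(__aarch64__) || defined(_M_ARM64)
  return cpuinfo_has_arm_bf16();        // Armv8.6-A BF16 extension
#else
  return false;
#endif
}
```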

snnn added 2 commits October 15, 2025 12:44
@github-actions bot (Contributor) left a comment

You can commit the suggested changes from lintrunner.

Comment on lines 2431 to 2432

```cpp
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
```

Suggested change: whitespace/formatting only.
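
For context, a BuildKernelCreateInfo entry like the one above typically pairs with a typed kernel definition elsewhere in the CPU provider. A hedged sketch of what that definition might look like, patterned on onnxruntime's existing typed registrations (the exact builder arguments are an assumption, not copied from this PR):

```cpp
// Hedged sketch of a typed CPU kernel registration for BFloat16 MatMul,
// modeled on the float/int64_t registrations; details may differ in the PR.
ONNX_CPU_OPERATOR_TYPED_KERNEL(
    MatMul,                                   // op name
    13,                                       // opset version
    BFloat16,                                  // type parameter
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<BFloat16>()),
    MatMul<BFloat16>);
```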

Comment on lines 185 to 186

```cpp
                      ptrdiff_t N, ptrdiff_t K, BFloat16 alpha, const BFloat16* A, const BFloat16* B, BFloat16 beta,
                      BFloat16* C, ThreadPool*) {
```

Suggested change: whitespace/formatting only.

Further context from the same function, with lintrunner comments pinned to individual lines (each suggested change is whitespace/formatting only):

```cpp
      switch (TransB) {
        case CblasNoTrans:
          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), N, K) *
                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), K, M));
          return;
        case CblasTrans:
          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), K, N).transpose() *
                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), K, M));
      // ...
      switch (TransB) {
        case CblasNoTrans:
          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), N, K) *
                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), M, K).transpose());
          return;
        case CblasTrans:
          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), K, N).transpose() *
                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), M, K).transpose());
```
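
Pieced together, here is a self-contained sketch of what this Eigen fallback computes. The `EigenMatrixMap` aliases, the `CBLAS_TRANSPOSE` enum, and the beta handling are assumptions reconstructed from the visible fragments, not the PR's exact code:

```cpp
// Hedged reconstruction of the Eigen-based BFloat16 GEMM fallback.
#include <Eigen/Core>
#include <cstddef>

// Assumed aliases, mirroring onnxruntime's math utilities.
template <typename T>
using EigenMatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenMatrixMap = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>;

// CBLAS convention values; assumed for self-containedness.
enum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112 };

void GemmBf16Sketch(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
                    ptrdiff_t M, ptrdiff_t N, ptrdiff_t K,
                    Eigen::bfloat16 alpha, const Eigen::bfloat16* A,
                    const Eigen::bfloat16* B, Eigen::bfloat16 beta,
                    Eigen::bfloat16* C) {
  // Row-major M x N output viewed as a column-major N x M Eigen matrix.
  auto C_mat = EigenMatrixMap<Eigen::bfloat16>(C, N, M);
  // Standard GEMM semantics: C = alpha * op(A) * op(B) + beta * C.
  if (beta == Eigen::bfloat16(0.f)) {
    C_mat.setZero();  // overwrite; never read possibly-uninitialized C
  } else {
    C_mat *= beta;
  }
  switch (TransA) {
    case CblasNoTrans:
      switch (TransB) {
        case CblasNoTrans:
          C_mat.noalias() += alpha * (ConstEigenMatrixMap<Eigen::bfloat16>(B, N, K) *
                                      ConstEigenMatrixMap<Eigen::bfloat16>(A, K, M));
          return;
        case CblasTrans:
          C_mat.noalias() += alpha * (ConstEigenMatrixMap<Eigen::bfloat16>(B, K, N).transpose() *
                                      ConstEigenMatrixMap<Eigen::bfloat16>(A, K, M));
          return;
      }
      return;
    case CblasTrans:
      switch (TransB) {
        case CblasNoTrans:
          C_mat.noalias() += alpha * (ConstEigenMatrixMap<Eigen::bfloat16>(B, N, K) *
                                      ConstEigenMatrixMap<Eigen::bfloat16>(A, M, K).transpose());
          return;
        case CblasTrans:
          C_mat.noalias() += alpha * (ConstEigenMatrixMap<Eigen::bfloat16>(B, K, N).transpose() *
                                      ConstEigenMatrixMap<Eigen::bfloat16>(A, M, K).transpose());
          return;
      }
      return;
  }
}
```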

Comment on lines 168 to 169

```cpp
             VECTOR_HEAD(X_bf16), VECTOR_HEAD(W_bf16), kZero_bf16, VECTOR_HEAD(Y_bf16),
             tp.get());
```

Suggested change: whitespace/formatting only.

Comment on lines 175 to 176

```cpp
             VECTOR_HEAD(X_fp32), VECTOR_HEAD(W_fp32), 0.0f, VECTOR_HEAD(Y_ref),
             tp.get());
```

Suggested change: whitespace/formatting only.

Comment on lines 211 to 212

```cpp
             VECTOR_HEAD(X_bf16), VECTOR_HEAD(W_bf16), kZero_bf16, VECTOR_HEAD(Y_bf16),
             tp.get());
```

Suggested change: whitespace/formatting only.

Comment on lines 218 to 219

```cpp
             VECTOR_HEAD(X_fp32), VECTOR_HEAD(W_fp32), 0.0f, VECTOR_HEAD(Y_ref),
             tp.get());
```

Suggested change: whitespace/formatting only.
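
These calls follow a common test pattern: run the GEMM in BFloat16, run the same problem through the fp32 path as a reference, and compare with a tolerance loose enough for bfloat16's 8-bit significand. A minimal sketch of that pattern, assuming `GemmBf16Sketch` from above and a naive fp32 reference (the helper name, shapes, and tolerance are illustrative, not the test's actual code; `VECTOR_HEAD` in the fragments presumably yields a vector's data pointer):

```cpp
// Hedged sketch of the bf16-vs-fp32 reference comparison used in such tests.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

void CheckBf16AgainstFp32(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K,
                          const std::vector<float>& X_fp32,
                          const std::vector<float>& W_fp32) {
  assert(X_fp32.size() == static_cast<size_t>(M * K));
  assert(W_fp32.size() == static_cast<size_t>(K * N));

  // Round inputs to bfloat16 and run the bf16 kernel (alpha = 1, beta = 0).
  std::vector<Eigen::bfloat16> X_bf16(X_fp32.begin(), X_fp32.end());
  std::vector<Eigen::bfloat16> W_bf16(W_fp32.begin(), W_fp32.end());
  std::vector<Eigen::bfloat16> Y_bf16(static_cast<size_t>(M * N));
  GemmBf16Sketch(CblasNoTrans, CblasNoTrans, M, N, K,
                 Eigen::bfloat16(1.f), X_bf16.data(), W_bf16.data(),
                 Eigen::bfloat16(0.f), Y_bf16.data());

  // fp32 reference: naive row-major triple loop.
  std::vector<float> Y_ref(static_cast<size_t>(M * N), 0.0f);
  for (ptrdiff_t m = 0; m < M; ++m)
    for (ptrdiff_t n = 0; n < N; ++n)
      for (ptrdiff_t k = 0; k < K; ++k)
        Y_ref[m * N + n] += X_fp32[m * K + k] * W_fp32[k * N + n];

  // bfloat16 keeps ~8 significand bits, so allow ~1% relative error.
  for (size_t i = 0; i < Y_ref.size(); ++i) {
    const float got = static_cast<float>(Y_bf16[i]);
    const float tol = 1e-2f * std::max(1.0f, std::fabs(Y_ref[i]));
    assert(std::fabs(got - Y_ref[i]) <= tol);
  }
}
```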

@github-actions bot (Contributor) left a comment

You can commit the suggested changes from lintrunner.

```cpp
  InitializeCpuInfo();
}
```

```cpp
*/
void WindowsEnv::InitializeCpuInfo() {
  // Initialize cpuinfo once on Windows similar to PosixEnv constructor.
  (void)cpuinfo_initialize(); //Ignore the error if it failed to initialize
```

Suggested change:

```cpp
  (void)cpuinfo_initialize(); // Ignore the error if it failed to initialize
```

@yuslepukhin (Member) left a comment

  1. I am not seeing tests specifically for BFloat16.
  2. I think we had better separate the refactoring from the BFloat16 changes.

Comment on lines 873 to +875

```cpp
void WindowsEnv::InitializeCpuInfo() {
  // Initialize cpuinfo once on Windows similar to PosixEnv constructor.
```

@yuslepukhin (Member): Do we need a macro here for whether cpuinfo is supported?

@snnn (Member, Author): I added code to force the library to be available on Windows.

Comment on the beta handling in the BFloat16 GEMM fallback:

```cpp
                      ptrdiff_t N, ptrdiff_t K, BFloat16 alpha, const BFloat16* A, const BFloat16* B, BFloat16 beta,
                      BFloat16* C, ThreadPool*) {
  auto C_mat = EigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<Eigen::bfloat16*>(C), N, M);
  if (beta == BFloat16(0.f)) {
```

@yuslepukhin (Member): Is beta an input, a local variable, or a mistyped output? It is not clear from the code.
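
For reference, beta in the standard GEMM contract is an input scale applied to the existing contents of C (C = alpha * op(A) * op(B) + beta * C). A sketch of the conventional handling, which the fragment above appears to follow (the Eigen conversion detail is an assumption):

```cpp
// Standard GEMM beta handling: when beta == 0, overwrite C rather than
// scale it, so possibly-uninitialized output memory cannot propagate NaNs.
if (beta == BFloat16(0.f)) {
  C_mat.setZero();
} else {
  C_mat *= Eigen::bfloat16(static_cast<float>(beta));
}
```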

@snnn (Member, Author) commented Oct 16, 2025

> I am not seeing tests specifically for BFloat16

I added tests:

```
onnxruntime_test_all --gtest_filter=MathBFloat16GemmTests/*
```

And I tested them on my local machine, which has BFloat16 support.

@snnn (Member, Author) commented Oct 16, 2025

I will split the refactoring (the renaming of the files) into a new PR.

snnn added a commit that referenced this pull request Oct 17, 2025
Move `endian.h`, `float16.h`, and `float8.h` from `core/framework/` to
`core/common/` to avoid circular dependencies and improve architectural
layering.

## Motivation

These headers define fundamental data types that are used across
multiple low-level libraries:
- `onnxruntime_common` (foundation layer)
- `onnxruntime_mlas` (math library, depends on common)
- `onnxruntime_util` (utilities, depends on common)
- `onnxruntime_graph` (graph IR, depends on common)

Previously, these types were in `core/framework/`, which is part of the
`onnxruntime_framework` library that sits at a higher architectural
level. This created circular dependency issues, since mlas uses
`float16.h`.

## Changes

### File Moves (3 files):
- `include/onnxruntime/core/framework/endian.h` →
`include/onnxruntime/core/common/endian.h`
- `include/onnxruntime/core/framework/float16.h` →
`include/onnxruntime/core/common/float16.h`
- `include/onnxruntime/core/framework/float8.h` →
`include/onnxruntime/core/common/float8.h`

### Include Path Updates (53 files):
Updated all references from:
- `core/framework/endian.h` → `core/common/endian.h`
- `core/framework/float16.h` → `core/common/float16.h`
- `core/framework/float8.h` → `core/common/float8.h`
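
For example, a consumer of the half-precision type updates its include like so:

```cpp
// Before: header lived in the framework layer.
#include "core/framework/float16.h"
// After: header lives in the common layer, reachable from mlas/util/graph.
#include "core/common/float16.h"
```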

Affected components:
- Contrib ops (CPU, CUDA, ROCm)
- Core framework and utilities
- Providers (CPU, CUDA, CANN, QNN, OpenVINO, MIGraphX)
- Tests
- Training code

## Architectural Benefits

This change establishes clearer architectural boundaries:

```
Level 0 (Foundation):
  onnxruntime_common (includes endian, float16, float8)
  onnxruntime_mlas → depends on common

Level 1 (Core):
  onnxruntime_util → depends on common
  onnxruntime_graph → depends on common

Level 2 (Framework):
  onnxruntime_framework → depends on common
```

By placing fundamental types in `common`, we ensure:
1. No circular dependencies between library targets
2. Lower-level libraries can access these types without pulling in
framework
3. Clear separation between fundamental types (common) and
framework-specific types like int4, float4 (framework)

This PR is split from #26317 as suggested by the reviewer.
Successfully merging this pull request may close these issues.

bfloat16 causing an error: NOT_IMPLEMENTED: Could not find an implementation for MatMul(13)
