2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -232,7 +232,7 @@ option(onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH "Path to external transformer s

option(onnxruntime_ENABLE_CUDA_PROFILING "Enable CUDA kernel profiling" OFF)

-option(onnxruntime_ENABLE_CPUINFO "Enable cpuinfo" ON)
+cmake_dependent_option(onnxruntime_ENABLE_CPUINFO "Enable cpuinfo" ON "NOT WIN32" ON)

# ATen fallback support
option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF)
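Note on semantics: `cmake_dependent_option(onnxruntime_ENABLE_CPUINFO "Enable cpuinfo" ON "NOT WIN32" ON)` keeps the option user-settable (default ON) while `NOT WIN32` holds, and forces it to ON otherwise. In other words, cpuinfo can no longer be switched off on Windows builds, which matches the author's reply below about forcing the library to be available on Windows.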
3 changes: 3 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -809,6 +809,9 @@ endif()
foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
+if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+  onnxruntime_add_include_to_target(${mlas_target} cpuinfo::cpuinfo)
+endif()

target_compile_definitions(${mlas_target} PRIVATE ${mlas_private_compile_definitions})

4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1865,13 +1865,13 @@ MlasHalfGemmConvertPackB(
void* PackedB
);

-#if defined(__aarch64__) && defined(__linux__)

/**
* @brief Whether current CPU supports Bfloat16(bf16) acceleration.
*/
bool MLASCALL
MlasBf16AccelerationSupported();

+#if defined(__aarch64__) && defined(__linux__)
/**
* @brief Interface for bf16 gemm post processors.
*
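With the declaration moved outside the `__aarch64__ && __linux__` guard, `MlasBf16AccelerationSupported()` can now be queried on any platform. A minimal caller sketch (the two `Run*` helpers are hypothetical, not MLAS APIs):

```cpp
#include "mlas.h"

void RunBf16Gemm();  // hypothetical bf16 fast path
void RunFp32Gemm();  // hypothetical fp32 fallback

void DispatchGemm() {
  // The query reflects runtime CPU feature detection; on targets without a
  // bf16 path it now compiles and simply returns false.
  if (MlasBf16AccelerationSupported()) {
    RunBf16Gemm();
  } else {
    RunFp32Gemm();
  }
}
```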
20 changes: 20 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -25,6 +25,10 @@ Module Name:

#include <thread>
#include <mutex>
+#if defined(MLAS_TARGET_AMD64_IX86)
+#include <cpuinfo.h>
+#endif


#if defined(MLAS_TARGET_POWER)
#if defined(__linux__)
@@ -781,6 +785,22 @@ Return Value:
#endif
}

+bool MLASCALL
+MlasBf16AccelerationSupported()
+{
+#if defined(MLAS_TARGET_ARM64) && defined(__linux__)
+    return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16();
+#elif defined(MLAS_TARGET_AMD64_IX86)
+    // cpuinfo is initialized early by the Env singleton (platform specific).
+    // Just query the feature flags here; if cpuinfo was unavailable initialization would have failed and
+    // the feature queries will safely return false.
+    return cpuinfo_has_x86_avx512bf16() || cpuinfo_has_x86_amx_bf16();
+#else
+    return false;
+#endif
+}


#ifdef MLAS_TARGET_AMD64_IX86

bool
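For reference, a standalone sketch of the cpuinfo query pattern the new x86 branch relies on, for use outside ORT where nothing has initialized cpuinfo yet (all calls are real cpuinfo APIs):

```cpp
#include <cpuinfo.h>
#include <cstdio>

int main() {
  // Feature queries are only meaningful after cpuinfo_initialize(); if
  // initialization fails, treat every feature as absent.
  if (!cpuinfo_initialize()) {
    std::printf("cpuinfo unavailable; assuming no bf16 acceleration\n");
    return 0;
  }
  const bool bf16 = cpuinfo_has_x86_avx512bf16() || cpuinfo_has_x86_amx_bf16();
  std::printf("x86 bf16 acceleration: %s\n", bf16 ? "yes" : "no");
  cpuinfo_deinitialize();
  return 0;
}
```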
10 changes: 0 additions & 10 deletions onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp
@@ -29,16 +29,6 @@ struct MLAS_SBGEMM_KERNEL_NEON {
static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K
};

-bool MLASCALL
-MlasBf16AccelerationSupported()
-{
-#if defined(MLAS_TARGET_ARM64)
-    return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16();
-#else
-    return false;
-#endif
-}

/*
This routine converts fp32 to bf16 and copies elements from the source
matrix to the destination packed buffer.
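The deleted definition is not lost: it moves to platform.cpp (above), so `MlasBf16AccelerationSupported()` is defined for every target instead of only in this aarch64 NEON sbgemm translation unit.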
7 changes: 7 additions & 0 deletions onnxruntime/core/platform/windows/env.cc
@@ -871,6 +871,13 @@
Note - every "id" here, given it be group id, core id, or logical processor id, starts from 0.
*/
void WindowsEnv::InitializeCpuInfo() {
+  // Initialize cpuinfo once on Windows similar to PosixEnv constructor.
Comment on lines 873 to +874
Member: Do we need a macro here if cpuinfo is supported?
Member (Author): I added code to force the library to be available on Windows.

+  (void)cpuinfo_initialize();  // Ignore the error if it failed to initialize
+  // TODO: we should also call cpuinfo_deinitialize()

+  // TODO: the cpuinfo_initialize() function also gets called when creating ort thread pool, it would be better to

+  // put them in one place.
+  // TODO: test how it works in ARM64EC.

[cpplint via reviewdog] Lines 876, 877, 879: Missing username in TODO; it should look like "// TODO(my_username): Stuff."

DWORD returnLength = 0;
GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &returnLength);
auto last_error = GetLastError();
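The TODOs note that `cpuinfo_initialize()` also gets called when the ORT thread pool is created. The call is designed to be safe to repeat, but if the calls were centralized as the TODO suggests, a once-guard is one plausible shape for it. A hypothetical sketch, not part of this PR:

```cpp
#include <mutex>
#include <cpuinfo.h>

// Hypothetical helper: funnel every cpuinfo_initialize() call through one
// guard so initialization (and a future cpuinfo_deinitialize()) has a
// single owner.
bool EnsureCpuInfoInitialized() {
  static bool ok = false;
  static std::once_flag flag;
  std::call_once(flag, [] { ok = cpuinfo_initialize(); });
  return ok;
}
```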
3 changes: 3 additions & 0 deletions onnxruntime/core/platform/windows/env.h
@@ -18,6 +18,9 @@ limitations under the License.
#include "core/platform/windows/telemetry.h"
#include "core/common/inlined_containers.h"
#include <Windows.h>
+#if defined(CPUINFO_SUPPORTED)
+#include <cpuinfo.h>
+#endif

namespace onnxruntime {

7 changes: 5 additions & 2 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -689,10 +689,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Max);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Mean);
@@ -2426,13 +2428,14 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t,
MatMul)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t,
-                                                            MatMul)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, MatMul)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Min)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Max)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Mean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Sign)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Sum)>,
48 changes: 46 additions & 2 deletions onnxruntime/core/providers/cpu/math/gemm.cc
@@ -100,6 +100,13 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
Gemm<MLFloat16>);

+ONNX_CPU_OPERATOR_TYPED_KERNEL(
+    Gemm,
+    13,
+    BFloat16,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<BFloat16>()),
+    Gemm<BFloat16>);

bool GemmPackBFp32(AllocatorPtr& alloc,
const Tensor& tensor_b,
bool trans_a,
@@ -157,7 +164,7 @@ void Gemm<T>::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b,
GemmBroadcastBias(M, N, beta, c_data, c_shape, y_data);

if (K == 0) {
-if (beta == 0 || c_data == nullptr) {
+if (beta == T(0.0f) || c_data == nullptr) {
EigenMatrixMapRowMajor<T> dest(y_data, narrow<Eigen::Index>(M), narrow<Eigen::Index>(N));
dest.setZero();
}
@@ -171,7 +178,7 @@
b_data,
// ideally we need to set the output buffer contents to 0 if bias is missing,
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
-c_data != nullptr ? beta : 0,
+c_data != nullptr ? beta : T(0),
y_data,
thread_pool);
}
@@ -401,6 +408,43 @@ Status Gemm<MLFloat16>::Compute(OpKernelContext* context) const {
return Status::OK();
}

+template <>
+Status Gemm<BFloat16>::Compute(OpKernelContext* context) const {
+  concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
+
+  const auto* A = context->Input<Tensor>(0);
+  const auto* B = context->Input<Tensor>(1);
+  const auto* C = context->Input<Tensor>(2);
+
+  // Bias could be missing. Treat as scalar 0 if that is the case.
+  GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans,
+                    C != nullptr ? C->Shape() : TensorShape({}));
+
+  if (!helper.State().IsOK())
+    return helper.State();
+
+  ptrdiff_t M = helper.M();
+  ptrdiff_t N = helper.N();
+  ptrdiff_t K = helper.K();
+
+  auto Y = context->Output(0, {M, N});
+
+  // if input is empty tensor, return as nothing need to be calculated and we've set the shape for the output
+  if (M == 0 || N == 0)
+    return Status::OK();
+
+  BFloat16* y_data = Y->MutableData<BFloat16>();
+  const BFloat16* c_data = C != nullptr ? C->Data<BFloat16>() : nullptr;
+  const TensorShape* c_shape = C != nullptr ? &C->Shape() : nullptr;
+
+  ComputeGemm(trans_A_, trans_B_, M, N, K, static_cast<BFloat16>(alpha_), A->Data<BFloat16>(), B->Data<BFloat16>(),
+              static_cast<BFloat16>(beta_), c_data, c_shape, y_data, thread_pool);
+
+  ComputeActivation(y_data, SafeInt<ptrdiff_t>(M) * N, thread_pool);
+
+  return Status::OK();
+}

template <>
Status Gemm<float>::Compute(OpKernelContext* context) const {
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
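For readers tracing the new BFloat16 path: `ComputeGemm` evaluates Y = alpha * op(A) * op(B) + beta * C, with the bias broadcast handled by `GemmBroadcastBias`. A naive scalar reference of those semantics, in float for clarity and without transposes (illustrative only, not the ORT kernel):

```cpp
#include <cstddef>

// Y = alpha * A * B + beta * C for row-major A (MxK), B (KxN), C/Y (MxN).
// C may be null, in which case the beta term is skipped, mirroring how the
// kernel above passes beta = 0 when the bias input is missing.
void NaiveGemmReference(std::ptrdiff_t M, std::ptrdiff_t N, std::ptrdiff_t K,
                        float alpha, const float* A, const float* B,
                        float beta, const float* C, float* Y) {
  for (std::ptrdiff_t i = 0; i < M; ++i) {
    for (std::ptrdiff_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::ptrdiff_t k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      Y[i * N + j] = alpha * acc + (C != nullptr ? beta * C[i * N + j] : 0.0f);
    }
  }
}
```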
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/math/gemm_helper.h
@@ -93,7 +93,7 @@ void GemmBroadcastBias(ptrdiff_t M, ptrdiff_t N, T beta,
_In_opt_ const T* c_data, _In_opt_ const TensorShape* c_shape,
_Out_writes_(M* N) T* y_data) {
// Broadcast the bias as needed if bias is given
-if (beta != 0 && c_data != nullptr) {
+if (beta != T(0.0f) && c_data != nullptr) {
ORT_ENFORCE(c_shape != nullptr, "c_shape is required if c_data is provided");
auto output_mat = EigenMatrixMapRowMajor<T>(y_data, M, N);
if (c_shape->Size() == 1) {
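Presumably the `T(0.0f)` spelling is needed because the new BFloat16 instantiation does not compare cleanly against an integer literal; constructing the zero in T's own type works uniformly for float, double, MLFloat16, and BFloat16.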
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cpu/math/matmul.cc
@@ -88,6 +88,13 @@
.TypeConstraint("T", BuildKernelDefConstraints<int64_t, uint64_t>()),
MatMul<int64_t>);

+ONNX_CPU_OPERATOR_TYPED_KERNEL(
+    MatMul,
+    13,
+    BFloat16,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<BFloat16>()),
+    MatMul<BFloat16>);

template <typename T>
Status MatMul<T>::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
52 changes: 52 additions & 0 deletions onnxruntime/core/util/math_cpu.cc
@@ -180,6 +180,53 @@ void Gemm<double, ThreadPool>(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, pt
}
#endif

+template <>
+void Gemm<BFloat16, ThreadPool>(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M,
+                                ptrdiff_t N, ptrdiff_t K, BFloat16 alpha, const BFloat16* A, const BFloat16* B,
+                                BFloat16 beta, BFloat16* C, ThreadPool*) {
+  auto C_mat = EigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<Eigen::bfloat16*>(C), N, M);
+  if (beta == BFloat16(0.f)) {
Member: Is `beta` an input, a local var, or a mistyped output? It is not clear from the code.

+    C_mat.setZero();
+  } else {
+    Eigen::bfloat16 beta_bfloat(static_cast<float>(beta));
+    C_mat *= beta_bfloat;
+  }
+  Eigen::bfloat16 alpha_bfloat(static_cast<float>(alpha));
+
+  switch (TransA) {
+    case CblasNoTrans: {
+      switch (TransB) {
+        case CblasNoTrans:
+          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), N, K) *
+                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), K, M));
+          return;
+        case CblasTrans:
+          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), K, N).transpose() *
+                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), K, M));
+          return;
+        default:
+          ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
+      }
+    }
+    case CblasTrans: {
+      switch (TransB) {
+        case CblasNoTrans:
+          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), N, K) *
+                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), M, K).transpose());
+          return;
+        case CblasTrans:
+          C_mat.noalias() += alpha_bfloat * (ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(B), K, N).transpose() *
+                                             ConstEigenMatrixMap<Eigen::bfloat16>(reinterpret_cast<const Eigen::bfloat16*>(A), M, K).transpose());
+          return;
+        default:
+          ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB);
+      }
+    }
+    default:
+      ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA);
+  }
+}

template <>
void MatMul<float>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const float* A, const float* B, float* C, ThreadPool* threadpool) {
MlasGemm(CblasNoTrans, CblasNoTrans, M, N, K, 1.f, A, K, B, N, 0.f, C, N, threadpool);
@@ -194,6 +241,11 @@ void MatMul<double>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const double* A, cons
EIGEN_MATMUL_FUNCTION(double)
#endif

+template <>
+void MatMul<BFloat16>(ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, const BFloat16* A, const BFloat16* B, BFloat16* C,
+                      ThreadPool* threadpool) {
+  Gemm<BFloat16, ThreadPool>(CblasNoTrans, CblasNoTrans, M, N, K, BFloat16(1.f), A, B, BFloat16(0.f), C, threadpool);
+}

template <>
void GemmEx<float, ThreadPool>(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K,
float alpha, const float* A, int lda, const float* B, int ldb, float beta, float* C,
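Two notes on the new `Gemm<BFloat16, ThreadPool>` specialization above. First, regarding the inline review question: `beta` is an input scalar that scales the prior contents of C before the alpha * op(A) * op(B) term is accumulated. Second, the maps look transposed because Eigen defaults to column-major while ORT buffers are row-major: a row-major M x N buffer viewed column-major is the N x M transpose, so the code computes C^T = B^T * A^T, which is why C is mapped as (N, M) and B appears first. A minimal float sketch of the same trick in plain Eigen (assumes only Eigen/Core):

```cpp
#include <Eigen/Core>

// Row-major C (MxN) = A (MxK) * B (KxN), done through Eigen's default
// column-major maps: a row-major buffer reinterpreted as column-major is
// its transpose, so compute C^T = B^T * A^T.
void RowMajorGemmViaColMajorMaps(int M, int N, int K,
                                 const float* A, const float* B, float* C) {
  Eigen::Map<const Eigen::MatrixXf> A_t(A, K, M);  // A^T, K x M
  Eigen::Map<const Eigen::MatrixXf> B_t(B, N, K);  // B^T, N x K
  Eigen::Map<Eigen::MatrixXf> C_t(C, N, M);        // C^T, N x M
  C_t.noalias() = B_t * A_t;                       // (A*B)^T = B^T * A^T
}
```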