From 99cc865d511a53b416a804dc771a779772aacfa3 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Sat, 6 Aug 2022 14:29:45 +0200 Subject: [PATCH 1/4] Test whether OS supports AVX512 Signed-off-by: Stefan Weil --- src/arch/simddetect.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index 1afe5a5d81..b7682f6959 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -171,9 +171,12 @@ SIMDDetect::SIMDDetect() { // be used inside an if. __cpuid_count(7, 0, eax, ebx, ecx, edx); avx2_available_ = (ebx & 0x00000020) != 0; - avx512F_available_ = (ebx & 0x00010000) != 0; - avx512BW_available_ = (ebx & 0x40000000) != 0; - avx512VNNI_available_ = (ecx & 0x00000800) != 0; + if ((xgetbv() & 0xe0) == 0xe0) { + // OS supports AVX512. + avx512F_available_ = (ebx & 0x00010000) != 0; + avx512BW_available_ = (ebx & 0x40000000) != 0; + avx512VNNI_available_ = (ecx & 0x00000800) != 0; + } } # endif } @@ -202,9 +205,12 @@ SIMDDetect::SIMDDetect() { if (max_function_id >= 7) { __cpuid(cpuInfo, 7); avx2_available_ = (cpuInfo[1] & 0x00000020) != 0; - avx512F_available_ = (cpuInfo[1] & 0x00010000) != 0; - avx512BW_available_ = (cpuInfo[1] & 0x40000000) != 0; - avx512VNNI_available_ = (cpuInfo[2] & 0x00000800) != 0; + if ((_xgetbv(0) & 0xe0) == 0xe0) { + // OS supports AVX512. + avx512F_available_ = (cpuInfo[1] & 0x00010000) != 0; + avx512BW_available_ = (cpuInfo[1] & 0x40000000) != 0; + avx512VNNI_available_ = (cpuInfo[2] & 0x00000800) != 0; + } } # endif } From 57034199a082eae3e39d9f9a58e42a5a78564258 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Fri, 5 Aug 2022 21:43:37 +0200 Subject: [PATCH 2/4] Implement intSimdMatrixAVX512VNNI (dummy) This dummy implementation just copied the code from AVX2. 
Signed-off-by: Stefan Weil --- Makefile.am | 8 + configure.ac | 7 + src/arch/intsimdmatrix.h | 1 + src/arch/intsimdmatrixavx512vnni.cpp | 594 +++++++++++++++++++++++++++ src/arch/simddetect.cpp | 11 + 5 files changed, 621 insertions(+) create mode 100644 src/arch/intsimdmatrixavx512vnni.cpp diff --git a/Makefile.am b/Makefile.am index 84159945b5..763368bd48 100644 --- a/Makefile.am +++ b/Makefile.am @@ -170,6 +170,14 @@ libtesseract_la_LIBADD += libtesseract_avx512.la noinst_LTLIBRARIES += libtesseract_avx512.la endif +if HAVE_AVX512VNNI +libtesseract_avx512vnni_la_CXXFLAGS = -march=icelake-client +libtesseract_avx512vnni_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil +libtesseract_avx512vnni_la_SOURCES = src/arch/intsimdmatrixavx512vnni.cpp +libtesseract_la_LIBADD += libtesseract_avx512vnni.la +noinst_LTLIBRARIES += libtesseract_avx512vnni.la +endif + if HAVE_FMA libtesseract_fma_la_CXXFLAGS = -mfma libtesseract_fma_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil diff --git a/configure.ac b/configure.ac index d32b6f1d59..a0e421f1c4 100644 --- a/configure.ac +++ b/configure.ac @@ -130,6 +130,7 @@ AX_CHECK_COMPILE_FLAG([-Werror=unused-command-line-argument], [WERROR=-Werror=un AM_CONDITIONAL([HAVE_AVX], false) AM_CONDITIONAL([HAVE_AVX2], false) AM_CONDITIONAL([HAVE_AVX512F], false) +AM_CONDITIONAL([HAVE_AVX512VNNI], false) AM_CONDITIONAL([HAVE_FMA], false) AM_CONDITIONAL([HAVE_SSE4_1], false) AM_CONDITIONAL([HAVE_NEON], false) @@ -156,6 +157,12 @@ case "${host_cpu}" in AC_DEFINE([HAVE_AVX512F], [1], [Enable AVX512F instructions]) fi + AX_CHECK_COMPILE_FLAG([-march=icelake-client], [avx512vnni=true], [avx512vnni=false], [$WERROR]) + AM_CONDITIONAL([HAVE_AVX512VNNI], $avx512vnni) + if $avx512vnni; then + AC_DEFINE([HAVE_AVX512VNNI], [1], [Enable AVX512VNNI instructions]) + fi + AX_CHECK_COMPILE_FLAG([-mfma], [fma=true], [fma=false], [$WERROR]) AM_CONDITIONAL([HAVE_FMA], $fma) if $fma; then diff --git a/src/arch/intsimdmatrix.h b/src/arch/intsimdmatrix.h index d93f928dbc..968d407baf 100644 --- a/src/arch/intsimdmatrix.h +++ b/src/arch/intsimdmatrix.h @@ -117,6 +117,7 @@ struct TESS_API IntSimdMatrix { static const IntSimdMatrix intSimdMatrixNEON; // Only available with AVX2 / AVX / FMA / SSE. static const IntSimdMatrix intSimdMatrixAVX2; + static const IntSimdMatrix intSimdMatrixAVX512VNNI; static const IntSimdMatrix intSimdMatrixSSE; }; diff --git a/src/arch/intsimdmatrixavx512vnni.cpp b/src/arch/intsimdmatrixavx512vnni.cpp new file mode 100644 index 0000000000..279c5ce686 --- /dev/null +++ b/src/arch/intsimdmatrixavx512vnni.cpp @@ -0,0 +1,594 @@ +/////////////////////////////////////////////////////////////////////// +// File: intsimdmatrixavx512vnni.cpp +// Description: matrix-vector product for 8-bit data on avx512vnni. +// Author: Ray Smith +// +// (C) Copyright 2017, Google Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+///////////////////////////////////////////////////////////////////////
+
+#include "intsimdmatrix.h"
+
+#if !defined(__AVX2__)
+# if defined(__i686__) || defined(__x86_64__)
+# error Implementation only for AVX2 capable architectures
+# endif
+#else
+# include <immintrin.h>
+# include <algorithm>
+# include <cstdint>
+# include <vector>
+
+# if defined(_MSC_VER) && _MSC_VER >= 1925 && _MSC_VER <= 1929 && \
+     defined(_WIN32) && !defined(_WIN64)
+// Optimize for size (/Os) instead of using the default optimization for some
+// versions of the 32 bit Visual Studio compiler which generate buggy code.
+# pragma optimize("", off)
+# pragma optimize("s", on)
+# endif
+
+namespace tesseract {
+
+// Number of outputs held in each register. 8 x 32 bit ints.
+constexpr int kNumOutputsPerRegister = 8;
+// Maximum number of registers that we will use.
+constexpr int kMaxOutputRegisters = 8;
+// Number of inputs in the inputs register.
+constexpr int kNumInputsPerRegister = 32;
+// Number of inputs in each weight group.
+constexpr int kNumInputsPerGroup = 4;
+// Number of groups of inputs to be broadcast.
+constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
+
+// Functions to compute part of a matrix.vector multiplication. The weights
+// are in a very specific order (see above) in w, which is multiplied by
+// u of length num_in, to produce output v after scaling the integer results
+// by the corresponding member of scales.
+// The amount of w and scales consumed is fixed and not available to the
+// caller. The number of outputs written to v will be at most num_out.
+
+// Computes one set of 4x8 products of inputs and weights, adding to result.
+// Horizontally adds 4 adjacent results, making 8x32-bit results.
+// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
+// Note that wi must previously have been re-organized with blocks of 4x8
+// weights in contiguous memory.
+// ones is a register of 16x16-bit values all equal to 1.
+// Note: wi is incremented by the amount of data read.
+// weights and reps are scratch registers.
+// This function must be inlined with references in order for the compiler to
+// correctly use the registers declared in the caller.
+static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, const int8_t *&wi,
+                                 __m256i &weights, __m256i &reps, __m256i &result) {
+  // Load a 4x8 block of weights.
+  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(wi));
+  wi += kNumInputsPerRegister;
+  // Normalize the signs on rep_input, weights, so weights is always +ve.
+  reps = _mm256_sign_epi8(rep_input, weights);
+  weights = _mm256_sign_epi8(weights, weights);
+  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
+  // with adjacent pairs added.
+  weights = _mm256_maddubs_epi16(weights, reps);
+  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
+  // with adjacent pairs added. What we really want is a horizontal add of
+  // 16+16=32 bit result, but there is no such instruction, so multiply by
+  // 16-bit ones instead. It is probably faster than all the sign-extending,
+  // permuting and adding that would otherwise be required.
+  weights = _mm256_madd_epi16(weights, ones);
+  result = _mm256_add_epi32(result, weights);
+}
+
+// Load 64 bits into the bottom of a 128bit register.
+// We don't actually care what the top 64bits are, but this ends
+// up with them being zero.
+static inline __m128i load64_to_128(const int8_t *wi_) { + const auto *wi = reinterpret_cast(wi_); + return _mm_set_epi64x(0, wi[0]); +} + +#if defined(FAST_FLOAT) + +static inline void ExtractResults8(__m256i result, const int8_t *wi, + const float *scales, float *v) { + __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg + __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg + __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + __m256 scale01234567 = _mm256_loadu_ps(scales); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result = _mm256_add_epi32(result, w256); // result += bias * 127 + __m256 res01234567 = _mm256_cvtepi32_ps(result); + result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); + res01234567 = _mm256_mul_ps(res01234567, scale01234567); + _mm256_storeu_ps(v, res01234567); +} + +static inline void ExtractResults16(__m256i result0, __m256i result1, + const int8_t *&wi, const float *&scales, + float *&v) { + __m128i w8 = _mm_loadu_si128(reinterpret_cast(wi)); + // 8x8bit vals in bottom of 128bit reg + const __m256i bias_scale = + _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg + __m256 scale01234567 = _mm256_loadu_ps(scales); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 + __m256 res01234567 = _mm256_cvtepi32_ps(result0); + result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); + res01234567 = _mm256_mul_ps(res01234567, scale01234567); + _mm256_storeu_ps(v, res01234567); + w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); + w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg + scale01234567 = _mm256_loadu_ps(scales + 8); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 + res01234567 = _mm256_cvtepi32_ps(result1); + result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); + res01234567 = _mm256_mul_ps(res01234567, scale01234567); + _mm256_storeu_ps(v + 8, res01234567); + wi += 16; + scales += 16; + v += 16; +} + +// Computes part of matrix.vector v = Wu. Computes N=64 results. +// The weights *must* be arranged so that consecutive reads from wi +// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of +// (kNumInputsPerGroup inputs))). After that there must be N consecutive +// bias weights, before continuing with any more weights. +// u must be padded out with zeros to +// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements. +static void PartialMatrixDotVector64(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + __m256i result2 = _mm256_setzero_si256(); + __m256i result3 = _mm256_setzero_si256(); + __m256i result4 = _mm256_setzero_si256(); + __m256i result5 = _mm256_setzero_si256(); + __m256i result6 = _mm256_setzero_si256(); + __m256i result7 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. 
+ for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + MultiplyGroup(rep_input, ones, wi, weights, reps, result2); + MultiplyGroup(rep_input, ones, wi, weights, reps, result3); + MultiplyGroup(rep_input, ones, wi, weights, reps, result4); + MultiplyGroup(rep_input, ones, wi, weights, reps, result5); + MultiplyGroup(rep_input, ones, wi, weights, reps, result6); + MultiplyGroup(rep_input, ones, wi, weights, reps, result7); + } + } + ExtractResults16(result0, result1, wi, scales, v); + ExtractResults16(result2, result3, wi, scales, v); + ExtractResults16(result4, result5, wi, scales, v); + ExtractResults16(result6, result7, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=32 results. +// For details see PartialMatrixDotVector64 with N=32. +static void PartialMatrixDotVector32(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + __m256i result2 = _mm256_setzero_si256(); + __m256i result3 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + MultiplyGroup(rep_input, ones, wi, weights, reps, result2); + MultiplyGroup(rep_input, ones, wi, weights, reps, result3); + } + } + ExtractResults16(result0, result1, wi, scales, v); + ExtractResults16(result2, result3, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=16 results. +// For details see PartialMatrixDotVector64 with N=16. +static void PartialMatrixDotVector16(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. 
+ __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + } + } + ExtractResults16(result0, result1, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=8 results. +// For details see PartialMatrixDotVector64 with N=8. +static inline void PartialMatrixDotVector8(const int8_t *wi, const float *scales, const int8_t *u, + int num_in, float *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + } + } + ExtractResults8(result0, wi, scales, v); +} + +static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const float *scales, + const int8_t *u, float *v) { + const int num_out = dim1; + const int num_in = dim2 - 1; + // Each call to a partial_func_ produces group_size outputs, except the + // last one, which can produce less. + const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup); + const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister); + int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; + int output = 0; + + int w_step = (rounded_num_in + 1) * group_size; + + // Run with this group size, until it would produce too much output, then + // switch to a smaller size. 
+ for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + output += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + output += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v); + } +} +#else +static inline void ExtractResults8(__m256i result, const int8_t *wi, const double *scales, + double *v) { + __m128i w128 = load64_to_128(wi); // 8x8bit vals in bottom of 128bit reg + __m256i w256 = _mm256_cvtepi8_epi32(w128); // 8x32bit vals in 256bit reg + __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + __m256d scale0123 = _mm256_loadu_pd(scales); + __m256d scale4567 = _mm256_loadu_pd(scales + 4); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result = _mm256_add_epi32(result, w256); // result += bias * 127 + __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); + result = _mm256_permute4x64_epi64(result, 2 + (3 << 2)); + __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result)); + res0123 = _mm256_mul_pd(res0123, scale0123); + res4567 = _mm256_mul_pd(res4567, scale4567); + _mm256_storeu_pd(v, res0123); + _mm256_storeu_pd(v + 4, res4567); +} + +static inline void ExtractResults16(__m256i result0, __m256i result1, const int8_t *&wi, + const double *&scales, double *&v) { + __m128i w8 = _mm_loadu_si128(reinterpret_cast(wi)); + // 8x8bit vals in bottom of 128bit reg + const __m256i bias_scale = _mm256_set_epi32(127, 127, 127, 127, 127, 127, 127, 127); + __m256i w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg + __m256d scale0123 = _mm256_loadu_pd(scales); + __m256d scale4567 = _mm256_loadu_pd(scales + 4); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result0 = _mm256_add_epi32(result0, w256); // result += bias * 127 + __m256d res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); + result0 = _mm256_permute4x64_epi64(result0, 2 + (3 << 2)); + __m256d res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result0)); + res0123 = _mm256_mul_pd(res0123, scale0123); + res4567 = _mm256_mul_pd(res4567, scale4567); + _mm256_storeu_pd(v, res0123); + _mm256_storeu_pd(v + 4, res4567); + w8 = _mm_shuffle_epi32(w8, 2 + (3 << 2)); + w256 = _mm256_cvtepi8_epi32(w8); // 8x32bit vals in 256bit reg + scale0123 = _mm256_loadu_pd(scales + 8); + scale4567 = _mm256_loadu_pd(scales + 12); + w256 = _mm256_mullo_epi32(w256, bias_scale); // 8x32 + result1 = _mm256_add_epi32(result1, w256); // result += bias * 127 + res0123 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); + result1 = _mm256_permute4x64_epi64(result1, 2 + (3 << 2)); + res4567 = _mm256_cvtepi32_pd(_mm256_castsi256_si128(result1)); + res0123 = _mm256_mul_pd(res0123, scale0123); + res4567 = _mm256_mul_pd(res4567, scale4567); + _mm256_storeu_pd(v + 8, res0123); + _mm256_storeu_pd(v + 12, res4567); + wi += 16; + scales += 16; + v += 16; +} + +// Computes part of matrix.vector v = Wu. Computes N=64 results. 
+// The weights *must* be arranged so that consecutive reads from wi +// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of +// (kNumInputsPerGroup inputs))). After that there must be N consecutive +// bias weights, before continuing with any more weights. +// u must be padded out with zeros to +// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements. +static void PartialMatrixDotVector64(const int8_t *wi, const double *scales, const int8_t *u, + int num_in, double *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + __m256i result2 = _mm256_setzero_si256(); + __m256i result3 = _mm256_setzero_si256(); + __m256i result4 = _mm256_setzero_si256(); + __m256i result5 = _mm256_setzero_si256(); + __m256i result6 = _mm256_setzero_si256(); + __m256i result7 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + MultiplyGroup(rep_input, ones, wi, weights, reps, result2); + MultiplyGroup(rep_input, ones, wi, weights, reps, result3); + MultiplyGroup(rep_input, ones, wi, weights, reps, result4); + MultiplyGroup(rep_input, ones, wi, weights, reps, result5); + MultiplyGroup(rep_input, ones, wi, weights, reps, result6); + MultiplyGroup(rep_input, ones, wi, weights, reps, result7); + } + } + ExtractResults16(result0, result1, wi, scales, v); + ExtractResults16(result2, result3, wi, scales, v); + ExtractResults16(result4, result5, wi, scales, v); + ExtractResults16(result6, result7, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=32 results. +// For details see PartialMatrixDotVector64 with N=32. +static void PartialMatrixDotVector32(const int8_t *wi, const double *scales, const int8_t *u, + int num_in, double *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + __m256i result2 = _mm256_setzero_si256(); + __m256i result3 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. 
+ for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + MultiplyGroup(rep_input, ones, wi, weights, reps, result2); + MultiplyGroup(rep_input, ones, wi, weights, reps, result3); + } + } + ExtractResults16(result0, result1, wi, scales, v); + ExtractResults16(result2, result3, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=16 results. +// For details see PartialMatrixDotVector64 with N=16. +static void PartialMatrixDotVector16(const int8_t *wi, const double *scales, const int8_t *u, + int num_in, double *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + __m256i result1 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. + inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + MultiplyGroup(rep_input, ones, wi, weights, reps, result1); + } + } + ExtractResults16(result0, result1, wi, scales, v); +} + +// Computes part of matrix.vector v = Wu. Computes N=8 results. +// For details see PartialMatrixDotVector64 with N=8. +static inline void PartialMatrixDotVector8(const int8_t *wi, const double *scales, const int8_t *u, + int num_in, double *v) { + // Register containing 16-bit ones for horizontal add with 16->32 bit + // conversion. + __m256i ones = _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); + __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1); + // Initialize all the results to 0. + __m256i result0 = _mm256_setzero_si256(); + // Iterate over the input (u), one registerful at a time. + for (int j = 0; j < num_in;) { + __m256i inputs = _mm256_loadu_si256(reinterpret_cast(u + j)); + // Inputs are processed in groups of kNumInputsPerGroup, replicated + // kNumInputGroups times. + for (int ig = 0; ig < kNumInputGroups && j < num_in; ++ig, j += kNumInputsPerGroup) { + // Replicate the low 32 bits (4 inputs) 8 times. + __m256i rep_input = _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs)); + // Rotate the inputs in groups of 4, so the next 4 inputs are ready. 
+ inputs = _mm256_permutevar8x32_epi32(inputs, shift_id); + __m256i weights, reps; + // Mul-add, with horizontal add of the 4 inputs to each of the results. + MultiplyGroup(rep_input, ones, wi, weights, reps, result0); + } + } + ExtractResults8(result0, wi, scales, v); +} + +static void matrixDotVector(int dim1, int dim2, const int8_t *wi, const double *scales, + const int8_t *u, double *v) { + const int num_out = dim1; + const int num_in = dim2 - 1; + // Each call to a partial_func_ produces group_size outputs, except the + // last one, which can produce less. + const int rounded_num_in = IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup); + const int rounded_num_out = IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister); + int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; + int output = 0; + + int w_step = (rounded_num_in + 1) * group_size; + + // Run with this group size, until it would produce too much output, then + // switch to a smaller size. + for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector64(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector32(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + output += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector16(wi, scales, u, rounded_num_in, v); + wi += w_step; + scales += group_size; + v += group_size; + output += group_size; + } + group_size /= 2; + w_step /= 2; + + if (output + group_size <= rounded_num_out) { + PartialMatrixDotVector8(wi, scales, u, rounded_num_in, v); + } +} +#endif + +const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX512VNNI = { + // Function. + matrixDotVector, + // Number of 32 bit outputs held in each register. + kNumOutputsPerRegister, + // Maximum number of registers that we will use to hold outputs. + kMaxOutputRegisters, + // Number of 8 bit inputs in the inputs register. + kNumInputsPerRegister, + // Number of inputs in each weight group. + kNumInputsPerGroup +}; + +} // namespace tesseract. + +#endif diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index b7682f6959..3b43a5d1d5 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -243,7 +243,18 @@ SIMDDetect::SIMDDetect() { #if defined(HAVE_AVX512F) } else if (avx512F_available_) { // AVX512F detected. 
+# if defined(HAVE_AVX512VNNI) + if (avx512VNNI_available_) { + printf("mit VNNI\n"); + SetDotProduct(DotProductAVX512F, &IntSimdMatrix::intSimdMatrixAVX512VNNI); + } else { + printf("ohne VNNI\n"); + SetDotProduct(DotProductAVX512F, &IntSimdMatrix::intSimdMatrixAVX2); + } +# else + printf("ohne VNNI???\n"); SetDotProduct(DotProductAVX512F, &IntSimdMatrix::intSimdMatrixAVX2); +# endif #endif #if defined(HAVE_AVX2) } else if (avx2_available_) { From 8d0bd689835332be3e732b3f57f7d41ea557b367 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Sat, 6 Aug 2022 09:36:16 +0200 Subject: [PATCH 3/4] Add unittest for AVX512VNNI Signed-off-by: Stefan Weil --- unittest/intsimdmatrix_test.cc | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unittest/intsimdmatrix_test.cc b/unittest/intsimdmatrix_test.cc index 95688eed5a..b014410548 100644 --- a/unittest/intsimdmatrix_test.cc +++ b/unittest/intsimdmatrix_test.cc @@ -135,4 +135,18 @@ TEST_F(IntSimdMatrixTest, AVX2) { #endif } +// Tests that the AVX512VNNI implementation gets the same result as the vanilla. +TEST_F(IntSimdMatrixTest, AVX512VNNI) { +#if defined(HAVE_AVX512VNNI) + if (!SIMDDetect::IsAVX512VNNIAvailable()) { + GTEST_LOG_(INFO) << "No AVX512VNNI found! Not tested!"; + GTEST_SKIP(); + } + ExpectEqualResults(IntSimdMatrix::intSimdMatrixAVX512VNNI); +#else + GTEST_LOG_(INFO) << "AVX512VNNI unsupported! Not tested!"; + GTEST_SKIP(); +#endif +} + } // namespace tesseract From 232093f0463ae7e34dd94b026029cd1bf2e82657 Mon Sep 17 00:00:00 2001 From: Amit Dovev Date: Tue, 9 Aug 2022 15:22:38 +0300 Subject: [PATCH 4/4] Add the needed changes to support AVX512VNNI --- Makefile.am | 2 +- configure.ac | 2 +- src/arch/intsimdmatrixavx512vnni.cpp | 20 ++++++++------------ 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/Makefile.am b/Makefile.am index 763368bd48..8398ed5b04 100644 --- a/Makefile.am +++ b/Makefile.am @@ -171,7 +171,7 @@ noinst_LTLIBRARIES += libtesseract_avx512.la endif if HAVE_AVX512VNNI -libtesseract_avx512vnni_la_CXXFLAGS = -march=icelake-client +libtesseract_avx512vnni_la_CXXFLAGS = -mavx512vnni -mavx512vl libtesseract_avx512vnni_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil libtesseract_avx512vnni_la_SOURCES = src/arch/intsimdmatrixavx512vnni.cpp libtesseract_la_LIBADD += libtesseract_avx512vnni.la diff --git a/configure.ac b/configure.ac index a0e421f1c4..dc02f2d90b 100644 --- a/configure.ac +++ b/configure.ac @@ -157,7 +157,7 @@ case "${host_cpu}" in AC_DEFINE([HAVE_AVX512F], [1], [Enable AVX512F instructions]) fi - AX_CHECK_COMPILE_FLAG([-march=icelake-client], [avx512vnni=true], [avx512vnni=false], [$WERROR]) + AX_CHECK_COMPILE_FLAG([-mavx512vnni], [avx512vnni=true], [avx512vnni=false], [$WERROR]) AM_CONDITIONAL([HAVE_AVX512VNNI], $avx512vnni) if $avx512vnni; then AC_DEFINE([HAVE_AVX512VNNI], [1], [Enable AVX512VNNI instructions]) diff --git a/src/arch/intsimdmatrixavx512vnni.cpp b/src/arch/intsimdmatrixavx512vnni.cpp index 279c5ce686..bf6f0764ac 100644 --- a/src/arch/intsimdmatrixavx512vnni.cpp +++ b/src/arch/intsimdmatrixavx512vnni.cpp @@ -17,9 +17,9 @@ #include "intsimdmatrix.h" -#if !defined(__AVX2__) +#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__) # if defined(__i686__) || defined(__x86_64__) -# error Implementation only for AVX2 capable architectures +# error Implementation only for AVX512VNNI capable architectures # endif #else # include @@ -73,16 +73,12 @@ static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones, // Normalize the signs on rep_input, weights, so weights is 
always +ve. reps = _mm256_sign_epi8(rep_input, weights); weights = _mm256_sign_epi8(weights, weights); - // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results, - // with adjacent pairs added. - weights = _mm256_maddubs_epi16(weights, reps); - // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results, - // with adjacent pairs added. What we really want is a horizontal add of - // 16+16=32 bit result, but there is no such instruction, so multiply by - // 16-bit ones instead. It is probably faster than all the sign-extending, - // permuting and adding that would otherwise be required. - weights = _mm256_madd_epi16(weights, ones); - result = _mm256_add_epi32(result, weights); + + // VNNI instruction. It replaces 3 AVX2 instructions: + //weights = _mm256_maddubs_epi16(weights, reps); + //weights = _mm256_madd_epi16(weights, ones); + //result = _mm256_add_epi32(result, weights); + result = _mm256_dpbusd_epi32(result, weights, reps); } // Load 64 bits into the bottom of a 128bit register.
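
The core of PATCH 4 is that one VNNI instruction (vpdpbusd, exposed as _mm256_dpbusd_epi32) performs the unsigned-by-signed 8-bit multiply and the horizontal add of each group of 4 products straight into the 32-bit accumulator, replacing the maddubs/madd/add sequence of the AVX2 kernel. The following minimal standalone sketch (not part of the patches; file name, build flags and sample values are illustrative assumptions, e.g. g++ -O2 -mavx2 -mavx512vnni -mavx512vl vnni_check.cpp) compares the two paths after the same sign normalization that MultiplyGroup uses:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  // 32 signed 8-bit inputs and weights, as MultiplyGroup sees them.
  int8_t input[32], weight[32];
  for (int i = 0; i < 32; ++i) {
    input[i] = static_cast<int8_t>(i - 16);
    weight[i] = static_cast<int8_t>(3 * i - 40);
  }
  __m256i rep_input = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(input));
  __m256i weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(weight));
  // Same sign normalization as MultiplyGroup: the first multiplicand must be unsigned.
  __m256i reps = _mm256_sign_epi8(rep_input, weights);
  __m256i uweights = _mm256_sign_epi8(weights, weights);

  // AVX2 path: 8-bit multiplies with pairwise adds into 16-bit lanes, then a
  // second pairwise add (via multiplication by 1) into 32-bit lanes.
  __m256i ones = _mm256_set1_epi16(1);
  __m256i avx2 = _mm256_madd_epi16(_mm256_maddubs_epi16(uweights, reps), ones);

  // VNNI path: unsigned x signed 8-bit multiplies, each group of 4 products
  // summed directly into the 32-bit accumulator by a single instruction.
  __m256i vnni = _mm256_dpbusd_epi32(_mm256_setzero_si256(), uweights, reps);

  int32_t a[8], b[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i *>(a), avx2);
  _mm256_storeu_si256(reinterpret_cast<__m256i *>(b), vnni);
  for (int i = 0; i < 8; ++i) {
    printf("%d: avx2=%d vnni=%d%s\n", i, int(a[i]), int(b[i]),
           a[i] == b[i] ? "" : "  MISMATCH");
  }
  return 0;
}

The two paths agree whenever the 16-bit intermediate sums of the AVX2 sequence do not saturate, which holds for the int8 ranges this kernel works with; the VNNI form simply does the same work in one instruction per weight group.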