Skip to content

Commit 4abd575

Browse files
authored
Improve Avx2 constant shift on uint8_t and fix sse2 (#1215)
1 parent 79902d7 commit 4abd575

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,15 @@ namespace xsimd
311311
{
312312
constexpr auto bits = std::numeric_limits<T>::digits + std::numeric_limits<T>::is_signed;
313313
static_assert(shift < bits, "Shift must be less than the number of bits in T");
314+
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
315+
{
316+
// 8-bit left shift via 16-bit shift + mask
317+
__m256i shifted = _mm256_slli_epi16(self, shift);
318+
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
319+
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? (~0u << shift) : 0);
320+
const __m256i mask = _mm256_set1_epi8(mask8);
321+
return _mm256_and_si256(shifted, mask);
322+
}
314323
XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
315324
{
316325
return _mm256_slli_epi16(self, shift);
@@ -323,10 +332,6 @@ namespace xsimd
323332
{
324333
return _mm256_slli_epi64(self, shift);
325334
}
326-
else
327-
{
328-
return bitwise_lshift<shift>(self, avx {});
329-
}
330335
}
331336

332337
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value>::type>
@@ -444,10 +449,12 @@ namespace xsimd
444449
{
445450
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
446451
{
447-
const __m256i byte_mask = _mm256_set1_epi16(0x00FF);
448-
__m256i u16 = _mm256_and_si256(self, byte_mask);
449-
__m256i r16 = _mm256_srli_epi16(u16, shift);
450-
return _mm256_and_si256(r16, byte_mask);
452+
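// 8-bit logical right shift via 16-bit shift + mask
// NOTE(review): the mask below, (1u << shift) - 1u, keeps only the low `shift` bits of each byte.
// Clearing the bits that leak in from the neighbouring byte needs the low `8 - shift` bits,
// i.e. (0xFFu >> shift) — confirm.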
453+
const __m256i shifted = _mm256_srli_epi16(self, shift);
454+
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
455+
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? ((1u << shift) - 1u) : 0);
456+
const __m256i mask = _mm256_set1_epi8(mask8);
457+
return _mm256_and_si256(shifted, mask);
451458
}
452459
XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
453460
{
@@ -461,10 +468,6 @@ namespace xsimd
461468
{
462469
return _mm256_srli_epi64(self, shift);
463470
}
464-
else
465-
{
466-
return bitwise_rshift<shift>(self, avx {});
467-
}
468471
}
469472
}
470473

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ namespace xsimd
306306
{
307307
// 8-bit left shift via 16-bit shift + mask
308308
__m128i shifted = _mm_slli_epi16(self, static_cast<int>(shift));
309-
__m128i mask = _mm_set1_epi8(static_cast<char>(0xFF << shift));
309+
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
310+
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? (~0u << shift) : 0);
311+
const __m128i mask = _mm_set1_epi8(mask8);
310312
return _mm_and_si128(shifted, mask);
311313
}
312314
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
@@ -489,10 +491,12 @@ namespace xsimd
489491
{
490492
XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
491493
{
492-
// Emulate byte-wise logical right shift using 16-bit shifts + per-byte mask.
493-
__m128i s16 = _mm_srli_epi16(self, static_cast<int>(shift));
494-
__m128i mask = _mm_set1_epi8(static_cast<char>(0xFFu >> shift));
495-
return _mm_and_si128(s16, mask);
494+
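// 8-bit logical right shift via 16-bit shift + mask
// NOTE(review): the mask below, (1u << shift) - 1u, keeps only the low `shift` bits of each byte.
// The removed implementation masked with (0xFFu >> shift), which keeps the low `8 - shift` bits — confirm
// which mask is intended.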
495+
__m128i shifted = _mm_srli_epi16(self, static_cast<int>(shift));
496+
// TODO(C++17): without `if constexpr` we must ensure the compile-time shift does not overflow
497+
constexpr uint8_t mask8 = static_cast<uint8_t>(sizeof(T) == 1 ? ((1u << shift) - 1u) : 0);
498+
const __m128i mask = _mm_set1_epi8(mask8);
499+
return _mm_and_si128(shifted, mask);
496500
}
497501
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
498502
{

test/test_xsimd_api.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ struct xsimd_api_integral_types_functions
358358
value_type val1(shift);
359359
value_type r = val0 << val1;
360360
value_type ir = val0 << shift;
361-
value_type cr = xsimd::bitwise_lshift<shift>(val0);
361+
T cr = xsimd::bitwise_lshift<shift>(T(val0));
362362
CHECK_EQ(extract(xsimd::bitwise_lshift(T(val0), T(val1))), r);
363363
CHECK_EQ(extract(ir), r);
364364
CHECK_EQ(extract(cr), r);
@@ -371,7 +371,7 @@ struct xsimd_api_integral_types_functions
371371
value_type val1(shift);
372372
value_type r = val0 >> val1;
373373
value_type ir = val0 >> shift;
374-
value_type cr = xsimd::bitwise_rshift<shift>(val0);
374+
T cr = xsimd::bitwise_rshift<shift>(T(val0));
375375
CHECK_EQ(extract(xsimd::bitwise_rshift(T(val0), T(val1))), r);
376376
CHECK_EQ(extract(ir), r);
377377
CHECK_EQ(extract(cr), r);
@@ -391,7 +391,7 @@ struct xsimd_api_integral_types_functions
391391
value_type val0(12);
392392
value_type val1(count);
393393
value_type r = (val0 << val1) | (val0 >> (N - val1));
394-
value_type cr = xsimd::rotl<count>(val0);
394+
T cr = xsimd::rotl<count>(T(val0));
395395
CHECK_EQ(extract(xsimd::rotl(T(val0), T(val1))), r);
396396
CHECK_EQ(extract(cr), r);
397397
}
@@ -403,7 +403,7 @@ struct xsimd_api_integral_types_functions
403403
value_type val0(12);
404404
value_type val1(count);
405405
value_type r = (val0 >> val1) | (val0 << (N - val1));
406-
value_type cr = xsimd::rotr<3>(val0);
406+
T cr = xsimd::rotr<3>(T(val0));
407407
CHECK_EQ(extract(xsimd::rotr(T(val0), T(val1))), r);
408408
CHECK_EQ(extract(cr), r);
409409
}

0 commit comments

Comments
 (0)