
Commit 6ebf925

AVX swizzle broadcast and swap optimization (#1213)
1 parent 4abd575 commit 6ebf925

1 file changed: +68 -58 lines changed

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 68 additions & 58 deletions
@@ -1629,88 +1629,98 @@ namespace xsimd
                 }
                 return split;
             }
-            // Duplicate lanes separately
-            // 1) duplicate low and high lanes
-            __m256 low_dup = _mm256_permute2f128_ps(self, self, 0x00); // [low | low]
-            __m256 hi_dup = _mm256_permute2f128_ps(self, self, 0x11); // [high| high]
+            constexpr auto lane_mask = mask % make_batch_constant<uint32_t, (mask.size / 2), A>();
+            XSIMD_IF_CONSTEXPR(detail::is_only_from_lo(mask))
+            {
+                __m256 broadcast = _mm256_permute2f128_ps(self, self, 0x00); // [low | low]
+                return _mm256_permutevar_ps(broadcast, lane_mask.as_batch());
+            }
+            XSIMD_IF_CONSTEXPR(detail::is_only_from_hi(mask))
+            {
+                __m256 broadcast = _mm256_permute2f128_ps(self, self, 0x11); // [high | high]
+                return _mm256_permutevar_ps(broadcast, lane_mask.as_batch());
+            }
+
+            // Fallback to general algorithm. This is the same as the dynamic version with the exception
+            // that possible operations are done at compile time.
+
+            // swap lanes
+            __m256 swapped = _mm256_permute2f128_ps(self, self, 0x01); // [high | low]

-            // 2) build lane-local index vector (each element = source_index & 3)
-            constexpr batch_constant<uint32_t, A, (V0 % 4), (V1 % 4), (V2 % 4), (V3 % 4), (V4 % 4), (V5 % 4), (V6 % 4), (V7 % 4)> half_mask;
+            // normalize mask taking modulo 4
+            constexpr auto half_mask = mask % make_batch_constant<uint32_t, 4, A>();

-            __m256 r0 = _mm256_permutevar_ps(low_dup, half_mask.as_batch()); // pick from low lane
-            __m256 r1 = _mm256_permutevar_ps(hi_dup, half_mask.as_batch()); // pick from high lane
+            // permute within each lane
+            __m256 r0 = _mm256_permutevar_ps(self, half_mask.as_batch());
+            __m256 r1 = _mm256_permutevar_ps(swapped, half_mask.as_batch());

-            constexpr batch_bool_constant<uint32_t, A, (V0 >= 4), (V1 >= 4), (V2 >= 4), (V3 >= 4), (V4 >= 4), (V5 >= 4), (V6 >= 4), (V7 >= 4)> lane_mask {};
+            // select lane by the mask index divided by 4
+            constexpr auto lane = batch_constant<uint32_t, A, 0, 0, 0, 0, 1, 1, 1, 1> {};
+            constexpr int lane_idx = ((mask / make_batch_constant<uint32_t, 4, A>()) != lane).mask();

-            return _mm256_blend_ps(r0, r1, lane_mask.mask());
+            return _mm256_blend_ps(r0, r1, lane_idx);
         }

         template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
         XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3> mask, requires_arch<avx>) noexcept
         {
             // cannot use detail::mod_shuffle as the mod and shift are different in this case
-            constexpr auto imm = ((V0 & 1) << 0) | ((V1 & 1) << 1) | ((V2 & 1) << 2) | ((V3 & 1) << 3);
-            XSIMD_IF_CONSTEXPR(detail::is_identity(mask)) { return self; }
+            constexpr auto imm = ((V0 % 2) << 0) | ((V1 % 2) << 1) | ((V2 % 2) << 2) | ((V3 % 2) << 3);
+            XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
+            {
+                return self;
+            }
             XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
             {
                 return _mm256_permute_pd(self, imm);
             }
-            // duplicate low and high part of input
-            __m256d lo = _mm256_permute2f128_pd(self, self, 0x00);
-            __m256d hi = _mm256_permute2f128_pd(self, self, 0x11);
+            XSIMD_IF_CONSTEXPR(detail::is_only_from_lo(mask))
+            {
+                __m256d broadcast = _mm256_permute2f128_pd(self, self, 0x00); // [low | low]
+                return _mm256_permute_pd(broadcast, imm);
+            }
+            XSIMD_IF_CONSTEXPR(detail::is_only_from_hi(mask))
+            {
+                __m256d broadcast = _mm256_permute2f128_pd(self, self, 0x11); // [high | high]
+                return _mm256_permute_pd(broadcast, imm);
+            }
+
+            // Fallback to general algorithm. This is the same as the dynamic version with the exception
+            // that possible operations are done at compile time.
+
+            // swap lanes
+            __m256d swapped = _mm256_permute2f128_pd(self, self, 0x01); // [high | low]

             // permute within each lane
-            __m256d r0 = _mm256_permute_pd(lo, imm);
-            __m256d r1 = _mm256_permute_pd(hi, imm);
+            __m256d r0 = _mm256_permute_pd(self, imm);
+            __m256d r1 = _mm256_permute_pd(swapped, imm);

-            // mask to choose the right lane
-            constexpr batch_bool_constant<uint64_t, A, (V0 >= 2), (V1 >= 2), (V2 >= 2), (V3 >= 2)> blend_mask;
+            // select lane by the mask index divided by 2
+            constexpr auto lane = batch_constant<uint64_t, A, 0, 0, 1, 1> {};
+            constexpr int lane_idx = ((mask / make_batch_constant<uint64_t, 2, A>()) != lane).mask();

             // blend the two permutes
-            return _mm256_blend_pd(r0, r1, blend_mask.mask());
-        }
-        template <class A,
-                  typename T,
-                  uint32_t V0,
-                  uint32_t V1,
-                  uint32_t V2,
-                  uint32_t V3,
-                  uint32_t V4,
-                  uint32_t V5,
-                  uint32_t V6,
-                  uint32_t V7,
-                  detail::enable_sized_integral_t<T, 4> = 0>
-        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self,
-                                         batch_constant<uint32_t, A,
-                                                        V0,
-                                                        V1,
-                                                        V2,
-                                                        V3,
-                                                        V4,
-                                                        V5,
-                                                        V6,
-                                                        V7> const& mask,
-                                         requires_arch<avx>) noexcept
+            return _mm256_blend_pd(r0, r1, lane_idx);
+        }
+
+        template <
+            class A, typename T,
+            uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7,
+            detail::enable_sized_integral_t<T, 4> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(
+            batch<T, A> const& self,
+            batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> const& mask,
+            requires_arch<avx>) noexcept
         {
-            return bitwise_cast<T>(
-                swizzle(bitwise_cast<float>(self), mask));
+            return bitwise_cast<T>(swizzle(bitwise_cast<float>(self), mask));
         }

-        template <class A,
-                  typename T,
-                  uint64_t V0,
-                  uint64_t V1,
-                  uint64_t V2,
-                  uint64_t V3,
-                  detail::enable_sized_integral_t<T, 8> = 0>
-        XSIMD_INLINE batch<T, A>
-        swizzle(batch<T, A> const& self,
-                batch_constant<uint64_t, A, V0, V1, V2, V3> const& mask,
-                requires_arch<avx>) noexcept
+        template <class A, typename T, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3, detail::enable_sized_integral_t<T, 8> = 0>
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3> const& mask, requires_arch<avx>) noexcept
         {
-            return bitwise_cast<T>(
-                swizzle(bitwise_cast<double>(self), mask));
+            return bitwise_cast<T>(swizzle(bitwise_cast<double>(self), mask));
         }
+
         // transpose
         template <class A>
         XSIMD_INLINE void transpose(batch<float, A>* matrix_begin, batch<float, A>* matrix_end, requires_arch<avx>) noexcept
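
To make the broadcast fast path concrete, here is a minimal standalone sketch in raw AVX intrinsics (not the xsimd API; the function name swizzle_low_only and the constant mask {1, 2, 0, 3, 3, 2, 1, 0} are invented for this example) of what the new is_only_from_lo branch reduces to: one lane broadcast plus one in-lane variable permute, instead of two cross-lane permutes, two variable permutes and a blend.

#include <immintrin.h>

// Hypothetical helper, not part of xsimd: the is_only_from_lo fast path
// spelled out for the fixed mask {1, 2, 0, 3, 3, 2, 1, 0}. Every index is
// below 4, so every output element comes from the low 128-bit lane.
static inline __m256 swizzle_low_only(__m256 self)
{
    // Broadcast the low lane into both halves: [low | low].
    __m256 broadcast = _mm256_permute2f128_ps(self, self, 0x00);
    // Lane-local indices (mask % 4); this particular mask is already lane-local.
    __m256i lane_mask = _mm256_setr_epi32(1, 2, 0, 3, 3, 2, 1, 0);
    // One in-lane variable permute finishes the job.
    return _mm256_permutevar_ps(broadcast, lane_mask);
}

The is_only_from_hi branch is identical except that the broadcast immediate is 0x11, which duplicates the high lane instead.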

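A similar sketch, under the same caveats (raw intrinsics, hypothetical function name, a fixed cross-lane mask {2, 1, 3, 0}), shows the general swap-and-blend fallback for doubles; the two immediates are what the compile-time expressions in the diff would evaluate to.

#include <immintrin.h>

// Hypothetical helper, not part of xsimd: the swap-and-blend fallback for the
// fixed mask {2, 1, 3, 0}.
static inline __m256d swizzle_2130(__m256d self)
{
    // In-lane immediate, bit i = mask[i] % 2  ->  0b0110.
    constexpr int imm = ((2 % 2) << 0) | ((1 % 2) << 1) | ((3 % 2) << 2) | ((0 % 2) << 3);
    // Swap the 128-bit lanes: [high | low].
    __m256d swapped = _mm256_permute2f128_pd(self, self, 0x01);
    // Permute within each lane of the original and of the swapped vector.
    __m256d r0 = _mm256_permute_pd(self, imm);
    __m256d r1 = _mm256_permute_pd(swapped, imm);
    // Blend bit i is set when mask[i] / 2 differs from the lane element i lives
    // in, i.e. when the value must come from the swapped vector: 0b1001.
    return _mm256_blend_pd(r0, r1, 0b1001);
}

That is four instructions instead of the five the previous duplicate-both-lanes version needed, and the broadcast fast paths above avoid even the blend.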