diff --git a/include/stdx/bit.hpp b/include/stdx/bit.hpp index 3706177..fff7213 100644 --- a/include/stdx/bit.hpp +++ b/include/stdx/bit.hpp @@ -399,21 +399,24 @@ template struct bitmask_subtract> { }; } // namespace detail -template - 1, +template > - 1, std::size_t Lsb = 0> [[nodiscard]] CONSTEVAL auto bit_mask() noexcept -> T { - static_assert(Msb < detail::num_digits_v, + using U = underlying_type_t; + static_assert(Msb < detail::num_digits_v, "bit_mask requested exceeds the range of the type"); static_assert(Msb >= Lsb, "bit_mask range is invalid"); - return detail::bitmask_subtract{}(detail::mask_bits_t{}(Msb + 1), - detail::mask_bits_t{}(Lsb)); + return static_cast(detail::bitmask_subtract{}( + detail::mask_bits_t{}(Msb + 1), detail::mask_bits_t{}(Lsb))); } template [[nodiscard]] constexpr auto bit_mask(std::size_t Msb, std::size_t Lsb = 0) noexcept -> T { - return detail::bitmask_subtract{}(detail::mask_bits_t{}(Msb + 1), - detail::mask_bits_t{}(Lsb)); + using U = underlying_type_t; + return static_cast(detail::bitmask_subtract{}( + detail::mask_bits_t{}(Msb + 1), detail::mask_bits_t{}(Lsb))); } template constexpr auto bit_size() -> std::size_t { diff --git a/test/bit.cpp b/test/bit.cpp index 982b5d5..e4fdba2 100644 --- a/test/bit.cpp +++ b/test/bit.cpp @@ -296,6 +296,15 @@ TEST_CASE("template bit_mask (large array type)", "[bit]") { CHECK(m == A{0, 0, 0, 1}); } +namespace { +enum struct scoped_enum : std::uint8_t { A, B, C }; +} // namespace + +TEST_CASE("template bit_mask (enum type)", "[bit]") { + constexpr auto m = stdx::bit_mask(); + STATIC_REQUIRE(m == scoped_enum{0xffu}); +} + TEST_CASE("arg bit_mask (whole range)", "[bit]") { constexpr auto m = stdx::bit_mask(63); STATIC_REQUIRE(m == std::numeric_limits::max()); @@ -326,6 +335,11 @@ TEST_CASE("arg bit_mask (single bit)", "[bit]") { STATIC_REQUIRE(std::is_same_v); } +TEST_CASE("arg bit_mask (enum type)", "[bit]") { + constexpr auto m = stdx::bit_mask(1); + STATIC_REQUIRE(m == scoped_enum{0b11u}); +} + TEMPLATE_TEST_CASE("bit_size", "[bit]", std::uint8_t, std::uint16_t, std::uint32_t, std::uint64_t, std::int8_t, std::int16_t, std::int32_t, std::int64_t) {