diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index 7696c059..db9acdb8 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -1194,26 +1194,19 @@ template bool constexpr IsPowerOfTwo(T x) { return (x != 0) && ((x & (x - 1)) == 0); } +template +constexpr unsigned int MostSignificantBit() { + unsigned int result = 0; + long long x = N; + while (x >>= 1) { + ++result; + } + return result; // return N == 0 ? 0 : (sizeof(long long) * 8 - 1 - __builtin_clz(N)); +} // Helper for getting a bytes size which is a power of two. -template +template struct NextPowerOfTwo { - static constexpr int value = Size; -}; -template <> -struct NextPowerOfTwo<3> { - static constexpr int value = 4; -}; -template <> -struct NextPowerOfTwo<5> { - static constexpr int value = 8; -}; -template <> -struct NextPowerOfTwo<6> { - static constexpr int value = 8; -}; -template <> -struct NextPowerOfTwo<7> { - static constexpr int value = 8; + static constexpr int value = IsPowerOfTwo(Size) ? Size : 2 << MostSignificantBit(); }; // Helper for getting a bit representation provided a byte size.