Last active
April 16, 2024 10:45
-
-
Save WiwilZ/c636d5d202d5793f9d1f938728106b3b to your computer and use it in GitHub Desktop.
Check whether an integer array is (strictly) increasing or (strictly) decreasing, optimized with SIMD instructions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #if defined(_MSC_VER) && !defined(__clang__) && (defined(_M_IX86) || defined(_M_X64) && !defined(_M_ARM64EC)) | |
| #define __SSE4_1__ | |
| #define __SSE4_2__ | |
| #endif | |
| #if defined(__AVX512BW__) || defined(__AVX2__) || defined(__SSE4_2__) || defined(__SSE4_1__) | |
| #if defined(_MSC_VER) && !defined(__clang__) | |
| #include <intrin.h> | |
| #elif defined(__AVX512BW__) || defined(__AVX2__) | |
| #include <immintrin.h> | |
| #else | |
| #include <emmintrin.h> | |
| #include <smmintrin.h> | |
| #ifdef __SSE4_2__ | |
| #include <nmmintrin.h> // _mm_cmpgt_epi64 | |
| #endif | |
| #endif | |
| #include <cstddef> | |
| #include <cstdint> | |
| #include <type_traits> | |
| namespace detail { | |
// Ordering relation that every adjacent pair (previous, next) of the checked
// array has to satisfy.
enum class SortedWay {
    Increasing,        // previous <= next
    StrictIncreasing,  // previous <  next
    Decreasing,        // previous >= next
    StrictDecreasing   // previous >  next
};
| template <typename T> | |
| concept Integer = std::is_same_v<T, int8_t> || std::is_same_v<T, int16_t> || std::is_same_v<T, int32_t> | |
| #ifdef __AVX512BW__ | |
| || std::is_same_v<T, int64_t> || std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> || std::is_same_v<T, uint32_t> || std::is_same_v<T, uint64_t> | |
| #elif defined(__SSE4_2__) | |
| || std::is_same_v<T, int64_t> | |
| #endif | |
| ; | |
| } | |
| #ifdef __AVX512BW__ | |
| namespace detail { | |
| template <SortedWay Way, Integer T> | |
| auto Compare(__m512i a, __m512i b) noexcept { | |
| const int cmp = Way == SortedWay::Increasing ? _MM_CMPINT_GT : | |
| Way == SortedWay::StrictIncreasing ? _MM_CMPINT_GE : | |
| Way == SortedWay::Decreasing ? _MM_CMPINT_LT : _MM_CMPINT_LE; | |
| if constexpr (std::is_same_v<T, int8_t>) { | |
| return _mm512_cmp_epi8_mask(a, b, cmp); | |
| } else if constexpr (std::is_same_v<T, int16_t>) { | |
| return _mm512_cmp_epi16_mask(a, b, cmp); | |
| } else if constexpr (std::is_same_v<T, int32_t>) { | |
| return _mm512_cmp_epi32_mask(a, b, cmp); | |
| } else if constexpr (std::is_same_v<T, int64_t>) { | |
| return _mm512_cmp_epi64_mask(a, b, cmp); | |
| } else if constexpr (std::is_same_v<T, uint8_t>) { | |
| return _mm512_cmp_epu8_mask(a, b, cmp); | |
| } else if constexpr (std::is_same_v<T, uint16_t>) { | |
| return _mm512_cmp_epu16_mask(a, b, cmp); | |
| } else if constexpr (std::is_same_v<T, uint32_t>) { | |
| return _mm512_cmp_epu32_mask(a, b, cmp); | |
| } else { | |
| return _mm512_cmp_epu64_mask(a, b, cmp); | |
| } | |
| } | |
| template <SortedWay Way, Integer T> | |
| bool IsSorted(const T* const data, size_t length) noexcept { | |
| constexpr size_t ChunkSize = sizeof(__m512i) / sizeof(T); | |
| const T* buffer = data; | |
| while (length > ChunkSize * 8) { | |
| const auto x0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 0); | |
| const auto x1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 1); | |
| const auto x2 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 2); | |
| const auto x3 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 3); | |
| const auto x4 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 4); | |
| const auto x5 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 5); | |
| const auto x6 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 6); | |
| const auto x7 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 7); | |
| const auto y0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 0); | |
| const auto y1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 1); | |
| const auto y2 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 2); | |
| const auto y3 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 3); | |
| const auto y4 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 4); | |
| const auto y5 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 5); | |
| const auto y6 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 6); | |
| const auto y7 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 7); | |
| const auto mask0 = Compare<Way, T>(x0, y0); | |
| const auto mask1 = Compare<Way, T>(x1, y1); | |
| const auto mask2 = Compare<Way, T>(x2, y2); | |
| const auto mask3 = Compare<Way, T>(x3, y3); | |
| const auto mask4 = Compare<Way, T>(x4, y4); | |
| const auto mask5 = Compare<Way, T>(x5, y5); | |
| const auto mask6 = Compare<Way, T>(x6, y6); | |
| const auto mask7 = Compare<Way, T>(x7, y7); | |
| if (mask0 | mask1 | mask2 | mask3 | mask4 | mask5 | mask6 | mask7) { | |
| return false; | |
| } | |
| buffer += ChunkSize * 8; | |
| length -= ChunkSize * 8; | |
| } | |
| while (length > ChunkSize) { | |
| const __m512i a = _mm512_loadu_si512(buffer); | |
| const __m512i b = _mm512_loadu_si512(buffer + 1); | |
| const auto mask = Compare<Way, T>(a, b); | |
| if (mask) { | |
| return false; | |
| } | |
| buffer += ChunkSize; | |
| length -= ChunkSize; | |
| } | |
| if (length > 1) { | |
| const __m512i a = _mm512_loadu_si512(buffer); | |
| const __m512i b = _mm512_loadu_si512(buffer + 1); | |
| const auto mask = Compare<Way, T>(a, b); | |
| const decltype(mask) mask2 = (static_cast<decltype(mask)>(1) << (length - 1)) - 1; | |
| if (mask & mask2) { | |
| return false; | |
| } | |
| } | |
| return true; | |
| } | |
| } // namespace detail | |
| #elif defined(__AVX2__) | |
| namespace detail { | |
| template <Integer T> | |
| __m256i EqualTo(__m256i a, __m256i b) noexcept { | |
| if constexpr (std::is_same_v<T, int8_t>) { | |
| return _mm256_cmpeq_epi8(a, b); | |
| } else if constexpr (std::is_same_v<T, int16_t>) { | |
| return _mm256_cmpeq_epi16(a, b); | |
| } else if constexpr (std::is_same_v<T, int32_t>) { | |
| return _mm256_cmpeq_epi32(a, b); | |
| } else { | |
| return _mm256_cmpeq_epi64(a, b); | |
| } | |
| } | |
| template <Integer T> | |
| __m256i GreaterThan(__m256i a, __m256i b) noexcept { | |
| if constexpr (std::is_same_v<T, int8_t>) { | |
| return _mm256_cmpgt_epi8(a, b); | |
| } else if constexpr (std::is_same_v<T, int16_t>) { | |
| return _mm256_cmpgt_epi16(a, b); | |
| } else if constexpr (std::is_same_v<T, int32_t>) { | |
| return _mm256_cmpgt_epi32(a, b); | |
| } else { | |
| return _mm256_cmpgt_epi64(a, b); | |
| } | |
| } | |
| template <Integer T> | |
| __m256i GreaterEqual(__m256i a, __m256i b) noexcept { | |
| const auto gt = GreaterThan<T>(a, b); | |
| const auto eq = EqualTo<T>(a, b); | |
| return _mm256_or_si256(gt, eq); | |
| } | |
| template <SortedWay Way, Integer T> | |
| auto Compare(__m256i a, __m256i b) noexcept { | |
| if constexpr (Way == SortedWay::Increasing) { | |
| // require b >= a or a <= b => all bits in a > b to be 0 | |
| return GreaterThan<T>(a, b); | |
| } else if constexpr (Way == SortedWay::StrictIncreasing) { | |
| // require b > a => all bits in b > a to be 1 | |
| return GreaterThan<T>(b, a); | |
| } else if constexpr (Way == SortedWay::Decreasing) { | |
| // require b <= a or a >= b => all bits in a >= b to be 1 | |
| return GreaterEqual<T>(a, b); | |
| } else { | |
| // require b < a => all bits in b >= a to be 0 | |
| return GreaterEqual<T>(b, a); | |
| } | |
| } | |
| template <SortedWay Way> | |
| auto MixMask(__m256i mask0, __m256i mask1, __m256i mask2, __m256i mask3) { | |
| if constexpr (Way == SortedWay::Increasing || Way == SortedWay::StrictDecreasing) { | |
| // all bits should be 0 | |
| return _mm256_or_si256(_mm256_or_si256(_mm256_or_si256(mask0, mask1), mask2), mask3); | |
| } else { | |
| // all bits should be 1 | |
| return _mm256_and_si256(_mm256_and_si256(_mm256_and_si256(mask0, mask1), mask2), mask3); | |
| } | |
| } | |
| template <SortedWay Way> | |
| auto TestMask(__m256i mask) { | |
| if constexpr (Way == SortedWay::Increasing || Way == SortedWay::StrictDecreasing) { | |
| // all bits should be 0 | |
| return _mm256_testz_si256(mask, mask); | |
| } else { | |
| // all bits should be 1 | |
| return _mm256_testc_si256(mask, _mm256_set1_epi8(0xFF)); | |
| } | |
| } | |
| template <SortedWay Way, Integer T> | |
| bool IsSorted(const T* const data, size_t length) noexcept { | |
| constexpr size_t ChunkSize = sizeof(__m256i) / sizeof(T); | |
| const T* buffer = data; | |
| while (length > ChunkSize * 4) { | |
| const auto x0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 0); | |
| const auto x1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 1); | |
| const auto x2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 2); | |
| const auto x3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 3); | |
| const auto y0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 0); | |
| const auto y1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 1); | |
| const auto y2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 2); | |
| const auto y3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 3); | |
| const auto mask0 = Compare<Way, T>(x0, y0); | |
| const auto mask1 = Compare<Way, T>(x1, y1); | |
| const auto mask2 = Compare<Way, T>(x2, y2); | |
| const auto mask3 = Compare<Way, T>(x3, y3); | |
| const auto mask = MixMask<Way>(mask0, mask1, mask2, mask3); | |
| if (!TestMask<Way>(mask)) { | |
| return false; | |
| } | |
| buffer += ChunkSize * 4; | |
| length -= ChunkSize * 4; | |
| } | |
| while (length > ChunkSize) { | |
| const auto a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer)); | |
| const auto b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1)); | |
| const auto mask = Compare<Way, T>(a, b); | |
| if (!TestMask<Way>(mask)) { | |
| return false; | |
| } | |
| buffer += ChunkSize; | |
| length -= ChunkSize; | |
| } | |
| if (length > 1) { | |
| const auto a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer)); | |
| const auto b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1)); | |
| const auto mask = Compare<Way, T>(a, b); | |
| const int mask2 = (1 << ((length - 1) * sizeof(T))) - 1; | |
| const int value = Way == SortedWay::Increasing || Way == SortedWay::StrictDecreasing ? 0 : mask2; | |
| if ((_mm256_movemask_epi8(mask) & mask2) != value) { | |
| return false; | |
| } | |
| } | |
| return true; | |
| } | |
| } // namespace detail | |
| #else | |
| namespace detail { | |
| template <Integer T> | |
| __m128i GreaterThan(__m128i a, __m128i b) noexcept { | |
| if constexpr (std::is_same_v<T, int8_t>) { | |
| return _mm_cmpgt_epi8(a, b); | |
| } else if constexpr (std::is_same_v<T, int16_t>) { | |
| return _mm_cmpgt_epi16(a, b); | |
| } else if constexpr (std::is_same_v<T, int32_t>) { | |
| return _mm_cmpgt_epi32(a, b); | |
| } else { | |
| #ifdef __SSE4_2__ | |
| return _mm_cmpgt_epi64(a, b); | |
| #endif | |
| } | |
| } | |
| #ifdef __SSE4_2__ | |
// Lane-wise signed 64-bit `a >= b`, built as (a > b) OR (a == b).
// Declared inline because this is an ordinary (non-template) function defined
// in header-style code: without it, including the file from more than one
// translation unit produces multiple definitions at link time.
inline __m128i GreaterEquali64(__m128i a, __m128i b) noexcept {
    const auto gt = _mm_cmpgt_epi64(a, b);
    const auto eq = _mm_cmpeq_epi64(a, b);
    return _mm_or_si128(gt, eq);
}
| #endif | |
| template <Integer T> | |
| __m128i LessThan(__m128i a, __m128i b) noexcept { | |
| if constexpr (std::is_same_v<T, int8_t>) { | |
| return _mm_cmplt_epi8(a, b); | |
| } else if constexpr (std::is_same_v<T, int16_t>) { | |
| return _mm_cmplt_epi16(a, b); | |
| } else if constexpr (std::is_same_v<T, int32_t>) { | |
| return _mm_cmplt_epi32(a, b); | |
| } | |
| } | |
| template <SortedWay Way, Integer T> | |
| auto Compare(__m128i a, __m128i b) noexcept { | |
| if constexpr (Way == SortedWay::Increasing) { | |
| // require b >= a or a <= b => all bits in a > b to be 0 | |
| return GreaterThan<T>(a, b); | |
| } else if constexpr (Way == SortedWay::StrictIncreasing) { | |
| // require b > a => all bits in b > a to be 1 | |
| return GreaterThan<T>(b, a); | |
| } else if constexpr (Way == SortedWay::Decreasing) { | |
| // require b <= a or a >= b | |
| if constexpr (std::is_same_v<T, int64_t>) { | |
| // all bits in a >= b to be 1 | |
| return GreaterEquali64(a, b); | |
| } else { | |
| // all bits in a < b to be 0 | |
| return LessThan<T>(a, b); | |
| } | |
| } else { | |
| // require b < a | |
| if constexpr (std::is_same_v<T, int64_t>) { | |
| // all bits in b >= a to be 0 | |
| return GreaterEquali64(b, a); | |
| } else { | |
| // all bits in b < a to be 1 | |
| return LessThan<T>(b, a); | |
| } | |
| } | |
| } | |
| template <SortedWay Way, Integer T> | |
| auto MixMask(__m128i mask0, __m128i mask1, __m128i mask2, __m128i mask3) { | |
| if constexpr (Way == SortedWay::Increasing || | |
| Way == SortedWay::Decreasing && !std::is_same_v<T, int64_t> || | |
| Way == SortedWay::StrictDecreasing && std::is_same_v<T, int64_t>) { | |
| // all bits should be 0 | |
| return _mm_or_si128(_mm_or_si128(_mm_or_si128(mask0, mask1), mask2), mask3); | |
| } else { | |
| // all bits should be 1 | |
| return _mm_and_si128(_mm_and_si128(_mm_and_si128(mask0, mask1), mask2), mask3); | |
| } | |
| } | |
| template <SortedWay Way, Integer T> | |
| auto TestMask(__m128i mask) { | |
| if constexpr (Way == SortedWay::Increasing || | |
| Way == SortedWay::Decreasing && !std::is_same_v<T, int64_t> || | |
| Way == SortedWay::StrictDecreasing && std::is_same_v<T, int64_t>) { | |
| // all bits should be 0 | |
| return _mm_testz_si128(mask, mask); | |
| } else { | |
| // all bits should be 1 | |
| return _mm_testc_si128(mask, _mm_set1_epi8(0xFF)); | |
| } | |
| } | |
| template <SortedWay Way, Integer T> | |
| bool IsSorted(const T* const data, size_t length) noexcept { | |
| constexpr size_t ChunkSize = sizeof(__m128i) / sizeof(T); | |
| const T* buffer = data; | |
| while (length > ChunkSize * 4) { | |
| const auto x0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 0); | |
| const auto x1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 1); | |
| const auto x2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 2); | |
| const auto x3 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 3); | |
| const auto y0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 0); | |
| const auto y1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 1); | |
| const auto y2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 2); | |
| const auto y3 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 3); | |
| const auto mask0 = Compare<Way, T>(x0, y0); | |
| const auto mask1 = Compare<Way, T>(x1, y1); | |
| const auto mask2 = Compare<Way, T>(x2, y2); | |
| const auto mask3 = Compare<Way, T>(x3, y3); | |
| const auto mask = MixMask<Way, T>(mask0, mask1, mask2, mask3); | |
| if (!TestMask<Way, T>(mask)) { | |
| return false; | |
| } | |
| buffer += ChunkSize * 4; | |
| length -= ChunkSize * 4; | |
| } | |
| while (length > ChunkSize) { | |
| const auto a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer)); | |
| const auto b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1)); | |
| const auto mask = Compare<Way, T>(a, b); | |
| if (!TestMask<Way, T>(mask)) { | |
| return false; | |
| } | |
| buffer += ChunkSize; | |
| length -= ChunkSize; | |
| } | |
| if (length > 1) { | |
| const auto a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer)); | |
| const auto b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1)); | |
| const auto mask = Compare<Way, T>(a, b); | |
| const int mask2 = (1 << ((length - 1) * sizeof(T))) - 1; | |
| const int value = Way == SortedWay::Increasing || | |
| Way == SortedWay::Decreasing && !std::is_same_v<T, int64_t> || | |
| Way == SortedWay::StrictDecreasing && std::is_same_v<T, int64_t> ? 0 : mask2; | |
| if ((_mm_movemask_epi8(mask) & mask2) != value) { | |
| return false; | |
| } | |
| } | |
| return true; | |
| } | |
| } // namespace detail | |
| #endif | |
| template <detail::Integer T> | |
| bool IsIncreasing(const T* const data, size_t length) noexcept { | |
| return detail::IsSorted<detail::SortedWay::Increasing>(data, length); | |
| } | |
| template <detail::Integer T> | |
| bool IsStrictIncreasing(const T* const data, size_t length) noexcept { | |
| return detail::IsSorted<detail::SortedWay::StrictIncreasing>(data, length); | |
| } | |
| template <detail::Integer T> | |
| bool IsDecreasing(const T* const data, size_t length) noexcept { | |
| return detail::IsSorted<detail::SortedWay::Decreasing>(data, length); | |
| } | |
| template <detail::Integer T> | |
| bool IsStrictDecreasing(const T* const data, size_t length) noexcept { | |
| return detail::IsSorted<detail::SortedWay::StrictDecreasing>(data, length); | |
| } | |
| #endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment