Skip to content

Instantly share code, notes, and snippets.

@WiwilZ
Last active April 16, 2024 10:45
Show Gist options
  • Select an option

  • Save WiwilZ/c636d5d202d5793f9d1f938728106b3b to your computer and use it in GitHub Desktop.

Select an option

Save WiwilZ/c636d5d202d5793f9d1f938728106b3b to your computer and use it in GitHub Desktop.
Check whether an integer array is (strictly) increasing or (strictly) decreasing, optimized with SIMD instructions
#if defined(_MSC_VER) && !defined(__clang__) && (defined(_M_IX86) || defined(_M_X64) && !defined(_M_ARM64EC))
#define __SSE4_1__
#define __SSE4_2__
#endif
#if defined(__AVX512BW__) || defined(__AVX2__) || defined(__SSE4_2__) || defined(__SSE4_1__)
#if defined(_MSC_VER) && !defined(__clang__)
#include <intrin.h>
#elif defined(__AVX512BW__) || defined(__AVX2__)
#include <immintrin.h>
#else
#include <emmintrin.h>
#include <smmintrin.h>
#ifdef __SSE4_2__
#include <nmmintrin.h> // _mm_cmpgt_epi64
#endif
#endif
#include <cstddef>
#include <cstdint>
#include <type_traits>
namespace detail {
// Ordering relation that must hold between every element and its successor.
enum class SortedWay {
    Increasing = 0,        // successor >= element (equal neighbours allowed)
    StrictIncreasing = 1,  // successor >  element
    Decreasing = 2,        // successor <= element (equal neighbours allowed)
    StrictDecreasing = 3   // successor <  element
};
// Element types the selected instruction set can compare directly:
//   - always:          int8_t, int16_t, int32_t
//   - with AVX-512BW:  additionally int64_t and all four unsigned widths
//   - with SSE4.2:     additionally int64_t
template <typename T>
concept Integer = std::disjunction_v<
    std::is_same<T, int8_t>,
    std::is_same<T, int16_t>,
    std::is_same<T, int32_t>
#ifdef __AVX512BW__
    , std::is_same<T, int64_t>
    , std::is_same<T, uint8_t>
    , std::is_same<T, uint16_t>
    , std::is_same<T, uint32_t>
    , std::is_same<T, uint64_t>
#elif defined(__SSE4_2__)
    , std::is_same<T, int64_t>
#endif
>;
}
#ifdef __AVX512BW__
namespace detail {
// Compares each lane of `a` with the same lane of `b` (the successor elements)
// and returns a k-mask in which a set bit marks a lane that VIOLATES `Way`.
// An all-zero mask therefore means the whole vector is in order.
template <SortedWay Way, Integer T>
auto Compare(__m512i a, __m512i b) noexcept {
    // Pick the predicate that is true for an out-of-order pair. It must be a
    // compile-time constant because the intrinsics take it as an immediate.
    constexpr int violation = []() constexpr -> int {
        switch (Way) {
        case SortedWay::Increasing:
            return _MM_CMPINT_GT;  // a >  b breaks "b >= a"
        case SortedWay::StrictIncreasing:
            return _MM_CMPINT_GE;  // a >= b breaks "b >  a"
        case SortedWay::Decreasing:
            return _MM_CMPINT_LT;  // a <  b breaks "b <= a"
        default:
            return _MM_CMPINT_LE;  // a <= b breaks "b <  a"
        }
    }();

    // Integer only admits the eight fixed-width types, so signedness plus
    // width identifies the element type uniquely.
    if constexpr (std::is_signed_v<T>) {
        if constexpr (sizeof(T) == 1) {
            return _mm512_cmp_epi8_mask(a, b, violation);
        } else if constexpr (sizeof(T) == 2) {
            return _mm512_cmp_epi16_mask(a, b, violation);
        } else if constexpr (sizeof(T) == 4) {
            return _mm512_cmp_epi32_mask(a, b, violation);
        } else {
            return _mm512_cmp_epi64_mask(a, b, violation);
        }
    } else {
        if constexpr (sizeof(T) == 1) {
            return _mm512_cmp_epu8_mask(a, b, violation);
        } else if constexpr (sizeof(T) == 2) {
            return _mm512_cmp_epu16_mask(a, b, violation);
        } else if constexpr (sizeof(T) == 4) {
            return _mm512_cmp_epu32_mask(a, b, violation);
        } else {
            return _mm512_cmp_epu64_mask(a, b, violation);
        }
    }
}
template <SortedWay Way, Integer T>
bool IsSorted(const T* const data, size_t length) noexcept {
constexpr size_t ChunkSize = sizeof(__m512i) / sizeof(T);
const T* buffer = data;
while (length > ChunkSize * 8) {
const auto x0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 0);
const auto x1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 1);
const auto x2 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 2);
const auto x3 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 3);
const auto x4 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 4);
const auto x5 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 5);
const auto x6 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 6);
const auto x7 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer) + 7);
const auto y0 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 0);
const auto y1 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 1);
const auto y2 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 2);
const auto y3 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 3);
const auto y4 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 4);
const auto y5 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 5);
const auto y6 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 6);
const auto y7 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(buffer + 1) + 7);
const auto mask0 = Compare<Way, T>(x0, y0);
const auto mask1 = Compare<Way, T>(x1, y1);
const auto mask2 = Compare<Way, T>(x2, y2);
const auto mask3 = Compare<Way, T>(x3, y3);
const auto mask4 = Compare<Way, T>(x4, y4);
const auto mask5 = Compare<Way, T>(x5, y5);
const auto mask6 = Compare<Way, T>(x6, y6);
const auto mask7 = Compare<Way, T>(x7, y7);
if (mask0 | mask1 | mask2 | mask3 | mask4 | mask5 | mask6 | mask7) {
return false;
}
buffer += ChunkSize * 8;
length -= ChunkSize * 8;
}
while (length > ChunkSize) {
const __m512i a = _mm512_loadu_si512(buffer);
const __m512i b = _mm512_loadu_si512(buffer + 1);
const auto mask = Compare<Way, T>(a, b);
if (mask) {
return false;
}
buffer += ChunkSize;
length -= ChunkSize;
}
if (length > 1) {
const __m512i a = _mm512_loadu_si512(buffer);
const __m512i b = _mm512_loadu_si512(buffer + 1);
const auto mask = Compare<Way, T>(a, b);
const decltype(mask) mask2 = (static_cast<decltype(mask)>(1) << (length - 1)) - 1;
if (mask & mask2) {
return false;
}
}
return true;
}
} // namespace detail
#elif defined(__AVX2__)
namespace detail {
// Lane-wise a == b for elements of type T; a lane is all-ones when equal and
// all-zeros otherwise.
template <Integer T>
__m256i EqualTo(__m256i a, __m256i b) noexcept {
    // On this code path Integer admits only int8_t/int16_t/int32_t/int64_t,
    // so the element width alone selects the instruction.
    if constexpr (sizeof(T) == sizeof(int8_t)) {
        return _mm256_cmpeq_epi8(a, b);
    } else if constexpr (sizeof(T) == sizeof(int16_t)) {
        return _mm256_cmpeq_epi16(a, b);
    } else if constexpr (sizeof(T) == sizeof(int32_t)) {
        return _mm256_cmpeq_epi32(a, b);
    } else {
        return _mm256_cmpeq_epi64(a, b);
    }
}
// Lane-wise signed a > b for elements of type T; a lane is all-ones when the
// relation holds and all-zeros otherwise.
template <Integer T>
__m256i GreaterThan(__m256i a, __m256i b) noexcept {
    // On this code path Integer admits only int8_t/int16_t/int32_t/int64_t,
    // so the element width alone selects the instruction.
    if constexpr (sizeof(T) == sizeof(int8_t)) {
        return _mm256_cmpgt_epi8(a, b);
    } else if constexpr (sizeof(T) == sizeof(int16_t)) {
        return _mm256_cmpgt_epi16(a, b);
    } else if constexpr (sizeof(T) == sizeof(int32_t)) {
        return _mm256_cmpgt_epi32(a, b);
    } else {
        return _mm256_cmpgt_epi64(a, b);
    }
}
// Lane-wise a >= b, built as (a > b) | (a == b) from the two comparison
// helpers GreaterThan and EqualTo.
template <Integer T>
__m256i GreaterEqual(__m256i a, __m256i b) noexcept {
    return _mm256_or_si256(GreaterThan<T>(a, b), EqualTo<T>(a, b));
}
// Compares elements `a` with their successors `b` and returns a lane mask.
// The polarity of the mask depends on `Way`:
//   Increasing / StrictDecreasing : lanes are all-ones where the order is
//                                   BROKEN, so a sorted vector gives all zeros.
//   StrictIncreasing / Decreasing : lanes are all-ones where the order HOLDS,
//                                   so a sorted vector gives all ones.
// MixMask and TestMask use the same polarity.
template <SortedWay Way, Integer T>
auto Compare(__m256i a, __m256i b) noexcept {
    switch (Way) {
    case SortedWay::Increasing:
        // Need b >= a; flag lanes where a > b.
        return GreaterThan<T>(a, b);
    case SortedWay::StrictIncreasing:
        // Need b > a; every lane must report b > a.
        return GreaterThan<T>(b, a);
    case SortedWay::Decreasing:
        // Need a >= b; every lane must report a >= b.
        return GreaterEqual<T>(a, b);
    default:
        // Need b < a; flag lanes where b >= a.
        return GreaterEqual<T>(b, a);
    }
}
// Folds four Compare results into one mask that TestMask can check in a single
// step: OR when a sorted input yields all zeros, AND when it yields all ones.
template <SortedWay Way>
auto MixMask(__m256i mask0, __m256i mask1, __m256i mask2, __m256i mask3) {
    constexpr bool sortedIsAllZero =
        Way == SortedWay::Increasing || Way == SortedWay::StrictDecreasing;
    if constexpr (sortedIsAllZero) {
        // Any set bit in any mask must survive.
        const __m256i front = _mm256_or_si256(mask0, mask1);
        const __m256i back = _mm256_or_si256(mask2, mask3);
        return _mm256_or_si256(front, back);
    } else {
        // Any cleared bit in any mask must survive.
        const __m256i front = _mm256_and_si256(mask0, mask1);
        const __m256i back = _mm256_and_si256(mask2, mask3);
        return _mm256_and_si256(front, back);
    }
}
// Returns non-zero when `mask` has the value a sorted vector produces for
// `Way`: all zeros for Increasing/StrictDecreasing, all ones otherwise.
template <SortedWay Way>
auto TestMask(__m256i mask) {
    if constexpr (Way != SortedWay::Increasing && Way != SortedWay::StrictDecreasing) {
        // testc(m, ones) is 1 exactly when (~m & ones) == 0, i.e. m is all ones.
        const __m256i allOnes = _mm256_set1_epi8(static_cast<char>(0xFF));
        return _mm256_testc_si256(mask, allOnes);
    } else {
        // testz(m, m) is 1 exactly when m is all zeros.
        return _mm256_testz_si256(mask, mask);
    }
}
template <SortedWay Way, Integer T>
bool IsSorted(const T* const data, size_t length) noexcept {
constexpr size_t ChunkSize = sizeof(__m256i) / sizeof(T);
const T* buffer = data;
while (length > ChunkSize * 4) {
const auto x0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 0);
const auto x1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 1);
const auto x2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 2);
const auto x3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer) + 3);
const auto y0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 0);
const auto y1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 1);
const auto y2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 2);
const auto y3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1) + 3);
const auto mask0 = Compare<Way, T>(x0, y0);
const auto mask1 = Compare<Way, T>(x1, y1);
const auto mask2 = Compare<Way, T>(x2, y2);
const auto mask3 = Compare<Way, T>(x3, y3);
const auto mask = MixMask<Way>(mask0, mask1, mask2, mask3);
if (!TestMask<Way>(mask)) {
return false;
}
buffer += ChunkSize * 4;
length -= ChunkSize * 4;
}
while (length > ChunkSize) {
const auto a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer));
const auto b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1));
const auto mask = Compare<Way, T>(a, b);
if (!TestMask<Way>(mask)) {
return false;
}
buffer += ChunkSize;
length -= ChunkSize;
}
if (length > 1) {
const auto a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer));
const auto b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(buffer + 1));
const auto mask = Compare<Way, T>(a, b);
const int mask2 = (1 << ((length - 1) * sizeof(T))) - 1;
const int value = Way == SortedWay::Increasing || Way == SortedWay::StrictDecreasing ? 0 : mask2;
if ((_mm256_movemask_epi8(mask) & mask2) != value) {
return false;
}
}
return true;
}
} // namespace detail
#else
namespace detail {
// Lane-wise signed a > b for elements of type T; a lane is all-ones when the
// relation holds and all-zeros otherwise.
template <Integer T>
__m128i GreaterThan(__m128i a, __m128i b) noexcept {
    // On this code path Integer admits only signed types, so the element
    // width alone selects the instruction.
    if constexpr (sizeof(T) == sizeof(int8_t)) {
        return _mm_cmpgt_epi8(a, b);
    } else if constexpr (sizeof(T) == sizeof(int16_t)) {
        return _mm_cmpgt_epi16(a, b);
    } else if constexpr (sizeof(T) == sizeof(int32_t)) {
        return _mm_cmpgt_epi32(a, b);
    } else {
        // 64-bit compare exists only with SSE4.2; without it Integer rejects
        // int64_t, so this branch is never instantiated.
#ifdef __SSE4_2__
        return _mm_cmpgt_epi64(a, b);
#endif
    }
}
#ifdef __SSE4_2__
// Lane-wise signed 64-bit a >= b, built as (a > b) | (a == b).
// Declared `inline` because this is a non-template function defined in
// header-style code: without it, every translation unit that includes the file
// emits its own external definition and linking them together fails.
inline __m128i GreaterEquali64(__m128i a, __m128i b) noexcept {
    const __m128i greater = _mm_cmpgt_epi64(a, b);
    const __m128i equal = _mm_cmpeq_epi64(a, b);
    return _mm_or_si128(greater, equal);
}
#endif
// Lane-wise signed a < b for elements of type T; a lane is all-ones when the
// relation holds and all-zeros otherwise.
template <Integer T>
__m128i LessThan(__m128i a, __m128i b) noexcept {
    if constexpr (std::is_same_v<T, int8_t>) {
        return _mm_cmplt_epi8(a, b);
    } else if constexpr (std::is_same_v<T, int16_t>) {
        return _mm_cmplt_epi16(a, b);
    } else if constexpr (std::is_same_v<T, int32_t>) {
        return _mm_cmplt_epi32(a, b);
    } else {
        // Previously this case fell off the end of the function without
        // returning a value. There is no 64-bit "less than" instruction, so
        // express a < b as b > a, mirroring GreaterThan. Without SSE4.2 the
        // Integer concept rejects int64_t and this branch is never
        // instantiated.
#ifdef __SSE4_2__
        return _mm_cmpgt_epi64(b, a);
#endif
    }
}
// Compares elements `a` with their successors `b` and returns a lane mask.
// The polarity of the mask depends on `Way` and on whether T is int64_t
// (64-bit lanes only have "greater than", so they use a >= helper instead of
// LessThan):
//   Increasing                      : ones where order is BROKEN (sorted = 0s)
//   StrictIncreasing                : ones where order HOLDS     (sorted = 1s)
//   Decreasing,       T != int64_t  : ones where order is BROKEN (sorted = 0s)
//   Decreasing,       T == int64_t  : ones where order HOLDS     (sorted = 1s)
//   StrictDecreasing, T != int64_t  : ones where order HOLDS     (sorted = 1s)
//   StrictDecreasing, T == int64_t  : ones where order is BROKEN (sorted = 0s)
// MixMask, TestMask and IsSorted rely on exactly this table.
//
// The int64_t branches are compiled only when SSE4.2 is available:
// GreaterEquali64 is declared only in that configuration, and naming it
// unconditionally made SSE4.1-only builds fail at template definition time.
template <SortedWay Way, Integer T>
auto Compare(__m128i a, __m128i b) noexcept {
    if constexpr (Way == SortedWay::Increasing) {
        // require b >= a  =>  lanes of (a > b) must all be 0
        return GreaterThan<T>(a, b);
    } else if constexpr (Way == SortedWay::StrictIncreasing) {
        // require b > a  =>  lanes of (b > a) must all be 1
        return GreaterThan<T>(b, a);
    } else if constexpr (Way == SortedWay::Decreasing) {
        // require a >= b
#ifdef __SSE4_2__
        if constexpr (std::is_same_v<T, int64_t>) {
            // lanes of (a >= b) must all be 1
            return GreaterEquali64(a, b);
        } else
#endif
        {
            // lanes of (a < b) must all be 0
            return LessThan<T>(a, b);
        }
    } else {
        // require b < a
#ifdef __SSE4_2__
        if constexpr (std::is_same_v<T, int64_t>) {
            // lanes of (b >= a) must all be 0
            return GreaterEquali64(b, a);
        } else
#endif
        {
            // lanes of (b < a) must all be 1
            return LessThan<T>(b, a);
        }
    }
}
// Folds four Compare results into one mask that TestMask can check in a single
// step: OR when a sorted input yields all zeros, AND when it yields all ones.
template <SortedWay Way, Integer T>
auto MixMask(__m128i mask0, __m128i mask1, __m128i mask2, __m128i mask3) {
    constexpr bool is64Bit = std::is_same_v<T, int64_t>;
    constexpr bool sortedIsAllZero =
        Way == SortedWay::Increasing ||
        (Way == SortedWay::Decreasing && !is64Bit) ||
        (Way == SortedWay::StrictDecreasing && is64Bit);
    if constexpr (sortedIsAllZero) {
        // Any set bit in any mask must survive.
        const __m128i front = _mm_or_si128(mask0, mask1);
        const __m128i back = _mm_or_si128(mask2, mask3);
        return _mm_or_si128(front, back);
    } else {
        // Any cleared bit in any mask must survive.
        const __m128i front = _mm_and_si128(mask0, mask1);
        const __m128i back = _mm_and_si128(mask2, mask3);
        return _mm_and_si128(front, back);
    }
}
// Returns non-zero when `mask` has the value a sorted vector produces for this
// combination of `Way` and T (all zeros or all ones, matching MixMask).
template <SortedWay Way, Integer T>
auto TestMask(__m128i mask) {
    constexpr bool is64Bit = std::is_same_v<T, int64_t>;
    constexpr bool sortedIsAllZero =
        Way == SortedWay::Increasing ||
        (Way == SortedWay::Decreasing && !is64Bit) ||
        (Way == SortedWay::StrictDecreasing && is64Bit);
    if constexpr (!sortedIsAllZero) {
        // testc(m, ones) is 1 exactly when (~m & ones) == 0, i.e. m is all ones.
        const __m128i allOnes = _mm_set1_epi8(static_cast<char>(0xFF));
        return _mm_testc_si128(mask, allOnes);
    } else {
        // testz(m, m) is 1 exactly when m is all zeros.
        return _mm_testz_si128(mask, mask);
    }
}
template <SortedWay Way, Integer T>
bool IsSorted(const T* const data, size_t length) noexcept {
constexpr size_t ChunkSize = sizeof(__m128i) / sizeof(T);
const T* buffer = data;
while (length > ChunkSize * 4) {
const auto x0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 0);
const auto x1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 1);
const auto x2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 2);
const auto x3 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer) + 3);
const auto y0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 0);
const auto y1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 1);
const auto y2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 2);
const auto y3 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1) + 3);
const auto mask0 = Compare<Way, T>(x0, y0);
const auto mask1 = Compare<Way, T>(x1, y1);
const auto mask2 = Compare<Way, T>(x2, y2);
const auto mask3 = Compare<Way, T>(x3, y3);
const auto mask = MixMask<Way, T>(mask0, mask1, mask2, mask3);
if (!TestMask<Way, T>(mask)) {
return false;
}
buffer += ChunkSize * 4;
length -= ChunkSize * 4;
}
while (length > ChunkSize) {
const auto a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer));
const auto b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1));
const auto mask = Compare<Way, T>(a, b);
if (!TestMask<Way, T>(mask)) {
return false;
}
buffer += ChunkSize;
length -= ChunkSize;
}
if (length > 1) {
const auto a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer));
const auto b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(buffer + 1));
const auto mask = Compare<Way, T>(a, b);
const int mask2 = (1 << ((length - 1) * sizeof(T))) - 1;
const int value = Way == SortedWay::Increasing ||
Way == SortedWay::Decreasing && !std::is_same_v<T, int64_t> ||
Way == SortedWay::StrictDecreasing && std::is_same_v<T, int64_t> ? 0 : mask2;
if ((_mm_movemask_epi8(mask) & mask2) != value) {
return false;
}
}
return true;
}
} // namespace detail
#endif
template <detail::Integer T>
bool IsIncreasing(const T* const data, size_t length) noexcept {
return detail::IsSorted<detail::SortedWay::Increasing>(data, length);
}
template <detail::Integer T>
bool IsStrictIncreasing(const T* const data, size_t length) noexcept {
return detail::IsSorted<detail::SortedWay::StrictIncreasing>(data, length);
}
template <detail::Integer T>
bool IsDecreasing(const T* const data, size_t length) noexcept {
return detail::IsSorted<detail::SortedWay::Decreasing>(data, length);
}
template <detail::Integer T>
bool IsStrictDecreasing(const T* const data, size_t length) noexcept {
return detail::IsSorted<detail::SortedWay::StrictDecreasing>(data, length);
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment