Created
February 28, 2024 13:29
-
-
Save WiwilZ/83cbd6cfcde435d1c40bf29d8d4bcb73 to your computer and use it in GitHub Desktop.
Use SSE4.2 string intrinsics to implement some string functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #if defined(_MSC_VER) && !defined(__clang__) | |
| #include <intrin.h> | |
| #else | |
| #include <x86intrin.h> | |
| #endif | |
| #include <cstddef> | |
| #include <cstdint> | |
// Returns the number of bytes in the null-terminated string `str`, not counting
// the terminator.
//
// The string is scanned one byte at a time until the cursor is 16-byte aligned and
// then 16 bytes at a time with aligned loads. An aligned 16-byte load never
// straddles a page boundary, so the scan cannot fault by touching a page beyond
// the one that holds the terminator (the unaligned loads used previously could).
size_t StringLength(const char* str) noexcept {
    const char* cursor = str;

    // Scalar prologue: walk up to the next 16-byte boundary.
    while (reinterpret_cast<std::uintptr_t>(cursor) % 16 != 0) {
        if (*cursor == '\0') {
            return static_cast<size_t>(cursor - str);
        }
        ++cursor;
    }

    const auto zero = _mm_setzero_si128();
    for (;; cursor += 16) {
        const auto chunk = _mm_load_si128(reinterpret_cast<const __m128i*>(cursor));
        // With an empty first operand, EQUAL_EACH only reports positions where both
        // operands are past their terminator, so the lowest reported index is the
        // position of the first '\0' in `chunk` (16 when the chunk has none).
        const int termination_idx = _mm_cmpistri(zero, chunk, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (termination_idx != 16) {
            return static_cast<size_t>(cursor - str) + static_cast<size_t>(termination_idx);
        }
    }
}
// Lexicographically compares two null-terminated strings.
//
// Bytes are compared as unsigned char (the strcmp convention), so bytes >= 0x80
// order after ASCII regardless of whether plain `char` is signed.
//
// Returns 0 when the strings are equal, a negative value when `s1` orders before
// `s2`, and a positive value when it orders after.
int StringCompare(const char* s1, const char* s2) noexcept {
    // A 16-byte load that starts within the last 15 bytes of a 4 KiB page reaches
    // into the following page, which may be unmapped even though the string ends
    // earlier. Such chunks are compared byte by byte instead. (4 KiB is the
    // smallest x86 page size, so the check is conservative for larger pages.)
    const auto load_crosses_page = [](const char* p) noexcept {
        return (reinterpret_cast<std::uintptr_t>(p) & 4095u) > 4096u - 16u;
    };

    const char* p1 = s1;
    const char* p2 = s2;
    for (;; p1 += 16, p2 += 16) {
        if (load_crosses_page(p1) || load_crosses_page(p2)) {
            for (int i = 0; i < 16; ++i) {
                const int c1 = static_cast<unsigned char>(p1[i]);
                const int c2 = static_cast<unsigned char>(p2[i]);
                if (c1 != c2) {
                    return c1 - c2;
                }
                if (c1 == 0) {
                    return 0;
                }
            }
            continue;
        }

        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(p1));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(p2));
        // Index of the first position where the chunks differ (including the case
        // where only one of them has already terminated); 16 when there is none.
        const int ne_idx = _mm_cmpistri(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (ne_idx != 16) {
            return static_cast<unsigned char>(p1[ne_idx]) - static_cast<unsigned char>(p2[ne_idx]);
        }
        // No difference in this chunk: if it holds a terminator, both strings end
        // at the same place and are equal.
        if (_mm_cmpistrz(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return 0;
        }
    }
}
// Lexicographically compares two length-delimited byte strings.
//
// `s1` must reference at least `len1` readable bytes and `s2` at least `len2`.
// Bytes are compared as unsigned char and embedded '\0' bytes are ordinary data.
// No byte beyond the stated lengths is read.
//
// Returns a negative value, 0, or a positive value when the first string orders
// before, equal to, or after the second. When one string is a prefix of the other
// the shorter one orders first. Only the sign of the result is meaningful.
int StringCompare(const char* s1, size_t len1, const char* s2, size_t len2) noexcept {
    const size_t common = len1 < len2 ? len1 : len2;
    size_t pos = 0;

    // Full 16-byte blocks. Explicit-length compares are used so that a '\0' inside
    // the data does not end the comparison early.
    for (; common - pos >= 16; pos += 16) {
        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s1 + pos));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s2 + pos));
        const int ne_idx = _mm_cmpestri(v1, 16, v2, 16, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (ne_idx != 16) {
            return static_cast<unsigned char>(s1[pos + ne_idx]) - static_cast<unsigned char>(s2[pos + ne_idx]);
        }
    }

    // Remaining (< 16) bytes are compared one at a time so nothing past the end of
    // either buffer is loaded.
    for (; pos < common; ++pos) {
        const int c1 = static_cast<unsigned char>(s1[pos]);
        const int c2 = static_cast<unsigned char>(s2[pos]);
        if (c1 != c2) {
            return c1 - c2;
        }
    }

    // Common prefix is identical: order by length without narrowing a size_t
    // difference into int.
    return static_cast<int>(len1 > len2) - static_cast<int>(len1 < len2);
}
// Returns true when the two null-terminated strings have identical contents.
//
// Strings are compared 16 bytes at a time. A chunk whose 16-byte load would reach
// into the next 4 KiB page is compared byte by byte instead, so the function never
// touches a page that lies beyond a string's terminator.
bool StringEquals(const char* s1, const char* s2) noexcept {
    // True when a 16-byte load from `p` would cross a 4 KiB page boundary
    // (4 KiB is the smallest x86 page size).
    const auto load_crosses_page = [](const char* p) noexcept {
        return (reinterpret_cast<std::uintptr_t>(p) & 4095u) > 4096u - 16u;
    };

    const char* p1 = s1;
    const char* p2 = s2;
    for (;; p1 += 16, p2 += 16) {
        if (load_crosses_page(p1) || load_crosses_page(p2)) {
            for (int i = 0; i < 16; ++i) {
                if (p1[i] != p2[i]) {
                    return false;
                }
                if (p1[i] == '\0') {
                    return true;
                }
            }
            continue;
        }

        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(p1));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(p2));
        // Carry flag: some position differs, or only one string has terminated.
        if (_mm_cmpistrc(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return false;
        }
        // Zero flag: the chunk holds a terminator. Since nothing differed, both
        // strings end at the same position.
        if (_mm_cmpistrz(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return true;
        }
    }
}
// Returns true when the two length-delimited byte strings have the same length
// and identical contents.
//
// `s1` must reference at least `len1` readable bytes and `s2` at least `len2`.
// Embedded '\0' bytes are ordinary data, and no byte beyond the stated lengths is
// read.
bool StringEquals(const char* s1, size_t len1, const char* s2, size_t len2) noexcept {
    // Strings of different length can never be equal; skip the scan entirely.
    if (len1 != len2) {
        return false;
    }

    size_t pos = 0;
    // Full 16-byte blocks, compared with explicit lengths so a '\0' inside the
    // data does not hide differences that follow it.
    for (; len1 - pos >= 16; pos += 16) {
        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s1 + pos));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s2 + pos));
        const bool differs = _mm_cmpestrc(v1, 16, v2, 16, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (differs) {
            return false;
        }
    }

    // Remaining (< 16) bytes, one at a time, to stay inside both buffers.
    for (; pos < len1; ++pos) {
        if (s1[pos] != s2[pos]) {
            return false;
        }
    }
    return true;
}
// Returns true when the null-terminated string `str` contains the byte `c`.
//
// As with strchr, the terminator counts as part of the string, so searching for
// '\0' always succeeds.
//
// The string is scanned one byte at a time until the cursor is 16-byte aligned and
// then with aligned 16-byte loads, which never straddle a page boundary; the scan
// therefore cannot fault on a page beyond the one that holds the terminator.
bool StringContains(const char* str, char c) noexcept {
    const char* cursor = str;

    // Scalar prologue up to the next 16-byte boundary. The match test comes first
    // so that c == '\0' matches the terminator.
    for (; reinterpret_cast<std::uintptr_t>(cursor) % 16 != 0; ++cursor) {
        if (*cursor == c) {
            return true;
        }
        if (*cursor == '\0') {
            return false;
        }
    }

    const auto pattern = _mm_set1_epi8(c);
    for (;; cursor += 16) {
        const auto chunk = _mm_load_si128(reinterpret_cast<const __m128i*>(cursor));
        // Carry flag: `c` occurs in the chunk before its terminator (or, for
        // c == '\0', the chunk contains the terminator).
        if (_mm_cmpistrc(pattern, chunk, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return true;
        }
        // Zero flag: the chunk contains the terminator, so the string ended
        // without a match.
        if (_mm_cmpistrz(pattern, chunk, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return false;
        }
    }
}
// Returns true when the byte `c` occurs within the first `length` bytes of `str`.
//
// `str` must reference at least `length` readable bytes. Embedded '\0' bytes are
// ordinary data (and '\0' itself may be searched for). No byte beyond `length` is
// read.
bool StringContains(const char* str, size_t length, char c) noexcept {
    const auto pattern = _mm_set1_epi8(c);
    size_t pos = 0;

    // Full 16-byte blocks, compared with explicit lengths so that neither a '\0'
    // in the data nor c == '\0' cuts the comparison short.
    for (; length - pos >= 16; pos += 16) {
        const auto chunk = _mm_loadu_si128(reinterpret_cast<const __m128i*>(str + pos));
        const bool found = _mm_cmpestrc(pattern, 16, chunk, 16, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (found) {
            return true;
        }
    }

    // Remaining (< 16) bytes, one at a time, to stay inside the buffer.
    for (; pos < length; ++pos) {
        if (str[pos] == c) {
            return true;
        }
    }
    return false;
}
// Returns the index of the first occurrence of the byte `c` in the null-terminated
// string `str`, or static_cast<size_t>(-1) when `c` does not occur.
//
// As with strchr, the terminator counts as part of the string, so searching for
// '\0' yields the string's length.
//
// The string is scanned one byte at a time until the cursor is 16-byte aligned and
// then with aligned 16-byte loads, which never straddle a page boundary; the scan
// therefore cannot fault on a page beyond the one that holds the terminator.
size_t StringFind(const char* str, char c) noexcept {
    const char* cursor = str;

    // Scalar prologue up to the next 16-byte boundary. The match test comes first
    // so that c == '\0' matches the terminator.
    for (; reinterpret_cast<std::uintptr_t>(cursor) % 16 != 0; ++cursor) {
        if (*cursor == c) {
            return static_cast<size_t>(cursor - str);
        }
        if (*cursor == '\0') {
            return static_cast<size_t>(-1);
        }
    }

    const auto pattern = _mm_set1_epi8(c);
    for (;; cursor += 16) {
        const auto chunk = _mm_load_si128(reinterpret_cast<const __m128i*>(cursor));
        // Lowest index at which `c` occurs before the chunk's terminator (for
        // c == '\0': the index of the terminator); 16 when there is none.
        const int match_idx = _mm_cmpistri(pattern, chunk, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (match_idx != 16) {
            return static_cast<size_t>(cursor - str) + static_cast<size_t>(match_idx);
        }
        // Zero flag: the chunk contains the terminator, so the string ended
        // without a match.
        if (_mm_cmpistrz(pattern, chunk, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return static_cast<size_t>(-1);
        }
    }
}
// Returns the index of the first occurrence of the byte `c` within the first
// `length` bytes of `str`, or static_cast<size_t>(-1) when `c` does not occur.
//
// `str` must reference at least `length` readable bytes. Embedded '\0' bytes are
// ordinary data (and '\0' itself may be searched for). No byte beyond `length` is
// read.
size_t StringFind(const char* str, size_t length, char c) noexcept {
    const auto pattern = _mm_set1_epi8(c);
    size_t pos = 0;

    // Full 16-byte blocks, compared with explicit lengths so that neither a '\0'
    // in the data nor c == '\0' cuts the comparison short.
    for (; length - pos >= 16; pos += 16) {
        const auto chunk = _mm_loadu_si128(reinterpret_cast<const __m128i*>(str + pos));
        const int match_idx = _mm_cmpestri(pattern, 16, chunk, 16, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (match_idx != 16) {
            return pos + static_cast<size_t>(match_idx);
        }
    }

    // Remaining (< 16) bytes, one at a time, to stay inside the buffer.
    for (; pos < length; ++pos) {
        if (str[pos] == c) {
            return pos;
        }
    }
    return static_cast<size_t>(-1);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment