Skip to content

Instantly share code, notes, and snippets.

@WiwilZ
Created February 28, 2024 13:29
Show Gist options
  • Select an option

  • Save WiwilZ/83cbd6cfcde435d1c40bf29d8d4bcb73 to your computer and use it in GitHub Desktop.

Select an option

Save WiwilZ/83cbd6cfcde435d1c40bf29d8d4bcb73 to your computer and use it in GitHub Desktop.
Use the SSE4.2 string intrinsics to implement some string functions
#if defined(_MSC_VER) && !defined(__clang__)
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
#include <cstddef>
#include <cstdint>
// Returns the number of bytes in the null-terminated string `str`, not
// counting the terminator (same contract as strlen).
//
// The string is scanned 16 bytes at a time with PCMPISTRI. To make sure a
// vector load never touches a page that the string itself does not reach,
// the first few bytes are handled one at a time until the cursor is 16-byte
// aligned; from there on only aligned loads are issued, and an aligned
// 16-byte load can never straddle a page boundary.
size_t StringLength(const char* str) noexcept {
    constexpr std::uintptr_t kBlockMask = sizeof(__m128i) - 1;

    // Scalar prologue: walk up to the next 16-byte boundary.
    const char* cur = str;
    while ((reinterpret_cast<std::uintptr_t>(cur) & kBlockMask) != 0) {
        if (*cur == '\0') {
            return static_cast<size_t>(cur - str);
        }
        ++cur;
    }

    const auto zero = _mm_setzero_si128();
    for (auto p = reinterpret_cast<const __m128i*>(cur);; ++p) {
        const auto v = _mm_load_si128(p);
        // With an all-zero first operand, EQUAL_EACH reports the first
        // position at which `v` is past its terminator, i.e. the index of the
        // first '\0' in `v`, or 16 when the block contains none.
        const int termination_idx = _mm_cmpistri(zero, v, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (termination_idx != 16) {
            return static_cast<size_t>(reinterpret_cast<const char*>(p) - str) +
                   static_cast<size_t>(termination_idx);
        }
    }
}
// Three-way comparison of two null-terminated strings.
//
// Returns 0 when the strings are equal, a negative value when `s1` orders
// before `s2`, and a positive value otherwise. Bytes are compared as
// `unsigned char`, matching both strcmp and the _SIDD_UBYTE_OPS mode used by
// the vector path.
//
// Both strings are processed in 16-byte blocks. A block whose unaligned
// 16-byte load would run into the next 4 KiB page (for either pointer) is
// compared byte by byte instead, so the function never reads from a page the
// strings do not reach.
int StringCompare(const char* s1, const char* s2) noexcept {
    constexpr std::uintptr_t kBlock = sizeof(__m128i);
    constexpr std::uintptr_t kPage = 4096;  // smallest x86 page size
    constexpr std::uintptr_t kLastSafeOffset = kPage - kBlock;

    for (;; s1 += kBlock, s2 += kBlock) {
        const bool near_page_end =
            (reinterpret_cast<std::uintptr_t>(s1) & (kPage - 1)) > kLastSafeOffset ||
            (reinterpret_cast<std::uintptr_t>(s2) & (kPage - 1)) > kLastSafeOffset;
        if (near_page_end) {
            // Scalar fallback for this block only.
            for (std::uintptr_t i = 0; i < kBlock; ++i) {
                const auto c1 = static_cast<unsigned char>(s1[i]);
                const auto c2 = static_cast<unsigned char>(s2[i]);
                if (c1 != c2) {
                    return static_cast<int>(c1) - static_cast<int>(c2);
                }
                if (c1 == 0) {
                    return 0;
                }
            }
            continue;
        }

        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s1));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s2));
        // Index of the first position where the blocks differ (including one
        // string ending before the other), or 16 if there is none. Positions
        // after a terminator shared by both blocks never count as different.
        const int ne_idx = _mm_cmpistri(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (ne_idx != 16) {
            return static_cast<int>(static_cast<unsigned char>(s1[ne_idx])) -
                   static_cast<int>(static_cast<unsigned char>(s2[ne_idx]));
        }
        // No difference in this block: if it holds a terminator (the Z flag
        // reports a '\0' in v2, and v1 matches it), the strings are equal.
        if (_mm_cmpistrz(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return 0;
        }
    }
}
// Three-way comparison of two length-delimited byte strings.
//
// Compares the first min(len1, len2) bytes as `unsigned char`; if they are
// all equal, the shorter string orders first. Returns a negative value, 0, or
// a positive value. Embedded '\0' bytes are ordinary data. Only bytes inside
// [s1, s1 + len1) and [s2, s2 + len2) are read.
int StringCompare(const char* s1, size_t len1, const char* s2, size_t len2) noexcept {
    constexpr size_t kBlock = sizeof(__m128i);
    const auto byte_diff = [](char a, char b) noexcept {
        return static_cast<int>(static_cast<unsigned char>(a)) -
               static_cast<int>(static_cast<unsigned char>(b));
    };

    const size_t common = len1 < len2 ? len1 : len2;
    size_t pos = 0;

    // Full 16-byte blocks. The explicit-length form (PCMPESTRI) is used so
    // that a '\0' inside the data does not cut the comparison short.
    for (; common - pos >= kBlock; pos += kBlock) {
        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s1 + pos));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s2 + pos));
        const int ne_idx = _mm_cmpestri(v1, 16, v2, 16, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (ne_idx != 16) {
            return byte_diff(s1[pos + ne_idx], s2[pos + ne_idx]);
        }
    }

    // Remaining (< 16) bytes are compared one by one so that no load reaches
    // past the end of either buffer.
    for (; pos < common; ++pos) {
        if (s1[pos] != s2[pos]) {
            return byte_diff(s1[pos], s2[pos]);
        }
    }

    // Common prefix is identical: order by length without narrowing a size_t
    // difference into int.
    return static_cast<int>(len1 > len2) - static_cast<int>(len1 < len2);
}
// Returns true when the two null-terminated strings hold the same bytes.
//
// Both strings are processed in 16-byte blocks. A block whose unaligned
// 16-byte load would run into the next 4 KiB page (for either pointer) is
// compared byte by byte instead, so the function never reads from a page the
// strings do not reach.
bool StringEquals(const char* s1, const char* s2) noexcept {
    constexpr std::uintptr_t kBlock = sizeof(__m128i);
    constexpr std::uintptr_t kPage = 4096;  // smallest x86 page size
    constexpr std::uintptr_t kLastSafeOffset = kPage - kBlock;

    for (;; s1 += kBlock, s2 += kBlock) {
        const bool near_page_end =
            (reinterpret_cast<std::uintptr_t>(s1) & (kPage - 1)) > kLastSafeOffset ||
            (reinterpret_cast<std::uintptr_t>(s2) & (kPage - 1)) > kLastSafeOffset;
        if (near_page_end) {
            // Scalar fallback for this block only.
            for (std::uintptr_t i = 0; i < kBlock; ++i) {
                if (s1[i] != s2[i]) {
                    return false;
                }
                if (s1[i] == '\0') {
                    return true;
                }
            }
            continue;
        }

        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s1));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s2));
        // Carry flag: some position differs (including one string ending
        // before the other). Positions after a terminator shared by both
        // blocks never count as different.
        if (_mm_cmpistrc(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return false;
        }
        // Zero flag: v2 contains a '\0'. As no position differs, v1 ends at
        // the same place, so the strings are equal.
        if (_mm_cmpistrz(v1, v2, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return true;
        }
    }
}
// Returns true when the two length-delimited byte strings have the same
// length and the same contents. Embedded '\0' bytes are ordinary data. Only
// bytes inside [s1, s1 + len1) and [s2, s2 + len2) are read.
bool StringEquals(const char* s1, size_t len1, const char* s2, size_t len2) noexcept {
    constexpr size_t kBlock = sizeof(__m128i);

    // Strings of different length can never be equal; skip the scan entirely.
    if (len1 != len2) {
        return false;
    }

    size_t pos = 0;

    // Full 16-byte blocks. The explicit-length form (PCMPESTRI) is used so
    // that a '\0' inside the data does not cut the comparison short. The
    // carry flag is set when at least one position differs.
    for (; len1 - pos >= kBlock; pos += kBlock) {
        const auto v1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s1 + pos));
        const auto v2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s2 + pos));
        if (_mm_cmpestrc(v1, 16, v2, 16, _SIDD_NEGATIVE_POLARITY | _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return false;
        }
    }

    // Remaining (< 16) bytes are compared one by one so that no load reaches
    // past the end of either buffer.
    for (; pos < len1; ++pos) {
        if (s1[pos] != s2[pos]) {
            return false;
        }
    }
    return true;
}
// Returns true when the null-terminated string `str` contains the byte `c`.
// As with strchr, the terminator is considered part of the string, so
// searching for '\0' always succeeds.
//
// The first few bytes are handled one at a time until the cursor is 16-byte
// aligned; from there on only aligned 16-byte loads are issued, which can
// never straddle a page boundary, so the scan does not touch a page the
// string does not reach.
bool StringContains(const char* str, char c) noexcept {
    constexpr std::uintptr_t kBlockMask = sizeof(__m128i) - 1;

    // Scalar prologue: walk up to the next 16-byte boundary.
    while ((reinterpret_cast<std::uintptr_t>(str) & kBlockMask) != 0) {
        if (*str == c) {
            return true;
        }
        if (*str == '\0') {
            return false;
        }
        ++str;
    }

    const auto pattern = _mm_set1_epi8(c);
    for (auto p = reinterpret_cast<const __m128i*>(str);; ++p) {
        const auto v = _mm_load_si128(p);
        // Carry flag: `c` occurs in `v` before its terminator (or, when `c`
        // is '\0', `v` contains the terminator).
        if (_mm_cmpistrc(pattern, v, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return true;
        }
        // Zero flag: `v` contains the terminator, so the string ended without
        // a match.
        if (_mm_cmpistrz(pattern, v, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return false;
        }
    }
}
// Returns true when the byte `c` occurs anywhere in the `length` bytes
// starting at `str`. Embedded '\0' bytes are ordinary data (and can be
// searched for). Only bytes inside [str, str + length) are read.
bool StringContains(const char* str, size_t length, char c) noexcept {
    constexpr size_t kBlock = sizeof(__m128i);
    const auto pattern = _mm_set1_epi8(c);
    size_t pos = 0;

    // Full 16-byte blocks. The explicit-length form (PCMPESTRI) is used so
    // that a '\0' inside the data does not hide the bytes that follow it.
    // The carry flag is set when at least one position equals `c`.
    for (; length - pos >= kBlock; pos += kBlock) {
        const auto v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(str + pos));
        if (_mm_cmpestrc(pattern, 16, v, 16, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return true;
        }
    }

    // Remaining (< 16) bytes are checked one by one so that no load reaches
    // past the end of the buffer.
    for (; pos < length; ++pos) {
        if (str[pos] == c) {
            return true;
        }
    }
    return false;
}
// Returns the index of the first occurrence of the byte `c` in the
// null-terminated string `str`, or static_cast<size_t>(-1) when `c` does not
// occur. As with strchr, the terminator is considered part of the string, so
// searching for '\0' yields the string length.
//
// The first few bytes are handled one at a time until the cursor is 16-byte
// aligned; from there on only aligned 16-byte loads are issued, which can
// never straddle a page boundary, so the scan does not touch a page the
// string does not reach.
size_t StringFind(const char* str, char c) noexcept {
    constexpr std::uintptr_t kBlockMask = sizeof(__m128i) - 1;
    constexpr size_t kNotFound = static_cast<size_t>(-1);

    // Scalar prologue: walk up to the next 16-byte boundary.
    const char* cur = str;
    while ((reinterpret_cast<std::uintptr_t>(cur) & kBlockMask) != 0) {
        if (*cur == c) {
            return static_cast<size_t>(cur - str);
        }
        if (*cur == '\0') {
            return kNotFound;
        }
        ++cur;
    }

    const auto pattern = _mm_set1_epi8(c);
    for (auto p = reinterpret_cast<const __m128i*>(cur);; ++p) {
        const auto v = _mm_load_si128(p);
        // Index of the first occurrence of `c` before the terminator of `v`
        // (or of the terminator itself when `c` is '\0'); 16 if there is none.
        const int match_idx = _mm_cmpistri(pattern, v, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (match_idx != 16) {
            return static_cast<size_t>(reinterpret_cast<const char*>(p) - str) +
                   static_cast<size_t>(match_idx);
        }
        // Zero flag: `v` contains the terminator, so the string ended without
        // a match.
        if (_mm_cmpistrz(pattern, v, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS)) {
            return kNotFound;
        }
    }
}
// Returns the index of the first occurrence of the byte `c` within the
// `length` bytes starting at `str`, or static_cast<size_t>(-1) when `c` does
// not occur. Embedded '\0' bytes are ordinary data (and can be searched for).
// Only bytes inside [str, str + length) are read.
size_t StringFind(const char* str, size_t length, char c) noexcept {
    constexpr size_t kBlock = sizeof(__m128i);
    constexpr size_t kNotFound = static_cast<size_t>(-1);
    const auto pattern = _mm_set1_epi8(c);
    size_t pos = 0;

    // Full 16-byte blocks. The explicit-length form (PCMPESTRI) is used so
    // that a '\0' inside the data does not hide the bytes that follow it.
    // The result is the lowest matching index, or 16 when nothing matches.
    for (; length - pos >= kBlock; pos += kBlock) {
        const auto v = _mm_loadu_si128(reinterpret_cast<const __m128i*>(str + pos));
        const int match_idx = _mm_cmpestri(pattern, 16, v, 16, _SIDD_CMP_EQUAL_EACH | _SIDD_UBYTE_OPS);
        if (match_idx != 16) {
            return pos + static_cast<size_t>(match_idx);
        }
    }

    // Remaining (< 16) bytes are checked one by one so that no load reaches
    // past the end of the buffer.
    for (; pos < length; ++pos) {
        if (str[pos] == c) {
            return pos;
        }
    }
    return kNotFound;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment