Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
arch:
- i686
- x86_64
- aarch64
features:
- default
- runtime-dispatch-simd
Expand Down
20 changes: 15 additions & 5 deletions src/integer_simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ unsafe fn usize_load_unchecked(bytes: &[u8], offset: usize) -> usize {
ptr::copy_nonoverlapping(
bytes.as_ptr().add(offset),
&mut output as *mut usize as *mut u8,
mem::size_of::<usize>()
mem::size_of::<usize>(),
);
output
}
Expand Down Expand Up @@ -65,11 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
// 8
let mut counts = 0;
for i in 0..(haystack.len() - offset) / chunksize {
counts += bytewise_equal(usize_load_unchecked(haystack, offset + i * chunksize), needles);
counts += bytewise_equal(
usize_load_unchecked(haystack, offset + i * chunksize),
needles,
);
}
if haystack.len() % 8 != 0 {
let mask = usize::from_le(!(!0 >> ((haystack.len() % chunksize) * 8)));
counts += bytewise_equal(usize_load_unchecked(haystack, haystack.len() - chunksize), needles) & mask;
counts += bytewise_equal(
usize_load_unchecked(haystack, haystack.len() - chunksize),
needles,
) & mask;
}
count += sum_usize(counts);

Expand Down Expand Up @@ -98,11 +104,15 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
// 8
let mut counts = 0;
for i in 0..(utf8_chars.len() - offset) / chunksize {
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
counts +=
is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
}
if utf8_chars.len() % 8 != 0 {
let mask = usize::from_le(!(!0 >> ((utf8_chars.len() % chunksize) * 8)));
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, utf8_chars.len() - chunksize)) & mask;
counts += is_leading_utf8_byte(usize_load_unchecked(
utf8_chars,
utf8_chars.len() - chunksize,
)) & mask;
}
count += sum_usize(counts);

Expand Down
35 changes: 29 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
//! still on small strings.

#![deny(missing_docs)]

#![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)]

#[cfg(not(feature = "runtime-dispatch-simd"))]
Expand All @@ -45,7 +44,11 @@ pub use naive::*;
mod integer_simd;

#[cfg(any(
all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")),
all(
feature = "runtime-dispatch-simd",
any(target_arch = "x86", target_arch = "x86_64")
),
target_arch = "aarch64",
feature = "generic-simd"
))]
mod simd;
Expand All @@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { return simd::x86_avx2::chunk_count(haystack, needle); }
unsafe {
return simd::x86_avx2::chunk_count(haystack, needle);
}
}
}

Expand All @@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
))]
{
if is_x86_feature_detected!("sse2") {
unsafe { return simd::x86_sse2::chunk_count(haystack, needle); }
unsafe {
return simd::x86_sse2::chunk_count(haystack, needle);
}
}
}
// NEON is used on aarch64 unless the portable `generic-simd` backend was
// requested. The feature name must be spelled with a hyphen, exactly as in the
// `mod simd` cfg; `generic_simd` names a feature that does not exist, which
// made the `not(...)` guard unconditionally true.
#[cfg(all(target_arch = "aarch64", not(feature = "generic-simd")))]
{
    unsafe {
        return simd::aarch64::chunk_count(haystack, needle);
    }
}
}
Expand Down Expand Up @@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); }
unsafe {
return simd::x86_avx2::chunk_num_chars(utf8_chars);
}
}
}

Expand All @@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
))]
{
if is_x86_feature_detected!("sse2") {
unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); }
unsafe {
return simd::x86_sse2::chunk_num_chars(utf8_chars);
}
}
}
// NEON is used on aarch64 unless the portable `generic-simd` backend was
// requested. The feature name must be spelled with a hyphen, exactly as in the
// `mod simd` cfg; `generic_simd` names a feature that does not exist, which
// made the `not(...)` guard unconditionally true.
#[cfg(all(target_arch = "aarch64", not(feature = "generic-simd")))]
{
    unsafe {
        return simd::aarch64::chunk_num_chars(utf8_chars);
    }
}
}
Expand Down
9 changes: 7 additions & 2 deletions src/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ pub fn naive_count_32(haystack: &[u8], needle: u8) -> usize {
/// assert_eq!(number_of_spaces, 6);
/// ```
pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
    // Single pass over the slice: each byte contributes 1 when it equals
    // `needle` and 0 otherwise.
    utf8_chars
        .iter()
        .fold(0, |n, c| n + (*c == needle) as usize)
}

/// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, simple
Expand All @@ -38,5 +40,8 @@ pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
/// assert_eq!(char_count, 4);
/// ```
pub fn naive_num_chars(utf8_chars: &[u8]) -> usize {
    // A byte whose top two bits are `10` is a UTF-8 continuation byte; every
    // other byte starts a code point, so counting those counts the characters.
    utf8_chars
        .iter()
        .filter(|&&byte| (byte >> 6) != 0b10)
        .count()
}
157 changes: 157 additions & 0 deletions src/simd/aarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use core::arch::aarch64::{
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
vmvnq_u8, vsubq_u8,
};

/// Sliding tail mask: sixteen `0` bytes followed by sixteen `255` bytes.
///
/// Loading 16 bytes starting at index `r` yields a vector whose first `16 - r`
/// lanes are `0` and whose last `r` lanes are `255`; the straggler steps in this
/// file load it at `len % 16`.
const MASK: [u8; 32] = {
    let mut mask = [0u8; 32];
    let mut lane = 16;
    while lane < 32 {
        mask[lane] = 255;
        lane += 1;
    }
    mask
};

/// Loads the 16 bytes `slice[offset..offset + 16]` into one NEON vector.
///
/// # Safety
///
/// `offset + 16` must not exceed `slice.len()`. The bound is only verified by a
/// `debug_assert!`, so builds without debug assertions read unchecked.
#[target_feature(enable = "neon")]
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
    let len = slice.len();
    debug_assert!(offset + 16 <= len, "{} + 16 ≥ {}", offset, len);

    // NOTE(review): the original author asked whether this load needs to be
    // aligned. `vld1q_u8` takes a plain `*const u8`, which presumably means no
    // alignment beyond that of `u8` is required — confirm against the intrinsic
    // documentation.
    let start = slice.as_ptr().add(offset);
    vld1q_u8(start)
}

/// Loads the 64 bytes `slice[offset..offset + 64]` as four consecutive NEON
/// vectors.
///
/// # Safety
///
/// `offset + 64` must not exceed `slice.len()`. The bound is only verified by a
/// `debug_assert!`, so builds without debug assertions read unchecked.
#[target_feature(enable = "neon")]
unsafe fn u8x16_x4_from_offset(slice: &[u8], offset: usize) -> uint8x16x4_t {
    let len = slice.len();
    debug_assert!(offset + 64 <= len, "{} + 64 ≥ {}", offset, len);

    let start = slice.as_ptr().add(offset);
    vld1q_u8_x4(start)
}

/// Adds up all sixteen `u8` lanes of `u8s` and returns the total as a `usize`.
///
/// The horizontal add widens its result, so the total cannot wrap.
///
/// # Safety
///
/// Must only be called on a CPU that supports NEON.
#[target_feature(enable = "neon")]
unsafe fn sum(u8s: uint8x16_t) -> usize {
    let lane_total = vaddlvq_u8(u8s);
    usize::from(lane_total)
}

/// Counts the bytes of `haystack` that equal `needle`, using NEON vectors.
///
/// A matching lane compares to `0xFF`, and subtracting `0xFF` from a `u8`
/// counter adds one to it. Counters are flushed into the scalar total after at
/// most 255 additions per lane, so a lane never wraps.
///
/// # Panics
///
/// Panics if `haystack` is shorter than 16 bytes.
///
/// # Safety
///
/// Must only be called on a CPU that supports NEON.
#[target_feature(enable = "neon")]
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
    assert!(haystack.len() >= 16);

    let len = haystack.len();
    let needles = vdupq_n_u8(needle);
    let mut pos: usize = 0;
    let mut total: usize = 0;

    // Full batches of 255 * 64 = 16320 bytes: the most a `u8` lane can absorb.
    while len >= pos + 64 * 255 {
        let batch_end = pos + 64 * 255;
        let mut acc_a = vdupq_n_u8(0);
        let mut acc_b = vdupq_n_u8(0);
        let mut acc_c = vdupq_n_u8(0);
        let mut acc_d = vdupq_n_u8(0);
        while pos < batch_end {
            let uint8x16x4_t(va, vb, vc, vd) = u8x16_x4_from_offset(haystack, pos);
            acc_a = vsubq_u8(acc_a, vceqq_u8(va, needles));
            acc_b = vsubq_u8(acc_b, vceqq_u8(vb, needles));
            acc_c = vsubq_u8(acc_c, vceqq_u8(vc, needles));
            acc_d = vsubq_u8(acc_d, vceqq_u8(vd, needles));
            pos += 64;
        }
        total += sum(acc_a) + sum(acc_b) + sum(acc_c) + sum(acc_d);
    }

    // Remaining whole 64-byte blocks; fewer than 255 of them are left here.
    let mut acc_a = vdupq_n_u8(0);
    let mut acc_b = vdupq_n_u8(0);
    let mut acc_c = vdupq_n_u8(0);
    let mut acc_d = vdupq_n_u8(0);
    while len - pos >= 64 {
        let uint8x16x4_t(va, vb, vc, vd) = u8x16_x4_from_offset(haystack, pos);
        acc_a = vsubq_u8(acc_a, vceqq_u8(va, needles));
        acc_b = vsubq_u8(acc_b, vceqq_u8(vb, needles));
        acc_c = vsubq_u8(acc_c, vceqq_u8(vc, needles));
        acc_d = vsubq_u8(acc_d, vceqq_u8(vd, needles));
        pos += 64;
    }
    total += sum(acc_a) + sum(acc_b) + sum(acc_c) + sum(acc_d);

    // Remaining whole 16-byte vectors (at most three).
    let mut acc = vdupq_n_u8(0);
    while len - pos >= 16 {
        acc = vsubq_u8(acc, vceqq_u8(u8x16_from_offset(haystack, pos), needles));
        pos += 16;
    }

    // Straggler: re-read the final 16 bytes and keep only the trailing
    // `len % 16` lanes, which are the ones not counted above.
    let tail = len % 16;
    if tail != 0 {
        let hits = vceqq_u8(u8x16_from_offset(haystack, len - 16), needles);
        let keep = u8x16_from_offset(&MASK, tail);
        acc = vsubq_u8(acc, vandq_u8(hits, keep));
    }

    total + sum(acc)
}

/// Returns a lane mask that is `0xFF` where the byte is not a UTF-8
/// continuation byte and `0x00` where it is.
///
/// A continuation byte has the bit pattern `10xx_xxxx`, so only the top two
/// bits of each lane are inspected.
///
/// # Safety
///
/// Must only be called on a CPU that supports NEON.
#[target_feature(enable = "neon")]
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
    let top_two_bits = vandq_u8(u8s, vdupq_n_u8(0b1100_0000));
    let continuation = vceqq_u8(top_two_bits, vdupq_n_u8(0b1000_0000));
    vmvnq_u8(continuation)
}

/// Counts the UTF-8 code points in `utf8_chars` by counting every byte that is
/// not a continuation byte, using NEON vectors.
///
/// Each non-continuation lane adds one to a `u8` counter; counters are flushed
/// into the scalar total after at most 255 additions per lane, so a lane never
/// wraps.
///
/// # Panics
///
/// Panics if `utf8_chars` is shorter than 16 bytes.
///
/// # Safety
///
/// Must only be called on a CPU that supports NEON.
#[target_feature(enable = "neon")]
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
    assert!(utf8_chars.len() >= 16);

    let len = utf8_chars.len();
    let mut pos: usize = 0;
    let mut total: usize = 0;

    // Full batches of 255 vectors (4080 bytes): the most a `u8` lane can absorb.
    while len >= pos + 16 * 255 {
        let batch_end = pos + 16 * 255;
        let mut lanes = vdupq_n_u8(0);
        while pos < batch_end {
            let leading = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, pos));
            lanes = vsubq_u8(lanes, leading);
            pos += 16;
        }
        total += sum(lanes);
    }

    // At most one batch of 128 vectors (2048 bytes), so that fewer than 128
    // whole vectors remain for the final accumulator.
    if len >= pos + 16 * 128 {
        let batch_end = pos + 16 * 128;
        let mut lanes = vdupq_n_u8(0);
        while pos < batch_end {
            let leading = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, pos));
            lanes = vsubq_u8(lanes, leading);
            pos += 16;
        }
        total += sum(lanes);
    }

    // Remaining whole 16-byte vectors.
    let mut lanes = vdupq_n_u8(0);
    while len - pos >= 16 {
        let leading = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, pos));
        lanes = vsubq_u8(lanes, leading);
        pos += 16;
    }

    // Straggler: re-read the final 16 bytes and keep only the trailing
    // `len % 16` lanes, which are the ones not counted above.
    let tail = len % 16;
    if tail != 0 {
        let leading = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, len - 16));
        let keep = u8x16_from_offset(&MASK, tail);
        lanes = vsubq_u8(lanes, vandq_u8(leading, keep));
    }

    total + sum(lanes)
}
20 changes: 11 additions & 9 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use std::mem;
use self::packed_simd::{u8x32, u8x64, FromCast};

const MASK: [u8; 64] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
];

unsafe fn u8x64_from_offset(slice: &[u8], offset: usize) -> u8x64 {
Expand Down Expand Up @@ -66,15 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
// 32
let mut counts = u8x32::splat(0);
for i in 0..(haystack.len() - offset) / 32 {
counts -= u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
counts -=
u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
}
count += sum_x32(&counts);

// Straggler; need to reset counts because prior loop can run 255 times
counts = u8x32::splat(0);
if haystack.len() % 32 != 0 {
counts -= u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) &
u8x32_from_offset(&MASK, haystack.len() % 32);
counts -=
u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32))
& u8x32_from_offset(&MASK, haystack.len() % 32);
}
count += sum_x32(&counts);

Expand Down Expand Up @@ -127,8 +128,9 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
// Straggler; need to reset counts because prior loop can run 255 times
counts = u8x32::splat(0);
if utf8_chars.len() % 32 != 0 {
counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) &
u8x32_from_offset(&MASK, utf8_chars.len() % 32);
counts -=
is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32))
& u8x32_from_offset(&MASK, utf8_chars.len() % 32);
}
count += sum_x32(&counts);

Expand Down
4 changes: 4 additions & 0 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ pub mod x86_sse2;
// Runtime feature detection is not available with no_std.
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
pub mod x86_avx2;

/// Modern ARM machines are also quite capable thanks to NEON
#[cfg(target_arch = "aarch64")]
pub mod aarch64;
Loading