diff --git a/src/packed_seq.rs b/src/packed_seq.rs index f484cb1..431ddc1 100644 --- a/src/packed_seq.rs +++ b/src/packed_seq.rs @@ -1,3 +1,4 @@ +use core::cell::RefCell; use traits::Seq; use wide::u16x8; @@ -5,6 +6,44 @@ use crate::{intrinsics::transpose, padded_it::ChunkIt}; use super::*; +type SimdBuf = [S; 8]; + +thread_local! { + static IT_BUF: RefCell>> = { + RefCell::new(vec![Box::new(SimdBuf::default())]) + }; +} + +struct RecycledBox(Option>); + +impl RecycledBox { + #[inline(always)] + pub fn init_if_needed(&mut self) { + if self.0.is_none() { + self.0 = Some(Box::new(SimdBuf::default())); + } + } + + #[inline(always)] + pub fn get(&self) -> &SimdBuf { + unsafe { self.0.as_ref().unwrap_unchecked() } + } + + #[inline(always)] + pub fn get_mut(&mut self) -> &mut SimdBuf { + unsafe { self.0.as_mut().unwrap_unchecked() } + } +} + +impl Drop for RecycledBox { + #[inline(always)] + fn drop(&mut self) { + let mut x = None; + core::mem::swap(&mut x, &mut self.0); + IT_BUF.with_borrow_mut(|v| v.push(unsafe { x.unwrap_unchecked() })); + } +} + #[doc(hidden)] pub struct Bits; #[doc(hidden)] @@ -15,7 +54,7 @@ impl SupportedBits for Bits<4> {} impl SupportedBits for Bits<8> {} /// Number of padding bytes at the end of `PackedSeqVecBase::seq`. -const PADDING: usize = 16; +pub(crate) const PADDING: usize = 40; /// A 2-bit packed non-owned slice of DNA bases. #[doc(hidden)] @@ -220,7 +259,7 @@ where let start_byte = self.offset / Self::C8; let end_byte = (self.offset + self.len).div_ceil(Self::C8); Self { - seq: &self.seq[start_byte..end_byte], + seq: &self.seq[start_byte..end_byte + PADDING], offset: self.offset % Self::C8, len: self.len, } @@ -233,6 +272,16 @@ where } } +/// Read up to 32 bytes starting at idx. +#[inline(always)] +pub(crate) unsafe fn read_slice_32_unchecked(seq: &[u8], idx: usize) -> u32x8 { + unsafe { + let src = seq.as_ptr().add(idx); + debug_assert!(idx + 32 <= seq.len()); + std::mem::transmute::<_, *const u32x8>(src).read_unaligned() + } +} + /// Read up to 32 bytes starting at idx. #[inline(always)] pub(crate) fn read_slice_32(seq: &[u8], idx: usize) -> u32x8 { @@ -294,7 +343,6 @@ where #[inline(always)] fn as_u64(&self) -> u64 { assert!(self.len() <= 64 / B); - debug_assert!(self.seq.len() <= 9); let mask = u64::MAX >> (64 - B * self.len()); @@ -328,7 +376,6 @@ where self.len() <= (128 - 8) / B + 1, "Sequences >61 long cannot be read with a single unaligned u128 read." ); - debug_assert!(self.seq.len() <= 17); let mask = u128::MAX >> (128 - B * self.len()); @@ -470,13 +517,23 @@ where // Boxed, so it doesn't consume precious registers. // Without this, cur is not always inlined into a register. - let mut buf = Box::new([S::ZERO; 8]); + // let mut buf = Box::new([S::ZERO; 8]); + let mut buf = IT_BUF.with_borrow_mut(|v| RecycledBox(v.pop())); + buf.init_if_needed(); + + let simd_char_mask: u32x8 = unsafe { core::mem::transmute([Self::CHAR_MASK as u32; 8]) }; + let simd_b: u32x8 = unsafe { core::mem::transmute([B as u32; 8]) }; let par_len = if num_kmers == 0 { 0 } else { n + context + o - 1 }; + + let last_read = par_len.saturating_sub(1) / Self::C32 * Self::C32; + // Safety check for the `read_slice_32_unchecked`: + assert!(offsets[7] + (last_read / Self::C8) + 32 <= this.seq.len()); + let it = (0..par_len) .map( #[inline(always)] @@ -486,16 +543,21 @@ where // Read a u256 for each lane containing the next 128 characters. let data: [u32x8; 8] = from_fn( #[inline(always)] - |lane| read_slice_32(this.seq, offsets[lane] + (i / Self::C8)), + |lane| unsafe { + let idx = offsets[lane] + (i / Self::C8); + read_slice_32_unchecked(this.seq, idx) + }, ); - *buf = transpose(data); + // *buf = transpose(data); + *buf.get_mut() = transpose(data); } - cur = buf[(i % Self::C256) / Self::C32]; + // cur = buf[(i % Self::C256) / Self::C32]; + cur = buf.get()[(i % Self::C256) / Self::C32]; } // Extract the last 2 bits of each character. - let chars = cur & S::splat(Self::CHAR_MASK as u32); + let chars = cur & simd_char_mask; // Shift remaining characters to the right. - cur = cur >> S::splat(B as u32); + cur = cur >> simd_b; chars }, ) @@ -842,7 +904,7 @@ where #[inline(always)] fn as_slice(&self) -> Self::Seq<'_> { PackedSeqBase { - seq: &self.seq[..self.len.div_ceil(Self::C8)], + seq: &self.seq[..self.len.div_ceil(Self::C8) + PADDING], offset: 0, len: self.len, } @@ -867,14 +929,11 @@ where fn push_seq<'a>(&mut self, seq: PackedSeqBase<'_, B>) -> Range { let start = self.len.next_multiple_of(Self::C8) + seq.offset; let end = start + seq.len(); - // Reserve *additional* capacity. - self.seq.reserve(seq.seq.len()); + // Shrink away the padding. self.seq.resize(self.len.div_ceil(Self::C8), 0); // Extend. self.seq.extend(seq.seq); - // Push padding. - self.seq.extend(std::iter::repeat_n(0u8, PADDING)); self.len = end; start..end } diff --git a/src/test.rs b/src/test.rs index a9df3fe..9c7a9ee 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,6 +1,8 @@ use rand::{Rng, random_range}; use wide::u32x8; +use crate::packed_seq::PADDING; + use super::*; fn pack_naive(seq: &[u8]) -> (Vec, usize) { @@ -825,7 +827,7 @@ fn packed_seq_push_seq() { ); total_len_in_bp += len.next_multiple_of(4); } - assert_eq!(packed.seq.len(), total_len_in_bp.div_ceil(4) + 16,); + assert_eq!(packed.seq.len(), total_len_in_bp.div_ceil(4) + PADDING); } #[test] @@ -974,14 +976,18 @@ fn par_iter_bp_bench() { let seq = PackedSeqVec::random(len); let start = std::time::Instant::now(); - for _ in 0..rep { - let mut x = u32x8::splat(0); - let PaddedIt { it, .. } = seq.as_slice().par_iter_bp(1); - it.for_each(|y| { - x += y; - }); - core::hint::black_box(&x); - } + (0..rep).for_each( + #[inline(always)] + |_| { + let PaddedIt { it, .. } = seq.as_slice().par_iter_bp(1); + it.for_each( + #[inline(always)] + |y| { + core::hint::black_box(&y); + }, + ); + }, + ); eprintln!( "Len {len:>7} => {:.03} Gbp/s", start.elapsed().as_secs_f64().recip() @@ -1000,14 +1006,18 @@ fn par_iter_kmer_ambiguity_bench() { let seq = BitSeqVec::random(len, 0.01); let start = std::time::Instant::now(); - for _ in 0..rep { - let mut x = u32x8::splat(0); - let PaddedIt { it, .. } = seq.as_slice().par_iter_kmer_ambiguity(k, k - 1, 0); - it.for_each(|y| { - x += y; - }); - core::hint::black_box(&x); - } + (0..rep).for_each( + #[inline(always)] + |_| { + let PaddedIt { it, .. } = seq.as_slice().par_iter_kmer_ambiguity(k, k - 1, 0); + it.for_each( + #[inline(always)] + |y| { + core::hint::black_box(&y); + }, + ); + }, + ); eprintln!( "Len {len:>7} => {:.03} Gbp/s", start.elapsed().as_secs_f64().recip()