Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 74 additions & 15 deletions src/packed_seq.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,49 @@
use core::cell::RefCell;
use traits::Seq;
use wide::u16x8;

use crate::{intrinsics::transpose, padded_it::ChunkIt};

use super::*;

/// Scratch buffer of eight `S` SIMD vectors.
///
/// NOTE(review): `S` comes from `use super::*`; presumably it is the crate's
/// 8-lane `u32` SIMD type — confirm against the parent module.
type SimdBuf = [S; 8];

thread_local! {
    /// Per-thread pool of spare boxed [`SimdBuf`]s, seeded with one
    /// default-initialised buffer.
    ///
    /// Buffers are popped from this pool when an iterator needs scratch space
    /// and pushed back by `RecycledBox`'s `Drop` impl.
    static IT_BUF: RefCell<Vec<Box<SimdBuf>>> =
        RefCell::new(vec![Box::new(SimdBuf::default())]);
}

/// Wrapper around a boxed [`SimdBuf`] obtained from the thread-local `IT_BUF` pool.
///
/// The inner option is `None` when the pool had no spare buffer to hand out;
/// `init_if_needed` must then be called before `get`/`get_mut` are used.
/// Dropping the wrapper pushes the buffer back onto `IT_BUF`.
struct RecycledBox(Option<Box<SimdBuf>>);

impl RecycledBox {
    /// Makes sure a buffer is held, allocating a default-initialised one when
    /// the wrapper is still empty. Does nothing if a buffer is already present.
    #[inline(always)]
    pub fn init_if_needed(&mut self) {
        if self.0.is_some() {
            return;
        }
        self.0 = Some(Box::new(SimdBuf::default()));
    }

    /// Shared access to the held buffer.
    ///
    /// The presence of the buffer is *not* checked: `init_if_needed` must have
    /// been called beforehand, otherwise behaviour is undefined.
    #[inline(always)]
    pub fn get(&self) -> &SimdBuf {
        // SAFETY: relies on the caller having run `init_if_needed`, which
        // guarantees the option is `Some`.
        unsafe { self.0.as_deref().unwrap_unchecked() }
    }

    /// Exclusive access to the held buffer.
    ///
    /// The presence of the buffer is *not* checked: `init_if_needed` must have
    /// been called beforehand, otherwise behaviour is undefined.
    #[inline(always)]
    pub fn get_mut(&mut self) -> &mut SimdBuf {
        // SAFETY: relies on the caller having run `init_if_needed`, which
        // guarantees the option is `Some`.
        unsafe { self.0.as_deref_mut().unwrap_unchecked() }
    }
}

impl Drop for RecycledBox {
    /// Returns the held buffer (if any) to the thread-local `IT_BUF` pool so a
    /// later iterator on this thread can reuse the allocation.
    #[inline(always)]
    fn drop(&mut self) {
        // A wrapper that was never initialised holds no buffer; there is
        // nothing to recycle. (Unwrapping unchecked here would be undefined
        // behaviour in that case.)
        let Some(buf) = self.0.take() else {
            return;
        };
        // `try_with` instead of `with_borrow_mut`: if the thread-local has
        // already been torn down (drop running during thread exit), accessing
        // it would panic inside `drop`. In that case the closure is discarded
        // and `buf` is simply freed.
        let _ = IT_BUF.try_with(|pool| pool.borrow_mut().push(buf));
    }
}

#[doc(hidden)]
pub struct Bits<const B: usize>;
#[doc(hidden)]
Expand All @@ -15,7 +54,7 @@ impl SupportedBits for Bits<4> {}
impl SupportedBits for Bits<8> {}

/// Number of padding bytes at the end of `PackedSeqVecBase::seq`.
const PADDING: usize = 16;
pub(crate) const PADDING: usize = 40;

/// A 2-bit packed non-owned slice of DNA bases.
#[doc(hidden)]
Expand Down Expand Up @@ -220,7 +259,7 @@ where
let start_byte = self.offset / Self::C8;
let end_byte = (self.offset + self.len).div_ceil(Self::C8);
Self {
seq: &self.seq[start_byte..end_byte],
seq: &self.seq[start_byte..end_byte + PADDING],
offset: self.offset % Self::C8,
len: self.len,
}
Expand All @@ -233,6 +272,16 @@ where
}
}

/// Read exactly 32 bytes of `seq` starting at byte `idx` as one unaligned `u32x8`.
///
/// No bounds check is performed in release builds.
///
/// # Safety
///
/// The caller must guarantee `idx + 32 <= seq.len()`, i.e. that all 32 bytes
/// read lie inside `seq`.
#[inline(always)]
pub(crate) unsafe fn read_slice_32_unchecked(seq: &[u8], idx: usize) -> u32x8 {
    // Checked *before* any pointer arithmetic: offsetting the pointer past the
    // end of the slice is already undefined behaviour.
    debug_assert!(idx + 32 <= seq.len());
    // SAFETY: per this function's contract `seq[idx..idx + 32]` is in bounds,
    // so the offset pointer is valid for a 32-byte read. `read_unaligned`
    // imposes no alignment requirement on the source pointer.
    unsafe { seq.as_ptr().add(idx).cast::<u32x8>().read_unaligned() }
}

/// Read up to 32 bytes starting at idx.
#[inline(always)]
pub(crate) fn read_slice_32(seq: &[u8], idx: usize) -> u32x8 {
Expand Down Expand Up @@ -294,7 +343,6 @@ where
#[inline(always)]
fn as_u64(&self) -> u64 {
assert!(self.len() <= 64 / B);
debug_assert!(self.seq.len() <= 9);

let mask = u64::MAX >> (64 - B * self.len());

Expand Down Expand Up @@ -328,7 +376,6 @@ where
self.len() <= (128 - 8) / B + 1,
"Sequences >61 long cannot be read with a single unaligned u128 read."
);
debug_assert!(self.seq.len() <= 17);

let mask = u128::MAX >> (128 - B * self.len());

Expand Down Expand Up @@ -470,13 +517,23 @@ where

// Boxed, so it doesn't consume precious registers.
// Without this, cur is not always inlined into a register.
let mut buf = Box::new([S::ZERO; 8]);
// let mut buf = Box::new([S::ZERO; 8]);
let mut buf = IT_BUF.with_borrow_mut(|v| RecycledBox(v.pop()));
buf.init_if_needed();

let simd_char_mask: u32x8 = unsafe { core::mem::transmute([Self::CHAR_MASK as u32; 8]) };
let simd_b: u32x8 = unsafe { core::mem::transmute([B as u32; 8]) };

let par_len = if num_kmers == 0 {
0
} else {
n + context + o - 1
};

let last_read = par_len.saturating_sub(1) / Self::C32 * Self::C32;
// Safety check for the `read_slice_32_unchecked`:
assert!(offsets[7] + (last_read / Self::C8) + 32 <= this.seq.len());

let it = (0..par_len)
.map(
#[inline(always)]
Expand All @@ -486,16 +543,21 @@ where
// Read a u256 for each lane containing the next 128 characters.
let data: [u32x8; 8] = from_fn(
#[inline(always)]
|lane| read_slice_32(this.seq, offsets[lane] + (i / Self::C8)),
|lane| unsafe {
let idx = offsets[lane] + (i / Self::C8);
read_slice_32_unchecked(this.seq, idx)
},
);
*buf = transpose(data);
// *buf = transpose(data);
*buf.get_mut() = transpose(data);
}
cur = buf[(i % Self::C256) / Self::C32];
// cur = buf[(i % Self::C256) / Self::C32];
cur = buf.get()[(i % Self::C256) / Self::C32];
}
// Extract the last 2 bits of each character.
let chars = cur & S::splat(Self::CHAR_MASK as u32);
let chars = cur & simd_char_mask;
// Shift remaining characters to the right.
cur = cur >> S::splat(B as u32);
cur = cur >> simd_b;
chars
},
)
Expand Down Expand Up @@ -842,7 +904,7 @@ where
#[inline(always)]
fn as_slice(&self) -> Self::Seq<'_> {
PackedSeqBase {
seq: &self.seq[..self.len.div_ceil(Self::C8)],
seq: &self.seq[..self.len.div_ceil(Self::C8) + PADDING],
offset: 0,
len: self.len,
}
Expand All @@ -867,14 +929,11 @@ where
fn push_seq<'a>(&mut self, seq: PackedSeqBase<'_, B>) -> Range<usize> {
let start = self.len.next_multiple_of(Self::C8) + seq.offset;
let end = start + seq.len();
// Reserve *additional* capacity.
self.seq.reserve(seq.seq.len());

// Shrink away the padding.
self.seq.resize(self.len.div_ceil(Self::C8), 0);
// Extend.
self.seq.extend(seq.seq);
// Push padding.
self.seq.extend(std::iter::repeat_n(0u8, PADDING));
self.len = end;
start..end
}
Expand Down
44 changes: 27 additions & 17 deletions src/test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use rand::{Rng, random_range};
use wide::u32x8;

use crate::packed_seq::PADDING;

use super::*;

fn pack_naive(seq: &[u8]) -> (Vec<u8>, usize) {
Expand Down Expand Up @@ -825,7 +827,7 @@ fn packed_seq_push_seq() {
);
total_len_in_bp += len.next_multiple_of(4);
}
assert_eq!(packed.seq.len(), total_len_in_bp.div_ceil(4) + 16,);
assert_eq!(packed.seq.len(), total_len_in_bp.div_ceil(4) + PADDING);
}

#[test]
Expand Down Expand Up @@ -974,14 +976,18 @@ fn par_iter_bp_bench() {
let seq = PackedSeqVec::random(len);

let start = std::time::Instant::now();
for _ in 0..rep {
let mut x = u32x8::splat(0);
let PaddedIt { it, .. } = seq.as_slice().par_iter_bp(1);
it.for_each(|y| {
x += y;
});
core::hint::black_box(&x);
}
(0..rep).for_each(
#[inline(always)]
|_| {
let PaddedIt { it, .. } = seq.as_slice().par_iter_bp(1);
it.for_each(
#[inline(always)]
|y| {
core::hint::black_box(&y);
},
);
},
);
eprintln!(
"Len {len:>7} => {:.03} Gbp/s",
start.elapsed().as_secs_f64().recip()
Expand All @@ -1000,14 +1006,18 @@ fn par_iter_kmer_ambiguity_bench() {
let seq = BitSeqVec::random(len, 0.01);

let start = std::time::Instant::now();
for _ in 0..rep {
let mut x = u32x8::splat(0);
let PaddedIt { it, .. } = seq.as_slice().par_iter_kmer_ambiguity(k, k - 1, 0);
it.for_each(|y| {
x += y;
});
core::hint::black_box(&x);
}
(0..rep).for_each(
#[inline(always)]
|_| {
let PaddedIt { it, .. } = seq.as_slice().par_iter_kmer_ambiguity(k, k - 1, 0);
it.for_each(
#[inline(always)]
|y| {
core::hint::black_box(&y);
},
);
},
);
eprintln!(
"Len {len:>7} => {:.03} Gbp/s",
start.elapsed().as_secs_f64().recip()
Expand Down