Merged
121 changes: 63 additions & 58 deletions src/lockmap.rs
@@ -4,13 +4,13 @@ use std::borrow::Borrow;
use std::cell::UnsafeCell;
use std::collections::BTreeSet;
use std::hash::Hash;
-use std::sync::atomic::AtomicU32;
+use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::OnceLock;

/// Internal state for a key-value pair in the `LockMap`.
///
-/// This type manages both the stored value and the queue of waiting threads
-/// for per-key synchronization.
+/// This type manages the stored value, the per-key lock, and a reference count
+/// used for both synchronization optimization and memory management.
struct State<V> {
refcnt: AtomicU32,
mutex: Mutex,
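
Aside: the `value` field of `State` falls outside this hunk. A sketch of the
complete layout, with the `value` type inferred from the test removed at the
bottom of this diff (it constructs `State { refcnt, mutex, value:
UnsafeCell::new(None) }`):

struct State<V> {
    /// Number of live `Entry` handles referencing this key.
    refcnt: AtomicU32,
    /// Per-key lock, held for the lifetime of an `Entry`.
    mutex: Mutex,
    /// The stored value; `None` means the key is locked but currently empty.
    value: UnsafeCell<Option<V>>,
}
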
@@ -138,9 +138,7 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let ptr: *mut State<V> = self.map.update(key.clone(), |s| match s {
Some(state) => {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
let ptr = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, ptr)
}
@@ -185,9 +183,7 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let ptr: *mut State<V> = self.map.update_by_ref(key, |s| match s {
Some(state) => {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
let ptr = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, ptr)
}
@@ -237,15 +233,14 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
let mut ptr: *mut State<V> = std::ptr::null_mut();
let value = self.map.simple_update(key, |s| match s {
Some(state) => {
-if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+// Use Acquire to ensure we see the latest value if refcnt is 0.
+if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_ref() }.clone();
(SimpleAction::Keep, value)
} else {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
ptr = &**state as *const State<V> as *mut State<V>;
(SimpleAction::Keep, None)
}
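
Aside: the switch from `Relaxed` to `Acquire`/`AcqRel` is the heart of this
change. A dropping `Entry` publishes its writes with an `AcqRel` `fetch_sub`;
a reader that then observes `refcnt == 0` needs an `Acquire` load to be
guaranteed to see those writes. A minimal, self-contained illustration of the
pairing using only std atomics (not this crate's code):

use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    let refcnt = Arc::new(AtomicU32::new(1));
    let data = Arc::new(AtomicU32::new(0)); // stand-in for the protected value

    let writer = {
        let (refcnt, data) = (Arc::clone(&refcnt), Arc::clone(&data));
        thread::spawn(move || {
            data.store(42, Ordering::Relaxed); // write the value...
            refcnt.fetch_sub(1, Ordering::AcqRel); // ...then release the reference
        })
    };

    // Acquire pairs with the AcqRel release above: once 0 is observed here,
    // the store of 42 is guaranteed to be visible.
    while refcnt.load(Ordering::Acquire) != 0 {
        std::hint::spin_loop();
    }
    assert_eq!(data.load(Ordering::Relaxed), 42);
    writer.join().unwrap();
}
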
@@ -289,15 +284,14 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let (ptr, value) = self.map.update(key.clone(), move |s| match s {
Some(state) => {
-if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+// Use Acquire to ensure we see the latest value if refcnt is 0.
+if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_mut() }.replace(value);
(UpdateAction::Keep, (std::ptr::null_mut(), value))
} else {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
let ptr: *mut State<V> = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, (ptr, Some(value)))
}
@@ -349,15 +343,14 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
{
let (ptr, value) = self.map.update_by_ref(key, move |s| match s {
Some(state) => {
-if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+// Use Acquire to ensure we see the latest value if refcnt is 0.
+if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_mut() }.replace(value);
(UpdateAction::Keep, (std::ptr::null_mut(), value))
} else {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
let ptr: *mut State<V> = &**state as *const State<V> as *mut State<V>;
(UpdateAction::Keep, (ptr, Some(value)))
}
@@ -410,14 +403,13 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
let mut ptr: *mut State<V> = std::ptr::null_mut();
let value = self.map.simple_update(key, |s| match s {
Some(state) => {
-if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+// Use Acquire to ensure we see the latest value if refcnt is 0.
+if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
// meaning no other thread can be holding an `Entry` for this key.
(SimpleAction::Keep, unsafe { state.value_ref() }.is_some())
} else {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
ptr = &**state as *const State<V> as *mut State<V>;
(SimpleAction::Keep, false)
}
@@ -463,15 +455,14 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
let mut ptr: *mut State<V> = std::ptr::null_mut();
let value = self.map.simple_update(key, |s| match s {
Some(state) => {
-if state.refcnt.load(std::sync::atomic::Ordering::Relaxed) == 0 {
+// Use Acquire to ensure we see the latest value if refcnt is 0.
+if state.refcnt.load(Ordering::Acquire) == 0 {
// SAFETY: We are inside the map's shard lock, and refcnt is 0,
// meaning no other thread can be holding an `Entry` for this key.
let value = unsafe { state.value_mut() }.take();
(SimpleAction::Remove, value)
} else {
-state
-    .refcnt
-    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+state.refcnt.fetch_add(1, Ordering::AcqRel);
ptr = &**state as *const State<V> as *mut State<V>;
(SimpleAction::Keep, None)
}
@@ -551,29 +542,33 @@ impl<K: Eq + Hash, V> LockMap<K, V> {
.collect()
}

-fn unlock<Q>(&self, key: &Q)
+/// Attempts to remove an entry from the map if it's no longer needed.
+///
+/// An entry is considered no longer needed if its reference count is 0
+/// and it contains no value.
+fn try_remove_entry<Q>(&self, key: &Q)
where
K: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.map.simple_update(key, |value| match value {
Some(state) => {
-let prev = state
-    .refcnt
-    .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
-if prev == 1 && unsafe {
-    // SAFETY: We are inside the map's shard lock, and refcnt was 1 (now 0),
-    // meaning no other thread can be holding an `Entry` for this key.
-    state.value_ref()
-}
-.is_none()
+// SAFETY: We are inside the map's shard lock. If `refcnt` is 0 here,
+// then no `Entry` is currently held for this key, and no other thread
+// can increment `refcnt` without first acquiring this same shard lock.
+// Therefore, if the stored value is also `None`, it is safe to remove
+// the entry from the map.
+if state.refcnt.load(Ordering::Acquire) == 0
+    && unsafe { state.value_ref() }.is_none()
{
(SimpleAction::Remove, ())
} else {
(SimpleAction::Keep, ())
}
}
-None => panic!("impossible: unlock a non-existent key!"),
+// The key might have been removed by another thread (e.g., via `remove`)
+// between the `refcnt` decrement and this call.
+None => (SimpleAction::Keep, ()),
});
}
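
Aside: the `None` arm replaces the old panic because the `refcnt` decrement in
`Drop` and this cleanup call are separate critical sections. One interleaving
that makes the arm reachable (thread labels are illustrative):

// thread A: Entry::drop           -> refcnt.fetch_sub(1) returns 1 (refcnt now 0)
// thread B: remove(key)           -> under the shard lock, sees refcnt == 0,
//                                    takes the value and removes the map entry
// thread A: try_remove_entry(key) -> key already gone; Keep is the right answer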

@@ -736,7 +731,20 @@ impl<K: Eq + Hash, V> Drop for EntryByVal<'_, K, V> {
fn drop(&mut self) {
// SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it.
unsafe { (*self.state).mutex.unlock() };
-self.map.unlock(&self.key);
+
+// SAFETY: The pointer `self.state` remains valid here because the `EntryByVal`
+// incremented the `State`'s reference count when it was created. While `self` is
+// alive in this `drop` call, the reference count is therefore at least 1, and this
+// `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The
+// `State` is only deallocated once its reference count reaches zero, which can only
+// occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access
+// `refcnt` is safe at this point.
+let prev = (unsafe { &*self.state })
+    .refcnt
+    .fetch_sub(1, Ordering::AcqRel);
+if prev == 1 {
+    self.map.try_remove_entry(&self.key);
+}
}
}
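
Aside: the `prev == 1` gate means only the final handle for a key pays for the
shard-locked cleanup in `try_remove_entry`. A minimal model of the protocol
with std types (`Handle` is hypothetical, not this crate's API):

use std::sync::atomic::{AtomicU32, Ordering};

struct Handle<'a> {
    refcnt: &'a AtomicU32,
}

impl Drop for Handle<'_> {
    fn drop(&mut self) {
        // AcqRel: publish this handle's writes and, if we are the last
        // holder, observe every other handle's writes.
        let prev = self.refcnt.fetch_sub(1, Ordering::AcqRel);
        if prev == 1 {
            // Only the final handle attempts cleanup.
            println!("last handle dropped; would call try_remove_entry here");
        }
    }
}

fn main() {
    let refcnt = AtomicU32::new(2);
    let a = Handle { refcnt: &refcnt };
    let b = Handle { refcnt: &refcnt };
    drop(a); // prev == 2: no cleanup
    drop(b); // prev == 1: cleanup attempted exactly once
}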

@@ -880,7 +888,20 @@ impl<K: Eq + Hash + Borrow<Q>, Q: Eq + Hash + ?Sized, V> Drop for EntryByRef<'_,
fn drop(&mut self) {
// SAFETY: The entry holds the lock on the `State`, so it is safe to unlock it.
unsafe { (*self.state).mutex.unlock() };
-self.map.unlock(self.key);
+
+// SAFETY: The pointer `self.state` remains valid here because the `EntryByRef`
+// incremented the `State`'s reference count when it was created. While `self` is
+// alive in this `drop` call, the reference count is therefore at least 1, and this
+// `fetch_sub(1, ...)` is decrementing that last reference held by the entry. The
+// `State` is only deallocated once its reference count reaches zero, which can only
+// occur after this `fetch_sub` completes. Thus, dereferencing `self.state` to access
+// `refcnt` is safe at this point.
+let prev = (unsafe { &*self.state })
+    .refcnt
+    .fetch_sub(1, Ordering::AcqRel);
+if prev == 1 {
+    self.map.try_remove_entry(self.key);
+}
}
}

@@ -957,22 +978,6 @@ mod tests {
}
}

-#[test]
-#[should_panic(expected = "impossible: unlock a non-existent key!")]
-fn test_lockmap_invalid_unlock() {
-    let map = LockMap::<u32, u32>::new();
-    let state = State {
-        refcnt: AtomicU32::new(1),
-        mutex: Mutex::new(),
-        value: UnsafeCell::new(None),
-    };
-    let _ = EntryByVal {
-        map: &map,
-        key: 7268,
-        state: &state as *const State<u32> as *mut State<u32>,
-    };
-}
-
#[test]
fn test_lockmap_same_key_by_value() {
let lock_map = Arc::new(LockMap::<usize, usize>::with_capacity(256));