Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 142 additions & 124 deletions grpc/src/client/load_balancing/child_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,34 @@ use std::sync::Mutex;
use std::{collections::HashMap, error::Error, hash::Hash, mem, sync::Arc};

use crate::client::load_balancing::{
ChannelController, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, WeakSubchannel,
WorkScheduler,
ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState,
WeakSubchannel, WorkScheduler,
};
use crate::client::name_resolution::{Address, ResolverUpdate};
use crate::client::service_config::LbConfig;

use super::{Subchannel, SubchannelState};

use tokio::sync::{mpsc, watch, Notify};
use tokio::task::{AbortHandle, JoinHandle};

// An LbPolicy implementation that manages multiple children.
pub struct ChildManager<T> {
subchannels: HashMap<WeakSubchannel, Arc<T>>,
children: HashMap<Arc<T>, Child>,
sharder: Box<dyn ResolverUpdateSharder<T>>,
updated: bool, // true iff a child has updated its state since the last call to has_updated.
work_requests: Arc<Mutex<HashSet<Arc<T>>>>,
work_scheduler: Arc<dyn WorkScheduler>,
subchannel_child_map: HashMap<WeakSubchannel, usize>,
children: Vec<Child<T>>,
update_sharder: Box<dyn ResolverUpdateSharder<T>>,
pending_work: Arc<Mutex<HashSet<usize>>>,
}

pub trait ChildIdentifier: PartialEq + Hash + Eq + Send + Sync + 'static {}

struct Child {
struct Child<T> {
identifier: T,
policy: Box<dyn LbPolicy>,
state: LbState,
work_scheduler: Arc<ChildWorkScheduler>,
}

/// A collection of data sent to a child of the ChildManager.
pub struct ChildUpdate {
pub struct ChildUpdate<T> {
/// The identifier the ChildManager should use for this child.
pub child_identifier: T,
/// The builder the ChildManager should use to create this child if it does
/// not exist.
pub child_policy_builder: Box<dyn LbPolicyBuilder>,
Expand All @@ -73,70 +71,48 @@ pub trait ResolverUpdateSharder<T: ChildIdentifier>: Send {
fn shard_update(
&self,
resolver_update: ResolverUpdate,
) -> Result<HashMap<T, ChildUpdate>, Box<dyn Error + Send + Sync>>;
) -> Result<Box<dyn Iterator<Item = ChildUpdate<T>>>, Box<dyn Error + Send + Sync>>;
}

impl<T: ChildIdentifier> ChildManager<T> {
/// Creates a new ChildManager LB policy. shard_update is called whenever a
/// resolver_update operation occurs.
pub fn new(
work_scheduler: Arc<dyn WorkScheduler>,
sharder: Box<dyn ResolverUpdateSharder<T>>,
) -> Self {
ChildManager {
subchannels: HashMap::default(),
children: HashMap::default(),
sharder,
updated: false,
work_requests: Arc::default(),
work_scheduler,
pub fn new(update_sharder: Box<dyn ResolverUpdateSharder<T>>) -> Self {
Self {
update_sharder,
subchannel_child_map: Default::default(),
children: Default::default(),
pending_work: Default::default(),
}
}

/// Returns data for all current children.
pub fn child_states(&mut self) -> impl Iterator<Item = (&T, &LbState)> {
self.children
.iter()
.map(|(id, child)| (id.as_ref(), &child.state))
}

pub fn has_updated(&mut self) -> bool {
mem::take(&mut self.updated)
.map(|child| (&child.identifier, &child.state))
}

// Called to update all accounting in the ChildManager from operations
// performed by a child policy on the WrappedController that was created for
// it.
// it. child_idx is an index into the children map for the relevant child.
//
// TODO: this post-processing step can be eliminated by capturing the right
// state inside the WrappedController, however it is fairly complex. Decide
// which way is better.
fn resolve_child_controller(
&mut self,
channel_controller: WrappedController,
child_id: Arc<T>,
child_idx: usize,
) {
// Add all created subchannels into the subchannel_child_map.
for csc in channel_controller.created_subchannels {
self.subchannels
.insert(WeakSubchannel::new(csc), child_id.clone());
self.subchannel_child_map.insert(csc.into(), child_idx);
}
// Update the tracked state if the child produced an update.
if let Some(state) = channel_controller.picker_update {
self.children.get_mut(&child_id.clone()).unwrap().state = state;
self.updated = true;
self.children[child_idx].state = state;
};
// Prune subchannels created by this child that are no longer
// referenced.
self.subchannels.retain(|sc, cid| {
if cid != &child_id {
return true;
}
if sc.upgrade().is_none() {
return false;
}
true
});
}
}

Expand All @@ -148,47 +124,108 @@ impl<T: ChildIdentifier> LbPolicy for ChildManager<T> {
channel_controller: &mut dyn ChannelController,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// First determine if the incoming update is valid.
let child_updates = self.sharder.shard_update(resolver_update)?;
let child_updates = self.update_sharder.shard_update(resolver_update)?;

// Remove children that are no longer active.
self.children
.retain(|child_id, _| child_updates.contains_key(child_id));

// Apply child updates to respective policies, instantiating new ones as
// needed.
for (id, update) in child_updates.into_iter() {
let child_id: Arc<T> = Arc::new(id);
let child_policy: &mut dyn LbPolicy = match self.children.get_mut(&child_id) {
Some(child) => child.policy.as_mut(),
None => {
self.children.insert(
child_id.clone(),
Child {
policy: update.child_policy_builder.build(LbPolicyOptions {
work_scheduler: Arc::new(ChildScheduler::new(
child_id.clone(),
self.work_scheduler.clone(),
self.work_requests.clone(),
)),
}),
state: LbState::initial(),
},
);
self.children.get_mut(&child_id).unwrap().policy.as_mut()
// Hold the lock to prevent new work requests during this operation and
// rewrite the indices.
let mut pending_work = self.pending_work.lock().unwrap();

// Reset pending work; we will re-add any entries it contains with the
// right index later.
let old_pending_work = mem::take(&mut *pending_work);

// Replace self.children with an empty vec.
let old_children = mem::take(&mut self.children);

// Replace the subchannel map with an empty map.
let old_subchannel_child_map = mem::take(&mut self.subchannel_child_map);

// Reverse the old subchannel map.
let mut old_child_subchannels_map: HashMap<usize, Vec<WeakSubchannel>> = HashMap::new();

for (subchannel, child_idx) in old_subchannel_child_map {
old_child_subchannels_map
.entry(child_idx)
.or_default()
.push(subchannel);
}

// Build a map of the old children from their IDs for efficient lookups.
let old_children = old_children
.into_iter()
.enumerate()
.map(|(old_idx, e)| (e.identifier, (e.policy, e.state, old_idx, e.work_scheduler)));
let mut old_children: HashMap<T, _> = old_children.collect();

// Split the child updates into the IDs and builders, and the
// ResolverUpdates.
let (ids_builders, updates): (Vec<_>, Vec<_>) = child_updates
.map(|e| ((e.child_identifier, e.child_policy_builder), e.child_update))
.unzip();

// Transfer children whose identifiers appear before and after the
// update, and create new children. Add entries back into the
// subchannel map.
for (new_idx, (identifier, builder)) in ids_builders.into_iter().enumerate() {
if let Some((policy, state, old_idx, work_scheduler)) = old_children.remove(&identifier)
{
for subchannel in old_child_subchannels_map
.remove(&old_idx)
.into_iter()
.flatten()
{
self.subchannel_child_map.insert(subchannel, new_idx);
}
if old_pending_work.contains(&old_idx) {
pending_work.insert(new_idx);
}
*work_scheduler.idx.lock().unwrap() = Some(new_idx);
self.children.push(Child {
identifier,
state,
policy,
work_scheduler,
});
} else {
let work_scheduler = Arc::new(ChildWorkScheduler {
pending_work: self.pending_work.clone(),
idx: Mutex::new(Some(new_idx)),
});
let policy = builder.build(LbPolicyOptions {
work_scheduler: work_scheduler.clone(),
});
let state = LbState::initial();
self.children.push(Child {
identifier,
state,
policy,
work_scheduler,
});
};
let mut channel_controller = WrappedController::new(channel_controller);
let _ = child_policy.resolver_update(
update.child_update.clone(),
config,
&mut channel_controller,
);
self.resolve_child_controller(channel_controller, child_id.clone());
}

// Keep only the subchannels associated with currently active children.
self.subchannels
.retain(|_, child_id| self.children.contains_key(child_id));
// Invalidate all deleted children's work_schedulers.
for (_, (_, _, _, work_scheduler)) in old_children {
*work_scheduler.idx.lock().unwrap() = None;
}

// Release the pending_work mutex before calling into the children to
// allow their work scheduler calls to unblock.
drop(pending_work);

// Anything left in old_children will just be Dropped and cleaned up.

// Call resolver_update on all children.
let mut updates = updates.into_iter();
for child_idx in 0..self.children.len() {
let child = &mut self.children[child_idx];
let child_update = updates.next().unwrap();
let mut channel_controller = WrappedController::new(channel_controller);
let _ = child
.policy
.resolver_update(child_update, config, &mut channel_controller);
self.resolve_child_controller(channel_controller, child_idx);
}
Ok(())
}

Expand All @@ -199,32 +236,26 @@ impl<T: ChildIdentifier> LbPolicy for ChildManager<T> {
channel_controller: &mut dyn ChannelController,
) {
// Determine which child created this subchannel.
let child_id = self
.subchannels
.get(&WeakSubchannel::new(subchannel.clone()))
.unwrap_or_else(|| {
panic!("Subchannel not found in child manager: {}", subchannel);
});
let policy = &mut self.children.get_mut(&child_id.clone()).unwrap().policy;

let child_idx = *self
.subchannel_child_map
.get(&WeakSubchannel::new(&subchannel))
.unwrap();
let policy = &mut self.children[child_idx].policy;
// Wrap the channel_controller to track the child's operations.
let mut channel_controller = WrappedController::new(channel_controller);
// Call the proper child.
policy.subchannel_update(subchannel, state, &mut channel_controller);
self.resolve_child_controller(channel_controller, child_id.clone());
self.resolve_child_controller(channel_controller, child_idx);
}

fn work(&mut self, channel_controller: &mut dyn ChannelController) {
let children = mem::take(&mut *self.work_requests.lock().unwrap());
// It is possible that work was queued for a child that got removed as
// part of a subsequent resolver_update. So, it is safe to ignore such a
// child here.
for child_id in children {
if let Some(child) = self.children.get_mut(&child_id) {
let mut channel_controller = WrappedController::new(channel_controller);
child.policy.work(&mut channel_controller);
self.resolve_child_controller(channel_controller, child_id.clone());
}
let child_idxes = mem::take(&mut *self.pending_work.lock().unwrap());
for child_idx in child_idxes {
let mut channel_controller = WrappedController::new(channel_controller);
self.children[child_idx]
.policy
.work(&mut channel_controller);
self.resolve_child_controller(channel_controller, child_idx);
}
}
}
Expand Down Expand Up @@ -261,29 +292,16 @@ impl ChannelController for WrappedController<'_> {
}
}

struct ChildScheduler<T: ChildIdentifier> {
child_identifier: Arc<T>,
work_requests: Arc<Mutex<HashSet<Arc<T>>>>,
work_scheduler: Arc<dyn WorkScheduler>,
struct ChildWorkScheduler {
pending_work: Arc<Mutex<HashSet<usize>>>, // Must be taken first for correctness
idx: Mutex<Option<usize>>, // None if the child is deleted.
}

impl<T: ChildIdentifier> ChildScheduler<T> {
fn new(
child_identifier: Arc<T>,
work_scheduler: Arc<dyn WorkScheduler>,
work_requests: Arc<Mutex<HashSet<Arc<T>>>>,
) -> Self {
Self {
child_identifier,
work_requests,
work_scheduler,
}
}
}

impl<T: ChildIdentifier> WorkScheduler for ChildScheduler<T> {
impl WorkScheduler for ChildWorkScheduler {
fn schedule_work(&self) {
(*self.work_requests.lock().unwrap()).insert(self.child_identifier.clone());
self.work_scheduler.schedule_work();
let mut pending_work = self.pending_work.lock().unwrap();
if let Some(idx) = *self.idx.lock().unwrap() {
pending_work.insert(idx);
}
}
}
10 changes: 8 additions & 2 deletions grpc/src/client/load_balancing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,16 @@ impl Display for dyn Subchannel {

struct WeakSubchannel(Weak<dyn Subchannel>);

impl WeakSubchannel {
pub fn new(subchannel: Arc<dyn Subchannel>) -> Self {
impl From<Arc<dyn Subchannel>> for WeakSubchannel {
fn from(subchannel: Arc<dyn Subchannel>) -> Self {
WeakSubchannel(Arc::downgrade(&subchannel))
}
}

impl WeakSubchannel {
pub fn new(subchannel: &Arc<dyn Subchannel>) -> Self {
WeakSubchannel(Arc::downgrade(subchannel))
}

pub fn upgrade(&self) -> Option<Arc<dyn Subchannel>> {
self.0.upgrade()
Expand Down
Loading