diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index e9ab5c9b3..9b4a45ad1 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -1,21 +1,30 @@ use crate::client::{ load_balancing::{ - ChannelController, ExternalSubchannel, ForwardingSubchannel, LbState, Subchannel, - WorkScheduler, + ChannelController, ExternalSubchannel, ForwardingSubchannel, LbPolicy, LbPolicyBuilder, + LbPolicyOptions, LbState, ParsedJsonLbConfig, Pick, PickResult, Picker, Subchannel, + SubchannelState, WorkScheduler, }, name_resolution::Address, + service_config::LbConfig, + ConnectivityState, }; use crate::service::{Message, Request, Response, Service}; +use futures_util::future::ok; +use serde::{Deserialize, Serialize}; use std::{ + collections::HashMap, + error::Error, fmt::Display, hash::{Hash, Hasher}, ops::Add, + ptr, sync::Arc, }; use tokio::{ sync::{mpsc, Notify}, task::AbortHandle, }; +use tonic::metadata::MetadataMap; pub(crate) struct EmptyMessage {} impl Message for EmptyMessage {} @@ -127,3 +136,282 @@ impl WorkScheduler for TestWorkScheduler { self.tx_events.send(TestEvent::ScheduleWork).unwrap(); } } + +pub struct MockBalancerOne { + connectivity_state: ConnectivityState, + subchannel_list: Option, +} + +pub struct MockBalancerTwo { + connectivity_state: ConnectivityState, + subchannel_list: Option, +} + +impl LbPolicy for MockBalancerOne { + fn resolver_update( + &mut self, + update: crate::client::name_resolution::ResolverUpdate, + config: Option<&crate::client::service_config::LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), Box> { + if let Ok(ref endpoints) = update.endpoints { + let addresses: Vec<_> = endpoints + .iter() + .flat_map(|ep| ep.addresses.clone()) + .collect(); + let scl = SubchannelList::new(&addresses, channel_controller); + self.subchannel_list = Some(scl); + } + channel_controller.update_picker(LbState { + connectivity_state: self.connectivity_state, + picker: Arc::new(MockPickerOne), + }); + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &super::SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + if let Some(ref mut scl) = self.subchannel_list { + scl.update_subchannel_data(&subchannel, state); + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(MockPickerOne), + }); + } + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } + + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } +} + + +impl LbPolicy for MockBalancerTwo { + fn resolver_update( + &mut self, + update: crate::client::name_resolution::ResolverUpdate, + config: Option<&crate::client::service_config::LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), Box> { + if let Ok(ref endpoints) = update.endpoints { + let addresses: Vec<_> = endpoints + .iter() + .flat_map(|ep| ep.addresses.clone()) + .collect(); + let scl = SubchannelList::new(&addresses, channel_controller); + self.subchannel_list = Some(scl); + } + channel_controller.update_picker(LbState { + connectivity_state: self.connectivity_state, + picker: Arc::new(MockPickerTwo), + }); + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &super::SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + if let Some(ref mut scl) = self.subchannel_list { + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(MockPickerTwo), + }); + } + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } + + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } +} +pub static POLICY_NAME: &str = "mock_policy_one"; +pub static MOCK_POLICY_TWO: &str = "mock_policy_two"; + +struct MockPolicyOneBuilder {} +struct MockPolicyTwoBuilder {} + +impl LbPolicyBuilder for MockPolicyOneBuilder { + fn build(&self, options: LbPolicyOptions) -> Box { + Box::new(MockBalancerOne { + subchannel_list: None, + connectivity_state: ConnectivityState::Connecting, + }) + } + + fn name(&self) -> &'static str { + POLICY_NAME + } + + fn parse_config( + &self, + config: &ParsedJsonLbConfig, + ) -> Result, Box> { + let cfg: MockConfig = match config.convert_to() { + Ok(c) => c, + Err(e) => { + return Err(format!("failed to parse JSON config: {}", e).into()); + } + }; + Ok(Some(LbConfig::new(cfg))) + } +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub(super) struct MockConfig { + shuffle_address_list: Option, +} + +impl LbPolicyBuilder for MockPolicyTwoBuilder { + fn build(&self, options: LbPolicyOptions) -> Box { + Box::new(MockBalancerTwo { + subchannel_list: None, + connectivity_state: ConnectivityState::Connecting, + }) + } + + fn name(&self) -> &'static str { + MOCK_POLICY_TWO + } + + fn parse_config( + &self, + config: &ParsedJsonLbConfig, + ) -> Result, Box> { + let cfg: MockConfig = match config.convert_to() { + Ok(c) => c, + Err(e) => { + return Err(format!("failed to parse JSON config: {}", e).into()); + } + }; + Ok(Some(LbConfig::new(cfg))) + } +} + +#[derive(Clone)] +struct SubchannelData { + state: Option, + seen_transient_failure: bool, +} + +impl SubchannelData { + fn new() -> SubchannelData { + SubchannelData { + state: None, + seen_transient_failure: false, + } + } +} + +struct SubchannelList { + subchannels: HashMap, SubchannelData>, + ordered_subchannels: Vec>, + current_idx: usize, + num_initial_notifications_seen: usize, +} + +impl SubchannelList { + fn new(addresses: &Vec
, channel_controller: &mut dyn ChannelController) -> Self { + let mut scl = SubchannelList { + subchannels: HashMap::new(), + ordered_subchannels: Vec::new(), + current_idx: 0, + num_initial_notifications_seen: 0, + }; + for address in addresses { + let sc = channel_controller.new_subchannel(address); + scl.ordered_subchannels.push(sc.clone()); + scl.subchannels.insert(sc, SubchannelData::new()); + } + scl + } + + fn subchannel_data(&self, sc: &Arc) -> Option { + self.subchannels.get(sc).cloned() + } + + fn contains(&self, sc: &Arc) -> bool { + self.subchannels.contains_key(sc) + } + + // Updates internal state of the subchannel with the new state. Callers must + // ensure that this method is called only for subchannels in the list. + // + // Returns old state corresponding to the subchannel, if one exists. + fn update_subchannel_data( + &mut self, + sc: &Arc, + state: &SubchannelState, + ) -> Option { + let sc_data = self.subchannels.get_mut(sc).unwrap(); + + // Increment the counter when seeing the first update. + if sc_data.state.is_none() { + self.num_initial_notifications_seen += 1; + } + + let old_state = sc_data.state.clone(); + sc_data.state = Some(state.clone()); + match state.connectivity_state { + ConnectivityState::Ready => sc_data.seen_transient_failure = false, + ConnectivityState::TransientFailure => sc_data.seen_transient_failure = true, + _ => {} + } + + old_state + } +} + +pub struct MockPickerOne; +pub struct MockPickerTwo; + +impl Picker for MockPickerOne { + fn pick(&self, _req: &Request) -> PickResult { + PickResult::Pick(Pick { + subchannel: Arc::new(TestSubchannel::new( + Address { + address: "one".to_string(), + ..Default::default() + }, + mpsc::unbounded_channel().0, + )), + on_complete: None, + metadata: MetadataMap::new(), + }) + } +} + +impl Picker for MockPickerTwo { + fn pick(&self, _req: &Request) -> PickResult { + PickResult::Pick(Pick { + subchannel: Arc::new(TestSubchannel::new( + Address { + address: "two".to_string(), + ..Default::default() + }, + mpsc::unbounded_channel().0, + )), + on_complete: None, + metadata: MetadataMap::new(), + }) + } +} + +pub fn reg() { + super::GLOBAL_LB_REGISTRY.add_builder(MockPolicyOneBuilder {}); + super::GLOBAL_LB_REGISTRY.add_builder(MockPolicyTwoBuilder {}); +}