From 33e683ddcbbf9043755719435bab2c394a5eb0fb Mon Sep 17 00:00:00 2001 From: Cathy Zhao Date: Tue, 1 Jul 2025 06:42:42 +0000 Subject: [PATCH 1/2] add mock balancers for testing --- grpc/src/client/load_balancing/test_utils.rs | 345 ++++++++++++++++++- 1 file changed, 337 insertions(+), 8 deletions(-) diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index e9ab5c9b3..c1c4066df 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -1,21 +1,19 @@ use crate::client::{ load_balancing::{ - ChannelController, ExternalSubchannel, ForwardingSubchannel, LbState, Subchannel, - WorkScheduler, - }, - name_resolution::Address, + ChannelController, ExternalSubchannel, ForwardingSubchannel, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, ParsedJsonLbConfig, Pick, PickResult, Picker, Subchannel, SubchannelState, WorkScheduler + }, name_resolution::Address, service_config::LbConfig, ConnectivityState }; use crate::service::{Message, Request, Response, Service}; use std::{ - fmt::Display, - hash::{Hash, Hasher}, - ops::Add, - sync::Arc, + collections::HashMap, error::Error, fmt::Display, hash::{Hash, Hasher}, ops::Add, ptr, sync::Arc }; +use futures_util::future::ok; +use serde::{Deserialize, Serialize}; use tokio::{ sync::{mpsc, Notify}, task::AbortHandle, }; +use tonic::metadata::MetadataMap; pub(crate) struct EmptyMessage {} impl Message for EmptyMessage {} @@ -127,3 +125,334 @@ impl WorkScheduler for TestWorkScheduler { self.tx_events.send(TestEvent::ScheduleWork).unwrap(); } } + +pub struct MockBalancerOne { + connectivity_state: ConnectivityState, + subchannel_list: Option, +} + +pub struct MockBalancerTwo { + connectivity_state: ConnectivityState, + subchannel_list: Option, +} + +impl LbPolicy for MockBalancerOne { + fn resolver_update( + &mut self, + update: crate::client::name_resolution::ResolverUpdate, + config: Option<&crate::client::service_config::LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), Box> { + if let Ok(ref endpoints) = update.endpoints { + let addresses: Vec<_> = endpoints.iter().flat_map(|ep| ep.addresses.clone()).collect(); + let scl = SubchannelList::new(&addresses, channel_controller); + self.subchannel_list = Some(scl); + } + channel_controller.update_picker(LbState { + connectivity_state: self.connectivity_state, + picker: Arc::new(MockPickerOne), + }); + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &super::SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + if let Some(ref mut scl) = self.subchannel_list { + scl.update_subchannel_data(&subchannel, state); + println!("updating ready picker for mock balancer 2"); + // Optionally, send a picker update to simulate state change + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(MockPickerOne), // or MockPickerTwo + }); + } + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } + + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } +} + +impl LbPolicy for MockBalancerTwo { + fn resolver_update( + &mut self, + update: crate::client::name_resolution::ResolverUpdate, + config: Option<&crate::client::service_config::LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), Box> { + if let Ok(ref endpoints) = update.endpoints { + let addresses: Vec<_> = endpoints.iter().flat_map(|ep| ep.addresses.clone()).collect(); + let scl = SubchannelList::new(&addresses, channel_controller); + self.subchannel_list = Some(scl); + } + channel_controller.update_picker(LbState { + connectivity_state: self.connectivity_state, + picker: Arc::new(MockPickerTwo), + }); + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &super::SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + if let Some(ref mut scl) = self.subchannel_list { + // scl.update_subchannel_data(&subchannel, state); + println!("updating ready picker for mock balancer 2"); + // Optionally, send a picker update to simulate state change + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(MockPickerTwo), // or MockPickerTwo + }); + } + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } + + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + todo!() + } +} +pub static POLICY_NAME: &str = "mock_policy_one"; +pub static MOCK_POLICY_TWO: &str = "mock_policy_two"; + +struct MockPolicyOneBuilder {} +struct MockPolicyTwoBuilder {} + + +impl LbPolicyBuilder for MockPolicyOneBuilder { + fn build(&self, options: LbPolicyOptions) -> Box { + Box::new(MockBalancerOne { + // work_scheduler: options.work_scheduler, + subchannel_list: None, + // selected_subchannel: None, + // addresses: vec![], + // last_resolver_error: None, + // last_connection_error: None, + connectivity_state: ConnectivityState::Connecting, + // sent_connecting_state: false, + // num_transient_failures: 0, + }) + } + + fn name(&self) -> &'static str { + POLICY_NAME + } + + fn parse_config( + &self, + config: &ParsedJsonLbConfig, + ) -> Result, Box> { + let cfg: MockConfig = match config.convert_to() { + Ok(c) => c, + Err(e) => { + return Err(format!("failed to parse JSON config: {}", e).into()); + } + }; + Ok(Some(LbConfig::new(cfg))) + } +} + +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub(super) struct MockConfig { + shuffle_address_list: Option, +} + +impl LbPolicyBuilder for MockPolicyTwoBuilder { + fn build(&self, options: LbPolicyOptions) -> Box { + Box::new(MockBalancerTwo { + // work_scheduler: options.work_scheduler, + subchannel_list: None, + // selected_subchannel: None, + // addresses: vec![], + // last_resolver_error: None, + // last_connection_error: None, + connectivity_state: ConnectivityState::Connecting, + // sent_connecting_state: false, + // num_transient_failures: 0, + }) + } + + fn name(&self) -> &'static str { + MOCK_POLICY_TWO + } + + fn parse_config( + &self, + config: &ParsedJsonLbConfig, + ) -> Result, Box> { + let cfg: MockConfig = match config.convert_to() { + Ok(c) => c, + Err(e) => { + return Err(format!("failed to parse JSON config: {}", e).into()); + } + }; + Ok(Some(LbConfig::new(cfg))) + } + + +} + +#[derive(Clone)] +struct SubchannelData { + state: Option, + seen_transient_failure: bool, +} + +impl SubchannelData { + fn new() -> SubchannelData { + SubchannelData { + state: None, + seen_transient_failure: false, + } + } +} + +struct SubchannelList { + subchannels: HashMap, SubchannelData>, + ordered_subchannels: Vec>, + current_idx: usize, + num_initial_notifications_seen: usize, +} + +impl SubchannelList { + fn new(addresses: &Vec
, channel_controller: &mut dyn ChannelController) -> Self { + let mut scl = SubchannelList { + subchannels: HashMap::new(), + ordered_subchannels: Vec::new(), + current_idx: 0, + num_initial_notifications_seen: 0, + }; + for address in addresses { + let sc = channel_controller.new_subchannel(address); + scl.ordered_subchannels.push(sc.clone()); + scl.subchannels.insert(sc, SubchannelData::new()); + } + + println!("created new subchannel list with {} subchannels", scl.len()); + scl + } + + fn len(&self) -> usize { + self.ordered_subchannels.len() + } + + fn subchannel_data(&self, sc: &Arc) -> Option { + self.subchannels.get(sc).cloned() + } + + fn contains(&self, sc: &Arc) -> bool { + self.subchannels.contains_key(sc) + } + + // Updates internal state of the subchannel with the new state. Callers must + // ensure that this method is called only for subchannels in the list. + // + // Returns old state corresponding to the subchannel, if one exists. + fn update_subchannel_data( + &mut self, + sc: &Arc, + state: &SubchannelState, + ) -> Option { + let sc_data = self.subchannels.get_mut(sc).unwrap(); + + // Increment the counter when seeing the first update. + if sc_data.state.is_none() { + self.num_initial_notifications_seen += 1; + } + + let old_state = sc_data.state.clone(); + sc_data.state = Some(state.clone()); + match state.connectivity_state { + ConnectivityState::Ready => sc_data.seen_transient_failure = false, + ConnectivityState::TransientFailure => sc_data.seen_transient_failure = true, + _ => {} + } + + old_state + } + + + // Initiates a connection attempt to the next subchannel in the list that is + // IDLE. Returns false if there are no more subchannels in the list. + fn connect_to_next_subchannel( + &mut self, + channel_controller: &mut dyn ChannelController, + ) -> bool { + // Special case for the first connection attempt, as current_idx is set + // to 0 when the subchannel list is created. + if self.current_idx != 0 { + self.current_idx += 1; + } + + for idx in self.current_idx..self.ordered_subchannels.len() { + // Grab the next subchannel and its data. + let sc = &self.ordered_subchannels[idx]; + let sc_data = self.subchannels.get(sc).unwrap(); + + match &sc_data.state { + Some(state) => { + if state.connectivity_state == ConnectivityState::Connecting + || state.connectivity_state == ConnectivityState::TransientFailure + { + self.current_idx += 1; + continue; + } else if state.connectivity_state == ConnectivityState::Idle { + sc.connect(); + return true; + } + } + None => { + debug_assert!( + false, + "No state available when asked to connect to subchannel: {}", + sc, + ); + } + } + } + false + } + +} + +pub struct MockPickerOne; +pub struct MockPickerTwo; + +impl Picker for MockPickerOne { + fn pick(&self, _req: &Request) -> PickResult { + PickResult::Pick(Pick { + subchannel: Arc::new(TestSubchannel::new(Address { address: "one".to_string(), ..Default::default() }, mpsc::unbounded_channel().0)), + on_complete: None, + metadata: MetadataMap::new(), + }) + } +} + +impl Picker for MockPickerTwo { + fn pick(&self, _req: &Request) -> PickResult { + PickResult::Pick(Pick { + subchannel: Arc::new(TestSubchannel::new(Address { address: "two".to_string(), ..Default::default() }, mpsc::unbounded_channel().0)), + on_complete: None, + metadata: MetadataMap::new(), + }) + } +} + +pub fn reg() { + super::GLOBAL_LB_REGISTRY.add_builder(MockPolicyOneBuilder {}); + super::GLOBAL_LB_REGISTRY.add_builder(MockPolicyTwoBuilder {}); +} \ No newline at end of file From eb6a005087a143c787c38062757553240a9a1ca8 Mon Sep 17 00:00:00 2001 From: Cathy Zhao Date: Mon, 7 Jul 2025 20:03:56 +0000 Subject: [PATCH 2/2] cleaned up test balancers --- grpc/src/client/load_balancing/test_utils.rs | 125 +++++++------------ 1 file changed, 42 insertions(+), 83 deletions(-) diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index c1c4066df..9b4a45ad1 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -1,14 +1,25 @@ use crate::client::{ load_balancing::{ - ChannelController, ExternalSubchannel, ForwardingSubchannel, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, ParsedJsonLbConfig, Pick, PickResult, Picker, Subchannel, SubchannelState, WorkScheduler - }, name_resolution::Address, service_config::LbConfig, ConnectivityState + ChannelController, ExternalSubchannel, ForwardingSubchannel, LbPolicy, LbPolicyBuilder, + LbPolicyOptions, LbState, ParsedJsonLbConfig, Pick, PickResult, Picker, Subchannel, + SubchannelState, WorkScheduler, + }, + name_resolution::Address, + service_config::LbConfig, + ConnectivityState, }; use crate::service::{Message, Request, Response, Service}; -use std::{ - collections::HashMap, error::Error, fmt::Display, hash::{Hash, Hasher}, ops::Add, ptr, sync::Arc -}; use futures_util::future::ok; use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + error::Error, + fmt::Display, + hash::{Hash, Hasher}, + ops::Add, + ptr, + sync::Arc, +}; use tokio::{ sync::{mpsc, Notify}, task::AbortHandle, @@ -144,7 +155,10 @@ impl LbPolicy for MockBalancerOne { channel_controller: &mut dyn ChannelController, ) -> Result<(), Box> { if let Ok(ref endpoints) = update.endpoints { - let addresses: Vec<_> = endpoints.iter().flat_map(|ep| ep.addresses.clone()).collect(); + let addresses: Vec<_> = endpoints + .iter() + .flat_map(|ep| ep.addresses.clone()) + .collect(); let scl = SubchannelList::new(&addresses, channel_controller); self.subchannel_list = Some(scl); } @@ -163,11 +177,9 @@ impl LbPolicy for MockBalancerOne { ) { if let Some(ref mut scl) = self.subchannel_list { scl.update_subchannel_data(&subchannel, state); - println!("updating ready picker for mock balancer 2"); - // Optionally, send a picker update to simulate state change channel_controller.update_picker(LbState { connectivity_state: state.connectivity_state, - picker: Arc::new(MockPickerOne), // or MockPickerTwo + picker: Arc::new(MockPickerOne), }); } } @@ -181,6 +193,7 @@ impl LbPolicy for MockBalancerOne { } } + impl LbPolicy for MockBalancerTwo { fn resolver_update( &mut self, @@ -189,7 +202,10 @@ impl LbPolicy for MockBalancerTwo { channel_controller: &mut dyn ChannelController, ) -> Result<(), Box> { if let Ok(ref endpoints) = update.endpoints { - let addresses: Vec<_> = endpoints.iter().flat_map(|ep| ep.addresses.clone()).collect(); + let addresses: Vec<_> = endpoints + .iter() + .flat_map(|ep| ep.addresses.clone()) + .collect(); let scl = SubchannelList::new(&addresses, channel_controller); self.subchannel_list = Some(scl); } @@ -207,12 +223,9 @@ impl LbPolicy for MockBalancerTwo { channel_controller: &mut dyn ChannelController, ) { if let Some(ref mut scl) = self.subchannel_list { - // scl.update_subchannel_data(&subchannel, state); - println!("updating ready picker for mock balancer 2"); - // Optionally, send a picker update to simulate state change channel_controller.update_picker(LbState { connectivity_state: state.connectivity_state, - picker: Arc::new(MockPickerTwo), // or MockPickerTwo + picker: Arc::new(MockPickerTwo), }); } } @@ -231,19 +244,11 @@ pub static MOCK_POLICY_TWO: &str = "mock_policy_two"; struct MockPolicyOneBuilder {} struct MockPolicyTwoBuilder {} - impl LbPolicyBuilder for MockPolicyOneBuilder { fn build(&self, options: LbPolicyOptions) -> Box { Box::new(MockBalancerOne { - // work_scheduler: options.work_scheduler, subchannel_list: None, - // selected_subchannel: None, - // addresses: vec![], - // last_resolver_error: None, - // last_connection_error: None, connectivity_state: ConnectivityState::Connecting, - // sent_connecting_state: false, - // num_transient_failures: 0, }) } @@ -274,15 +279,8 @@ pub(super) struct MockConfig { impl LbPolicyBuilder for MockPolicyTwoBuilder { fn build(&self, options: LbPolicyOptions) -> Box { Box::new(MockBalancerTwo { - // work_scheduler: options.work_scheduler, subchannel_list: None, - // selected_subchannel: None, - // addresses: vec![], - // last_resolver_error: None, - // last_connection_error: None, connectivity_state: ConnectivityState::Connecting, - // sent_connecting_state: false, - // num_transient_failures: 0, }) } @@ -302,8 +300,6 @@ impl LbPolicyBuilder for MockPolicyTwoBuilder { }; Ok(Some(LbConfig::new(cfg))) } - - } #[derive(Clone)] @@ -341,15 +337,9 @@ impl SubchannelList { scl.ordered_subchannels.push(sc.clone()); scl.subchannels.insert(sc, SubchannelData::new()); } - - println!("created new subchannel list with {} subchannels", scl.len()); scl } - fn len(&self) -> usize { - self.ordered_subchannels.len() - } - fn subchannel_data(&self, sc: &Arc) -> Option { self.subchannels.get(sc).cloned() } @@ -384,49 +374,6 @@ impl SubchannelList { old_state } - - - // Initiates a connection attempt to the next subchannel in the list that is - // IDLE. Returns false if there are no more subchannels in the list. - fn connect_to_next_subchannel( - &mut self, - channel_controller: &mut dyn ChannelController, - ) -> bool { - // Special case for the first connection attempt, as current_idx is set - // to 0 when the subchannel list is created. - if self.current_idx != 0 { - self.current_idx += 1; - } - - for idx in self.current_idx..self.ordered_subchannels.len() { - // Grab the next subchannel and its data. - let sc = &self.ordered_subchannels[idx]; - let sc_data = self.subchannels.get(sc).unwrap(); - - match &sc_data.state { - Some(state) => { - if state.connectivity_state == ConnectivityState::Connecting - || state.connectivity_state == ConnectivityState::TransientFailure - { - self.current_idx += 1; - continue; - } else if state.connectivity_state == ConnectivityState::Idle { - sc.connect(); - return true; - } - } - None => { - debug_assert!( - false, - "No state available when asked to connect to subchannel: {}", - sc, - ); - } - } - } - false - } - } pub struct MockPickerOne; @@ -435,7 +382,13 @@ pub struct MockPickerTwo; impl Picker for MockPickerOne { fn pick(&self, _req: &Request) -> PickResult { PickResult::Pick(Pick { - subchannel: Arc::new(TestSubchannel::new(Address { address: "one".to_string(), ..Default::default() }, mpsc::unbounded_channel().0)), + subchannel: Arc::new(TestSubchannel::new( + Address { + address: "one".to_string(), + ..Default::default() + }, + mpsc::unbounded_channel().0, + )), on_complete: None, metadata: MetadataMap::new(), }) @@ -445,7 +398,13 @@ impl Picker for MockPickerOne { impl Picker for MockPickerTwo { fn pick(&self, _req: &Request) -> PickResult { PickResult::Pick(Pick { - subchannel: Arc::new(TestSubchannel::new(Address { address: "two".to_string(), ..Default::default() }, mpsc::unbounded_channel().0)), + subchannel: Arc::new(TestSubchannel::new( + Address { + address: "two".to_string(), + ..Default::default() + }, + mpsc::unbounded_channel().0, + )), on_complete: None, metadata: MetadataMap::new(), }) @@ -455,4 +414,4 @@ impl Picker for MockPickerTwo { pub fn reg() { super::GLOBAL_LB_REGISTRY.add_builder(MockPolicyOneBuilder {}); super::GLOBAL_LB_REGISTRY.add_builder(MockPolicyTwoBuilder {}); -} \ No newline at end of file +}