From 7b6af3988204540208e619c67fbb4b5e42153393 Mon Sep 17 00:00:00 2001 From: Ragnt Date: Sat, 12 Oct 2024 22:30:53 -0400 Subject: [PATCH 1/7] Add ability to resolve (and cache) multicast groups for a given family --- src/handle.rs | 18 ++++++- src/resolver.rs | 139 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 155 insertions(+), 2 deletions(-) diff --git a/src/handle.rs b/src/handle.rs index 0102789..6f2cdbd 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -6,13 +6,14 @@ use crate::{ resolver::Resolver, }; use futures::{lock::Mutex, Stream, StreamExt}; +use log::trace; use netlink_packet_core::{ DecodeError, Emitable, NetlinkMessage, NetlinkPayload, ParseableParametrized, }; use netlink_packet_generic::{GenlFamily, GenlHeader, GenlMessage}; use netlink_proto::{sys::SocketAddr, ConnectionHandle}; -use std::{fmt::Debug, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, sync::Arc}; /// The generic netlink connection handle /// @@ -69,6 +70,21 @@ impl GenetlinkHandle { .await } + /// Resolve the multicast groups of the given [`GenlFamily`]. + pub async fn resolve_mcast_groups( + &self, + ) -> Result, GenetlinkError> + where + F: GenlFamily, + { + trace!("Requesting Groups from Resolver: {:?}", F::family_name()); + self.resolver + .lock() + .await + .query_family_multicast_groups(self, F::family_name()) + .await + } + /// Clear the resolver's fanily id cache pub async fn clear_family_id_cache(&self) { self.resolver.lock().await.clear_cache(); diff --git a/src/resolver.rs b/src/resolver.rs index ee46ae0..a376860 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -2,9 +2,10 @@ use crate::{error::GenetlinkError, GenetlinkHandle}; use futures::{future::Either, StreamExt}; +use log::{error, trace, warn}; use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST}; use netlink_packet_generic::{ - ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + ctrl::{nlas::{GenlCtrlAttrs, McastGrpAttrs}, GenlCtrl, GenlCtrlCmd}, GenlMessage, }; use std::{collections::HashMap, future::Future}; @@ -12,12 +13,14 @@ use std::{collections::HashMap, future::Future}; #[derive(Clone, Debug, Default)] pub struct Resolver { cache: HashMap<&'static str, u16>, + groups_cache: HashMap<&'static str, HashMap> } impl Resolver { pub fn new() -> Self { Self { cache: HashMap::new(), + groups_cache: HashMap::new(), } } @@ -25,6 +28,10 @@ impl Resolver { self.cache.get(family_name).copied() } + pub fn get_groups_cache_by_name(&self, family_name: &str) -> Option> { + self.groups_cache.get(family_name).cloned() + } + pub fn query_family_id( &mut self, handle: &GenetlinkHandle, @@ -85,9 +92,112 @@ impl Resolver { } } + pub fn query_family_multicast_groups( + &mut self, + handle: &GenetlinkHandle, + family_name: &'static str, + ) -> impl Future, GenetlinkError>> + '_ { + let mut handle = handle.clone(); + async move { + trace!("Starting query_family_multicast_groups for family_name: '{}'", family_name); + + // First, get the family ID (this uses your existing method) + trace!("Calling query_family_id for family_name: '{}'", family_name); + let family_id = self.query_family_id(&handle, family_name).await?; + trace!("Received family_id: {}", family_id); + + // Create the request message to get family details + trace!("Creating GenlMessage for CTRL_CMD_GETFAMILY"); + let mut genlmsg: GenlMessage = GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyId(family_id)], + }); + genlmsg.finalize(); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST; + nlmsg.finalize(); + trace!("NetlinkMessage created: {:?}", nlmsg); + + // Send the request + trace!("Sending NetlinkMessage to netlink socket"); + let mut res = handle.send_request(nlmsg)?; + trace!("Request sent, awaiting response"); + + // Prepare to collect multicast groups + let mut mc_groups = HashMap::new(); + + // Process the response + trace!("Processing responses"); + while let Some(result) = res.next().await { + trace!("Received a response"); + let rx_packet = result?; + trace!("Received NetlinkMessage: {:?}", rx_packet); + match rx_packet.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + trace!("Processing InnerMessage: {:?}", genlmsg); + for nla in genlmsg.payload.nlas { + trace!("Processing NLA: {:?}", nla); + if let GenlCtrlAttrs::McastGroups(groups) = nla { + trace!("Found McastGroups: {:?}", groups); + for group in groups { + // 'group' is a Vec + let mut group_name = None; + let mut group_id = None; + + for group_attr in group { + trace!("Processing group_attr: {:?}", group_attr); + match group_attr { + McastGrpAttrs::Name(ref name) => { + group_name = Some(name.clone()); + trace!("Found group name: '{}'", name); + } + McastGrpAttrs::Id(id) => { + group_id = Some(id); + trace!("Found group id: {}", id); + } + } + } + + if let (Some(name), Some(id)) = (group_name, group_id) { + mc_groups.insert(name.clone(), id); + trace!( + "Inserted group '{}' with id {} into mc_groups", + name, + id + ); + } + } + } else { + trace!("Unhandled NLA: {:?}", nla); + } + } + } + NetlinkPayload::Error(e) => { + error!("Received NetlinkPayload::Error: {:?}", e); + return Err(e.into()); + } + other => { + warn!("Received unexpected NetlinkPayload: {:?}", other); + } + } + } + trace!("Finished processing responses"); + + // Update the cache + self.groups_cache.insert(family_name, mc_groups.clone()); + trace!("Updated groups_cache for family_name: '{}'", family_name); + + trace!("Returning mc_groups: {:?}", mc_groups); + Ok(mc_groups) + } + } + + pub fn clear_cache(&mut self) { self.cache.clear(); + self.groups_cache.clear(); } + } #[cfg(all(test, feature = "tokio_socket"))] @@ -152,6 +262,33 @@ mod test { let cache = resolver.get_cache_by_name(name).unwrap(); assert_eq!(id, cache); + + let mcast_groups = resolver + .query_family_multicast_groups(&handle, name) + .await + .or_else(|e| { + if let GenetlinkError::NetlinkError(io_err) = &e { + if io_err.kind() == ErrorKind::NotFound { + // Ignore non exist entries + Ok(0) + } else { + Err(e) + } + } else { + Err(e) + } + }) + .unwrap(); + if mcast_groups.is_empty() { + log::warn!( + "Generic family \"{name}\" not exist or not loaded \ + in this environment. Ignored." + ); + continue; + } + + let cache = resolver.get_groups_cache_by_name(name).unwrap(); + assert_eq!(mcast_groups, cache); log::warn!("{:?}", (name, cache)); } } From 4890961019997b3b501c6dacf395c6a96ab8b3e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Vi=C3=B6l?= Date: Fri, 5 Sep 2025 14:23:11 +0200 Subject: [PATCH 2/7] Remove tracing logs --- src/handle.rs | 2 -- src/resolver.rs | 88 ++++++++++++++++--------------------------------- 2 files changed, 28 insertions(+), 62 deletions(-) diff --git a/src/handle.rs b/src/handle.rs index 6f2cdbd..d9245cd 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -6,7 +6,6 @@ use crate::{ resolver::Resolver, }; use futures::{lock::Mutex, Stream, StreamExt}; -use log::trace; use netlink_packet_core::{ DecodeError, Emitable, NetlinkMessage, NetlinkPayload, ParseableParametrized, @@ -77,7 +76,6 @@ impl GenetlinkHandle { where F: GenlFamily, { - trace!("Requesting Groups from Resolver: {:?}", F::family_name()); self.resolver .lock() .await diff --git a/src/resolver.rs b/src/resolver.rs index a376860..d8c1eaf 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -2,10 +2,12 @@ use crate::{error::GenetlinkError, GenetlinkHandle}; use futures::{future::Either, StreamExt}; -use log::{error, trace, warn}; use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST}; use netlink_packet_generic::{ - ctrl::{nlas::{GenlCtrlAttrs, McastGrpAttrs}, GenlCtrl, GenlCtrlCmd}, + ctrl::{ + nlas::{GenlCtrlAttrs, McastGrpAttrs}, + GenlCtrl, GenlCtrlCmd, + }, GenlMessage, }; use std::{collections::HashMap, future::Future}; @@ -13,7 +15,7 @@ use std::{collections::HashMap, future::Future}; #[derive(Clone, Debug, Default)] pub struct Resolver { cache: HashMap<&'static str, u16>, - groups_cache: HashMap<&'static str, HashMap> + groups_cache: HashMap<&'static str, HashMap>, } impl Resolver { @@ -28,7 +30,10 @@ impl Resolver { self.cache.get(family_name).copied() } - pub fn get_groups_cache_by_name(&self, family_name: &str) -> Option> { + pub fn get_groups_cache_by_name( + &self, + family_name: &str, + ) -> Option> { self.groups_cache.get(family_name).cloned() } @@ -96,108 +101,79 @@ impl Resolver { &mut self, handle: &GenetlinkHandle, family_name: &'static str, - ) -> impl Future, GenetlinkError>> + '_ { + ) -> impl Future, GenetlinkError>> + '_ + { let mut handle = handle.clone(); async move { - trace!("Starting query_family_multicast_groups for family_name: '{}'", family_name); - - // First, get the family ID (this uses your existing method) - trace!("Calling query_family_id for family_name: '{}'", family_name); let family_id = self.query_family_id(&handle, family_name).await?; - trace!("Received family_id: {}", family_id); - + // Create the request message to get family details - trace!("Creating GenlMessage for CTRL_CMD_GETFAMILY"); - let mut genlmsg: GenlMessage = GenlMessage::from_payload(GenlCtrl { - cmd: GenlCtrlCmd::GetFamily, - nlas: vec![GenlCtrlAttrs::FamilyId(family_id)], - }); + let mut genlmsg: GenlMessage = + GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyId(family_id)], + }); genlmsg.finalize(); let mut nlmsg = NetlinkMessage::from(genlmsg); nlmsg.header.flags = NLM_F_REQUEST; nlmsg.finalize(); - trace!("NetlinkMessage created: {:?}", nlmsg); - + // Send the request - trace!("Sending NetlinkMessage to netlink socket"); let mut res = handle.send_request(nlmsg)?; - trace!("Request sent, awaiting response"); - + // Prepare to collect multicast groups let mut mc_groups = HashMap::new(); - + // Process the response - trace!("Processing responses"); while let Some(result) = res.next().await { - trace!("Received a response"); let rx_packet = result?; - trace!("Received NetlinkMessage: {:?}", rx_packet); match rx_packet.payload { NetlinkPayload::InnerMessage(genlmsg) => { - trace!("Processing InnerMessage: {:?}", genlmsg); for nla in genlmsg.payload.nlas { - trace!("Processing NLA: {:?}", nla); if let GenlCtrlAttrs::McastGroups(groups) = nla { - trace!("Found McastGroups: {:?}", groups); for group in groups { // 'group' is a Vec let mut group_name = None; let mut group_id = None; - + for group_attr in group { - trace!("Processing group_attr: {:?}", group_attr); match group_attr { McastGrpAttrs::Name(ref name) => { group_name = Some(name.clone()); - trace!("Found group name: '{}'", name); } McastGrpAttrs::Id(id) => { group_id = Some(id); - trace!("Found group id: {}", id); } } } - - if let (Some(name), Some(id)) = (group_name, group_id) { + + if let (Some(name), Some(id)) = + (group_name, group_id) + { mc_groups.insert(name.clone(), id); - trace!( - "Inserted group '{}' with id {} into mc_groups", - name, - id - ); } } - } else { - trace!("Unhandled NLA: {:?}", nla); } } } NetlinkPayload::Error(e) => { - error!("Received NetlinkPayload::Error: {:?}", e); return Err(e.into()); } - other => { - warn!("Received unexpected NetlinkPayload: {:?}", other); - } + _ => (), } } - trace!("Finished processing responses"); - + // Update the cache self.groups_cache.insert(family_name, mc_groups.clone()); - trace!("Updated groups_cache for family_name: '{}'", family_name); - - trace!("Returning mc_groups: {:?}", mc_groups); + Ok(mc_groups) } } - pub fn clear_cache(&mut self) { self.cache.clear(); self.groups_cache.clear(); } - } #[cfg(all(test, feature = "tokio_socket"))] @@ -253,10 +229,6 @@ mod test { }) .unwrap(); if id == 0 { - log::warn!( - "Generic family \"{name}\" not exist or not loaded \ - in this environment. Ignored." - ); continue; } @@ -280,10 +252,6 @@ mod test { }) .unwrap(); if mcast_groups.is_empty() { - log::warn!( - "Generic family \"{name}\" not exist or not loaded \ - in this environment. Ignored." - ); continue; } From 9a9ef7191dffec87f54253706aa195dbb229beff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Vi=C3=B6l?= Date: Fri, 5 Sep 2025 14:27:28 +0200 Subject: [PATCH 3/7] Fix test --- src/resolver.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/resolver.rs b/src/resolver.rs index d8c1eaf..432a53f 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -242,7 +242,7 @@ mod test { if let GenetlinkError::NetlinkError(io_err) = &e { if io_err.kind() == ErrorKind::NotFound { // Ignore non exist entries - Ok(0) + Ok(HashMap::new()) } else { Err(e) } From 70a4b52ff38e8e7dad061c7904202f7df8fe8c03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Vi=C3=B6l?= Date: Fri, 5 Sep 2025 14:32:10 +0200 Subject: [PATCH 4/7] Use cached group names if filled --- src/resolver.rs | 106 ++++++++++++++++++++++++++---------------------- 1 file changed, 57 insertions(+), 49 deletions(-) diff --git a/src/resolver.rs b/src/resolver.rs index 432a53f..5f7863f 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -103,70 +103,78 @@ impl Resolver { family_name: &'static str, ) -> impl Future, GenetlinkError>> + '_ { - let mut handle = handle.clone(); - async move { - let family_id = self.query_family_id(&handle, family_name).await?; - - // Create the request message to get family details - let mut genlmsg: GenlMessage = - GenlMessage::from_payload(GenlCtrl { - cmd: GenlCtrlCmd::GetFamily, - nlas: vec![GenlCtrlAttrs::FamilyId(family_id)], - }); - genlmsg.finalize(); - let mut nlmsg = NetlinkMessage::from(genlmsg); - nlmsg.header.flags = NLM_F_REQUEST; - nlmsg.finalize(); + if let Some(groups) = self.get_groups_cache_by_name(family_name) { + Either::Left(futures::future::ready(Ok(groups))) + } else { + let mut handle = handle.clone(); + Either::Right(async move { + // Create the request message to get family details + let mut genlmsg: GenlMessage = + GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyName( + family_name.to_owned(), + )], + }); + genlmsg.finalize(); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST; + nlmsg.finalize(); - // Send the request - let mut res = handle.send_request(nlmsg)?; + // Send the request + let mut res = handle.send_request(nlmsg)?; - // Prepare to collect multicast groups - let mut mc_groups = HashMap::new(); + // Prepare to collect multicast groups + let mut mc_groups = HashMap::new(); - // Process the response - while let Some(result) = res.next().await { - let rx_packet = result?; - match rx_packet.payload { - NetlinkPayload::InnerMessage(genlmsg) => { - for nla in genlmsg.payload.nlas { - if let GenlCtrlAttrs::McastGroups(groups) = nla { - for group in groups { - // 'group' is a Vec - let mut group_name = None; - let mut group_id = None; + // Process the response + while let Some(result) = res.next().await { + let rx_packet = result?; + match rx_packet.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + for nla in genlmsg.payload.nlas { + if let GenlCtrlAttrs::McastGroups(groups) = nla + { + for group in groups { + // 'group' is a Vec + let mut group_name = None; + let mut group_id = None; - for group_attr in group { - match group_attr { - McastGrpAttrs::Name(ref name) => { - group_name = Some(name.clone()); - } - McastGrpAttrs::Id(id) => { - group_id = Some(id); + for group_attr in group { + match group_attr { + McastGrpAttrs::Name( + ref name, + ) => { + group_name = + Some(name.clone()); + } + McastGrpAttrs::Id(id) => { + group_id = Some(id); + } } } - } - if let (Some(name), Some(id)) = - (group_name, group_id) - { - mc_groups.insert(name.clone(), id); + if let (Some(name), Some(id)) = + (group_name, group_id) + { + mc_groups.insert(name.clone(), id); + } } } } } + NetlinkPayload::Error(e) => { + return Err(e.into()); + } + _ => (), } - NetlinkPayload::Error(e) => { - return Err(e.into()); - } - _ => (), } - } - // Update the cache - self.groups_cache.insert(family_name, mc_groups.clone()); + // Update the cache + self.groups_cache.insert(family_name, mc_groups.clone()); - Ok(mc_groups) + Ok(mc_groups) + }) } } From 663445843cff7e7007967672454c31600da1356d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Vi=C3=B6l?= Date: Fri, 5 Sep 2025 14:36:08 +0200 Subject: [PATCH 5/7] Reduce left drift in impl --- src/resolver.rs | 54 +++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/src/resolver.rs b/src/resolver.rs index 5f7863f..90dcfdc 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -129,38 +129,34 @@ impl Resolver { // Process the response while let Some(result) = res.next().await { - let rx_packet = result?; - match rx_packet.payload { + match result?.payload { NetlinkPayload::InnerMessage(genlmsg) => { - for nla in genlmsg.payload.nlas { - if let GenlCtrlAttrs::McastGroups(groups) = nla - { - for group in groups { - // 'group' is a Vec - let mut group_name = None; - let mut group_id = None; - - for group_attr in group { - match group_attr { - McastGrpAttrs::Name( - ref name, - ) => { - group_name = - Some(name.clone()); - } - McastGrpAttrs::Id(id) => { - group_id = Some(id); - } - } - } - - if let (Some(name), Some(id)) = - (group_name, group_id) - { - mc_groups.insert(name.clone(), id); - } + // One specific family id was requested, it can be + // assumed, that the mcast + // groups are part of that family. + let Some(mcast_groups) = genlmsg + .payload + .nlas + .into_iter() + .filter_map(|attr| match attr { + GenlCtrlAttrs::McastGroups(groups) => { + Some(groups) } + _ => None, + }) + .next() + else { + continue; + }; + + for group in mcast_groups.into_iter().filter_map(|attrs| { + match attrs.as_slice() { + [McastGrpAttrs::Name(name), McastGrpAttrs::Id(i)] | + [McastGrpAttrs::Id(i), McastGrpAttrs::Name(name)] => Some((name.clone(), *i)), + _ => None } + }) { + mc_groups.insert(group.0, group.1); } } NetlinkPayload::Error(e) => { From 5e81adf1d51f5f2cacc5bf6ccad69aed2989290c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Vi=C3=B6l?= Date: Fri, 5 Sep 2025 14:39:57 +0200 Subject: [PATCH 6/7] Fix clippy --- src/handle.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/handle.rs b/src/handle.rs index d9245cd..63179c5 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -36,13 +36,14 @@ use std::{collections::HashMap, fmt::Debug, sync::Arc}; /// 2. Query the family id using the builtin resolver. /// 3. If the id is in the cache, returning the id in the cache and skip step 4. /// 4. The resolver sends `CTRL_CMD_GETFAMILY` request to get the id and records -/// it in the cache. 5. fill the family id using -/// [`GenlMessage::set_resolved_family_id()`]. 6. Serialize the payload to -/// [`RawGenlMessage`]. 7. Send it through the connection. -/// - The family id filled into `message_type` field in -/// [`NetlinkMessage::finalize()`]. +/// it in the cache. +/// 5. fill the family id using [`GenlMessage::set_resolved_family_id()`]. +/// 6. Serialize the payload to [`RawGenlMessage`]. +/// 7. Send it through the connection. +/// - The family id filled into `message_type` field in +/// [`NetlinkMessage::finalize()`]. /// 8. In the response stream, deserialize the payload back to -/// [`GenlMessage`]. +/// [`GenlMessage`]. #[derive(Clone, Debug)] pub struct GenetlinkHandle { handle: ConnectionHandle, From 2676788960d50976c84f9940ad8f8e26387eb7e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Vi=C3=B6l?= Date: Fri, 5 Sep 2025 14:48:06 +0200 Subject: [PATCH 7/7] Fix typo --- src/handle.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/handle.rs b/src/handle.rs index 63179c5..8d0ce71 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -84,7 +84,7 @@ impl GenetlinkHandle { .await } - /// Clear the resolver's fanily id cache + /// Clear the resolver's family id cache pub async fn clear_family_id_cache(&self) { self.resolver.lock().await.clear_cache(); }