diff --git a/netwatch/src/udp.rs b/netwatch/src/udp.rs index bbc1a8f..6cfc278 100644 --- a/netwatch/src/udp.rs +++ b/netwatch/src/udp.rs @@ -688,6 +688,8 @@ enum SocketState { addr: SocketAddr, }, Closed { + /// The addr to rebind to when recovering. + addr: SocketAddr, last_max_gso_segments: NonZeroUsize, last_gro_segments: NonZeroUsize, last_may_fragment: bool, @@ -768,25 +770,34 @@ impl SocketState { } fn rebind(&mut self) -> io::Result<()> { - let (addr, closed_state) = match self { - Self::Connected { state, addr, .. } => { - let s = SocketState::Closed { - last_max_gso_segments: state.max_gso_segments(), - last_gro_segments: state.gro_segments(), - last_may_fragment: state.may_fragment(), - }; - (*addr, s) - } - Self::Closed { .. } => { - return Err(io::Error::other("socket is closed and cannot be rebound")); - } + let addr = match self { + Self::Connected { addr, .. } => *addr, + Self::Closed { addr, .. } => *addr, }; debug!("rebinding {}", addr); - *self = closed_state; - *self = Self::bind(addr)?; + // Transition to Closed first to drop the old socket. + // This is needed so the port is released before we try to bind again. + if let Self::Connected { state, .. } = self { + *self = SocketState::Closed { + addr, + last_max_gso_segments: state.max_gso_segments(), + last_gro_segments: state.gro_segments(), + last_may_fragment: state.may_fragment(), + }; + } - Ok(()) + match Self::bind(addr) { + Ok(new_state) => { + *self = new_state; + Ok(()) + } + Err(err) => { + // Stay in Closed state but allow future rebind attempts + debug!("rebind failed, will retry on next attempt: {}", err); + Err(err) + } + } } fn is_closed(&self) -> bool { @@ -795,8 +806,9 @@ impl SocketState { fn close(&mut self) -> Option<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> { match self { - Self::Connected { state, .. } => { + Self::Connected { state, addr, .. } => { let s = SocketState::Closed { + addr: *addr, last_max_gso_segments: state.max_gso_segments(), last_gro_segments: state.gro_segments(), last_may_fragment: state.may_fragment(),