3 changes: 3 additions & 0 deletions iroh-relay/src/protos/common.rs
@@ -53,6 +53,9 @@ pub enum FrameType {
/// Payload is two big endian u32 durations in milliseconds: when to reconnect,
/// and how long to try total.
Restarting = 12,
/// Sent from server to client when the server closes the connection.
/// Contains the reason for closing.
Close = 13,
}

#[stack_error(derive, add_meta)]
57 changes: 57 additions & 0 deletions iroh-relay/src/protos/relay.rs
@@ -93,6 +93,10 @@ pub enum RelayToClientMsg {
/// until a problem exists.
problem: String,
},
#[deprecated(
since = "0.97.0",
note = "Frame is no longer used but kept in place for wire backwards compatibility"
)]
/// A one-way message from relay to client, advertising that the relay is restarting.
Restarting {
/// An advisory duration that the client should wait before attempting to reconnect.
@@ -110,6 +114,41 @@ pub enum RelayToClientMsg {
/// Reply to a [`ClientToRelayMsg::Ping`] from a client
/// with the payload sent previously in the ping.
Pong([u8; 8]),
/// Sent from the server before it closes the connection.
Close {
/// Contains the reason why the server chose to close the connection.
reason: CloseReason,
},
}

#[derive(Debug, Clone, PartialEq, Eq)]
/// Reason why a relay server closes a connection to a client.
pub enum CloseReason {
/// The relay server is shutting down.
Shutdown,
/// Another endpoint with the same endpoint id connected to the relay server.
///
/// When a new connection comes in from an endpoint id for which the server already has a connection,
/// the previous connection is terminated with this error.
SameEndpointIdConnected,
}

impl CloseReason {
#[cfg(feature = "server")]
fn to_u8(&self) -> u8 {
match self {
CloseReason::Shutdown => 0,
CloseReason::SameEndpointIdConnected => 1,
}
}

fn try_from_u8(value: u8) -> Option<Self> {
match value {
0 => Some(CloseReason::Shutdown),
1 => Some(CloseReason::SameEndpointIdConnected),
_ => None,
}
}
}

/// Messages that clients send to relays.
@@ -258,7 +297,9 @@ impl RelayToClientMsg {
Self::Ping { .. } => FrameType::Ping,
Self::Pong { .. } => FrameType::Pong,
Self::Health { .. } => FrameType::Health,
#[allow(deprecated, reason = "kept for wire backwards compatibility")]
Self::Restarting { .. } => FrameType::Restarting,
Self::Close { .. } => FrameType::Close,
}
}

@@ -293,13 +334,17 @@ impl RelayToClientMsg {
Self::Health { problem } => {
dst.put(problem.as_ref());
}
#[allow(deprecated, reason = "kept for wire backwards compatibility")]
Self::Restarting {
reconnect_in,
try_for,
} => {
dst.put_u32(reconnect_in.as_millis() as u32);
dst.put_u32(try_for.as_millis() as u32);
}
Self::Close { reason } => {
dst.put_u8(reason.to_u8());
}
}
dst
}
@@ -314,10 +359,12 @@ impl RelayToClientMsg {
Self::EndpointGone(_) => 32,
Self::Ping(_) | Self::Pong(_) => 8,
Self::Health { problem } => problem.len(),
#[allow(deprecated, reason = "kept for wire backwards compatibility")]
Self::Restarting { .. } => {
4 // u32
+ 4 // u32
}
Self::Close { .. } => 1,
};
self.typ().encoded_len() + payload_len
}
@@ -383,11 +430,19 @@ impl RelayToClientMsg {
);
let reconnect_in = Duration::from_millis(reconnect_in as u64);
let try_for = Duration::from_millis(try_for as u64);
#[allow(deprecated, reason = "kept for wire backwards compatibility")]
Self::Restarting {
reconnect_in,
try_for,
}
}
FrameType::Close => {
ensure!(content.len() == 1, Error::InvalidFrame);
let value = content.get_u8();
let reason =
CloseReason::try_from_u8(value).ok_or_else(|| e!(Error::InvalidFrame))?;
Self::Close { reason }
}
_ => {
return Err(e!(Error::InvalidFrameType { frame_type }));
}
@@ -592,6 +647,7 @@ mod tests {
48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
),
(
#[allow(deprecated)]
RelayToClientMsg::Restarting {
reconnect_in: Duration::from_millis(10),
try_for: Duration::from_millis(20),
@@ -725,6 +781,7 @@ mod proptests {
})
.prop_map(|problem| RelayToClientMsg::Health { problem });
let restarting = (any::<u32>(), any::<u32>()).prop_map(|(reconnect_in, try_for)| {
#[allow(deprecated)]
RelayToClientMsg::Restarting {
reconnect_in: Duration::from_millis(reconnect_in.into()),
try_for: Duration::from_millis(try_for.into()),
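For context on how the new message is meant to be consumed, here is a small, self-contained sketch of a client-side handler reacting to the `Close` variant. The `RelayEvent` wrapper, the printed reconnect policies, and `main` are illustrative assumptions, not APIs from this crate; only the reason semantics mirror the diff above.

```rust
// Standalone sketch (not crate code). The enum mirrors the wire-level
// CloseReason introduced above; everything else is hypothetical.
#[derive(Debug)]
enum CloseReason {
    Shutdown,
    SameEndpointIdConnected,
}

#[derive(Debug)]
enum RelayEvent {
    Datagram(Vec<u8>),
    Close(CloseReason),
}

fn on_event(event: RelayEvent) {
    match event {
        RelayEvent::Datagram(bytes) => println!("got {} bytes", bytes.len()),
        // The relay now says *why* it is dropping the connection, so the
        // client can pick a reconnect policy instead of guessing from an EOF.
        RelayEvent::Close(CloseReason::Shutdown) => {
            println!("relay shutting down; back off before reconnecting");
        }
        RelayEvent::Close(CloseReason::SameEndpointIdConnected) => {
            println!("replaced by a newer connection; avoid an immediate reconnect loop");
        }
    }
}

fn main() {
    on_event(RelayEvent::Datagram(vec![1, 2, 3]));
    on_event(RelayEvent::Close(CloseReason::SameEndpointIdConnected));
}
```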
77 changes: 45 additions & 32 deletions iroh-relay/src/server/client.rs
@@ -1,23 +1,30 @@
//! The server-side representation of an ongoing client relaying connection.

use std::{collections::HashSet, sync::Arc, time::Duration};
use std::{
collections::HashSet,
sync::{Arc, Mutex},
time::Duration,
};

use iroh_base::EndpointId;
use n0_error::{e, stack_error};
use n0_future::{SinkExt, StreamExt};
use rand::Rng;
use time::{Date, OffsetDateTime};
use tokio::{
sync::mpsc::{self, error::TrySendError},
sync::{
mpsc::{self, error::TrySendError},
oneshot,
},
time::MissedTickBehavior,
};
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tokio_util::task::AbortOnDropHandle;
use tracing::{Instrument, debug, trace, warn};

use crate::{
PingTracker,
protos::{
relay::{ClientToRelayMsg, Datagrams, PING_INTERVAL, RelayToClientMsg},
relay::{ClientToRelayMsg, CloseReason, Datagrams, PING_INTERVAL, RelayToClientMsg},
streams::BytesStreamSink,
},
server::{
@@ -62,7 +69,7 @@ pub struct Client {
/// Connection identifier.
connection_id: u64,
/// Used to close the connection loop.
done: CancellationToken,
done_s: Mutex<Option<oneshot::Sender<Option<CloseReason>>>>,
/// Actor handle.
handle: AbortOnDropHandle<()>,
/// Queue of packets intended for the client.
@@ -91,10 +98,9 @@ impl Client {
channel_capacity,
} = config;

let done = CancellationToken::new();
let (send_queue_s, send_queue_r) = mpsc::channel(channel_capacity);

let (peer_gone_s, peer_gone_r) = mpsc::channel(channel_capacity);
let (done_s, done_r) = oneshot::channel();

let actor = Actor {
stream,
@@ -110,8 +116,7 @@
};

// start io loop
let io_done = done.clone();
let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!(
let handle = tokio::task::spawn(actor.run(done_r).instrument(tracing::info_span!(
"client-connection-actor",
remote_endpoint = %endpoint_id.fmt_short(),
connection_id = connection_id
@@ -121,7 +126,7 @@
endpoint_id,
connection_id,
handle: AbortOnDropHandle::new(handle),
done,
done_s: Mutex::new(Some(done_s)),
send_queue: send_queue_s,
peer_gone: peer_gone_s,
}
@@ -134,8 +139,8 @@
    /// Shuts down the reader and writer loops and closes the connection.
///
/// Any shutdown errors will be logged as warnings.
pub(super) async fn shutdown(self) {
self.start_shutdown();
pub(super) async fn shutdown(self, reason: CloseReason) {
self.start_shutdown(Some(reason));
if let Err(e) = self.handle.await {
warn!(
remote_endpoint = %self.endpoint_id.fmt_short(),
@@ -145,8 +150,10 @@
}

/// Starts the process of shutdown.
pub(super) fn start_shutdown(&self) {
self.done.cancel();
pub(super) fn start_shutdown(&self, reason: Option<CloseReason>) {
if let Some(sender) = self.done_s.lock().expect("poisoned").take() {
sender.send(reason).ok();
}
}

pub(super) fn try_send_packet(
@@ -205,7 +212,7 @@ pub enum RunError {
source: ForwardPacketError,
},
#[error("Flush")]
Flush {},
CloseFlush {},
#[error(transparent)]
HandleFrame {
#[error(from)]
@@ -217,10 +224,8 @@ pub enum RunError {
PacketSend { source: WriteFrameError },
#[error("Server.endpoint_gone dropped")]
EndpointGoneDrop {},
#[error("EndpointGone write frame failed")]
EndpointGoneWriteFrame { source: WriteFrameError },
#[error("Keep alive write frame failed")]
KeepAliveWriteFrame { source: WriteFrameError },
#[error("Writing frame failed")]
WriteFrame { source: WriteFrameError },
#[error("Tick flush")]
TickFlush {},
}
@@ -268,15 +273,15 @@ impl<S> Actor<S>
where
S: BytesStreamSink,
{
async fn run(mut self, done: CancellationToken) {
async fn run(mut self, done_r: oneshot::Receiver<Option<CloseReason>>) {
// Note the accept and disconnects metrics must be in a pair. Technically the
// connection is accepted long before this in the HTTP server, but it is clearer to
// handle the metric here.
self.metrics.accepts.inc();
if self.client_counter.update(self.endpoint_id) {
self.metrics.unique_client_keys.inc();
}
match self.run_inner(done).await {
match self.run_inner(done_r).await {
Err(e) => {
warn!("actor errored {e:#}, exiting");
}
@@ -290,7 +295,10 @@ where
self.metrics.disconnects.inc();
}

async fn run_inner(&mut self, done: CancellationToken) -> Result<(), RunError> {
async fn run_inner(
&mut self,
mut done_r: oneshot::Receiver<Option<CloseReason>>,
) -> Result<(), RunError> {
// Add some jitter to ping pong interactions, to avoid all pings being sent at the same time
let next_interval = || {
let random_secs = rand::rng().random_range(1..=5);
@@ -306,10 +314,16 @@ where
tokio::select! {
biased;

_ = done.cancelled() => {
trace!("actor loop cancelled, exiting");
// final flush
self.stream.flush().await.map_err(|_| e!(RunError::Flush))?;
reason = &mut done_r => {
trace!("actor loop cancelled, exiting (reason: {reason:?})");
if let Ok(Some(reason)) = reason {
self.write_frame(RelayToClientMsg::Close { reason }).await
.map_err(|err| e!(RunError::WriteFrame, err))?;
}
self.stream
.flush()
.await
.map_err(|_| e!(RunError::CloseFlush))?;
break;
}
maybe_frame = self.stream.next() => {
@@ -332,7 +346,7 @@ where
trace!("endpoint_id gone: {:?}", endpoint_id);
self.write_frame(RelayToClientMsg::EndpointGone(endpoint_id))
.await
.map_err(|err| e!(RunError::EndpointGoneWriteFrame, err))?;
.map_err(|err| e!(RunError::WriteFrame, err))?;
}
_ = self.ping_tracker.timeout() => {
trace!("pong timed out");
@@ -345,7 +359,7 @@ where
let data = self.ping_tracker.new_ping();
self.write_frame(RelayToClientMsg::Ping(data))
.await
.map_err(|err| e!(RunError::KeepAliveWriteFrame, err))?;
.map_err(|err| e!(RunError::WriteFrame, err))?;
}
}

@@ -538,6 +552,7 @@ mod tests {

let (send_queue_s, send_queue_r) = mpsc::channel(10);
let (peer_gone_s, peer_gone_r) = mpsc::channel(10);
let (done_s, done_r) = oneshot::channel();

let endpoint_id = SecretKey::generate(&mut rng).public();
let (io, io_rw) = tokio::io::duplex(1024);
@@ -559,9 +574,7 @@ mod tests {
metrics,
};

let done = CancellationToken::new();
let io_done = done.clone();
let handle = tokio::task::spawn(async move { actor.run(io_done).await });
let handle = tokio::task::spawn(async move { actor.run(done_r).await });

// Write tests
println!("-- write");
@@ -621,7 +634,7 @@ mod tests {
.await
.std_context("send")?;

done.cancel();
done_s.send(None).ok();
handle.await.std_context("join")?;
Ok(())
}
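The shutdown change above swaps the `CancellationToken` for a `oneshot::Sender<Option<CloseReason>>`, so the actor can learn *why* it is being stopped and write a final `Close` frame before flushing. Below is a minimal, self-contained sketch of that pattern, assuming only tokio; the names other than the tokio APIs are illustrative, and the real actor handles incoming frames, pings, and the send queue in the same `select!` loop.

```rust
use tokio::sync::oneshot;

// Stand-in for the crate's CloseReason; illustrative only.
#[derive(Debug)]
enum Reason {
    Shutdown,
}

async fn actor_loop(mut done_r: oneshot::Receiver<Option<Reason>>) {
    loop {
        tokio::select! {
            biased;
            reason = &mut done_r => {
                // Ok(Some(_)): shutdown with a reason -> emit a Close frame
                // before the final flush. Ok(None) or Err(_): just exit.
                if let Ok(Some(reason)) = reason {
                    println!("would write Close frame with reason {reason:?}");
                }
                break;
            }
            // ... other branches (stream.next(), ping ticks, send queue) elided ...
        }
    }
}

#[tokio::main]
async fn main() {
    // Mirrors Client::start_shutdown(Some(reason)) handing the reason to the actor.
    let (done_s, done_r) = oneshot::channel();
    let handle = tokio::spawn(actor_loop(done_r));
    done_s.send(Some(Reason::Shutdown)).ok();
    handle.await.unwrap();
}
```

Dropping the sender without sending, or sending `None` as the updated test does, still ends the loop, which preserves the old cancel-without-reason behaviour.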