From a5b185b4cbcb07de839bcc00c00d18ebd3503b69 Mon Sep 17 00:00:00 2001 From: David Venhoek Date: Mon, 16 Feb 2026 14:59:32 +0100 Subject: [PATCH 1/3] Move shared wire structures for pairing and connection into common module. --- .../examples/pairing-client.rs | 5 +- .../examples/pairing-server.rs | 5 +- s2energy-connection/src/common/mod.rs | 1 + s2energy-connection/src/common/wire.rs | 113 +++++++++++++ s2energy-connection/src/lib.rs | 3 + s2energy-connection/src/pairing/client.rs | 1 + s2energy-connection/src/pairing/mod.rs | 23 ++- s2energy-connection/src/pairing/server.rs | 7 +- s2energy-connection/src/pairing/wire.rs | 154 +++--------------- 9 files changed, 166 insertions(+), 146 deletions(-) create mode 100644 s2energy-connection/src/common/mod.rs create mode 100644 s2energy-connection/src/common/wire.rs diff --git a/s2energy-connection/examples/pairing-client.rs b/s2energy-connection/examples/pairing-client.rs index 5cd6417..517d47b 100644 --- a/s2energy-connection/examples/pairing-client.rs +++ b/s2energy-connection/examples/pairing-client.rs @@ -1,7 +1,8 @@ use std::sync::Arc; -use s2energy_connection::pairing::{ - Client, ClientConfig, Deployment, EndpointConfig, MessageVersion, PairingRemote, S2NodeDescription, S2NodeId, S2Role, +use s2energy_connection::{ + Deployment, MessageVersion, S2NodeDescription, S2NodeId, S2Role, + pairing::{Client, ClientConfig, EndpointConfig, PairingRemote}, }; const PAIRING_TOKEN: &[u8] = &[1, 2, 3]; diff --git a/s2energy-connection/examples/pairing-server.rs b/s2energy-connection/examples/pairing-server.rs index 8a35f36..b3bab72 100644 --- a/s2energy-connection/examples/pairing-server.rs +++ b/s2energy-connection/examples/pairing-server.rs @@ -2,8 +2,9 @@ use axum_server::tls_rustls::RustlsConfig; use rustls::pki_types::{CertificateDer, pem::PemObject}; use std::{net::SocketAddr, path::PathBuf, sync::Arc}; -use s2energy_connection::pairing::{ - EndpointConfig, MessageVersion, PairingToken, S2NodeDescription, S2NodeId, S2Role, Server, ServerConfig, +use s2energy_connection::{ + MessageVersion, S2NodeDescription, S2NodeId, S2Role, + pairing::{EndpointConfig, PairingToken, Server, ServerConfig}, }; #[allow(unused)] diff --git a/s2energy-connection/src/common/mod.rs b/s2energy-connection/src/common/mod.rs new file mode 100644 index 0000000..c70eda0 --- /dev/null +++ b/s2energy-connection/src/common/mod.rs @@ -0,0 +1 @@ +pub(crate) mod wire; diff --git a/s2energy-connection/src/common/wire.rs b/s2energy-connection/src/common/wire.rs new file mode 100644 index 0000000..a110e68 --- /dev/null +++ b/s2energy-connection/src/common/wire.rs @@ -0,0 +1,113 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[serde(rename_all = "lowercase")] +pub(crate) enum PairingVersion { + V1, +} + +#[derive(Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[serde(rename_all = "lowercase")] +pub(crate) enum WirePairingVersion { + V1, + #[serde(other)] + Other, +} + +impl TryFrom for PairingVersion { + type Error = (); + + fn try_from(value: WirePairingVersion) -> Result { + match value { + WirePairingVersion::V1 => Ok(PairingVersion::V1), + WirePairingVersion::Other => Err(()), + } + } +} + +/// Message schema version. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct MessageVersion(pub String); + +/// Information about the pairing endpoint of a S2 node +#[derive(Default, Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct S2EndpointDescription { + /// Name of the endpoint + #[serde(default)] + pub name: Option, + /// URI of a logo to be used for the endpoint in GUIs + #[serde(default)] + pub logo_uri: Option, + /// Type of deployment used by the endpoint (local or globally routable). + #[serde(default)] + pub deployment: Option, +} + +/// One-time access token for secure access to the S2 message communication channel. It must be renewed every time a client wants to access +/// the S2 message communication channel by calling the requestToken endpoint. This token is valid for one time login, with a maximum 5 +/// years, and should have a minimum length of 32 bytes. +#[derive(Serialize, Deserialize, Clone)] +pub struct AccessToken(pub String); + +impl AccessToken { + pub fn new(rng: &mut impl rand::Rng) -> Self { + use base64::{Engine as _, engine::general_purpose::STANDARD}; + + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + + let encoded = STANDARD.encode(bytes); + Self(encoded) + } +} + +/// Unique identifier of the S2 node +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub struct S2NodeId(pub String); + +/// Information about the S2 node +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct S2NodeDescription { + /// Unique identifier of the node + pub id: S2NodeId, + /// Brandname used for the node + pub brand: String, + /// URI of a logo to be used for the node in GUIs + #[serde(default)] + pub logo_uri: Option, + /// The type of this node. + pub type_: String, + /// Model name of the device this node belongs to. + pub model_name: String, + /// A name for the device configured by the end user/owner. + #[serde(default)] + pub user_defined_name: Option, + /// The S2 role this device has (e.g. CEM or RM). + pub role: S2Role, +} + +/// Identifier of a protocol that can be used for communication of S2 messages between nodes, for example `"WebSocket"` +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub struct CommunicationProtocol(pub String); + +/// Role within the S2 standard. +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)] +#[serde(rename_all = "UPPERCASE")] +pub enum S2Role { + /// Customer Energy Manager. + Cem, + /// Resource Manager. + Rm, +} + +/// Place of deployment for an S2 Node +#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)] +#[serde(rename_all = "UPPERCASE")] +pub enum Deployment { + /// On a WAN, reachable over the internet + Wan, + /// On the local network, only reachable near the place the device is located. + Lan, +} diff --git a/s2energy-connection/src/lib.rs b/s2energy-connection/src/lib.rs index 77543f6..d97e31d 100644 --- a/s2energy-connection/src/lib.rs +++ b/s2energy-connection/src/lib.rs @@ -1 +1,4 @@ +pub(crate) mod common; pub mod pairing; + +pub use common::wire::{CommunicationProtocol, Deployment, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, S2Role}; diff --git a/s2energy-connection/src/pairing/client.rs b/s2energy-connection/src/pairing/client.rs index 6d79ee7..e0c28b0 100644 --- a/s2energy-connection/src/pairing/client.rs +++ b/s2energy-connection/src/pairing/client.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use reqwest::{StatusCode, Url}; use rustls::pki_types::CertificateDer; +use crate::common::wire::{AccessToken, Deployment, PairingVersion, S2NodeId, S2Role, WirePairingVersion}; use crate::pairing::transport::{HashProvider, hash_providing_https_client}; use crate::pairing::{Pairing, PairingRole, SUPPORTED_PAIRING_VERSIONS}; diff --git a/s2energy-connection/src/pairing/mod.rs b/s2energy-connection/src/pairing/mod.rs index 1d20d72..e38d5b7 100644 --- a/s2energy-connection/src/pairing/mod.rs +++ b/s2energy-connection/src/pairing/mod.rs @@ -7,7 +7,8 @@ //! The main configuration struct [`EndpointConfig`] describes an S2 endpoint. It is constructed through //! a builder pattern. For simple configuration, the builder can immediately be build: //! ```rust -//! # use s2energy_connection::pairing::{EndpointConfig, MessageVersion, S2NodeDescription, S2NodeId, S2Role}; +//! # use s2energy_connection::pairing::EndpointConfig; +//! # use s2energy_connection::{MessageVersion, S2NodeDescription, S2NodeId, S2Role}; //! let _config = EndpointConfig::builder(S2NodeDescription { //! id: S2NodeId(String::from("12121212")), //! brand: String::from("super-reliable-corp"), @@ -23,7 +24,8 @@ //! //! Additional information can be added through methods on the builder. For example, we can add a connection initiate url through: //! ```rust -//! # use s2energy_connection::pairing::{EndpointConfig, MessageVersion, S2NodeDescription, S2NodeId, S2Role}; +//! # use s2energy_connection::pairing::EndpointConfig; +//! # use s2energy_connection::{MessageVersion, S2NodeDescription, S2NodeId, S2Role}; //! let _config = EndpointConfig::builder(S2NodeDescription { //! id: S2NodeId(String::from("12121212")), //! brand: String::from("super-reliable-corp"), @@ -44,7 +46,8 @@ //! server. For this, you will also need to know the id of the node, and the URL on which its pairing server is reachable. //! ```rust //! # use std::sync::Arc; -//! # use s2energy_connection::pairing::{Client, ClientConfig, Deployment, EndpointConfig, MessageVersion, PairingRemote, S2NodeDescription, S2NodeId, S2Role}; +//! # use s2energy_connection::pairing::{Client, ClientConfig, EndpointConfig, PairingRemote}; +//! # use s2energy_connection::{Deployment, MessageVersion, S2NodeDescription, S2NodeId, S2Role}; //! # let config = EndpointConfig::builder(S2NodeDescription { //! # id: S2NodeId(String::from("12121212")), //! # brand: String::from("super-reliable-corp"), @@ -102,7 +105,8 @@ //! ```no_run //! # use std::{path::PathBuf, net::SocketAddr, sync::Arc}; //! # use axum_server::tls_rustls::RustlsConfig; -//! # use s2energy_connection::pairing::{EndpointConfig, MessageVersion, PairingToken, Server, ServerConfig, S2NodeDescription, S2NodeId, S2Role}; +//! # use s2energy_connection::pairing::{EndpointConfig, PairingToken, Server, ServerConfig}; +//! # use s2energy_connection::{MessageVersion, S2NodeDescription, S2NodeId, S2Role}; //! # #[tokio::main(flavor = "current_thread")] //! # async fn main() { //! # let tls_config = RustlsConfig::from_pem_file( @@ -137,7 +141,8 @@ //! ```no_run //! # use std::{path::PathBuf, net::SocketAddr, sync::Arc}; //! # use axum_server::tls_rustls::RustlsConfig; -//! # use s2energy_connection::pairing::{EndpointConfig, MessageVersion, PairingToken, Server, ServerConfig, S2NodeDescription, S2NodeId, S2Role}; +//! # use s2energy_connection::pairing::{EndpointConfig, PairingToken, Server, ServerConfig}; +//! # use s2energy_connection::{MessageVersion, S2NodeDescription, S2NodeId, S2Role}; //! # #[tokio::main(flavor = "current_thread")] //! # async fn main() { //! # let tls_config = RustlsConfig::from_pem_file( @@ -183,13 +188,15 @@ mod wire; use rand::Rng; -use wire::{AccessToken, HmacChallenge, HmacChallengeResponse}; +use wire::{HmacChallenge, HmacChallengeResponse}; pub use client::{Client, ClientConfig, PairingRemote}; pub use server::{PairingToken, PendingPairing, RepeatedPairing, Server, ServerConfig}; -pub use wire::{CommunicationProtocol, Deployment, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, S2Role}; -use crate::pairing::wire::PairingVersion; +use crate::{ + CommunicationProtocol, Deployment, MessageVersion, S2EndpointDescription, S2NodeDescription, S2Role, + common::wire::{AccessToken, PairingVersion}, +}; const SUPPORTED_PAIRING_VERSIONS: &[PairingVersion] = &[PairingVersion::V1]; diff --git a/s2energy-connection/src/pairing/server.rs b/s2energy-connection/src/pairing/server.rs index 98c58c5..b815b71 100644 --- a/s2energy-connection/src/pairing/server.rs +++ b/s2energy-connection/src/pairing/server.rs @@ -16,9 +16,12 @@ use rustls::pki_types::CertificateDer; use sha2::Digest; use tokio::time::Instant; -use crate::pairing::{PairingRole, SUPPORTED_PAIRING_VERSIONS}; +use crate::{ + common::wire::{AccessToken, PairingVersion, S2EndpointDescription, S2NodeDescription, S2NodeId}, + pairing::{PairingRole, SUPPORTED_PAIRING_VERSIONS}, +}; -use super::{EndpointConfig, Error, Network, Pairing, PairingResult, S2EndpointDescription, S2NodeDescription, wire::*}; +use super::{EndpointConfig, Error, Network, Pairing, PairingResult, wire::*}; const PERMANENT_PAIRING_BUFFER_SIZE: usize = 1; diff --git a/s2energy-connection/src/pairing/wire.rs b/s2energy-connection/src/pairing/wire.rs index 9b47ba6..2751496 100644 --- a/s2energy-connection/src/pairing/wire.rs +++ b/s2energy-connection/src/pairing/wire.rs @@ -2,6 +2,8 @@ use axum::http::{HeaderMap, HeaderName, HeaderValue}; use serde::*; use thiserror::Error; +use crate::common::wire::{AccessToken, CommunicationProtocol, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId}; + #[derive(Error, Debug, Serialize, Deserialize)] pub(crate) enum PairingResponseErrorMessage { #[error("Invalid combination of roles")] @@ -24,34 +26,29 @@ pub(crate) enum PairingResponseErrorMessage { Other, } -#[derive(Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[serde(rename_all = "lowercase")] -pub(crate) enum PairingVersion { - V1, -} - -#[derive(Deserialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[serde(rename_all = "lowercase")] -pub(crate) enum WirePairingVersion { - V1, - #[serde(other)] - Other, +#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)] +#[serde(rename_all = "UPPERCASE")] +pub(crate) enum HmacHashingAlgorithm { + Sha256, } -impl TryFrom for PairingVersion { - type Error = (); - - fn try_from(value: WirePairingVersion) -> Result { - match value { - WirePairingVersion::V1 => Ok(PairingVersion::V1), - WirePairingVersion::Other => Err(()), - } - } -} +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct HmacChallenge( + #[serde( + serialize_with = "base64_bytes::serialize", + deserialize_with = "base64_bytes::deserialize::<_, 32>" + )] + pub(crate) [u8; 32], +); -/// Message schema version. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct MessageVersion(pub String); +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub(crate) struct HmacChallengeResponse( + #[serde( + serialize_with = "base64_bytes::serialize", + deserialize_with = "base64_bytes::deserialize::<_, 32>" + )] + pub(crate) [u8; 32], +); #[derive(Serialize, Deserialize)] pub(crate) struct RequestPairing { @@ -77,113 +74,6 @@ pub(crate) struct RequestPairing { pub force_pairing: bool, } -/// Information about the pairing endpoint of a S2 node -#[derive(Default, Debug, Serialize, Deserialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct S2EndpointDescription { - /// Name of the endpoint - #[serde(default)] - pub name: Option, - /// URI of a logo to be used for the endpoint in GUIs - #[serde(default)] - pub logo_uri: Option, - /// Type of deployment used by the endpoint (local or globally routable). - #[serde(default)] - pub deployment: Option, -} - -/// One-time access token for secure access to the S2 message communication channel. It must be renewed every time a client wants to access -/// the S2 message communication channel by calling the requestToken endpoint. This token is valid for one time login, with a maximum 5 -/// years, and should have a minimum length of 32 bytes. -#[derive(Serialize, Deserialize, Clone)] -pub struct AccessToken(pub String); - -impl AccessToken { - pub fn new(rng: &mut impl rand::Rng) -> Self { - use base64::{Engine as _, engine::general_purpose::STANDARD}; - - let mut bytes = [0u8; 32]; - rng.fill(&mut bytes); - - let encoded = STANDARD.encode(bytes); - Self(encoded) - } -} - -/// Unique identifier of the S2 node -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] -pub struct S2NodeId(pub String); - -/// Information about the S2 node -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(rename_all = "camelCase")] -pub struct S2NodeDescription { - /// Unique identifier of the node - pub id: S2NodeId, - /// Brandname used for the node - pub brand: String, - /// URI of a logo to be used for the node in GUIs - #[serde(default)] - pub logo_uri: Option, - /// The type of this node. - pub type_: String, - /// Model name of the device this node belongs to. - pub model_name: String, - /// A name for the device configured by the end user/owner. - #[serde(default)] - pub user_defined_name: Option, - /// The S2 role this device has (e.g. CEM or RM). - pub role: S2Role, -} - -/// Identifier of a protocol that can be used for communication of S2 messages between nodes, for example `"WebSocket"` -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] -pub struct CommunicationProtocol(pub String); - -/// Role within the S2 standard. -#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)] -#[serde(rename_all = "UPPERCASE")] -pub enum S2Role { - /// Customer Energy Manager. - Cem, - /// Resource Manager. - Rm, -} - -/// Place of deployment for an S2 Node -#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)] -#[serde(rename_all = "UPPERCASE")] -pub enum Deployment { - /// On a WAN, reachable over the internet - Wan, - /// On the local network, only reachable near the place the device is located. - Lan, -} - -#[derive(Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)] -#[serde(rename_all = "UPPERCASE")] -pub(crate) enum HmacHashingAlgorithm { - Sha256, -} - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct HmacChallenge( - #[serde( - serialize_with = "base64_bytes::serialize", - deserialize_with = "base64_bytes::deserialize::<_, 32>" - )] - pub(crate) [u8; 32], -); - -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -pub(crate) struct HmacChallengeResponse( - #[serde( - serialize_with = "base64_bytes::serialize", - deserialize_with = "base64_bytes::deserialize::<_, 32>" - )] - pub(crate) [u8; 32], -); - /// An identifier that is generated by the server for each pairing attempt. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct PairingAttemptId(String); From a3d50b2c013ef21276c6f26f9cab59eed7a1c32b Mon Sep 17 00:00:00 2001 From: David Venhoek Date: Tue, 17 Feb 2026 13:41:48 +0100 Subject: [PATCH 2/3] Moved version negotiation logic for client and server to common. --- s2energy-connection/src/common/mod.rs | 40 +++++++++++++++++++++++ s2energy-connection/src/pairing/client.rs | 23 ++----------- s2energy-connection/src/pairing/mod.rs | 14 ++++++-- s2energy-connection/src/pairing/server.rs | 11 +++---- 4 files changed, 59 insertions(+), 29 deletions(-) diff --git a/s2energy-connection/src/common/mod.rs b/s2energy-connection/src/common/mod.rs index c70eda0..b58046b 100644 --- a/s2energy-connection/src/common/mod.rs +++ b/s2energy-connection/src/common/mod.rs @@ -1 +1,41 @@ +use axum::Json; + pub(crate) mod wire; + +use reqwest::{StatusCode, Url}; +use wire::PairingVersion; + +use crate::common::wire::WirePairingVersion; + +pub(crate) const SUPPORTED_PAIRING_VERSIONS: &[PairingVersion] = &[PairingVersion::V1]; + +pub(crate) async fn root() -> Json<&'static [PairingVersion]> { + Json(SUPPORTED_PAIRING_VERSIONS) +} + +pub(crate) enum BaseError { + TransportFailed, + ProtocolError, + NoSupportedVersion, +} + +pub(crate) async fn negotiate_version(client: &reqwest::Client, url: Url) -> Result { + let response = client.get(url).send().await.map_err(|_| BaseError::TransportFailed)?; + let status = response.status(); + if status != StatusCode::OK { + return Err(BaseError::ProtocolError); + } + + let supported_versions = response + .json::>() + .await + .map_err(|_| BaseError::ProtocolError)?; + + for version in supported_versions.into_iter().filter_map(|v| v.try_into().ok()) { + if SUPPORTED_PAIRING_VERSIONS.contains(&version) { + return Ok(version); + } + } + + Err(BaseError::NoSupportedVersion) +} diff --git a/s2energy-connection/src/pairing/client.rs b/s2energy-connection/src/pairing/client.rs index e0c28b0..2e5b08e 100644 --- a/s2energy-connection/src/pairing/client.rs +++ b/s2energy-connection/src/pairing/client.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use reqwest::{StatusCode, Url}; use rustls::pki_types::CertificateDer; -use crate::common::wire::{AccessToken, Deployment, PairingVersion, S2NodeId, S2Role, WirePairingVersion}; +use crate::common::negotiate_version; +use crate::common::wire::{AccessToken, Deployment, PairingVersion, S2NodeId, S2Role}; use crate::pairing::transport::{HashProvider, hash_providing_https_client}; -use crate::pairing::{Pairing, PairingRole, SUPPORTED_PAIRING_VERSIONS}; +use crate::pairing::{Pairing, PairingRole}; use super::EndpointConfig; use super::wire::*; @@ -80,24 +81,6 @@ impl Client { } } -async fn negotiate_version(client: &reqwest::Client, url: Url) -> Result { - let response = client.get(url).send().await.map_err(|_| Error::TransportFailed)?; - let status = response.status(); - if status != StatusCode::OK { - return Err(Error::ProtocolError); - } - - let supported_versions = response.json::>().await.map_err(|_| Error::ProtocolError)?; - - for version in supported_versions.into_iter().filter_map(|v| v.try_into().ok()) { - if SUPPORTED_PAIRING_VERSIONS.contains(&version) { - return Ok(version); - } - } - - Err(Error::NoSupportedVersion) -} - struct V1Session<'a> { client: reqwest::Client, base_url: Url, diff --git a/s2energy-connection/src/pairing/mod.rs b/s2energy-connection/src/pairing/mod.rs index e38d5b7..acb7964 100644 --- a/s2energy-connection/src/pairing/mod.rs +++ b/s2energy-connection/src/pairing/mod.rs @@ -195,11 +195,9 @@ pub use server::{PairingToken, PendingPairing, RepeatedPairing, Server, ServerCo use crate::{ CommunicationProtocol, Deployment, MessageVersion, S2EndpointDescription, S2NodeDescription, S2Role, - common::wire::{AccessToken, PairingVersion}, + common::{BaseError, wire::AccessToken}, }; -const SUPPORTED_PAIRING_VERSIONS: &[PairingVersion] = &[PairingVersion::V1]; - /// Full description of an S2 endpoint #[derive(Debug, Clone)] pub struct EndpointConfig { @@ -384,6 +382,16 @@ pub enum Error { InvalidConfig(ConfigError), } +impl From for Error { + fn from(value: BaseError) -> Self { + match value { + BaseError::TransportFailed => Self::TransportFailed, + BaseError::ProtocolError => Self::ProtocolError, + BaseError::NoSupportedVersion => Self::NoSupportedVersion, + } + } +} + impl From for Error { fn from(value: ConfigError) -> Self { Self::InvalidConfig(value) diff --git a/s2energy-connection/src/pairing/server.rs b/s2energy-connection/src/pairing/server.rs index b815b71..9c5f175 100644 --- a/s2energy-connection/src/pairing/server.rs +++ b/s2energy-connection/src/pairing/server.rs @@ -17,8 +17,11 @@ use sha2::Digest; use tokio::time::Instant; use crate::{ - common::wire::{AccessToken, PairingVersion, S2EndpointDescription, S2NodeDescription, S2NodeId}, - pairing::{PairingRole, SUPPORTED_PAIRING_VERSIONS}, + common::{ + root, + wire::{AccessToken, PairingVersion, S2EndpointDescription, S2NodeDescription, S2NodeId}, + }, + pairing::PairingRole, }; use super::{EndpointConfig, Error, Network, Pairing, PairingResult, wire::*}; @@ -234,10 +237,6 @@ struct AppStateInner { attempts: Mutex>, } -async fn root() -> Json<&'static [PairingVersion]> { - Json(SUPPORTED_PAIRING_VERSIONS) -} - fn v1_router() -> Router { Router::new() .route("/requestPairing", post(v1_request_pairing)) From f52dae3a284e87a91bfe09d0d2c4155a803c8c22 Mon Sep 17 00:00:00 2001 From: David Venhoek Date: Tue, 17 Feb 2026 13:42:37 +0100 Subject: [PATCH 3/3] Initial rough implementation of connection subprotocol. --- Cargo.lock | 47 ++++ Cargo.toml | 1 + s2energy-connection/Cargo.toml | 1 + .../examples/communication-client.rs | 73 +++++ .../examples/communication-server.rs | 107 +++++++ s2energy-connection/src/common/wire.rs | 33 ++- .../src/communication/client.rs | 148 ++++++++++ s2energy-connection/src/communication/mod.rs | 94 +++++++ .../src/communication/server.rs | 263 ++++++++++++++++++ s2energy-connection/src/communication/wire.rs | 96 +++++++ s2energy-connection/src/lib.rs | 5 +- s2energy-connection/testdata/gen_cert.sh | 2 +- .../testdata/localhost.chain.pem | 43 +++ s2energy-connection/testdata/localhost.key | 28 ++ s2energy-connection/testdata/localhost.pem | 22 ++ 15 files changed, 957 insertions(+), 6 deletions(-) create mode 100644 s2energy-connection/examples/communication-client.rs create mode 100644 s2energy-connection/examples/communication-server.rs create mode 100644 s2energy-connection/src/communication/client.rs create mode 100644 s2energy-connection/src/communication/mod.rs create mode 100644 s2energy-connection/src/communication/server.rs create mode 100644 s2energy-connection/src/communication/wire.rs create mode 100644 s2energy-connection/testdata/localhost.chain.pem create mode 100644 s2energy-connection/testdata/localhost.key create mode 100644 s2energy-connection/testdata/localhost.pem diff --git a/Cargo.lock b/Cargo.lock index 5bbd612..e1a4200 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,6 +112,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-core", + "futures-util", + "headers", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-server" version = "0.8.0" @@ -568,6 +590,30 @@ dependencies = [ "foldhash", ] +[[package]] +name = "headers" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" +dependencies = [ + "base64", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + [[package]] name = "heck" version = "0.5.0" @@ -1303,6 +1349,7 @@ name = "s2energy-connection" version = "0.1.0" dependencies = [ "axum", + "axum-extra", "axum-server", "base64", "hmac", diff --git a/Cargo.toml b/Cargo.toml index 0e2a882..d85c744 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "3" [workspace.dependencies] axum = "0.8.8" +axum-extra = { version = "0.12.5", features = ["typed-header"] } base64 = "0.22.1" bon = "3.8.0" chrono = { version = "0.4.42", features = ["serde"] } diff --git a/s2energy-connection/Cargo.toml b/s2energy-connection/Cargo.toml index fd858b7..06832d7 100644 --- a/s2energy-connection/Cargo.toml +++ b/s2energy-connection/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" [dependencies] axum.workspace = true +axum-extra.workspace = true base64.workspace = true hmac.workspace = true rand.workspace = true diff --git a/s2energy-connection/examples/communication-client.rs b/s2energy-connection/examples/communication-client.rs new file mode 100644 index 0000000..b50becd --- /dev/null +++ b/s2energy-connection/examples/communication-client.rs @@ -0,0 +1,73 @@ +use std::{convert::Infallible, path::PathBuf, sync::Arc}; + +use rustls::pki_types::{CertificateDer, pem::PemObject}; +use s2energy_connection::{ + AccessToken, MessageVersion, S2NodeId, + communication::{Client, ClientConfig, ClientPairing, NodeConfig}, +}; + +struct MemoryPairing { + communication_url: String, + tokens: Vec, + server: S2NodeId, + client: S2NodeId, +} + +impl ClientPairing for &mut MemoryPairing { + type Error = Infallible; + + fn client_id(&self) -> S2NodeId { + self.client.clone() + } + + fn server_id(&self) -> S2NodeId { + self.server.clone() + } + + fn access_tokens(&self) -> impl AsRef<[AccessToken]> { + &self.tokens + } + + fn communication_url(&self) -> impl AsRef { + &self.communication_url + } + + async fn set_access_tokens(&mut self, tokens: Vec) -> Result<(), Self::Error> { + self.tokens = tokens; + Ok(()) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + let client = Client::new( + ClientConfig { + additional_certificates: vec![ + CertificateDer::from_pem_file(PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("testdata").join("root.pem")).unwrap(), + ], + endpoint_description: None, + }, + Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()), + ); + + let mut pairing = MemoryPairing { + communication_url: "https://localhost:8005/".into(), + tokens: vec![AccessToken("0123456789ABCDEF".into())], + server: S2NodeId("12".into()), + client: S2NodeId("34".into()), + }; + + let connection_info = client.connect(&mut pairing).await.unwrap(); + + println!( + "Url: {}, token: {}", + connection_info.communication_url, connection_info.communication_token.0 + ); + + let connection_info = client.connect(&mut pairing).await.unwrap(); + + println!( + "Url: {}, token: {}", + connection_info.communication_url, connection_info.communication_token.0 + ); +} diff --git a/s2energy-connection/examples/communication-server.rs b/s2energy-connection/examples/communication-server.rs new file mode 100644 index 0000000..ce29135 --- /dev/null +++ b/s2energy-connection/examples/communication-server.rs @@ -0,0 +1,107 @@ +use std::{ + convert::Infallible, + net::SocketAddr, + path::PathBuf, + sync::{Arc, Mutex}, +}; + +use axum_server::tls_rustls::RustlsConfig; +use s2energy_connection::{ + AccessToken, MessageVersion, S2NodeId, + communication::{NodeConfig, PairingLookupResult, Server, ServerConfig, ServerPairing, ServerPairingStore}, +}; + +struct MemoryPairingStoreInner { + token: AccessToken, + config: Arc, + server: S2NodeId, + client: S2NodeId, +} + +#[derive(Clone)] +struct MemoryPairingStore(Arc>); + +impl MemoryPairingStore { + fn new() -> Self { + MemoryPairingStore(Arc::new(Mutex::new(MemoryPairingStoreInner { + token: AccessToken("0123456789ABCDEF".into()), + config: Arc::new(NodeConfig::builder(vec![MessageVersion("v1".into())]).build()), + server: S2NodeId("12".into()), + client: S2NodeId("34".into()), + }))) + } +} + +impl ServerPairingStore for MemoryPairingStore { + type Error = Infallible; + + type Pairing<'a> + = MemoryPairingStore + where + Self: 'a; + + async fn lookup( + &self, + request: s2energy_connection::communication::PairingLookup, + ) -> Result>, Self::Error> { + let this = self.0.lock().unwrap(); + if this.client == request.client && this.server == request.server { + Ok(PairingLookupResult::Pairing(self.clone())) + } else { + Ok(PairingLookupResult::NeverPaired) + } + } +} + +impl ServerPairing for MemoryPairingStore { + type Error = Infallible; + + fn access_token(&self) -> impl AsRef { + self.0.lock().unwrap().token.clone() + } + + fn config(&self) -> impl AsRef { + self.0.lock().unwrap().config.clone() + } + + async fn set_access_token(&mut self, token: AccessToken) -> Result<(), Self::Error> { + self.0.lock().unwrap().token = token; + Ok(()) + } + + async fn update_remote_node_description(&mut self, _node_description: s2energy_connection::S2NodeDescription) { + println!("Received updated node description"); + } + + async fn update_remote_endpoint_description(&mut self, _endpoint_description: s2energy_connection::S2EndpointDescription) { + println!("Received updated endpoint description"); + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + let server = Server::new( + ServerConfig { + base_url: "localhost".into(), + endpoint_description: None, + }, + MemoryPairingStore::new(), + ); + + let addr = SocketAddr::from(([127, 0, 0, 1], 8005)); + + let rustls_config = RustlsConfig::from_pem_file( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("testdata") + .join("localhost.chain.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("testdata").join("localhost.key"), + ) + .await + .unwrap(); + + println!("listening on http://{}", addr); + axum_server::bind_rustls(addr, rustls_config) + .serve(server.get_router().into_make_service()) + .await + .unwrap(); +} diff --git a/s2energy-connection/src/common/wire.rs b/s2energy-connection/src/common/wire.rs index a110e68..8013b20 100644 --- a/s2energy-connection/src/common/wire.rs +++ b/s2energy-connection/src/common/wire.rs @@ -1,3 +1,6 @@ +use axum::extract::FromRequestParts; +use axum_extra::{TypedHeader, headers}; +use reqwest::StatusCode; use serde::{Deserialize, Serialize}; #[derive(Serialize, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -26,11 +29,11 @@ impl TryFrom for PairingVersion { } /// Message schema version. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct MessageVersion(pub String); /// Information about the pairing endpoint of a S2 node -#[derive(Default, Debug, Serialize, Deserialize, Clone)] +#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] #[serde(rename_all = "camelCase")] pub struct S2EndpointDescription { /// Name of the endpoint @@ -47,7 +50,7 @@ pub struct S2EndpointDescription { /// One-time access token for secure access to the S2 message communication channel. It must be renewed every time a client wants to access /// the S2 message communication channel by calling the requestToken endpoint. This token is valid for one time login, with a maximum 5 /// years, and should have a minimum length of 32 bytes. -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] pub struct AccessToken(pub String); impl AccessToken { @@ -62,12 +65,34 @@ impl AccessToken { } } +impl AsRef for AccessToken { + fn as_ref(&self) -> &AccessToken { + self + } +} + +impl FromRequestParts for AccessToken { + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result { + let Some(token) = Option::>>::from_request_parts(parts, state) + .await + .ok() + .flatten() + else { + return Err(StatusCode::UNAUTHORIZED); + }; + + Ok(AccessToken(token.token().into())) + } +} + /// Unique identifier of the S2 node #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] pub struct S2NodeId(pub String); /// Information about the S2 node -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] #[serde(rename_all = "camelCase")] pub struct S2NodeDescription { /// Unique identifier of the node diff --git a/s2energy-connection/src/communication/client.rs b/s2energy-connection/src/communication/client.rs new file mode 100644 index 0000000..ca0de6a --- /dev/null +++ b/s2energy-connection/src/communication/client.rs @@ -0,0 +1,148 @@ +use std::sync::Arc; + +use reqwest::{StatusCode, Url}; +use rustls::pki_types::CertificateDer; + +use crate::{ + AccessToken, CommunicationProtocol, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, + common::negotiate_version, + communication::{ + CommunicationResult, Error, NodeConfig, + wire::{CommunicationDetails, CommunicationToken, InitiateConnectionRequest, InitiateConnectionResponse}, + }, +}; + +/// Configuration for communication clients. +pub struct ClientConfig { + /// Additional roots of trust for TLS connections. Useful when testing during the development of WAN endpoints. + /// + /// When the remote is on the LAN, this is not used. + pub additional_certificates: Vec>, + /// Optional description of this endpoint, sent as update to the server. + pub endpoint_description: Option, +} + +pub struct Client { + config: Arc, + additional_certificates: Vec>, + endpoint_description: Option, +} + +pub struct ConnectionInfo { + pub server_node_description: Option, + pub server_endpoint_description: Option, + pub message_version: MessageVersion, + + // TODO: replace with actual transport. + pub communication_token: CommunicationToken, + pub communication_url: String, +} + +pub trait ClientPairing: Send { + type Error: std::error::Error; + + fn client_id(&self) -> S2NodeId; + fn server_id(&self) -> S2NodeId; + fn access_tokens(&self) -> impl AsRef<[AccessToken]>; + fn communication_url(&self) -> impl AsRef; + + fn set_access_tokens(&mut self, tokens: Vec) -> impl Future> + Send; +} + +impl Client { + pub fn new(config: ClientConfig, node_config: Arc) -> Self { + Client { + config: node_config, + additional_certificates: config.additional_certificates, + endpoint_description: config.endpoint_description, + } + } + + pub async fn connect(&self, mut pairing: impl ClientPairing) -> CommunicationResult { + let client = reqwest::Client::builder() + .tls_certs_merge( + self.additional_certificates + .iter() + .filter_map(|v| reqwest::Certificate::from_der(v).ok()), + ) + .build() + .map_err(|_| Error::TransportFailed)?; + + let communication_url = Url::parse(pairing.communication_url().as_ref()).map_err(|_| Error::InvalidUrl)?; + + let version = negotiate_version(&client, communication_url.clone()).await?; + + match version { + crate::common::wire::PairingVersion::V1 => { + let base_url = communication_url.join("v1/").unwrap(); + + let request = InitiateConnectionRequest { + client_node_id: pairing.client_id(), + server_node_id: pairing.server_id(), + supported_message_versions: self.config.supported_message_versions.clone(), + supported_communication_protocols: vec![CommunicationProtocol("WebSocket".into())], + node_description: self.config.node_description().cloned(), + endpoint_description: self.endpoint_description.clone(), + }; + + let Some((initiate_response, current_token)) = ('found: { + for token in pairing.access_tokens().as_ref() { + let response = client + .post(base_url.join("initiateConnection").unwrap()) + .bearer_auth(&token.0) + .json(&request) + .send() + .await + .map_err(|_| Error::TransportFailed)?; + + if response.status() == StatusCode::UNAUTHORIZED { + continue; + } + if response.status() != StatusCode::OK { + return Err(Error::TransportFailed); + } + + break 'found Some(( + response + .json::() + .await + .map_err(|_| Error::TransportFailed)?, + token.clone(), + )); + } + None + }) else { + return Err(Error::NotPaired); + }; + + pairing + .set_access_tokens(vec![current_token, initiate_response.access_token.clone()]) + .await + .map_err(|_| Error::Storage)?; + + let response = client + .post(base_url.join("confirmAccessToken").unwrap()) + .bearer_auth(&initiate_response.access_token.0) + .send() + .await + .map_err(|_| Error::TransportFailed)?; + + if response.status() != StatusCode::OK { + return Err(Error::ProtocolError); + } + + let communication_details = response.json::().await.map_err(|_| Error::TransportFailed)?; + + match communication_details { + CommunicationDetails::WebSocket(web_socket_communication_details) => Ok(ConnectionInfo { + server_node_description: initiate_response.node_description, + server_endpoint_description: initiate_response.endpoint_description, + message_version: initiate_response.message_version, + communication_token: web_socket_communication_details.websocket_token, + communication_url: web_socket_communication_details.websocket_url, + }), + } + } + } + } +} diff --git a/s2energy-connection/src/communication/mod.rs b/s2energy-connection/src/communication/mod.rs new file mode 100644 index 0000000..413bb0f --- /dev/null +++ b/s2energy-connection/src/communication/mod.rs @@ -0,0 +1,94 @@ +use crate::{MessageVersion, S2NodeDescription, common::BaseError}; + +mod client; +mod server; +mod wire; + +pub use client::{Client, ClientConfig, ClientPairing, ConnectionInfo}; +pub use server::{PairingLookup, PairingLookupResult, Server, ServerConfig, ServerPairing, ServerPairingStore}; + +/// Full description of an S2 endpoint +#[derive(Debug, Clone)] +pub struct NodeConfig { + node_description: Option, + supported_message_versions: Vec, +} + +impl NodeConfig { + /// Description of the S2 node. + pub fn node_description(&self) -> Option<&S2NodeDescription> { + self.node_description.as_ref() + } + + /// Message versions supported by this endpoint. + pub fn supported_message_versions(&self) -> &[MessageVersion] { + &self.supported_message_versions + } + + /// Create a builder for a new [`EndpointConfig`] + /// + /// All endpoint configurations must at least contain description of the node and supported message versions. Additional + /// properties can be configured through the builder. + pub fn builder(supported_message_versions: Vec) -> ConfigBuilder { + ConfigBuilder { + node_description: None, + supported_message_versions, + } + } +} + +/// Builder for an [`EndpointConfig`] +pub struct ConfigBuilder { + node_description: Option, + supported_message_versions: Vec, +} + +impl ConfigBuilder { + /// Set the node description. + /// + /// Note that this replaces any previous node decriptions passed + pub fn with_node_description(mut self, node_description: S2NodeDescription) -> Self { + self.node_description = Some(node_description); + self + } + + /// Create the actual [`EndpointConfig`], validating that it is reasonable. + pub fn build(self) -> NodeConfig { + NodeConfig { + node_description: self.node_description, + supported_message_versions: self.supported_message_versions, + } + } +} + +/// Error that occured during the communication process. +#[derive(Debug, Clone)] +pub enum Error { + /// Invalid URL for remote + InvalidUrl, + /// Something went wrong in the transport layers + TransportFailed, + /// The remote reacted outside our expectations + ProtocolError, + /// No shared version with the remote. + NoSupportedVersion, + /// The nodes are no longer paired + Unpaired, + /// The nodes were not paired + NotPaired, + /// Storage failed to persist token + Storage, +} + +impl From for Error { + fn from(value: BaseError) -> Self { + match value { + BaseError::TransportFailed => Self::TransportFailed, + BaseError::ProtocolError => Self::ProtocolError, + BaseError::NoSupportedVersion => Self::NoSupportedVersion, + } + } +} + +/// Convenience type for [`Result`] +pub type CommunicationResult = Result; diff --git a/s2energy-connection/src/communication/server.rs b/s2energy-connection/src/communication/server.rs new file mode 100644 index 0000000..fa918dc --- /dev/null +++ b/s2energy-connection/src/communication/server.rs @@ -0,0 +1,263 @@ +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, + time::Duration, +}; + +use axum::{ + Json, Router, + extract::State, + response::IntoResponse, + routing::{get, post}, +}; +use reqwest::StatusCode; + +use crate::{ + CommunicationProtocol, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, + common::{root, wire::AccessToken}, + communication::{ + NodeConfig, + wire::{ + CommunicationDetails, CommunicationDetailsErrorMessage, CommunicationToken, InitiateConnectionRequest, + InitiateConnectionResponse, WebSocketCommunicationDetails, + }, + }, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// A pairing to be looked up. +pub struct PairingLookup { + /// Identifier of the remote end of the pairing + pub client: S2NodeId, + /// Identifier of the local end of the pairing + pub server: S2NodeId, +} + +/// Result of looking up a pairing +pub enum PairingLookupResult { + /// Pairing exists + Pairing(Pairing), + /// Pairing existed in the past, but has recently unpaired + Unpaired, + /// Pairing never existed, or existed so long ago that that is no longer known. + NeverPaired, +} + +pub trait ServerPairingStore: Sync + Send + 'static { + type Error: std::error::Error; + type Pairing<'a>: ServerPairing + 'a + where + Self: 'a; + + fn lookup(&self, request: PairingLookup) -> impl Future>, Self::Error>> + Send; +} + +pub trait ServerPairing: Send { + type Error: std::error::Error; + + fn access_token(&self) -> impl AsRef; + fn config(&self) -> impl AsRef; + + fn set_access_token(&mut self, token: AccessToken) -> impl Future> + Send; + fn update_remote_node_description(&mut self, node_description: S2NodeDescription) -> impl Future + Send; + fn update_remote_endpoint_description(&mut self, endpoint_description: S2EndpointDescription) -> impl Future + Send; +} + +/// Configuration for the S2 connection server. +pub struct ServerConfig { + /// URL at which the communication server is reachable. + pub base_url: String, + pub endpoint_description: Option, +} + +pub struct Server { + app_state: AppState, +} + +type AppState = Arc>; + +struct AppStateInner { + store: Store, + pending_tokens: Mutex>, + base_url: String, + endpoint_description: Option, +} + +struct ExpiringSession { + start_time: tokio::time::Instant, + session: Session, +} + +impl ExpiringSession { + fn into_state(self) -> Option { + if self.start_time.elapsed() > Duration::from_secs(15) { + None + } else { + Some(self.session) + } + } +} + +#[expect(unused)] +struct Session { + lookup: PairingLookup, + token: AccessToken, + node_description: Option, + endpoint_description: Option, + message_version: MessageVersion, + communication_protocol: CommunicationProtocol, +} + +impl Server { + pub fn new(config: ServerConfig, store: Store) -> Self { + Server { + app_state: Arc::new(AppStateInner { + store, + pending_tokens: Mutex::new(HashMap::new()), + base_url: config.base_url, + endpoint_description: config.endpoint_description, + }), + } + } + + /// Get an [`axum::Router`] handling the endpoints for the communication protocol. + /// + /// Incomming http requests can be handled by this router through the [axum-server](https://docs.rs/axum-server/0.8.0/axum_server/) crate. + pub fn get_router(&self) -> axum::Router<()> { + Router::new() + .route("/", get(root)) + .nest("/v1", v1_router()) + .with_state(self.app_state.clone()) + } +} + +impl IntoResponse for CommunicationDetailsErrorMessage { + fn into_response(self) -> axum::response::Response { + (StatusCode::BAD_REQUEST, Json(self)).into_response() + } +} + +impl IntoResponse for InitiateConnectionResponse { + fn into_response(self) -> axum::response::Response { + Json(self).into_response() + } +} + +fn select_overlap(primary: &[T], secondary: &[T]) -> Option { + for el in primary { + if secondary.contains(el) { + return Some(el.clone()); + } + } + + None +} + +fn v1_router() -> Router> { + Router::new() + .route("/initiateConnection", post(v1_initiate_connection)) + .route("/confirmAccessToken", post(v1_confirm_access_token)) +} + +async fn v1_initiate_connection( + State(state): State>, + token: AccessToken, + Json(request): Json, +) -> axum::response::Response { + let lookup = PairingLookup { + client: request.client_node_id, + server: request.server_node_id, + }; + + let pairing = match state.store.lookup(lookup.clone()).await { + Ok(PairingLookupResult::Pairing(pairing)) => pairing, + Ok(PairingLookupResult::Unpaired) => return CommunicationDetailsErrorMessage::NoLongerPaired.into_response(), + Ok(PairingLookupResult::NeverPaired) => return StatusCode::UNAUTHORIZED.into_response(), + Err(_) => return StatusCode::INTERNAL_SERVER_ERROR.into_response(), + }; + + if pairing.access_token().as_ref() != &token { + return StatusCode::UNAUTHORIZED.into_response(); + } + + let config = pairing.config(); + + let Some(communication_protocol) = select_overlap( + &request.supported_communication_protocols, + &[CommunicationProtocol("WebSocket".into())], + ) else { + return CommunicationDetailsErrorMessage::IncompatibleCommunicationProtocols.into_response(); + }; + let Some(message_version) = select_overlap(&request.supported_message_versions, config.as_ref().supported_message_versions()) else { + return CommunicationDetailsErrorMessage::IncompatibleS2MessageVersions.into_response(); + }; + + let mut pending_tokens = state.pending_tokens.lock().unwrap(); + + let new_access_token = loop { + let candidate = AccessToken::new(&mut rand::rng()); + if !pending_tokens.contains_key(&candidate) { + break candidate; + } + }; + + pending_tokens.insert( + new_access_token.clone(), + ExpiringSession { + start_time: tokio::time::Instant::now(), + session: Session { + lookup, + token, + node_description: request.node_description, + endpoint_description: request.endpoint_description, + message_version: message_version.clone(), + communication_protocol: communication_protocol.clone(), + }, + }, + ); + + InitiateConnectionResponse { + communication_protocol, + message_version, + access_token: new_access_token, + node_description: config.as_ref().node_description().cloned(), + endpoint_description: state.endpoint_description.clone(), + } + .into_response() +} + +impl IntoResponse for CommunicationDetails { + fn into_response(self) -> axum::response::Response { + Json(self).into_response() + } +} + +async fn v1_confirm_access_token( + State(state): State>, + token: AccessToken, +) -> Result { + let session = { + let mut pending_tokens = state.pending_tokens.lock().unwrap(); + pending_tokens + .remove(&token) + .and_then(|v| v.into_state()) + .ok_or(StatusCode::UNAUTHORIZED)? + }; + + let mut pairing = match state.store.lookup(session.lookup.clone()).await { + Ok(PairingLookupResult::Pairing(pairing)) => pairing, + Ok(PairingLookupResult::Unpaired | PairingLookupResult::NeverPaired) => return Err(StatusCode::UNAUTHORIZED), + Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + + pairing + .set_access_token(token) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + // TODO: Implement websocket communication + Ok(CommunicationDetails::WebSocket(WebSocketCommunicationDetails { + websocket_token: CommunicationToken::new(&mut rand::rng()), + websocket_url: format!("wss://{}/v1/websocket", state.base_url), + })) +} diff --git a/s2energy-connection/src/communication/wire.rs b/s2energy-connection/src/communication/wire.rs new file mode 100644 index 0000000..efd8317 --- /dev/null +++ b/s2energy-connection/src/communication/wire.rs @@ -0,0 +1,96 @@ +use axum::extract::FromRequestParts; +use axum_extra::{TypedHeader, headers}; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::{CommunicationProtocol, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, common::wire::AccessToken}; + +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub(crate) enum CommunicationDetails { + WebSocket(WebSocketCommunicationDetails), +} + +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub(crate) struct WebSocketCommunicationDetails { + pub(crate) websocket_token: CommunicationToken, + pub(crate) websocket_url: String, +} + +#[derive(Serialize, Deserialize, Debug, Error, Clone, PartialEq, Eq, Hash)] +pub(crate) enum CommunicationDetailsErrorMessage { + #[error("Incompatible S2 message versions")] + IncompatibleS2MessageVersions, + #[error("Incompatible communication protocols")] + IncompatibleCommunicationProtocols, + #[error("No longer paired")] + NoLongerPaired, + #[error("Parsing error")] + ParsingError, + #[error("Other")] + Other, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct InitiateConnectionRequest { + #[serde(rename = "clientS2NodeId")] + pub(crate) client_node_id: S2NodeId, + #[serde(rename = "serverS2NodeId")] + pub(crate) server_node_id: S2NodeId, + #[serde(rename = "supportedS2MessageVersions")] + pub(crate) supported_message_versions: Vec, + #[serde(rename = "supportedCommunicationProtocols")] + pub(crate) supported_communication_protocols: Vec, + #[serde(rename = "clientS2NodeDescription")] + pub(crate) node_description: Option, + #[serde(rename = "clientS2EndpointDescription")] + pub(crate) endpoint_description: Option, +} + +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub(crate) struct InitiateConnectionResponse { + #[serde(rename = "selectedCommunicationProtocol")] + pub(crate) communication_protocol: CommunicationProtocol, + #[serde(rename = "selectedS2MessageVersion")] + pub(crate) message_version: MessageVersion, + #[serde(rename = "accessToken")] + pub(crate) access_token: AccessToken, + #[serde(rename = "serverS2NodeDescription")] + pub(crate) node_description: Option, + #[serde(rename = "serverS2EndpointDescription")] + pub(crate) endpoint_description: Option, +} + +/// One-time access token for secure access to the S2 message communication channel. It must be renewed every time a client wants to access +/// the S2 message communication channel by calling the requestToken endpoint. This token is valid for one time login, with a maximum 5 +/// years, and should have a minimum length of 32 bytes. +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash)] +pub struct CommunicationToken(pub String); + +impl CommunicationToken { + pub fn new(rng: &mut impl rand::Rng) -> Self { + use base64::{Engine as _, engine::general_purpose::STANDARD}; + + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + + let encoded = STANDARD.encode(bytes); + Self(encoded) + } +} + +impl FromRequestParts for CommunicationToken { + type Rejection = StatusCode; + + async fn from_request_parts(parts: &mut axum::http::request::Parts, state: &S) -> Result { + let Some(token) = Option::>>::from_request_parts(parts, state) + .await + .ok() + .flatten() + else { + return Err(StatusCode::UNAUTHORIZED); + }; + + Ok(CommunicationToken(token.token().into())) + } +} diff --git a/s2energy-connection/src/lib.rs b/s2energy-connection/src/lib.rs index d97e31d..a34efce 100644 --- a/s2energy-connection/src/lib.rs +++ b/s2energy-connection/src/lib.rs @@ -1,4 +1,7 @@ pub(crate) mod common; +pub mod communication; pub mod pairing; -pub use common::wire::{CommunicationProtocol, Deployment, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, S2Role}; +pub use common::wire::{ + AccessToken, CommunicationProtocol, Deployment, MessageVersion, S2EndpointDescription, S2NodeDescription, S2NodeId, S2Role, +}; diff --git a/s2energy-connection/testdata/gen_cert.sh b/s2energy-connection/testdata/gen_cert.sh index 75041f7..32f7ce2 100755 --- a/s2energy-connection/testdata/gen_cert.sh +++ b/s2energy-connection/testdata/gen_cert.sh @@ -42,7 +42,7 @@ EOF openssl x509 -req -in "$FILENAME".csr -CA "$CA".pem -CAkey "$CA".key -out "$FILENAME".pem -days 365 -sha256 -extfile "$FILENAME".ext # generate the full certificate chain version -cat "$FILENAME".pem "$CA".pem > "$FILENAME".fullchain.pem +cat "$FILENAME".pem "$CA".pem > "$FILENAME".chain.pem # cleanup rm "$FILENAME".csr "$FILENAME".ext diff --git a/s2energy-connection/testdata/localhost.chain.pem b/s2energy-connection/testdata/localhost.chain.pem new file mode 100644 index 0000000..51f34fd --- /dev/null +++ b/s2energy-connection/testdata/localhost.chain.pem @@ -0,0 +1,43 @@ +-----BEGIN CERTIFICATE----- +MIIDnzCCAoegAwIBAgIUZeLrw4Ef1ghsGyyOQ7o4NRI4by8wDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNjAyMTcxMjE0NDVaFw0yNzAy +MTcxMjE0NDVaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCwtrLZGNMRtGKPZFNCbDu5Fgf+Zx7291enIj6Fp/DW ++TvgZT5ZlaihGn849/P2IVY+0qKQgDPIkfgqh5Xx2atJO74shcr/BAPPGdrVkW1d +9Xte93Qlqi/3497Syy51cXM1+/SmatH1RfbJ7opUKHFw/J2LZOtOo0X7FB4Y2mGz +pA02FGFe7gsZi018z7UBDaw6s4Uhn0VGlCS7/wHwa3RFA0eT70/uon5TDeGlthtX +xa0kw2Wu0uVCDHRThIyh3S2F3YDsqpN9oNafmwlASqD0k5IqnRKZICkL8t0H424G +8V9auyemPdPAz6lQR9XzlVBHjMVy5KAPk4C3rk1klEn9AgMBAAGjgYYwgYMwHwYD +VR0jBBgwFoAU/s43TRxw6GnGd7XFg0DaX7vRhnAwCQYDVR0TBAIwADALBgNVHQ8E +BAMCA6gwEwYDVR0lBAwwCgYIKwYBBQUHAwEwFAYDVR0RBA0wC4IJbG9jYWxob3N0 +MB0GA1UdDgQWBBTLzjoq/J2PyhSPRcy1QZZGuiUWhzANBgkqhkiG9w0BAQsFAAOC +AQEAI8w9HGhK4iRSA4vHxRTc7JhkFuCAf/V9JLrPtEXubtUfZKsVgKNpIsuwjMgl +dORwBXH5z3UIwU8AeZBgsSex9iufXfQV3hf6z9u647F+Vz7DgrsYsfapKs59olci +Z3Wcf6NFTuQEhEZdzl35w2PpTbIK0lrrJ2WMoVgYFljvCRojuzqDPEJQtiWgxQ7f +MSgFmHkRfJ+gJbYzwozIJPS1uH/ZQQa/iIXRV4UIZddRp6XI2tqBsaqT1r/KjF8P +wMaLefxC0+6C6yHcB5sMptHSvl2VeerWLs4Dag/O53N1QzLoCZnRzAZodiO39FyU +fcvdl6j2lIbax24bgxR3zQUZ2Q== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIURu/3EiBqW5zBHLYfNRjzTfs2QMQwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNjAyMTAwODI3NDRaFw0zMTAy +MDkwODI3NDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDXr9mtQIgvaDBkxr6mw5UVg9Ox35e01rCz1rrl9MqS +c9SaKtjCCQVyotokHsXT6DKJ4H+CGDD+Y8ULhEQN7B8VSfGB8gg/nD7KHk2dzxNt +8kkZGWDKanyWWdawrsegcApvwV2eHa5/94sHHkJCZNJoRMtmoimZ0o848jOIAUoS +pO1bIxRq7N2YluJVaMYk/U2GBOfwpjhXcy74kQrq1mGyyE3hzJUgtaRGlDsvp3c0 +99b9Pd2fRAmqUzjijibQfheuum4KCLwoZCGvwnY4iQM6vQjNY06djAqyR6XFGH8s +7EVzNFSJyhJZK30FaAYPDeVDIJUKTSrJlTr2ddjF5Gp/AgMBAAGjUzBRMB0GA1Ud +DgQWBBT+zjdNHHDoacZ3tcWDQNpfu9GGcDAfBgNVHSMEGDAWgBT+zjdNHHDoacZ3 +tcWDQNpfu9GGcDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQCI +3ZjK6htARQAYx80BIV4ZOm/MPtQxEOMJ0gywcOckwUICHbmdjj5T2XdsE1L9EbYw +8U/5XItzcfnmf5A+7Pf1UqnfOgeCAw7tl4zX5zCHDm0l3nXmOSnyU1RMetJ+aXTT +LZyV6JJxcEFseQsqdBwx6AkXGz4CqLBDMbwi6j+1yRfib11m2gZGYozNFKDrw6xS +L0KFcBWCM8lzb6W5oc3P+oA+EoF3nhgydtb1vNwe9wkubrRl5GkFzRrnEHTDRpLe +NShyxRuBPtQoKwcIfMaNt+9W5qMwrYjh21mCGX122K8kAdXDT35AYcAK2X8WpT9F +nL09Lv6HBpesSih6ZRS3 +-----END CERTIFICATE----- diff --git a/s2energy-connection/testdata/localhost.key b/s2energy-connection/testdata/localhost.key new file mode 100644 index 0000000..0ece3c1 --- /dev/null +++ b/s2energy-connection/testdata/localhost.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCwtrLZGNMRtGKP +ZFNCbDu5Fgf+Zx7291enIj6Fp/DW+TvgZT5ZlaihGn849/P2IVY+0qKQgDPIkfgq +h5Xx2atJO74shcr/BAPPGdrVkW1d9Xte93Qlqi/3497Syy51cXM1+/SmatH1RfbJ +7opUKHFw/J2LZOtOo0X7FB4Y2mGzpA02FGFe7gsZi018z7UBDaw6s4Uhn0VGlCS7 +/wHwa3RFA0eT70/uon5TDeGlthtXxa0kw2Wu0uVCDHRThIyh3S2F3YDsqpN9oNaf +mwlASqD0k5IqnRKZICkL8t0H424G8V9auyemPdPAz6lQR9XzlVBHjMVy5KAPk4C3 +rk1klEn9AgMBAAECggEACRKNgsEthPPX444TtY33pab7PhUZQotXAAkCthCgjXOP +nr5T7Hwsg349bd/9Bx6Vs5/+IfZoXNxxpe1UG23rcf++jGv0t2cDXbdGxD7fzeUe +h7T4oj5nPA4s5cbyBFbmGALEKrnCkcRyc9JACStnt5SgmgQnuIqAEJtAeFIt60UO +MdWwn+5NWUu1E9EbYGDP3X+295VkcJXaJQEoqnFA5JEj9T09VTVa9vq/pTUYn/qB +aW/bmfWiZl5bfu2L+Qv1zs/i0ieEmkC8yMV61msjgwPvaS56KSo2qVBmk3V7/+hJ +eqEtjC0Eg0+87Bn3oi9kSV7ZQD523wzBF3SKZUqJgQKBgQDzP8bErUWnuhMEh75z +7nGa+0vYq1WeC52yPebXnW48mtTvbaZIB2ZK4pypJZHaObqNZPcscLzJ87/rIlwN +J1UvyeOCIXFyjuo/0ehSb4+drs3fyydKrMHFwwl/oXVJ/YW2LUZP7iO/LW2++5oI +9oKXdgvMt5BDJO5uBDMcqAdZ0QKBgQC5+hDCj/Uv62KzAUcVCI1W0PY4nOlT/shi +wQAlyhflz6m2xSkiF9OLuQyJ0C1YuSuiHLo9Q3hufd7U+eRb4GN0xgsD9uu32feJ +jSh48NcRaZ/6HKqISv1eGKTLod+MXrcVC/rXAelvZVuBpnp9MHqrKqO4btxRVoLF +/d6Rh3lMbQKBgQCakiNPpT+G9pHRJiUa7CEat6cZtr5AIOeDdRx0VODQ+B5pSscI +LFOPMHMWdP46qsZlxQvgHH+K4S5KT1opLZ5PML42WeQKRNCL32n+wE+FhqfiFukP +5bl4Xphxlvq+GrDV8+0jK5Nhj4+WdbELEwInFucmnlq4oAY2uMp14jxRkQKBgDf9 +EqKgWD5O7O3bCp1Ib9SdICM3Cf+hio5AcFzwFHW5KOy/OnzrE2LTGPU8WQqG5J3v +bBoZf94zwqv3d0o5qXd0T8inw5sb4avldTPDvdueIu1XR/e0K8byQFqVpwlJUnDh +pGiqSK6iowPLLMEXoTZ6pcNWjLloBAK7RRAm6tuZAoGAbR03owQ2zWjBzKZhIY24 +2MWLdhvdQmKK05x8k13fDG25iMdti9c7V9hf639f/vMGJBDTKdsnoMR7fGLuPkFg +wSBglTSrWMVqMPozCSeVu+JFtmv2nYdUcJyn7HAe4kHErydVzw9Rff9j9AuDWu4d +s1nWYDiS6FgnIYFBcnkl0P4= +-----END PRIVATE KEY----- diff --git a/s2energy-connection/testdata/localhost.pem b/s2energy-connection/testdata/localhost.pem new file mode 100644 index 0000000..96f666f --- /dev/null +++ b/s2energy-connection/testdata/localhost.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDnzCCAoegAwIBAgIUZeLrw4Ef1ghsGyyOQ7o4NRI4by8wDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNjAyMTcxMjE0NDVaFw0yNzAy +MTcxMjE0NDVaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCwtrLZGNMRtGKPZFNCbDu5Fgf+Zx7291enIj6Fp/DW ++TvgZT5ZlaihGn849/P2IVY+0qKQgDPIkfgqh5Xx2atJO74shcr/BAPPGdrVkW1d +9Xte93Qlqi/3497Syy51cXM1+/SmatH1RfbJ7opUKHFw/J2LZOtOo0X7FB4Y2mGz +pA02FGFe7gsZi018z7UBDaw6s4Uhn0VGlCS7/wHwa3RFA0eT70/uon5TDeGlthtX +xa0kw2Wu0uVCDHRThIyh3S2F3YDsqpN9oNafmwlASqD0k5IqnRKZICkL8t0H424G +8V9auyemPdPAz6lQR9XzlVBHjMVy5KAPk4C3rk1klEn9AgMBAAGjgYYwgYMwHwYD +VR0jBBgwFoAU/s43TRxw6GnGd7XFg0DaX7vRhnAwCQYDVR0TBAIwADALBgNVHQ8E +BAMCA6gwEwYDVR0lBAwwCgYIKwYBBQUHAwEwFAYDVR0RBA0wC4IJbG9jYWxob3N0 +MB0GA1UdDgQWBBTLzjoq/J2PyhSPRcy1QZZGuiUWhzANBgkqhkiG9w0BAQsFAAOC +AQEAI8w9HGhK4iRSA4vHxRTc7JhkFuCAf/V9JLrPtEXubtUfZKsVgKNpIsuwjMgl +dORwBXH5z3UIwU8AeZBgsSex9iufXfQV3hf6z9u647F+Vz7DgrsYsfapKs59olci +Z3Wcf6NFTuQEhEZdzl35w2PpTbIK0lrrJ2WMoVgYFljvCRojuzqDPEJQtiWgxQ7f +MSgFmHkRfJ+gJbYzwozIJPS1uH/ZQQa/iIXRV4UIZddRp6XI2tqBsaqT1r/KjF8P +wMaLefxC0+6C6yHcB5sMptHSvl2VeerWLs4Dag/O53N1QzLoCZnRzAZodiO39FyU +fcvdl6j2lIbax24bgxR3zQUZ2Q== +-----END CERTIFICATE-----