Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion s2energy-connection/src/pairing/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ impl<'a> V1Session<'a> {
Network::Wan
};

let client_hmac_challenge = HmacChallenge::new(&mut rand::rng());
const HMAC_CHALLENGE_BYTES: usize = 32;
let client_hmac_challenge = HmacChallenge::new(&mut rand::rng(), HMAC_CHALLENGE_BYTES);

let request_pairing_response = self.request_pairing(id, &client_hmac_challenge).await?;
let attempt_id = request_pairing_response.pairing_attempt_id;
Expand All @@ -139,6 +140,7 @@ impl<'a> V1Session<'a> {
}
}

debug_assert!(request_pairing_response.server_hmac_challenge.0.len() < 32);
let server_hmac_challenge_response = match request_pairing_response.selected_hmac_hashing_algorithm {
HmacHashingAlgorithm::Sha256 => request_pairing_response.server_hmac_challenge.sha256(&network, pairing_token),
};
Expand Down
6 changes: 3 additions & 3 deletions s2energy-connection/src/pairing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ pub struct Pairing {
}

impl HmacChallenge {
pub fn new(rng: &mut impl Rng) -> Self {
Self(rng.random())
pub fn new(rng: &mut impl Rng, len: usize) -> Self {
Self(rng.random_iter().take(len).collect())
}

pub fn sha256(&self, network: &Network, pairing_token: &[u8]) -> HmacChallengeResponse {
Expand All @@ -353,7 +353,7 @@ impl HmacChallenge {
}
}

HmacChallengeResponse(mac.finalize().into_bytes().into())
HmacChallengeResponse(mac.finalize().into_bytes().to_vec())
}
}

Expand Down
6 changes: 5 additions & 1 deletion s2energy-connection/src/pairing/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,11 @@ async fn v1_request_pairing(
return Err(PairingResponseErrorMessage::IncompatibleHMACHashingAlgorithms.into());
}

// 32 bytes is the minimum, this tests that the client can handle more.
const HMAC_CHALLENGE_BYTES: usize = 64;

let mut rng = rand::rng();
let server_hmac_challenge = HmacChallenge::new(&mut rng);
let server_hmac_challenge = HmacChallenge::new(&mut rng, HMAC_CHALLENGE_BYTES);

let open_pairing = {
let mut open_pairings = state.open_pairings.lock().unwrap();
Expand Down Expand Up @@ -297,6 +300,7 @@ async fn v1_request_pairing(
}
}

debug_assert!(request_pairing.client_hmac_challenge.0.len() < 32);
let client_hmac_challenge_response = request_pairing.client_hmac_challenge.sha256(&state.network, &open_pairing.token.0);

let pairing_attempt_id = {
Expand Down
33 changes: 18 additions & 15 deletions s2energy-connection/src/pairing/wire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,26 @@ pub(crate) enum HmacHashingAlgorithm {

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(crate) struct HmacChallenge(
#[serde(
serialize_with = "base64_bytes::serialize",
deserialize_with = "base64_bytes::deserialize::<_, 32>"
)]
pub(crate) [u8; 32],
#[serde(serialize_with = "base64_bytes::serialize", deserialize_with = "deserialize_hmac_challenge")] pub(crate) Vec<u8>,
);

pub(crate) fn deserialize_hmac_challenge<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let decoded = base64_bytes::deserialize(deserializer)?;

// The spec demands that an hmac challenge is at least 32 bytes.
if decoded.len() < 32 {
return Err(de::Error::custom("hmac challenge shorter than 32 bytes"));
}

Ok(decoded)
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(crate) struct HmacChallengeResponse(
#[serde(
serialize_with = "base64_bytes::serialize",
deserialize_with = "base64_bytes::deserialize::<_, 32>"
)]
pub(crate) [u8; 32],
#[serde(serialize_with = "base64_bytes::serialize", deserialize_with = "base64_bytes::deserialize")] pub(crate) Vec<u8>,
);

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -150,16 +156,13 @@ mod base64_bytes {
serializer.serialize_str(&encoded)
}

pub(crate) fn deserialize<'de, D, const N: usize>(deserializer: D) -> Result<[u8; N], D::Error>
pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let decoded = STANDARD.decode(&s).map_err(de::Error::custom)?;

decoded
.as_slice()
.try_into()
.map_err(|_| de::Error::custom(format!("expected {N} bytes after base64 decoding, got {}", decoded.len())))
Ok(decoded)
}
}