diff --git a/client/client_db/migrations/02-denim/up.sql b/client/client_db/migrations/02-denim/up.sql index d7553b1..be41d6b 100644 --- a/client/client_db/migrations/02-denim/up.sql +++ b/client/client_db/migrations/02-denim/up.sql @@ -6,7 +6,8 @@ CREATE TABLE DeniableDeviceSessionStore ( CREATE TABLE DeniablePayload ( id INTEGER PRIMARY KEY, - content BLOB NOT NULL + content BLOB NOT NULL, + chunk_count INTEGER NOT NULL ); CREATE TABLE DeniableKeyRequestsSent ( diff --git a/client/src/client.rs b/client/src/client.rs index 9434731..f7d517f 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -449,7 +449,7 @@ impl Client { .device .lock() .await - .store_deniable_payload(None, deniable_payload_serialized) + .store_deniable_payload(None, 0, deniable_payload_serialized) .await .map_err(DatabaseError::from)?; } @@ -583,15 +583,17 @@ impl Client { let mut deniable_payloads = Vec::new(); for chunk in new_chunks { - chunks.push(chunk.clone()); if chunk.is_final() { chunks.sort(); + chunks.push(chunk.clone()); let deniable_bytes: Vec = chunks.into_iter().flat_map(|c| c.chunk).collect(); let deniable_payload = deserialize(&deniable_bytes).expect("Should be deniable payload"); deniable_payloads.push(deniable_payload); chunks = Vec::new(); + } else { + chunks.push(chunk.clone()); } } @@ -682,7 +684,7 @@ impl Client { .device .lock() .await - .store_deniable_payload(None, deniable_payload_serialized) + .store_deniable_payload(None, 0, deniable_payload_serialized) .await .map_err(DatabaseError::from)?; self.storage diff --git a/client/src/storage/database.rs b/client/src/storage/database.rs index cbde786..5a0d6be 100644 --- a/client/src/storage/database.rs +++ b/client/src/storage/database.rs @@ -118,11 +118,15 @@ pub trait ClientDB { async fn get_aci(&self) -> Result; async fn set_pni(&mut self, new_pni: Pni) -> Result<(), Self::Error>; async fn get_pni(&self) -> Result; - async fn get_deniable_payload(&self) -> Result<(u32, Vec), Self::Error>; - async fn get_deniable_payload_by_id(&self, payload_id: u32) -> Result, Self::Error>; + async fn get_deniable_payload(&self) -> Result<(u32, Vec, i32), Self::Error>; + async fn get_deniable_payload_by_id( + &self, + payload_id: u32, + ) -> Result<(Vec, i32), Self::Error>; async fn store_deniable_payload( &self, payload_id: Option, + chunk_count: i32, payload: Vec, ) -> Result<(), Self::Error>; async fn remove_deniable_payload(&self, payload_id: u32) -> Result<(), Self::Error>; @@ -470,7 +474,7 @@ impl SessionStore for DeniableStore { #[async_trait(?Send)] impl DeniableSendingBuffer for DeniableStore { - async fn get_outgoing_message(&mut self) -> Result<(u32, Vec), SignalProtocolError> { + async fn get_outgoing_message(&mut self) -> Result<(u32, Vec, i32), SignalProtocolError> { self.db .lock() .await @@ -482,12 +486,13 @@ impl DeniableSendingBuffer for DeniableStore { async fn set_outgoing_message( &mut self, message_id: Option, + chunk_count: i32, outgoing_message: Vec, ) -> Result<(), SignalProtocolError> { self.db .lock() .await - .store_deniable_payload(message_id, outgoing_message) + .store_deniable_payload(message_id, chunk_count, outgoing_message) .await .map_err(|err| SignalProtocolError::InvalidArgument(format!("{err}"))) } diff --git a/client/src/storage/device.rs b/client/src/storage/device.rs index d538de7..836fc2c 100644 --- a/client/src/storage/device.rs +++ b/client/src/storage/device.rs @@ -965,32 +965,35 @@ impl ClientDB for Device { )?) } - async fn get_deniable_payload(&self) -> Result<(u32, Vec), Self::Error> { + async fn get_deniable_payload(&self) -> Result<(u32, Vec, i32), Self::Error> { let mut stmt = self .conn .prepare( r#" SELECT - id, content + id, content, chunk_count FROM DeniablePayload "#, ) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; - let row: (u32, Vec) = stmt - .query_row([], |row| Ok((row.get(0)?, row.get(1)?))) + let row: (u32, Vec, i32) = stmt + .query_row([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?))) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; Ok(row) } - async fn get_deniable_payload_by_id(&self, payload_id: u32) -> Result, Self::Error> { + async fn get_deniable_payload_by_id( + &self, + payload_id: u32, + ) -> Result<(Vec, i32), Self::Error> { let mut stmt = self .conn .prepare( r#" SELECT - content + content, chunk_count FROM DeniablePayload WHERE @@ -999,8 +1002,8 @@ impl ClientDB for Device { ) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; - let row: Vec = stmt - .query_row([payload_id], |row| Ok(row.get(0)?)) + let row: (Vec, i32) = stmt + .query_row([payload_id], |row| Ok((row.get(0)?, row.get(1)?))) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; Ok(row) } @@ -1008,6 +1011,7 @@ impl ClientDB for Device { async fn store_deniable_payload( &self, payload_id: Option, + chunk_count: i32, payload: Vec, ) -> Result<(), Self::Error> { if let Some(id) = payload_id { @@ -1016,21 +1020,21 @@ impl ClientDB for Device { .prepare( r#" UPDATE DeniablePayload - SET content = ?1 - WHERE id = ?2 + SET content = ?1, chunk_count = ?2 + WHERE id = ?3 "#, ) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; - stmt.execute(params![payload, id]) + stmt.execute(params![payload, chunk_count, id]) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; } else { let mut stmt = self .conn .prepare( r#" - INSERT INTO DeniablePayload (content) - VALUES (?1) + INSERT INTO DeniablePayload (content, chunk_count) + VALUES (?1, 0) "#, ) .map_err(|err| SignalProtocolError::InvalidArgument(format!("{}", err)))?; diff --git a/client/src/storage/in_memory.rs b/client/src/storage/in_memory.rs index 1e41bd2..96881f9 100644 --- a/client/src/storage/in_memory.rs +++ b/client/src/storage/in_memory.rs @@ -267,14 +267,19 @@ impl ClientDB for InMemory { Ok(self.pni) } - async fn get_deniable_payload(&self) -> Result<(u32, Vec), Self::Error> { + async fn get_deniable_payload(&self) -> Result<(u32, Vec, i32), Self::Error> { todo!() } - async fn get_deniable_payload_by_id(&self, _: u32) -> Result, Self::Error> { + async fn get_deniable_payload_by_id(&self, _: u32) -> Result<(Vec, i32), Self::Error> { todo!() } - async fn store_deniable_payload(&self, _: Option, _: Vec) -> Result<(), Self::Error> { + async fn store_deniable_payload( + &self, + _: Option, + _: i32, + _: Vec, + ) -> Result<(), Self::Error> { todo!() } diff --git a/common/src/deniable/chunk.rs b/common/src/deniable/chunk.rs index 27f9f59..c8b2182 100644 --- a/common/src/deniable/chunk.rs +++ b/common/src/deniable/chunk.rs @@ -60,13 +60,14 @@ impl Chunker { } else { new_chunk = DenimChunk { chunk: current_outgoing_message.1[..chunk_size].to_vec(), - flags: ChunkType::Data(0).into(), + flags: ChunkType::Data(current_outgoing_message.2).into(), }; let remaining_current_outgoing_message = current_outgoing_message.1[chunk_size..].to_vec(); buffer .set_outgoing_message( Some(current_outgoing_message.0), + current_outgoing_message.2 - 1, remaining_current_outgoing_message, ) .await @@ -157,13 +158,16 @@ mod test { #[async_trait(?Send)] impl DeniableSendingBuffer for MockDeniableSendingBuffer { - async fn get_outgoing_message(&mut self) -> Result<(u32, Vec), SignalProtocolError> { + async fn get_outgoing_message( + &mut self, + ) -> Result<(u32, Vec, i32), SignalProtocolError> { let message: [u8; 32] = rand::random(); - Ok((1, message.to_vec())) + Ok((1, message.to_vec(), 0)) } async fn set_outgoing_message( &mut self, _: Option, + _: i32, _: Vec, ) -> Result<(), SignalProtocolError> { Ok(()) diff --git a/common/src/deniable/mod.rs b/common/src/deniable/mod.rs index c8ce2e5..afd53c4 100644 --- a/common/src/deniable/mod.rs +++ b/common/src/deniable/mod.rs @@ -6,10 +6,11 @@ pub mod constants; #[async_trait(?Send)] pub trait DeniableSendingBuffer { - async fn get_outgoing_message(&mut self) -> Result<(u32, Vec), SignalProtocolError>; + async fn get_outgoing_message(&mut self) -> Result<(u32, Vec, i32), SignalProtocolError>; async fn set_outgoing_message( &mut self, message_id: Option, + chunk_count: i32, outgoing_message: Vec, ) -> Result<(), SignalProtocolError>; async fn remove_outgoing_message(&mut self, message_id: u32) diff --git a/common/src/web_api/mod.rs b/common/src/web_api/mod.rs index 574ef07..17c15fb 100644 --- a/common/src/web_api/mod.rs +++ b/common/src/web_api/mod.rs @@ -343,11 +343,11 @@ pub struct DenimChunk { impl DenimChunk { pub fn is_dummy(&self) -> bool { - self.flags & 1 == 1 + self.flags == 1 } pub fn is_final(&self) -> bool { - self.flags & 2 == 2 + self.flags == 2 } }