From 3156e5fcf6238797638fa6b16cf105e9d35cfab4 Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Tue, 13 Jan 2026 15:38:23 -0800 Subject: [PATCH] Add Finalize and Decrypt functions to ServerAccumulator, expand ShellTestingDecryptor to handle partial decryption requests. PiperOrigin-RevId: 855912104 --- willow/proto/willow/BUILD | 1 + willow/proto/willow/messages.proto | 12 + willow/proto/willow/server_accumulator.proto | 8 +- willow/src/api/BUILD | 5 + willow/src/api/client_test.cc | 4 +- willow/src/api/server_accumulator.cc | 95 ++++++- willow/src/api/server_accumulator.h | 49 +++- willow/src/api/server_accumulator.rs | 262 ++++++++++++++++-- willow/src/api/server_accumulator_test.cc | 115 ++++++-- .../testing_utils/shell_testing_decryptor.cc | 20 +- .../testing_utils/shell_testing_decryptor.h | 12 +- .../testing_utils/shell_testing_decryptor.rs | 82 +++++- .../shell_testing_decryptor_test.cc | 10 +- willow/src/willow_v1/server.rs | 2 +- 14 files changed, 580 insertions(+), 97 deletions(-) diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index fae653c..7d3ae66 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -94,6 +94,7 @@ proto_library( deps = [ ":aggregation_config_proto", ":messages_proto", + "//willow/proto/shell:shell_ciphertexts_proto", ], ) diff --git a/willow/proto/willow/messages.proto b/willow/proto/willow/messages.proto index 53c1318..a9375de 100644 --- a/willow/proto/willow/messages.proto +++ b/willow/proto/willow/messages.proto @@ -65,3 +65,15 @@ message VerifierStateProto { bytes nonce_lower_bound = 2; bytes nonce_upper_bound = 3; } + +// Result of finalizing an accumulator. +message FinalizedAccumulatorResult { + // Serialized decryption request to include in a DecryptRequest to be sent to + // the decryptor service. + bytes decryption_request = 1; + + // Serialized state for creating a final result decryptor, which will handle + // the response from the decryptor service and produce a plaintext aggregation + // result. + bytes final_result_decryptor_state = 2; +} diff --git a/willow/proto/willow/server_accumulator.proto b/willow/proto/willow/server_accumulator.proto index ef6d606..20c6819 100644 --- a/willow/proto/willow/server_accumulator.proto +++ b/willow/proto/willow/server_accumulator.proto @@ -26,7 +26,7 @@ message ServerAccumulatorState { ServerStateProto server_state = 1; AggregationConfigProto aggregation_config = 3; // We have one verifier state per range of nonces processed by this - // accumulator. States gat merged when adjacent ranges are processed. + // accumulator. States get merged when adjacent ranges are processed. repeated VerifierStateProto verifier_states = 2; // The ranges of nonces processed by this accumulator. In the same order as // the corresponding verifier states. @@ -43,3 +43,9 @@ message NonceRange { bytes start = 1; // Inclusive. bytes end = 2; // Exclusive. } + +// State for creating a FinalResultDecryptor. +message FinalResultDecryptorState { + ServerStateProto server_state = 1; + AggregationConfigProto aggregation_config = 2; +} diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index 44c5da4..db2311a 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -61,7 +61,9 @@ cc_library( "@abseil-cpp//absl/strings", "@cxx.rs//:core", "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:messages_cc_proto", "//willow/proto/willow:server_accumulator_cc_proto", + "//willow/src/input_encoding:codec", ], ) @@ -77,6 +79,7 @@ cc_test( "@abseil-cpp//absl/status:statusor", "//shell_wrapper:status_matchers", "//willow/proto/willow:aggregation_config_cc_proto", + "//willow/proto/willow:messages_cc_proto", "//willow/proto/willow:server_accumulator_cc_proto", "//willow/src/input_encoding:codec", "//willow/src/testing_utils:shell_testing_decryptor_cc", @@ -102,6 +105,8 @@ rust_library( "//shell_wrapper:status", "//willow/proto/willow:aggregation_config_rust_proto", "//willow/proto/willow:server_accumulator_rust_proto", + "//willow/proto/shell:shell_ciphertexts_rust_proto", + "//willow/proto/willow:messages_rust_proto", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", "//willow/src/shell:vahe_shell", diff --git a/willow/src/api/client_test.cc b/willow/src/api/client_test.cc index 64143a6..836995c 100644 --- a/willow/src/api/client_test.cc +++ b/willow/src/api/client_test.cc @@ -35,7 +35,9 @@ namespace willow { namespace { using secure_aggregation::secagg_internal::StatusIs; +using secure_aggregation::testing::ShellTestingDecryptor; using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Pair; using ::testing::UnorderedElementsAre; @@ -108,7 +110,7 @@ TEST(WillowShellClientTest, InitializeAndGenerateContribution) { for (const auto& [name, values] : encoded_data) { EXPECT_TRUE(decrypted_encoded_data.contains(name)); const auto& decrypted_values = decrypted_encoded_data[name]; - EXPECT_THAT(decrypted_values, testing::ElementsAreArray(values)); + EXPECT_THAT(decrypted_values, ElementsAreArray(values)); } // Decode decrypted data. diff --git a/willow/src/api/server_accumulator.cc b/willow/src/api/server_accumulator.cc index 139893f..8bc148a 100644 --- a/willow/src/api/server_accumulator.cc +++ b/willow/src/api/server_accumulator.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -26,12 +27,13 @@ #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" #include "willow/src/api/server_accumulator.rs.h" +#include "willow/src/input_encoding/codec.h" namespace secure_aggregation { +namespace willow { -absl::StatusOr> -WillowShellServerAccumulator::Create( - const willow::AggregationConfigProto& aggregation_config) { +absl::StatusOr> ServerAccumulator::Create( + const AggregationConfigProto& aggregation_config) { secure_aggregation::ServerAccumulator* out; std::unique_ptr status_message; int status_code = @@ -41,12 +43,11 @@ WillowShellServerAccumulator::Create( if (status_code != 0) { return absl::Status(absl::StatusCode(status_code), *status_message); } - return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out))); + return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } -absl::StatusOr> -WillowShellServerAccumulator::CreateFromSerializedState( - std::string serialized_state) { +absl::StatusOr> +ServerAccumulator::CreateFromSerializedState(std::string serialized_state) { secure_aggregation::ServerAccumulator* out; std::unique_ptr status_message; int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState( @@ -55,17 +56,17 @@ WillowShellServerAccumulator::CreateFromSerializedState( if (status_code != 0) { return absl::Status(absl::StatusCode(status_code), *status_message); } - return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out))); + return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } -absl::Status WillowShellServerAccumulator::ProcessClientMessages( - willow::ClientMessageRange client_messages) { +absl::Status ServerAccumulator::ProcessClientMessages( + ClientMessageRange client_messages) { auto serialized_client_messages = client_messages.SerializeAsString(); client_messages.Clear(); return ProcessClientMessages(std::move(serialized_client_messages)); } -absl::Status WillowShellServerAccumulator::ProcessClientMessages( +absl::Status ServerAccumulator::ProcessClientMessages( std::string serialized_client_messages) { std::unique_ptr status_message; int status_code = accumulator_->ProcessClientMessages( @@ -77,8 +78,8 @@ absl::Status WillowShellServerAccumulator::ProcessClientMessages( return absl::OkStatus(); } -absl::Status WillowShellServerAccumulator::Merge( - std::unique_ptr other) { +absl::Status ServerAccumulator::Merge( + std::unique_ptr other) { std::unique_ptr status_message; int status_code = accumulator_->Merge(std::move(other->accumulator_), &status_message); @@ -88,7 +89,7 @@ absl::Status WillowShellServerAccumulator::Merge( return absl::OkStatus(); } -absl::StatusOr WillowShellServerAccumulator::ToSerializedState() { +absl::StatusOr ServerAccumulator::ToSerializedState() { rust::Vec serialized_state; std::unique_ptr status_message; int status_code = @@ -100,4 +101,70 @@ absl::StatusOr WillowShellServerAccumulator::ToSerializedState() { serialized_state.size()); } +absl::StatusOr ServerAccumulator::Finalize() && { + // Finalize accumulator in Rust and store the serialized results. + rust::Vec decryption_request; + rust::Vec final_result_decryptor_state; + std::unique_ptr status_message; + int status_code = secure_aggregation::FinalizeServerAccumulator( + std::move(accumulator_), &decryption_request, + &final_result_decryptor_state, &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + + // Pack the two serialized results into a single proto. + FinalizedAccumulatorResult result_proto; + result_proto.set_decryption_request( + std::string(reinterpret_cast(decryption_request.data()), + decryption_request.size())); + result_proto.set_final_result_decryptor_state(std::string( + reinterpret_cast(final_result_decryptor_state.data()), + final_result_decryptor_state.size())); + + return result_proto; +} + +absl::StatusOr> +FinalResultDecryptor::CreateFromSerialized( + std::string final_result_decryptor_state) { + secure_aggregation::FinalResultDecryptor* out; + std::unique_ptr status_message; + int status_code = + secure_aggregation::CreateFinalResultDecryptorFromSerialized( + std::make_unique( + std::move(final_result_decryptor_state)), + &out, &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return absl::WrapUnique(new FinalResultDecryptor( + secure_aggregation::FinalResultDecryptorIntoBox(out))); +} + +absl::StatusOr FinalResultDecryptor::Decrypt( + std::string serialized_partial_decryption_response) { + rust::Vec out; + std::unique_ptr status_message; + int status_code = aggregated_ciphertexts_->Decrypt( + std::make_unique( + std::move(serialized_partial_decryption_response)), + &out, &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + EncodedData encoded_data; + for (const auto& rust_entry : out) { + std::string key(rust_entry.key); + std::vector val; + val.reserve(rust_entry.values.size()); + for (auto v : rust_entry.values) { + val.push_back(static_cast(v)); + } + encoded_data[std::move(key)] = std::move(val); + } + return encoded_data; +} + +} // namespace willow } // namespace secure_aggregation \ No newline at end of file diff --git a/willow/src/api/server_accumulator.h b/willow/src/api/server_accumulator.h index e8167e9..cc8427a 100644 --- a/willow/src/api/server_accumulator.h +++ b/willow/src/api/server_accumulator.h @@ -23,53 +23,84 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "include/cxx.h" #include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/messages.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" #include "willow/src/api/server_accumulator.rs.h" +#include "willow/src/input_encoding/codec.h" namespace secure_aggregation { +namespace willow { + +// Holds the relevant state from a finalized accumulation, and can decrypt the +// final result using the response from the decryptor service. +class FinalResultDecryptor { + public: + // Creates a new final result decryptor from the given serialized + // state, likely coming from a FinalizedAccumulatorResult. + static absl::StatusOr> + CreateFromSerialized(std::string final_result_decryptor_state); + + // Decrypts final result using the given partial decryption + // response. + absl::StatusOr Decrypt( + std::string serialized_partial_decryption_response); + + private: + explicit FinalResultDecryptor( + rust::Box + aggregated_ciphertexts) + : aggregated_ciphertexts_(std::move(aggregated_ciphertexts)) {} + + rust::Box aggregated_ciphertexts_; +}; // Implements an accumulator class intended to be used by a batch processing // system. Combines both the server and the verifier functionality of willow_v1, // using SHELL for the underlying cryptography. -class WillowShellServerAccumulator { +class ServerAccumulator { public: // Creates a new accumulator with the given aggregation_config and empty // state. - static absl::StatusOr> Create( - const willow::AggregationConfigProto& aggregation_config); + static absl::StatusOr> Create( + const AggregationConfigProto& aggregation_config); // Creates a new accumulator from the given serialized state, which must // correspond to a serialized ServerAccumulatorState proto. - static absl::StatusOr> + static absl::StatusOr> CreateFromSerializedState(std::string serialized_state); // Processes a list of client messages. If an invalid message is encountered, // an error is logged and processing continues. - absl::Status ProcessClientMessages( - willow::ClientMessageRange client_messages); + absl::Status ProcessClientMessages(ClientMessageRange client_messages); // Processes a list of client messages, given as a serialized // ClientMessageList proto. absl::Status ProcessClientMessages(std::string serialized_client_messages); // Merges the state of `other` into the current accumulator. - absl::Status Merge(std::unique_ptr other); + absl::Status Merge(std::unique_ptr other); // Converts the current state of the accumulator to a serialized // ServerAccumulatorState proto. absl::StatusOr ToSerializedState(); + // Finalizes the accumulator and returns a proto that holds the serialized + // decryption request (to be sent to the decryptor service) and the + // serialized decryptor state (to create a FinalResultDecryptor). This + // consumes the accumulator. + absl::StatusOr Finalize() &&; + private: - explicit WillowShellServerAccumulator( + explicit ServerAccumulator( rust::Box accumulator) : accumulator_(std::move(accumulator)) {} rust::Box accumulator_; }; +} // namespace willow } // namespace secure_aggregation #endif // SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_ diff --git a/willow/src/api/server_accumulator.rs b/willow/src/api/server_accumulator.rs index b6456a2..fddc00a 100644 --- a/willow/src/api/server_accumulator.rs +++ b/willow/src/api/server_accumulator.rs @@ -17,13 +17,16 @@ use aggregation_config_rust_proto::AggregationConfigProto; use ahe_traits::AheBase; use kahe_shell::ShellKahe; use kahe_traits::KaheBase; -use messages::ClientMessage; -use parameters_shell::{create_shell_ahe_config, create_shell_kahe_config}; +use messages::{ClientMessage, PartialDecryptionResponse}; +use messages_rust_proto::PartialDecryptionResponse as PartialDecryptionResponseProto; +use parameters_shell::create_shell_configs; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::prelude::*; use protobuf::AsView; use rangemap::RangeSet; -use server_accumulator_rust_proto::{ClientMessageRange, NonceRange, ServerAccumulatorState}; +use server_accumulator_rust_proto::{ + ClientMessageRange, FinalResultDecryptorState, NonceRange, ServerAccumulatorState, +}; use server_traits::SecureAggregationServer; use status::StatusError; use std::collections::BTreeMap; @@ -33,10 +36,16 @@ use verifier_traits::SecureAggregationVerifier; use willow_v1_server::{ServerState, WillowV1Server}; use willow_v1_verifier::{VerifierState, WillowV1Verifier}; -#[cxx::bridge] +#[cxx::bridge(namespace = "secure_aggregation")] pub mod ffi { + + // CXX requires shared structs to be defined in the same module. + struct EncodedDataEntry { + key: String, + values: Vec, + } + extern "Rust" { - #[namespace = "secure_aggregation"] type ServerAccumulator; // We cannot use status::FfiStatus because CXX requires shared structs to be defined in the @@ -45,23 +54,20 @@ pub mod ffi { // ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer // arguments are not null. - #[namespace = "secure_aggregation"] #[cxx_name = "NewServerAccumulatorFromSerializedConfig"] - unsafe fn new_server_accumulator_from_serialized_config( + unsafe fn new_accumulator_from_serialized_config( serialized_aggregation_config: UniquePtr, out: *mut *mut ServerAccumulator, out_status_message: *mut UniquePtr, ) -> i32; - #[namespace = "secure_aggregation"] #[cxx_name = "NewServerAccumulatorFromSerializedState"] - unsafe fn new_server_accumulator_from_serialized_state( - serialized_server_accumulator: UniquePtr, + unsafe fn new_accumulator_from_serialized_state( + serialized_accumulator: UniquePtr, out: *mut *mut ServerAccumulator, out_status_message: *mut UniquePtr, ) -> i32; - #[namespace = "secure_aggregation"] #[cxx_name = "ProcessClientMessages"] unsafe fn process_client_messages_ffi( self: &mut ServerAccumulator, @@ -69,7 +75,6 @@ pub mod ffi { out_status_message: *mut UniquePtr, ) -> i32; - #[namespace = "secure_aggregation"] #[cxx_name = "ToSerializedState"] unsafe fn to_serialized_state_ffi( self: &ServerAccumulator, @@ -77,7 +82,6 @@ pub mod ffi { out_status_message: *mut UniquePtr, ) -> i32; - #[namespace = "secure_aggregation"] #[cxx_name = "Merge"] unsafe fn merge_ffi( self: &mut ServerAccumulator, @@ -85,9 +89,39 @@ pub mod ffi { out_status_message: *mut UniquePtr, ) -> i32; - #[namespace = "secure_aggregation"] #[cxx_name = "IntoBox"] unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box; + + // #[cxx_name = "FinalResultDecryptorRust"] + type FinalResultDecryptor; + + #[cxx_name = "FinalizeServerAccumulator"] + unsafe fn finalize_accumulator_ffi( + accumulator: Box, + out_decryption_request: *mut Vec, + out_final_result_decryptor_state: *mut Vec, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[cxx_name = "Decrypt"] + unsafe fn decrypt_ffi( + self: &mut FinalResultDecryptor, + serialized_partial_decryption_response: UniquePtr, + out: *mut Vec, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[cxx_name = "CreateFinalResultDecryptorFromSerialized"] + unsafe fn create_final_result_decryptor_from_serialized( + serialized_final_result_decryptor_state: UniquePtr, + out: *mut *mut FinalResultDecryptor, + out_status_message: *mut UniquePtr, + ) -> i32; + + #[cxx_name = "FinalResultDecryptorIntoBox"] + unsafe fn final_result_decryptor_into_box( + ptr: *mut FinalResultDecryptor, + ) -> Box; } } @@ -113,8 +147,7 @@ pub struct ServerAccumulator { impl ServerAccumulator { fn new(aggregation_config: AggregationConfig) -> Result { let context_string = aggregation_config.compute_context_bytes()?; - let vahe_config = create_shell_ahe_config(aggregation_config.max_number_of_decryptors)?; - let kahe_config = create_shell_kahe_config(&aggregation_config)?; + let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?; let server_kahe = ShellKahe::new(kahe_config, &context_string)?; let server_vahe = ShellVahe::new(vahe_config.clone(), &context_string)?; let verifier_vahe = ShellVahe::new(vahe_config, &context_string)?; @@ -143,13 +176,13 @@ impl ServerAccumulator { } fn new_from_serialized_state( - serialized_server_accumulator: cxx::UniquePtr, + serialized_accumulator: cxx::UniquePtr, ) -> Result { - let serialized_server_accumulator_proto = ServerAccumulatorState::parse( - serialized_server_accumulator.as_bytes(), - ) - .map_err(|e| status::internal(format!("Failed to parse ServerAccumulatorState: {}", e)))?; - Self::from_proto(serialized_server_accumulator_proto, ()) + let serialized_accumulator_proto = + ServerAccumulatorState::parse(serialized_accumulator.as_bytes()).map_err(|e| { + status::internal(format!("Failed to parse ServerAccumulatorState: {}", e)) + })?; + Self::from_proto(serialized_accumulator_proto, ()) } // Updates the server and verifier states with the given client message. In case of error, @@ -492,14 +525,14 @@ impl FromProto for ServerAccumulator { // SAFETY: // - `out` must not be null. It must be turned into a rust::Box on the C++ side. // - `out_status_message` must not be null. -unsafe fn new_server_accumulator_from_serialized_config( +unsafe fn new_accumulator_from_serialized_config( serialized_aggregation_config: cxx::UniquePtr, out: *mut *mut ServerAccumulator, out_status_message: *mut cxx::UniquePtr, ) -> i32 { match ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); + Ok(accumulator) => { + *out = Box::into_raw(Box::new(accumulator)); 0 } Err(status_error) => { @@ -513,14 +546,14 @@ unsafe fn new_server_accumulator_from_serialized_config( // SAFETY: // - `out` must not be null. It must be turned into a rust::Box on the C++ side. // - `out_status_message` must not be null. -unsafe fn new_server_accumulator_from_serialized_state( - serialized_server_accumulator: cxx::UniquePtr, +unsafe fn new_accumulator_from_serialized_state( + serialized_accumulator: cxx::UniquePtr, out: *mut *mut ServerAccumulator, out_status_message: *mut cxx::UniquePtr, ) -> i32 { - match ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); + match ServerAccumulator::new_from_serialized_state(serialized_accumulator) { + Ok(accumulator) => { + *out = Box::into_raw(Box::new(accumulator)); 0 } Err(status_error) => { @@ -536,3 +569,172 @@ unsafe fn new_server_accumulator_from_serialized_state( unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box { Box::from_raw(ptr) } + +// SAFETY: +// - `ptr` must have been created by Box::into_raw or one of the functions in this module. +unsafe fn final_result_decryptor_into_box( + ptr: *mut FinalResultDecryptor, +) -> Box { + Box::from_raw(ptr) +} + +/// Final result decryptor. +pub struct FinalResultDecryptor { + /// Contains aggregated KAHE ciphertexts and aggregated AHE recover ciphertexts (ct_0) + /// + /// NOTE: We technically only need client_sum, not decryptor_public_key_shares or + /// partial_decryption_sum, but because of the monolithic SecureAggregationServer trait + /// (b/476137863) we need a complete ServerState to call the decryption functions. + server_state: ServerState, + + /// Server used to hold the necessary KAHE and AHE contexts. + server: WillowV1Server, +} + +fn finalize_accumulator(accumulator: ServerAccumulator) -> Result<(Vec, Vec), StatusError> { + // Consume and merge all verifier states into one. + let mut final_verifier_state = VerifierState::default(); + let verifier_states = accumulator.verifier_states; + for (_, verifier_state) in verifier_states.into_iter() { + final_verifier_state = + accumulator.verifier.merge_states(verifier_state, final_verifier_state)?; + } + + // Use merged verifier to prepare partial decryption request (i.e. sum of AHE ct_1 ciphertexts) + // The decryption service expects a serialized PartialDecryptionRequestProto + // (https://github.com/google-parfait/trusted-computations-platform/blob/60804e2364ad789cf0682d19d5957dba5d076553/apps/willow/decryptor/actor/src/actor.rs#L290) + let partial_decryption_request = + accumulator.verifier.create_partial_decryption_request(final_verifier_state)?; + let serialized_decryption_request = partial_decryption_request + .to_proto(&accumulator.server)? + .serialize() + .map_err(|e| status::internal(format!("Failed to serialize: {}", e)))?; + + // Extract the server state (i.e. sum of KAHE ciphertexts and sum of AHE ct_0 ciphertexts). + let server_state_proto = accumulator.server_state.to_proto(&accumulator.server)?; + let aggregation_config_proto = accumulator.aggregation_config.to_proto(())?; + + let final_result_decryptor_state = proto!(FinalResultDecryptorState { + server_state: server_state_proto, + aggregation_config: aggregation_config_proto, + }); + + let serialized_final_result_decryptor_state = final_result_decryptor_state + .serialize() + .map_err(|e| status::internal(format!("Failed to serialize: {}", e)))?; + + Ok((serialized_decryption_request, serialized_final_result_decryptor_state)) +} + +pub unsafe fn finalize_accumulator_ffi( + accumulator: Box, + out_decryption_request: *mut Vec, + out_final_result_decryptor_state: *mut Vec, + out_status_message: *mut cxx::UniquePtr, +) -> i32 { + match finalize_accumulator(*accumulator) { + Ok((decryption_request, final_result_decryptor_state)) => { + *out_decryption_request = decryption_request; + *out_final_result_decryptor_state = final_result_decryptor_state; + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } +} + +impl FinalResultDecryptor { + fn new_from_serialized( + serialized_proto: cxx::UniquePtr, + ) -> Result { + // Parse aggregation config and server state protos. + let final_result_decryptor_state_proto = + FinalResultDecryptorState::parse(serialized_proto.as_bytes()).map_err(|e| { + status::internal(format!("Failed to parse FinalResultDecryptorState: {}", e)) + })?; + let server_state_proto = final_result_decryptor_state_proto.server_state(); + let aggregation_config_proto = final_result_decryptor_state_proto.aggregation_config(); + + // Build server that holds the necessary KAHE and AHE contexts, and recover server state. + let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?; + let context_string = aggregation_config.compute_context_bytes()?; + let (kahe_config, vahe_config) = create_shell_configs(&aggregation_config)?; + let kahe = ShellKahe::new(kahe_config, &context_string)?; + let vahe = ShellVahe::new(vahe_config, &context_string)?; + let server = WillowV1Server { kahe, vahe }; + let server_state = ServerState::from_proto(server_state_proto, &server)?; + + Ok(FinalResultDecryptor { server_state, server }) + } + + /// Receives a single partial decryption response and attempts to recover right away. + /// This only works in the single-decryptor case. + pub fn decrypt( + &mut self, + serialized_partial_decryption_response: cxx::UniquePtr, + ) -> Result, StatusError> { + let pd_proto = PartialDecryptionResponseProto::parse( + serialized_partial_decryption_response.as_bytes(), + ) + .map_err(|e| { + status::internal(format!("Failed to parse PartialDecryptionResponse: {}", e)) + })?; + let pd = PartialDecryptionResponse::from_proto(pd_proto, &self.server)?; + + self.server.handle_partial_decryption(pd, &mut self.server_state)?; + + // This is a Kahe::Plaintext, i.e. HashMap> + let aggregation_result = self.server.recover_aggregation_result(&self.server_state)?; + + // Flatten hashmap for FFI like in shell_testing_decryptor.rs + let entries = aggregation_result + .into_iter() + .map(|(key, values)| ffi::EncodedDataEntry { key, values }) + .collect(); + Ok(entries) + } + + /// SAFETY: `out` and `out_status_message` must not be null. + pub unsafe fn decrypt_ffi( + &mut self, + serialized_partial_decryption_response: cxx::UniquePtr, + out: *mut Vec, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.decrypt(serialized_partial_decryption_response) { + Ok(result) => { + *out = result; + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } +} + +// SAFETY: +// - `out` must not be null. +// - `out_status_message` must not be null. +unsafe fn create_final_result_decryptor_from_serialized( + serialized_proto: cxx::UniquePtr, + out: *mut *mut FinalResultDecryptor, + out_status_message: *mut cxx::UniquePtr, +) -> i32 { + match FinalResultDecryptor::new_from_serialized(serialized_proto) { + Ok(final_result_decryptor) => { + *out = Box::into_raw(Box::new(final_result_decryptor)); + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } +} diff --git a/willow/src/api/server_accumulator_test.cc b/willow/src/api/server_accumulator_test.cc index 903a08b..05558e5 100644 --- a/willow/src/api/server_accumulator_test.cc +++ b/willow/src/api/server_accumulator_test.cc @@ -15,6 +15,7 @@ #include "willow/src/api/server_accumulator.h" #include +#include #include #include "absl/status/status.h" @@ -23,19 +24,21 @@ #include "gtest/gtest.h" #include "shell_wrapper/status_matchers.h" #include "willow/proto/willow/aggregation_config.pb.h" +#include "willow/proto/willow/messages.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" #include "willow/src/api/client.h" #include "willow/src/input_encoding/codec.h" #include "willow/src/testing_utils/shell_testing_decryptor.h" namespace secure_aggregation { +namespace willow { namespace { using ::secure_aggregation::secagg_internal::StatusIs; -using ::secure_aggregation::willow::AggregationConfigProto; -using ::secure_aggregation::willow::ClientMessageRange; -using ::secure_aggregation::willow::ServerAccumulatorState; -using ::secure_aggregation::willow::VectorConfig; +// using ::secure_aggregation::willow::AggregationConfigProto; +// using ::secure_aggregation::willow::ClientMessageRange; +// using ::secure_aggregation::willow::ServerAccumulatorState; +// using ::secure_aggregation::willow::VectorConfig; using ::testing::HasSubstr; AggregationConfigProto CreateValidConfig() { @@ -50,17 +53,17 @@ AggregationConfigProto CreateValidConfig() { return config; } -TEST(WillowShellServerAccumulatorTest, CreateSucceedsWithValidConfig) { +TEST(BasicServerAccumulatorTest, CreateSucceedsWithValidConfig) { AggregationConfigProto config = CreateValidConfig(); - auto accumulator_or = WillowShellServerAccumulator::Create(config); + auto accumulator_or = ServerAccumulator::Create(config); ASSERT_TRUE(accumulator_or.ok()) << accumulator_or.status(); EXPECT_NE(*accumulator_or, nullptr); } -TEST(WillowShellServerAccumulatorTest, ToSerializedStateHasCorrectConfig) { +TEST(BasicServerAccumulatorTest, ToSerializedStateHasCorrectConfig) { AggregationConfigProto config = CreateValidConfig(); SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator, - WillowShellServerAccumulator::Create(config)); + ServerAccumulator::Create(config)); auto serialized_state_or = accumulator->ToSerializedState(); ASSERT_TRUE(serialized_state_or.ok()) << serialized_state_or.status(); @@ -73,16 +76,15 @@ TEST(WillowShellServerAccumulatorTest, ToSerializedStateHasCorrectConfig) { config.max_number_of_clients()); } -TEST(WillowShellServerAccumulatorTest, CreateFromSerializedStateRoundTrip) { +TEST(BasicServerAccumulatorTest, CreateFromSerializedStateRoundTrip) { AggregationConfigProto config = CreateValidConfig(); SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator, - WillowShellServerAccumulator::Create(config)); + ServerAccumulator::Create(config)); auto serialized_state_or = accumulator->ToSerializedState(); ASSERT_TRUE(serialized_state_or.ok()) << serialized_state_or.status(); auto accumulator2_or = - WillowShellServerAccumulator::CreateFromSerializedState( - *serialized_state_or); + ServerAccumulator::CreateFromSerializedState(*serialized_state_or); ASSERT_TRUE(accumulator2_or.ok()) << accumulator2_or.status(); EXPECT_NE(*accumulator2_or, nullptr); @@ -91,12 +93,12 @@ TEST(WillowShellServerAccumulatorTest, CreateFromSerializedStateRoundTrip) { EXPECT_EQ(*serialized_state_or, *serialized_state2_or); } -TEST(WillowShellServerAccumulatorTest, MergeSucceedsWithEmptyAccumulators) { +TEST(BasicServerAccumulatorTest, MergeSucceedsWithEmptyAccumulators) { AggregationConfigProto config = CreateValidConfig(); SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator1, - WillowShellServerAccumulator::Create(config)); + ServerAccumulator::Create(config)); SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2, - WillowShellServerAccumulator::Create(config)); + ServerAccumulator::Create(config)); SECAGG_ASSERT_OK_AND_ASSIGN(auto serialized_1, accumulator1->ToSerializedState()); SECAGG_ASSERT_OK_AND_ASSIGN(auto serialized_2, @@ -105,9 +107,9 @@ TEST(WillowShellServerAccumulatorTest, MergeSucceedsWithEmptyAccumulators) { EXPECT_TRUE(accumulator1->Merge(std::move(accumulator2)).ok()); } -TEST(WillowShellServerAccumulatorTest, ProcessClientMessagesWithEmptyList) { +TEST(BasicServerAccumulatorTest, ProcessClientMessagesWithEmptyList) { AggregationConfigProto config = CreateValidConfig(); - auto accumulator = *WillowShellServerAccumulator::Create(config); + auto accumulator = *ServerAccumulator::Create(config); ClientMessageRange empty_list; EXPECT_TRUE(accumulator->ProcessClientMessages(empty_list).ok()); } @@ -117,15 +119,15 @@ class ServerAccumulatorTest : public ::testing::Test { void SetUp() override { config_ = CreateValidConfig(); SECAGG_ASSERT_OK_AND_ASSIGN(accumulator_, - WillowShellServerAccumulator::Create(config_)); - SECAGG_ASSERT_OK_AND_ASSIGN(decryptor_, - ShellTestingDecryptor::Create(config_)); + ServerAccumulator::Create(config_)); + SECAGG_ASSERT_OK_AND_ASSIGN( + decryptor_, testing::ShellTestingDecryptor::Create(config_)); SECAGG_ASSERT_OK_AND_ASSIGN(public_key_, decryptor_->GeneratePublicKey()); } AggregationConfigProto config_; - std::unique_ptr accumulator_; - std::unique_ptr decryptor_; + std::unique_ptr accumulator_; + std::unique_ptr decryptor_; willow::ShellAhePublicKey public_key_; }; @@ -280,7 +282,7 @@ TEST_F(ServerAccumulatorTest, ProcessClientMessagesMergesAdjacentRanges) { TEST_F(ServerAccumulatorTest, MergeSucceedsWithNonEmptyAccumulators) { SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2, - WillowShellServerAccumulator::Create(config_)); + ServerAccumulator::Create(config_)); willow::EncodedData encoded_data = { {"test_vector", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}; @@ -319,7 +321,7 @@ TEST_F(ServerAccumulatorTest, MergeSucceedsWithNonEmptyAccumulators) { TEST_F(ServerAccumulatorTest, MergeSucceedsAndMergesAdjacentRanges) { SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2, - WillowShellServerAccumulator::Create(config_)); + ServerAccumulator::Create(config_)); willow::EncodedData encoded_data = { {"test_vector", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}; @@ -355,7 +357,7 @@ TEST_F(ServerAccumulatorTest, MergeSucceedsAndMergesAdjacentRanges) { TEST_F(ServerAccumulatorTest, MergeFailsWithOverlappingRanges) { SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2, - WillowShellServerAccumulator::Create(config_)); + ServerAccumulator::Create(config_)); willow::EncodedData encoded_data = { {"test_vector", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}; @@ -386,7 +388,7 @@ TEST_F(ServerAccumulatorTest, MergeFailsWithConfigMismatch) { AggregationConfigProto config2 = config_; config2.set_session_id("other_session"); SECAGG_ASSERT_OK_AND_ASSIGN(auto accumulator2, - WillowShellServerAccumulator::Create(config2)); + ServerAccumulator::Create(config2)); EXPECT_THAT(accumulator_->Merge(std::move(accumulator2)), StatusIs(absl::StatusCode::kInvalidArgument, @@ -484,5 +486,66 @@ TEST_F(ServerAccumulatorTest, VerifiesCorrectly) { ASSERT_GE(verifier_state.ByteSizeLong(), 1); } +TEST_F(ServerAccumulatorTest, FinalizeSucceeds) { + willow::EncodedData encoded_data = { + {"test_vector", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}; + std::string nonce = "nonce1"; + SECAGG_ASSERT_OK_AND_ASSIGN( + auto client_message, + GenerateClientContribution(config_, encoded_data, public_key_, nonce)); + ClientMessageRange messages; + *messages.add_client_messages() = client_message; + messages.mutable_nonce_range()->set_start("nonce1"); + messages.mutable_nonce_range()->set_end("nonce2"); + + SECAGG_ASSERT_OK(accumulator_->ProcessClientMessages(messages)); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto aggregated_ciphertexts, + std::move(*accumulator_).Finalize()); +} + +TEST_F(ServerAccumulatorTest, FinalizeFailsWithEmptyAccumulator) { + EXPECT_THAT(std::move(*accumulator_).Finalize(), + StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("Must handle at least one client message"))); +} + +TEST_F(ServerAccumulatorTest, FinalizesAndDecryptsCorrectly) { + willow::EncodedData encoded_data = { + {"test_vector", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}; + std::string nonce = "nonce1"; + SECAGG_ASSERT_OK_AND_ASSIGN( + auto client_message, + GenerateClientContribution(config_, encoded_data, public_key_, nonce)); + ClientMessageRange messages; + *messages.add_client_messages() = client_message; + messages.mutable_nonce_range()->set_start("nonce1"); + messages.mutable_nonce_range()->set_end("nonce2"); + + SECAGG_ASSERT_OK(accumulator_->ProcessClientMessages(messages)); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto finalized_result, + std::move(*accumulator_).Finalize()); + + std::string partial_decryption_request = + finalized_result.decryption_request(); + + SECAGG_ASSERT_OK_AND_ASSIGN( + std::string partial_decryption_response, + decryptor_->GenerateSerializedPartialDecryptionResponse( + partial_decryption_request)); + + SECAGG_ASSERT_OK_AND_ASSIGN( + auto aggregated_ciphertexts, + FinalResultDecryptor::CreateFromSerialized( + finalized_result.final_result_decryptor_state())); + + SECAGG_ASSERT_OK_AND_ASSIGN(auto result, aggregated_ciphertexts->Decrypt( + partial_decryption_response)); + + EXPECT_EQ(result, encoded_data); +} + } // namespace +} // namespace willow } // namespace secure_aggregation diff --git a/willow/src/testing_utils/shell_testing_decryptor.cc b/willow/src/testing_utils/shell_testing_decryptor.cc index 08cd02e..1167812 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.cc +++ b/willow/src/testing_utils/shell_testing_decryptor.cc @@ -30,6 +30,7 @@ #include "willow/src/testing_utils/shell_testing_decryptor.rs.h" namespace secure_aggregation { +namespace testing { ShellTestingDecryptor::ShellTestingDecryptor( rust::Box decryptor) @@ -41,7 +42,7 @@ ShellTestingDecryptor::Create( std::string aggregation_config_proto = aggregation_config.SerializeAsString(); rust::Slice slice = ToRustSlice(aggregation_config_proto); - secure_aggregation::ShellTestingDecryptorRust* out; + ShellTestingDecryptorRust* out; std::unique_ptr status_message; int status_code = create_shell_testing_decryptor(slice, &out, &status_message); @@ -77,7 +78,7 @@ absl::StatusOr ShellTestingDecryptor::Decrypt( reinterpret_cast(contribution_proto.data()), contribution_proto.size()); - rust::Vec rust_flat_data; + rust::Vec rust_flat_data; std::unique_ptr status_message; int status_code = decryptor_->decrypt(slice, &rust_flat_data, &status_message); @@ -100,4 +101,19 @@ absl::StatusOr ShellTestingDecryptor::Decrypt( return encoded_data; } +absl::StatusOr +ShellTestingDecryptor::GenerateSerializedPartialDecryptionResponse( + std::string serialized_partial_decryption_request) { + rust::Vec out; + std::unique_ptr status_message; + int status_code = decryptor_->generate_partial_decryption_response( + ToRustSlice(serialized_partial_decryption_request), &out, + &status_message); + if (status_code != 0) { + return absl::Status(absl::StatusCode(status_code), *status_message); + } + return std::string(reinterpret_cast(out.data()), out.size()); +} + +} // namespace testing } // namespace secure_aggregation diff --git a/willow/src/testing_utils/shell_testing_decryptor.h b/willow/src/testing_utils/shell_testing_decryptor.h index 6aa2988..ec8d258 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.h +++ b/willow/src/testing_utils/shell_testing_decryptor.h @@ -18,6 +18,7 @@ #define SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_ #include +#include #include "absl/status/statusor.h" #include "willow/proto/shell/ciphertexts.pb.h" @@ -27,10 +28,11 @@ #include "willow/src/testing_utils/shell_testing_decryptor.rs.h" namespace secure_aggregation { +namespace testing { // Basic implementation of a single decryptor that uses Shell operations -// directly. Useful for testing Shell clients, by checking that encrypted -// messages can be decrypted properly. +// directly. Useful for testing Shell clients or servers, by checking that +// encrypted messages can be decrypted properly. class ShellTestingDecryptor { public: // Creates a new ShellTestingDecryptor from the given config, hashing the @@ -47,6 +49,11 @@ class ShellTestingDecryptor { absl::StatusOr Decrypt( const willow::ClientMessage& message); + // Computes partial decryption for a request containing an AHE partial + // decryption ciphertext. + absl::StatusOr GenerateSerializedPartialDecryptionResponse( + std::string serialized_partial_decryption_request); + private: explicit ShellTestingDecryptor( rust::Box decryptor); @@ -54,6 +61,7 @@ class ShellTestingDecryptor { rust::Box decryptor_; }; +} // namespace testing } // namespace secure_aggregation #endif // SECURE_AGGREGATION_WILLOW_SRC_TESTING_UTILS_SHELL_TESTING_DECRYPTOR_H_ diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index be53c0e..67ca1c2 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -21,8 +21,9 @@ use ahe_traits::{AheBase, AheKeygen, PartialDec}; use kahe_shell::Ciphertext as KaheCiphertext; use kahe_shell::ShellKahe; use kahe_traits::{KaheBase, KaheDecrypt, TrySecretKeyFrom}; -use messages::ClientMessage; +use messages::{ClientMessage, PartialDecryptionRequest, PartialDecryptionResponse}; use messages_rust_proto::ClientMessage as ClientMessageProto; +use messages_rust_proto::PartialDecryptionRequest as PartialDecryptionRequestProto; use parameters_shell::create_shell_configs; use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; @@ -32,7 +33,7 @@ use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; use vahe_shell::ShellVahe; use vahe_traits::Recover; -use vahe_traits::VaheBase; +use vahe_traits::{HasVahe, VaheBase}; /// Basic implementation of a single decryptor that uses Shell operations directly. Useful for /// testing Shell clients, by checking that encrypted messages can be decrypted properly. Comes with @@ -44,6 +45,13 @@ pub struct ShellTestingDecryptor { secret_key: Option<::SecretKeyShare>, } +impl HasVahe for ShellTestingDecryptor { + type Vahe = ShellVahe; + fn vahe(&self) -> &Self::Vahe { + &self.vahe + } +} + impl ShellTestingDecryptor { /// Creates a new ShellTestingDecryptor, using the given context string to seed KAHE and AHE /// public parameters. @@ -76,7 +84,7 @@ impl ShellTestingDecryptor { &mut self, client_message: &ClientMessage, ) -> Result<::Plaintext, StatusError> { - let decryption_request = + let partial_dec_ciphertext = self.vahe.get_partial_dec_ciphertext(&client_message.ahe_ciphertext)?; let rest_of_ciphertext = self.vahe.get_recover_ciphertext(&client_message.ahe_ciphertext)?; @@ -87,7 +95,7 @@ impl ShellTestingDecryptor { )), Some(sk_share) => { let partial_decryption = - self.vahe.partial_decrypt(&decryption_request, sk_share, &mut self.prng)?; + self.vahe.partial_decrypt(&partial_dec_ciphertext, sk_share, &mut self.prng)?; let decrypted_kahe_key = self.vahe.recover(&partial_decryption, &rest_of_ciphertext, None)?; let decrypted_kahe_key = self.kahe.try_secret_key_from(decrypted_kahe_key)?; @@ -170,6 +178,62 @@ impl ShellTestingDecryptor { } } } + + fn generate_partial_decryption_response( + &mut self, + request: &PartialDecryptionRequest, + ) -> Result, StatusError> { + match &self.secret_key { + None => Err(StatusError::new_with_current_location( + StatusErrorCode::InvalidArgument, + "No secret key available", + )), + Some(sk_share) => { + let partial_decryption = self.vahe.partial_decrypt( + &request.partial_dec_ciphertext, + sk_share, + &mut self.prng, + )?; + Ok(PartialDecryptionResponse { partial_decryption }) + } + } + } + + fn generate_partial_decryption_response_serialized( + &mut self, + request: &[u8], + ) -> Result, StatusError> { + let request_proto = PartialDecryptionRequestProto::parse(request).map_err(|e| { + status::internal(format!("Failed to parse PartialDecryptionRequestProto: {}", e)) + })?; + let request = PartialDecryptionRequest::from_proto(request_proto, self)?; + let response = self.generate_partial_decryption_response(&request)?; + response + .to_proto(self) + .map_err(|e| status::internal(format!("ToProto error: {}", e)))? + .serialize() + .map_err(|e| status::internal(format!("Serialize error: {}", e))) + } + + /// SAFETY: `out` and `out_status_message` must not be null. + unsafe fn generate_partial_decryption_response_ffi( + &mut self, + request: &[u8], + out: *mut Vec, + out_status_message: *mut cxx::UniquePtr, + ) -> i32 { + match self.generate_partial_decryption_response_serialized(request) { + Ok(response) => { + *out = response; + 0 + } + Err(status_error) => { + let ffi_status: FfiStatus = status_error.into(); + *out_status_message = ffi_status.message; + ffi_status.code + } + } + } } /// CXX bridge to call ShellTestingDecryptor from C++, using serialized protos as input and output. @@ -177,7 +241,7 @@ impl ShellTestingDecryptor { /// SAFETY: all functions in this module are only called from the wrapping C++ library, /// ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer /// arguments are not null. -#[cxx::bridge(namespace = "secure_aggregation")] +#[cxx::bridge(namespace = "secure_aggregation::testing")] pub mod ffi { struct EncodedDataEntry { key: String, @@ -209,6 +273,14 @@ pub mod ffi { out_status_message: *mut UniquePtr, ) -> i32; + #[rust_name = "generate_partial_decryption_response_ffi"] + unsafe fn generate_partial_decryption_response( + self: &mut ShellTestingDecryptor, + request: &[u8], + out: *mut Vec, + out_status_message: *mut UniquePtr, + ) -> i32; + unsafe fn decryptor_into_box(ptr: *mut ShellTestingDecryptor) -> Box; } diff --git a/willow/src/testing_utils/shell_testing_decryptor_test.cc b/willow/src/testing_utils/shell_testing_decryptor_test.cc index efa58e1..a2be3dd 100644 --- a/willow/src/testing_utils/shell_testing_decryptor_test.cc +++ b/willow/src/testing_utils/shell_testing_decryptor_test.cc @@ -23,14 +23,13 @@ #include "willow/proto/willow/aggregation_config.pb.h" namespace secure_aggregation { -namespace willow { namespace { using secure_aggregation::secagg_internal::StatusIs; using ::testing::NotNull; TEST(ShellTestingDecryptorTest, CreateAndGenerateKey) { - AggregationConfigProto config; + willow::AggregationConfigProto config; config.set_max_number_of_decryptors(1); config.set_max_number_of_clients(1); config.set_max_decryptor_dropouts(0); @@ -40,7 +39,7 @@ TEST(ShellTestingDecryptorTest, CreateAndGenerateKey) { vector_config.set_bound(100); SECAGG_ASSERT_OK_AND_ASSIGN(auto decryptor, - ShellTestingDecryptor::Create(config)); + testing::ShellTestingDecryptor::Create(config)); ASSERT_THAT(decryptor, NotNull()); SECAGG_ASSERT_OK_AND_ASSIGN(const auto& pk, decryptor->GeneratePublicKey()); @@ -49,17 +48,16 @@ TEST(ShellTestingDecryptorTest, CreateAndGenerateKey) { TEST(ShellTestingDecryptorTest, InvalidAggregationConfig) { // Aggregation config with no metrics. - AggregationConfigProto config_proto; + willow::AggregationConfigProto config_proto; config_proto.set_max_number_of_decryptors(1); config_proto.set_max_decryptor_dropouts(0); config_proto.set_max_number_of_clients(2); config_proto.set_session_id("test"); // Initialization fails because aggregation config is invalid. - EXPECT_THAT(ShellTestingDecryptor::Create(config_proto), + EXPECT_THAT(testing::ShellTestingDecryptor::Create(config_proto), StatusIs(absl::StatusCode::kInvalidArgument)); } } // namespace -} // namespace willow } // namespace secure_aggregation diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index 5ca4f61..ff3ec8b 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -58,7 +58,7 @@ pub struct ServerState { decryptor_public_key_shares: HashMap>, /// Running sum of client ciphertexts. client_sum: Option<(Kahe::Ciphertext, Vahe::RecoverCiphertext)>, - /// Running sum of partial decryption ciphertexts. + /// Running sum of partial decryptions. partial_decryption_sum: Option, }