Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions willow/proto/willow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ proto_library(
deps = [
":aggregation_config_proto",
":messages_proto",
"//willow/proto/shell:shell_ciphertexts_proto",
],
)

Expand Down
12 changes: 12 additions & 0 deletions willow/proto/willow/messages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,15 @@ message VerifierStateProto {
bytes nonce_lower_bound = 2;
bytes nonce_upper_bound = 3;
}

// Result of finalizing an accumulator.
message FinalizedAccumulatorResult {
// Serialized decryption request to include in a DecryptRequest to be sent to
// the decryptor service.
bytes decryption_request = 1;

// Serialized state for creating a final result decryptor, which will handle
// the response from the decryptor service and produce a plaintext aggregation
// result.
bytes final_result_decryptor_state = 2;
}
8 changes: 7 additions & 1 deletion willow/proto/willow/server_accumulator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ message ServerAccumulatorState {
ServerStateProto server_state = 1;
AggregationConfigProto aggregation_config = 3;
// We have one verifier state per range of nonces processed by this
// accumulator. States gat merged when adjacent ranges are processed.
// accumulator. States get merged when adjacent ranges are processed.
repeated VerifierStateProto verifier_states = 2;
// The ranges of nonces processed by this accumulator. In the same order as
// the corresponding verifier states.
Expand All @@ -43,3 +43,9 @@ message NonceRange {
bytes start = 1; // Inclusive.
bytes end = 2; // Exclusive.
}

// State for creating a FinalResultDecryptor.
message FinalResultDecryptorState {
ServerStateProto server_state = 1;
AggregationConfigProto aggregation_config = 2;
}
5 changes: 5 additions & 0 deletions willow/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ cc_library(
"@abseil-cpp//absl/strings",
"@cxx.rs//:core",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
"//willow/src/input_encoding:codec",
],
)

Expand All @@ -77,6 +79,7 @@ cc_test(
"@abseil-cpp//absl/status:statusor",
"//shell_wrapper:status_matchers",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
"//willow/src/input_encoding:codec",
"//willow/src/testing_utils:shell_testing_decryptor_cc",
Expand All @@ -102,6 +105,8 @@ rust_library(
"//shell_wrapper:status",
"//willow/proto/willow:aggregation_config_rust_proto",
"//willow/proto/willow:server_accumulator_rust_proto",
"//willow/proto/shell:shell_ciphertexts_rust_proto",
"//willow/proto/willow:messages_rust_proto",
"//willow/src/shell:kahe_shell",
"//willow/src/shell:parameters_shell",
"//willow/src/shell:vahe_shell",
Expand Down
4 changes: 3 additions & 1 deletion willow/src/api/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ namespace willow {
namespace {

using secure_aggregation::secagg_internal::StatusIs;
using secure_aggregation::testing::ShellTestingDecryptor;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::Pair;
using ::testing::UnorderedElementsAre;

Expand Down Expand Up @@ -108,7 +110,7 @@ TEST(WillowShellClientTest, InitializeAndGenerateContribution) {
for (const auto& [name, values] : encoded_data) {
EXPECT_TRUE(decrypted_encoded_data.contains(name));
const auto& decrypted_values = decrypted_encoded_data[name];
EXPECT_THAT(decrypted_values, testing::ElementsAreArray(values));
EXPECT_THAT(decrypted_values, ElementsAreArray(values));
}

// Decode decrypted data.
Expand Down
95 changes: 81 additions & 14 deletions willow/src/api/server_accumulator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
Expand All @@ -26,12 +27,13 @@
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/api/server_accumulator.rs.h"
#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation {
namespace willow {

absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
WillowShellServerAccumulator::Create(
const willow::AggregationConfigProto& aggregation_config) {
absl::StatusOr<std::unique_ptr<ServerAccumulator>> ServerAccumulator::Create(
const AggregationConfigProto& aggregation_config) {
secure_aggregation::ServerAccumulator* out;
std::unique_ptr<std::string> status_message;
int status_code =
Expand All @@ -41,12 +43,11 @@ WillowShellServerAccumulator::Create(
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out)));
return absl::WrapUnique(new ServerAccumulator(IntoBox(out)));
}

absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
WillowShellServerAccumulator::CreateFromSerializedState(
std::string serialized_state) {
absl::StatusOr<std::unique_ptr<ServerAccumulator>>
ServerAccumulator::CreateFromSerializedState(std::string serialized_state) {
secure_aggregation::ServerAccumulator* out;
std::unique_ptr<std::string> status_message;
int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState(
Expand All @@ -55,17 +56,17 @@ WillowShellServerAccumulator::CreateFromSerializedState(
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::WrapUnique(new WillowShellServerAccumulator(IntoBox(out)));
return absl::WrapUnique(new ServerAccumulator(IntoBox(out)));
}

absl::Status WillowShellServerAccumulator::ProcessClientMessages(
willow::ClientMessageRange client_messages) {
absl::Status ServerAccumulator::ProcessClientMessages(
ClientMessageRange client_messages) {
auto serialized_client_messages = client_messages.SerializeAsString();
client_messages.Clear();
return ProcessClientMessages(std::move(serialized_client_messages));
}

absl::Status WillowShellServerAccumulator::ProcessClientMessages(
absl::Status ServerAccumulator::ProcessClientMessages(
std::string serialized_client_messages) {
std::unique_ptr<std::string> status_message;
int status_code = accumulator_->ProcessClientMessages(
Expand All @@ -77,8 +78,8 @@ absl::Status WillowShellServerAccumulator::ProcessClientMessages(
return absl::OkStatus();
}

absl::Status WillowShellServerAccumulator::Merge(
std::unique_ptr<WillowShellServerAccumulator> other) {
absl::Status ServerAccumulator::Merge(
std::unique_ptr<ServerAccumulator> other) {
std::unique_ptr<std::string> status_message;
int status_code =
accumulator_->Merge(std::move(other->accumulator_), &status_message);
Expand All @@ -88,7 +89,7 @@ absl::Status WillowShellServerAccumulator::Merge(
return absl::OkStatus();
}

absl::StatusOr<std::string> WillowShellServerAccumulator::ToSerializedState() {
absl::StatusOr<std::string> ServerAccumulator::ToSerializedState() {
rust::Vec<uint8_t> serialized_state;
std::unique_ptr<std::string> status_message;
int status_code =
Expand All @@ -100,4 +101,70 @@ absl::StatusOr<std::string> WillowShellServerAccumulator::ToSerializedState() {
serialized_state.size());
}

absl::StatusOr<FinalizedAccumulatorResult> ServerAccumulator::Finalize() && {
// Finalize accumulator in Rust and store the serialized results.
rust::Vec<uint8_t> decryption_request;
rust::Vec<uint8_t> final_result_decryptor_state;
std::unique_ptr<std::string> status_message;
int status_code = secure_aggregation::FinalizeServerAccumulator(
std::move(accumulator_), &decryption_request,
&final_result_decryptor_state, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}

// Pack the two serialized results into a single proto.
FinalizedAccumulatorResult result_proto;
result_proto.set_decryption_request(
std::string(reinterpret_cast<const char*>(decryption_request.data()),
decryption_request.size()));
result_proto.set_final_result_decryptor_state(std::string(
reinterpret_cast<const char*>(final_result_decryptor_state.data()),
final_result_decryptor_state.size()));

return result_proto;
}

absl::StatusOr<std::unique_ptr<FinalResultDecryptor>>
FinalResultDecryptor::CreateFromSerialized(
std::string final_result_decryptor_state) {
secure_aggregation::FinalResultDecryptor* out;
std::unique_ptr<std::string> status_message;
int status_code =
secure_aggregation::CreateFinalResultDecryptorFromSerialized(
std::make_unique<std::string>(
std::move(final_result_decryptor_state)),
&out, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
return absl::WrapUnique(new FinalResultDecryptor(
secure_aggregation::FinalResultDecryptorIntoBox(out)));
}

absl::StatusOr<EncodedData> FinalResultDecryptor::Decrypt(
std::string serialized_partial_decryption_response) {
rust::Vec<EncodedDataEntry> out;
std::unique_ptr<std::string> status_message;
int status_code = aggregated_ciphertexts_->Decrypt(
std::make_unique<std::string>(
std::move(serialized_partial_decryption_response)),
&out, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
EncodedData encoded_data;
for (const auto& rust_entry : out) {
std::string key(rust_entry.key);
std::vector<int64_t> val;
val.reserve(rust_entry.values.size());
for (auto v : rust_entry.values) {
val.push_back(static_cast<int64_t>(v));
}
encoded_data[std::move(key)] = std::move(val);
}
return encoded_data;
}

} // namespace willow
} // namespace secure_aggregation
49 changes: 40 additions & 9 deletions willow/src/api/server_accumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,53 +23,84 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "include/cxx.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/messages.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/api/server_accumulator.rs.h"
#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation {
namespace willow {

// Holds the relevant state from a finalized accumulation, and can decrypt the
// final result using the response from the decryptor service.
class FinalResultDecryptor {
public:
// Creates a new final result decryptor from the given serialized
// state, likely coming from a FinalizedAccumulatorResult.
static absl::StatusOr<std::unique_ptr<FinalResultDecryptor>>
CreateFromSerialized(std::string final_result_decryptor_state);

// Decrypts final result using the given partial decryption
// response.
absl::StatusOr<EncodedData> Decrypt(
std::string serialized_partial_decryption_response);

private:
explicit FinalResultDecryptor(
rust::Box<secure_aggregation::FinalResultDecryptor>
aggregated_ciphertexts)
: aggregated_ciphertexts_(std::move(aggregated_ciphertexts)) {}

rust::Box<secure_aggregation::FinalResultDecryptor> aggregated_ciphertexts_;
};

// Implements an accumulator class intended to be used by a batch processing
// system. Combines both the server and the verifier functionality of willow_v1,
// using SHELL for the underlying cryptography.
class WillowShellServerAccumulator {
class ServerAccumulator {
public:
// Creates a new accumulator with the given aggregation_config and empty
// state.
static absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>> Create(
const willow::AggregationConfigProto& aggregation_config);
static absl::StatusOr<std::unique_ptr<ServerAccumulator>> Create(
const AggregationConfigProto& aggregation_config);

// Creates a new accumulator from the given serialized state, which must
// correspond to a serialized ServerAccumulatorState proto.
static absl::StatusOr<std::unique_ptr<WillowShellServerAccumulator>>
static absl::StatusOr<std::unique_ptr<ServerAccumulator>>
CreateFromSerializedState(std::string serialized_state);

// Processes a list of client messages. If an invalid message is encountered,
// an error is logged and processing continues.
absl::Status ProcessClientMessages(
willow::ClientMessageRange client_messages);
absl::Status ProcessClientMessages(ClientMessageRange client_messages);

// Processes a list of client messages, given as a serialized
// ClientMessageList proto.
absl::Status ProcessClientMessages(std::string serialized_client_messages);

// Merges the state of `other` into the current accumulator.
absl::Status Merge(std::unique_ptr<WillowShellServerAccumulator> other);
absl::Status Merge(std::unique_ptr<ServerAccumulator> other);

// Converts the current state of the accumulator to a serialized
// ServerAccumulatorState proto.
absl::StatusOr<std::string> ToSerializedState();

// Finalizes the accumulator and returns a proto that holds the serialized
// decryption request (to be sent to the decryptor service) and the
// serialized decryptor state (to create a FinalResultDecryptor). This
// consumes the accumulator.
absl::StatusOr<FinalizedAccumulatorResult> Finalize() &&;

private:
explicit WillowShellServerAccumulator(
explicit ServerAccumulator(
rust::Box<secure_aggregation::ServerAccumulator> accumulator)
: accumulator_(std::move(accumulator)) {}

rust::Box<secure_aggregation::ServerAccumulator> accumulator_;
};

} // namespace willow
} // namespace secure_aggregation

#endif // SECURE_AGGREGATION_WILLOW_SRC_API_SERVER_ACCUMULATOR_H_
Loading