diff --git a/willow/src/shell/vahe.rs b/willow/src/shell/vahe.rs index 55c1544..0faa362 100644 --- a/willow/src/shell/vahe.rs +++ b/willow/src/shell/vahe.rs @@ -356,6 +356,44 @@ impl EncryptVerify for ShellVahe { } Ok(()) } + + fn verify_multiple_encrypts( + &self, + items: &[(&Self::EncryptionProof, &Self::PartialDecCiphertext, &[u8])], + ) -> Status { + if items.is_empty() { + return Ok(()); + } + + let (transcript, proof_seed) = self.get_transcript_and_proof_seed(b"encryption")?; + let verifier = RlweRelationVerifier::new(proof_seed.as_bytes(), self.ahe.num_coeffs()); + + for &(proof, ciphertext, nonce) in items { + let num_polynomials = ciphertext.0.len(); + if proof.0.len() != num_polynomials { + return Err(status::permission_denied( + "Invalid proof. Proof length does not match number of polynomials in ciphertext.", + )); + } + + let mut transcript = transcript.clone(); + transcript.append_message(b"nonce:", nonce); + for i in 0..num_polynomials { + let statement = RlweRelationProofStatement { + n: self.ahe.num_coeffs(), + context: self.ahe.rns_context(), + a: &self.ahe.public_key_component_a()?, + flip_a: false, + c: &ciphertext.0[i], + q: self.q, + bound_r: 1, + bound_e: 16, + }; + verifier.verify(&statement, &proof.0[i], &mut transcript)?; + } + } + Ok(()) + } } impl VerifiablePartialDec for ShellVahe { @@ -627,6 +665,77 @@ mod test { Ok(()) } + #[gtest] + fn test_verify_multiple_encrypts() -> googletest::Result<()> { + let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING)?; + let seed = SingleThreadHkdfPrng::generate_seed()?; + let mut prng = SingleThreadHkdfPrng::create(&seed)?; + let (_, pk_share, _) = vahe.verifiable_key_gen(&mut prng)?; + let pk = vahe.aggregate_public_key_shares(&[pk_share])?; + let plaintext = vec![47i64; 8]; + let nonce1 = b"nonce1"; + let nonce2 = b"nonce2"; + + let (ciphertext1, proof1) = vahe.verifiable_encrypt(&plaintext, &pk, nonce1, &mut prng)?; + let (ciphertext2, proof2) = vahe.verifiable_encrypt(&plaintext, &pk, nonce2, &mut prng)?; + + let items = vec![ + (&proof1, &ciphertext1.component_a, nonce1 as &[u8]), + (&proof2, &ciphertext2.component_a, nonce2 as &[u8]), + ]; + + vahe.verify_multiple_encrypts(&items)?; + Ok(()) + } + + #[gtest] + fn test_verify_multiple_encrypts_with_bad_nonce() -> googletest::Result<()> { + let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING)?; + let seed = SingleThreadHkdfPrng::generate_seed()?; + let mut prng = SingleThreadHkdfPrng::create(&seed)?; + let (_, pk_share, _) = vahe.verifiable_key_gen(&mut prng)?; + let pk = vahe.aggregate_public_key_shares(&[pk_share])?; + let plaintext = vec![47i64; 8]; + let nonce1 = b"nonce1"; + let nonce2 = b"nonce2"; + + let (ciphertext1, proof1) = vahe.verifiable_encrypt(&plaintext, &pk, nonce1, &mut prng)?; + let (ciphertext2, proof2) = vahe.verifiable_encrypt(&plaintext, &pk, nonce2, &mut prng)?; + + let items = vec![ + (&proof1, &ciphertext1.component_a, nonce1 as &[u8]), + (&proof2, &ciphertext2.component_a, b"bad_nonce" as &[u8]), + ]; + + let status = vahe.verify_multiple_encrypts(&items); + assert!(status.is_err()); + Ok(()) + } + + #[gtest] + fn test_verify_multiple_encrypts_with_bad_proof() -> googletest::Result<()> { + let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING)?; + let seed = SingleThreadHkdfPrng::generate_seed()?; + let mut prng = SingleThreadHkdfPrng::create(&seed)?; + let (_, pk_share, _) = vahe.verifiable_key_gen(&mut prng)?; + let pk = vahe.aggregate_public_key_shares(&[pk_share])?; + let plaintext = vec![47i64; 8]; + let nonce1 = b"nonce1"; + let nonce2 = b"nonce2"; + + let (ciphertext1, proof1) = vahe.verifiable_encrypt(&plaintext, &pk, nonce1, &mut prng)?; + let (ciphertext2, _) = vahe.verifiable_encrypt(&plaintext, &pk, nonce2, &mut prng)?; + + let items = vec![ + (&proof1, &ciphertext1.component_a, nonce1 as &[u8]), + (&proof1, &ciphertext2.component_a, nonce2 as &[u8]), // Using proof1 for ciphertext2 + ]; + + let status = vahe.verify_multiple_encrypts(&items); + assert!(status.is_err()); + Ok(()) + } + #[gtest] fn test_verifiable_partial_dec() -> googletest::Result<()> { let vahe = ShellVahe::new(make_ahe_config(), CONTEXT_STRING)?; diff --git a/willow/src/traits/vahe.rs b/willow/src/traits/vahe.rs index f3881cf..851dba4 100644 --- a/willow/src/traits/vahe.rs +++ b/willow/src/traits/vahe.rs @@ -68,6 +68,14 @@ pub trait EncryptVerify: VaheBase { ciphertext: &Self::PartialDecCiphertext, nonce: &[u8], ) -> Status; + + /// Verify that multiple encryption proofs are valid. + /// + /// `nonce` must match the nonce passed to `verifiable_encrypt` for each proof. + fn verify_multiple_encrypts( + &self, + items: &[(&Self::EncryptionProof, &Self::PartialDecCiphertext, &[u8])], + ) -> Status; } pub trait VerifiablePartialDec: VaheBase { diff --git a/willow/src/zk/linear_ip.rs b/willow/src/zk/linear_ip.rs index d96f4da..2c6a2a2 100644 --- a/willow/src/zk/linear_ip.rs +++ b/willow/src/zk/linear_ip.rs @@ -36,6 +36,7 @@ pub struct LinearInnerProductParameters { F: RistrettoPoint, F_: RistrettoPoint, G: Vec, + seed: Vec, } pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar { @@ -59,6 +60,7 @@ fn common_setup(length: usize, parameter_seed: &[u8]) -> LinearInnerProductParam ) }) .collect(), + seed: parameter_seed.to_vec(), } } @@ -67,11 +69,9 @@ fn append_params_to_transcript( params: &LinearInnerProductParameters, ) { transcript.append_u64(b"n", params.n as u64); - for G_i in ¶ms.G { - transcript.append_message(b"G_i", G_i.compress().as_bytes()); - } - transcript.append_message(b"F", params.F.compress().as_bytes()); - transcript.append_message(b"F_", params.F_.compress().as_bytes()); + // We append the seed not the resulting params themselves because appending that many params + // more than doubles the run time of both prove and verify. + transcript.append_message(b"seed", ¶ms.seed); } fn validate_and_append_point( diff --git a/willow/src/zk/rlwe_relation.rs b/willow/src/zk/rlwe_relation.rs index c8dbc45..a77cde1 100644 --- a/willow/src/zk/rlwe_relation.rs +++ b/willow/src/zk/rlwe_relation.rs @@ -266,15 +266,15 @@ fn create_public_vec( fn update_public_vec_for_range_proof( public_vec: &mut Vec, result: &mut Scalar, - R_r: &Vec, - R_e: &Vec, - R_vw: &Vec, - z_r: &Vec, - z_e: &Vec, - z_vw: &Vec, - psi_r: Scalar, - psi_e: Scalar, - psi_vw: Scalar, + R_r: &[Scalar], + R_e: &[Scalar], + R_vw: &[Scalar], + z_r: &[Scalar], + z_e: &[Scalar], + z_vw: &[Scalar], + psi_r: &[Scalar], + psi_e: &[Scalar], + psi_vw: &[Scalar], n: usize, range_comm_offset: usize, samples_required: usize, @@ -298,20 +298,14 @@ fn update_public_vec_for_range_proof( // The range proofs equation also involves length 128 innerproducts involving the relevant // psi these are included in the last 3*128 entries of the inner product vectors. - let mut phi_psi_r_pow = phi; - let mut phi2_psi_e_pow = phi2; - let mut phi3_psi_vw_pow = phi3; for i in 0..128 { - public_vec[i + range_comm_offset] = phi_psi_r_pow; - public_vec[i + range_comm_offset + 128] = phi2_psi_e_pow; - public_vec[i + range_comm_offset + 256] = phi3_psi_vw_pow; + public_vec[i + range_comm_offset] = phi * Scalar::from(psi_r[i]); + public_vec[i + range_comm_offset + 128] = phi2 * Scalar::from(psi_e[i]); + public_vec[i + range_comm_offset + 256] = phi3 * Scalar::from(psi_vw[i]); // Add contributions of the range proofs to the overall inner product result. - *result += z_r[i] * phi_psi_r_pow; - *result += z_e[i] * phi2_psi_e_pow; - *result += z_vw[i] * phi3_psi_vw_pow; - phi_psi_r_pow *= psi_r; - phi2_psi_e_pow *= psi_e; - phi3_psi_vw_pow *= psi_vw; + *result += z_r[i] * public_vec[i + range_comm_offset]; + *result += z_e[i] * public_vec[i + range_comm_offset + 128]; + *result += z_vw[i] * public_vec[i + range_comm_offset + 256]; } } @@ -333,28 +327,41 @@ pub fn generate_challenge_matrix( result } -// Multiplies a 128 by n matrix m and a length n vector v. -// m is a binary matrix each column of which has entries given by the bits of a single entry in the -// input vector m. -// Both the output and v are vectors of 128 bit signed integers. -pub fn multiply_by_challenge_matrix( +// Applies a challenge matrix R1-R2 to a vector v and checks if the result satisfies the conditions +// for not needing to be rejected. An internal error is returned in the event of rejection otherwise +// the resulting vector z is returned. +// +// To understand the rejection conditions see the comment for generate_range_product. +pub fn try_matrices_and_compute_z( v: &[i128], - m: &[u128], + R1: &[u128], + R2: &[u128], + y: &[i128], + half_loose_bound: i128, ) -> Result, status::StatusError> { let n = v.len(); - if m.len() != n { - return Err(status::failed_precondition("m and v have different lengths".to_string())); + if n != R1.len() || n != R2.len() { + return Err(status::failed_precondition( + "R1, R2, and v must have the same length".to_string(), + )); } - - let mut result = vec![0 as i128; 128]; - for i in 0..n { - for j in 0..128 { - if m[i] & (1u128 << j) != 0 { - result[j] += v[i]; + let mut z = vec![0 as i128; 128]; + for j in 0..128 { + let mut u = 0i128; + for i in 0..n { + if R1[i] & (1u128 << j) != 0 { + u += v[i]; } + if R2[i] & (1u128 << j) != 0 { + u -= v[i]; + } + } + z[j] = u + y[j]; + if u.abs() > half_loose_bound / 128 || z[j].abs() > half_loose_bound { + return Err(status::internal("Sample Rejected")); } } - Ok(result) + Ok(z) } // Linearly combines the 128 vector challenges of a challenge matrix into a single vector challenge @@ -364,33 +371,38 @@ pub fn flatten_challenge_matrix( R1: Vec, R2: Vec, challenge_label: &'static [u8], -) -> Result<(Vec, Scalar), status::StatusError> { +) -> Result<(Vec, Vec), status::StatusError> { let n = R1.len(); if n != R2.len() { return Err(status::failed_precondition("R1 and R2 have different lengths".to_string())); } - let mut buf = [0u8; 64]; - transcript.challenge_bytes(challenge_label, &mut buf); - let psi = Scalar::from_bytes_mod_order_wide(&buf); - - let mut R = vec![Scalar::from(0 as u64); n]; - let mut psi_powers = [Scalar::from(1 as u64); 128]; - for j in 1..128 { - psi_powers[j] = psi_powers[j - 1] * psi; + let mut Rplus = vec![0u128; n]; + let mut Rminus = vec![0u128; n]; + let mut Rscalar = vec![Scalar::from(0u64); n]; + + let mut psi = [0u128; 128]; + let mut psi_scalar = vec![Scalar::from(0 as u64); 128]; + let mut buf = [0u8; 16]; + for j in 0..128 { + transcript.challenge_bytes(challenge_label, &mut buf); + // We only take challenges up to 2^121 so that the sum of 128 of them will fit in a u128. + psi[j] = u128::from_le_bytes(buf) >> 7; + psi_scalar[j] = Scalar::from(psi[j]); } for i in 0..n { for j in 0..128 { if R1[i] & (1u128 << j) != 0 { - R[i] += psi_powers[j]; + Rplus[i] += psi[j]; } if R2[i] & (1u128 << j) != 0 { - R[i] -= psi_powers[j]; + Rminus[i] += psi[j]; } } + Rscalar[i] = Scalar::from(Rplus[i]) - Scalar::from(Rminus[i]); } - Ok((R, psi)) + Ok((Rscalar, psi_scalar)) } // Check that loose_bound = bound*2500*sqrt(v.len()+1) fits within an i128. @@ -407,6 +419,16 @@ fn check_loose_bound_will_not_overflow(bound: u128, n: usize) -> Result<(), stat Ok(()) } +// Struct to hold the results of the generate_range_product function. +struct RangeProductMetadata { + R: Vec, + comm_y: RistrettoPoint, + y: Vec, + delta_y: Scalar, + psi: Vec, + z: Vec, +} + // Return the inner product that needs to be checked for the range proof, the commitment to y that // the verifier will need to verify it and the blinding information required for the proof. // @@ -429,10 +451,7 @@ fn generate_range_product( start: usize, transcript: &mut (impl Transcript + Clone), challenge_label: &'static [u8], -) -> Result< - (Vec, RistrettoPoint, Vec, Scalar, Scalar, Vec), - status::StatusError, -> { +) -> Result { // Check that computing loose bound does not result in an overflow. check_loose_bound_will_not_overflow(bound, v.len())?; @@ -453,7 +472,6 @@ fn generate_range_product( let mut z = vec![0 as i128; 128]; let mut attempts = 0; loop { - let mut done = true; attempts += 1; y = (0..128).map(|_| (rng.gen_range(0..possible_y) as i128)).collect(); for i in 0..128 { @@ -468,21 +486,9 @@ fn generate_range_product( // subtracting the other we get a challenge matrix with the correct distribution. R1 = generate_challenge_matrix(transcript, challenge_label, v.len()); R2 = generate_challenge_matrix(transcript, challenge_label, v.len()); - let u1 = multiply_by_challenge_matrix(v, &R1)?; - let u2 = multiply_by_challenge_matrix(v, &R2)?; - for i in 0..128 { - let u = u1[i] - u2[i]; - if u.abs() > half_loose_bound / 128 { - done = false; - break; - } - z[i] = u + y[i]; - if z[i].abs() > half_loose_bound { - done = false; - break; - } - } - if done { + let z_or_error = try_matrices_and_compute_z(v, &R1, &R2, &y, half_loose_bound); + if z_or_error.is_ok() { + z = z_or_error.unwrap(); break; } if attempts > 1000 { @@ -512,17 +518,19 @@ fn generate_range_product( }) .collect(); - Ok((R, comm_y, scalar_y, delta_y, psi, scalar_z)) + Ok(RangeProductMetadata { R, comm_y, y: scalar_y, delta_y, psi, z: scalar_z }) } +// Verifies the z bound and returns the linear combination of the 128 rows of the range proof +// projection matrix R and a vector psi of the coefficients used in that linear combination. fn generate_range_product_for_verification_and_verify_z_bound( n: usize, bound: u128, comm_y: RistrettoPoint, - z: &Vec, + z: &[Scalar], transcript: &mut impl Transcript, challenge_label: &'static [u8], -) -> Result<(Vec, Scalar), status::StatusError> { +) -> Result<(Vec, Vec), status::StatusError> { // Check that computing loose bound does not result in an overflow. check_loose_bound_will_not_overflow(bound, n)?; @@ -762,7 +770,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi // Get inner products to prove for range proofs. We then need to check // + = mod P etc. // This is explained in more detail in the comment above generate_range_product. - let (R_r, comm_y_r, y_r, delta_y_r, psi_r, z_r) = generate_range_product( + let range_product_r = generate_range_product( &signed_r, bound_r, &self.prover, @@ -770,7 +778,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi transcript, b"range matrix r", )?; - let (R_e, comm_y_e, y_e, delta_y_e, psi_e, z_e) = generate_range_product( + let range_product_e = generate_range_product( &signed_e, bound_e, &self.prover, @@ -778,7 +786,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi transcript, b"range matrix e", )?; - let (R_vw, comm_y_vw, y_vw, delta_y_vw, psi_vw, z_vw) = generate_range_product( + let range_product_vw = generate_range_product( &signed_vw, q * (n as u128), &self.prover, @@ -792,15 +800,15 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi update_public_vec_for_range_proof( &mut public_vec, &mut result, - &R_r, - &R_e, - &R_vw, - &z_r, - &z_e, - &z_vw, - psi_r, - psi_e, - psi_vw, + &range_product_r.R, + &range_product_e.R, + &range_product_vw.R, + &range_product_r.z, + &range_product_e.z, + &range_product_vw.z, + &range_product_r.psi, + &range_product_e.psi, + &range_product_vw.psi, n, range_comm_offset, samples_required, @@ -818,13 +826,21 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi private_vec[i + n + n + n] = scalar_wrho_vec[i]; } for i in 0..128 { - private_vec[i + range_comm_offset] = y_r[i]; - private_vec[i + range_comm_offset + 128] = y_e[i]; - private_vec[i + range_comm_offset + 256] = y_vw[i]; + private_vec[i + range_comm_offset] = range_product_r.y[i]; + private_vec[i + range_comm_offset + 128] = range_product_e.y[i]; + private_vec[i + range_comm_offset + 256] = range_product_vw.y[i]; } - let private_vec_comm = comm_rev + comm_wrho + comm_y_r + comm_y_e + comm_y_vw; - let blinding_factor = delta_rev + delta_w + delta_y_r + delta_y_e + delta_y_vw; + let private_vec_comm = comm_rev + + comm_wrho + + range_product_r.comm_y + + range_product_e.comm_y + + range_product_vw.comm_y; + let blinding_factor = delta_rev + + delta_w + + range_product_r.delta_y + + range_product_e.delta_y + + range_product_vw.delta_y; // Set up linear product statement and prove it let lip_statement = LinearInnerProductProofStatement { @@ -841,12 +857,12 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi Ok(RlweRelationProof { comm_rev: comm_rev.compress(), comm_wrho: comm_wrho.compress(), - comm_y_r: comm_y_r.compress(), - comm_y_e: comm_y_e.compress(), - comm_y_vw: comm_y_vw.compress(), - z_r: z_r, - z_e: z_e, - z_vw: z_vw, + comm_y_r: range_product_r.comm_y.compress(), + comm_y_e: range_product_e.comm_y.compress(), + comm_y_vw: range_product_vw.comm_y.compress(), + z_r: range_product_r.z, + z_e: range_product_e.z, + z_vw: range_product_vw.z, lip_proof: lip_proof, }) } @@ -977,9 +993,9 @@ impl<'a> ZeroKnowledgeVerifier, RlweRelationProof &proof.z_r, &proof.z_e, &proof.z_vw, - psi_r, - psi_e, - psi_vw, + &psi_r, + &psi_e, + &psi_vw, n, range_comm_offset, samples_required, @@ -1228,16 +1244,44 @@ mod tests { } #[test] - fn test_multiply_by_challenge_matrix_basic_case() -> googletest::Result<()> { - let v = &[10i128, 20i128]; - let m = &[(1u128 << 0) | (1u128 << 2), (1u128 << 1) | (1u128 << 2)]; + fn test_try_matrices_and_compute_z_valid() -> googletest::Result<()> { + let v = [1, -2, 3, -4]; + let R1 = [1, 2, 3, 4]; + let R2 = [4, 3, 2, 1]; + let y = [1; 128]; + let half_loose_bound = 10000; + let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound)?; + let mut expected_z = vec![1; 128]; + expected_z[0] += 10; + expected_z[1] += 0; + expected_z[2] += -5; + verify_eq!(result, expected_z)?; + Ok(()) + } - let mut expected_result = vec![0i128; 128]; - expected_result[0] = 10; - expected_result[1] = 20; - expected_result[2] = 30; + #[test] + fn test_try_matrices_and_compute_z_mismatched_lengths() -> googletest::Result<()> { + let v = [1, -2, 3, -4]; + let R1 = [1, 2, 3]; + let R2 = [4, 3, 2, 1]; + let y = [1; 128]; + let half_loose_bound = 1000; + let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound); + assert!(result.is_err()); + verify_eq!(result.unwrap_err().message(), "R1, R2, and v must have the same length")?; + Ok(()) + } - assert_eq!(multiply_by_challenge_matrix(v, m).unwrap(), expected_result); + #[test] + fn test_try_matrices_and_compute_z_sample_rejected() -> googletest::Result<()> { + let v = [1000, -2000, 3000, -4000]; + let R1 = [1, 2, 3, 4]; + let R2 = [4, 3, 2, 1]; + let y = [1; 128]; + let half_loose_bound = 100000; + let result = try_matrices_and_compute_z(&v, &R1, &R2, &y, half_loose_bound); + assert!(result.is_err()); + verify_eq!(result.unwrap_err().message(), "Sample Rejected")?; Ok(()) } @@ -1247,33 +1291,31 @@ mod tests { let v = [1, -2, 3, -4]; let prover = LinearInnerProductProver::new(b"42", 132); let mut transcript = MerlinTranscript::new(b"42"); - let (R, comm_y, y, delta_y, psi, z) = + let result = generate_range_product(&v, bound, &prover, 4, &mut transcript, b"test vector")?; let mut private_vec = [Scalar::from(0u128); 132]; for i in 0..4 { private_vec[i] = Scalar::from((v[i] + (bound as i128)) as u128) - Scalar::from(bound); } - for i in 4..132 { - private_vec[i] = y[i - 4]; + for i in 0..128 { + private_vec[i + 4] = result.y[i]; } let mut public_vec = [Scalar::from(0u128); 132]; for i in 0..4 { - public_vec[i] = R[i]; + public_vec[i] = result.R[i]; } - let mut psi_pow = Scalar::from(1u128); - let mut result = Scalar::from(0u128); - for i in 4..132 { - public_vec[i] = psi_pow; - result += z[i - 4] * psi_pow; - psi_pow *= psi; + let mut inner_product = Scalar::from(0u128); + for i in 0..128 { + public_vec[i + 4] = result.psi[i]; + inner_product += result.z[i] * result.psi[i]; } let mut expected_result = Scalar::from(0u128); for j in 0..132 { expected_result += public_vec[j] * private_vec[j]; } - assert_eq!(result, expected_result); - let expected_comm_y = prover.commit_partial(&y, delta_y, 4, 132)?; - assert_eq!(comm_y, expected_comm_y); + assert_eq!(inner_product, expected_result); + let expected_comm_y = prover.commit_partial(&result.y, result.delta_y, 4, 132)?; + assert_eq!(result.comm_y, expected_comm_y); Ok(()) }