diff --git a/crates/openfang-extensions/src/vault.rs b/crates/openfang-extensions/src/vault.rs index a264ffb..e439f1f 100644 --- a/crates/openfang-extensions/src/vault.rs +++ b/crates/openfang-extensions/src/vault.rs @@ -31,7 +31,6 @@ const SALT_LEN: usize = 16; /// Nonce length for AES-256-GCM. const NONCE_LEN: usize = 12; /// Magic bytes for vault file format versioning. -#[allow(dead_code)] const VAULT_MAGIC: &[u8; 4] = b"OFV1"; /// On-disk vault format (encrypted). @@ -312,14 +311,35 @@ impl CredentialVault { if let Some(parent) = self.path.parent() { std::fs::create_dir_all(parent)?; } - std::fs::write(&self.path, content)?; + + // Write to file: magic bytes + JSON + let mut file_content = Vec::with_capacity(VAULT_MAGIC.len() + content.len()); + file_content.extend_from_slice(VAULT_MAGIC); + file_content.extend_from_slice(content.as_bytes()); + std::fs::write(&self.path, &file_content)?; Ok(()) } /// Load and decrypt vault from disk. fn load(&mut self, master_key: &[u8; 32]) -> ExtensionResult<()> { - let content = std::fs::read_to_string(&self.path)?; - let vault_file: VaultFile = serde_json::from_str(&content) + let raw = std::fs::read(&self.path)?; + + // Determine if file has magic prefix or is legacy JSON + let json_bytes = if raw.len() >= 4 && &raw[..4] == VAULT_MAGIC { + &raw[4..] + } else if raw.first() == Some(&b'{') { + // Legacy format: plain JSON without magic prefix + &raw[..] + } else { + return Err(ExtensionError::Vault( + "Invalid vault file: unrecognized format (expected OFV1 header or JSON)" + .to_string(), + )); + }; + + let content = std::str::from_utf8(json_bytes) + .map_err(|e| ExtensionError::Vault(format!("Vault file is not valid UTF-8: {e}")))?; + let vault_file: VaultFile = serde_json::from_str(content) .map_err(|e| ExtensionError::Vault(format!("Vault file parse failed: {e}")))?; if vault_file.version != 1 { @@ -590,4 +610,89 @@ mod tests { let k2 = derive_key(&master, &salt).unwrap(); assert_eq!(k1.as_ref(), k2.as_ref()); } + + #[test] + fn vault_file_starts_with_magic_bytes() { + let (_dir, mut vault) = test_vault(); + let key = random_key(); + + vault.init_with_key(key).unwrap(); + vault + .set( + "TEST_KEY".to_string(), + Zeroizing::new("test_value".to_string()), + ) + .unwrap(); + + // Read raw bytes and verify magic prefix + let raw = std::fs::read(&vault.path).unwrap(); + assert!( + raw.len() >= 4, + "Vault file too short to contain magic bytes" + ); + assert_eq!( + &raw[..4], + VAULT_MAGIC, + "Vault file should start with OFV1 magic bytes" + ); + + // The rest should be valid JSON + let json_part = std::str::from_utf8(&raw[4..]).unwrap(); + let parsed: serde_json::Value = serde_json::from_str(json_part).unwrap(); + assert!(parsed.get("version").is_some()); + assert!(parsed.get("ciphertext").is_some()); + } + + #[test] + fn vault_rejects_wrong_magic_bytes() { + let (dir, _vault) = test_vault(); + let vault_path = dir.path().join("vault.enc"); + + // Write a file with wrong magic bytes + let bad_content = b"BAD!{\"not\": \"valid\"}"; + std::fs::write(&vault_path, bad_content).unwrap(); + + let mut vault = CredentialVault::new(vault_path); + let key = random_key(); + let result = vault.unlock_with_key(key); + + assert!(result.is_err(), "Should reject file with wrong magic bytes"); + let err_msg = format!("{}", result.unwrap_err()); + assert!( + err_msg.contains("unrecognized format"), + "Error should mention unrecognized format, got: {err_msg}" + ); + } + + #[test] + fn vault_loads_legacy_json_without_magic() { + let (dir, mut vault) = test_vault(); + let key = random_key(); + + // Init and store a secret (this writes with magic prefix) + vault.init_with_key(key.clone()).unwrap(); + vault + .set( + "LEGACY_SECRET".to_string(), + Zeroizing::new("legacy_value".to_string()), + ) + .unwrap(); + + // Strip the magic prefix to simulate a legacy vault file + let raw = std::fs::read(&vault.path).unwrap(); + assert_eq!(&raw[..4], VAULT_MAGIC); + let legacy_json = &raw[4..]; + std::fs::write(&vault.path, legacy_json).unwrap(); + + // Verify the file now starts with '{' (legacy format) + let raw_legacy = std::fs::read(&vault.path).unwrap(); + assert_eq!(raw_legacy[0], b'{', "Legacy file should start with '{{' "); + + // Load should succeed with backward compatibility + let mut vault2 = CredentialVault::new(dir.path().join("vault.enc")); + vault2.unlock_with_key(key).unwrap(); + + let val = vault2.get("LEGACY_SECRET").unwrap(); + assert_eq!(val.as_str(), "legacy_value"); + } }