Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 109 additions & 4 deletions crates/openfang-extensions/src/vault.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ const SALT_LEN: usize = 16;
/// Nonce length for AES-256-GCM.
const NONCE_LEN: usize = 12;
/// Magic bytes for vault file format versioning.
#[allow(dead_code)]
const VAULT_MAGIC: &[u8; 4] = b"OFV1";

/// On-disk vault format (encrypted).
Expand Down Expand Up @@ -312,14 +311,35 @@ impl CredentialVault {
if let Some(parent) = self.path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(&self.path, content)?;

// Write to file: magic bytes + JSON
let mut file_content = Vec::with_capacity(VAULT_MAGIC.len() + content.len());
file_content.extend_from_slice(VAULT_MAGIC);
file_content.extend_from_slice(content.as_bytes());
std::fs::write(&self.path, &file_content)?;
Ok(())
}

/// Load and decrypt vault from disk.
fn load(&mut self, master_key: &[u8; 32]) -> ExtensionResult<()> {
let content = std::fs::read_to_string(&self.path)?;
let vault_file: VaultFile = serde_json::from_str(&content)
let raw = std::fs::read(&self.path)?;

// Determine if file has magic prefix or is legacy JSON
let json_bytes = if raw.len() >= 4 && &raw[..4] == VAULT_MAGIC {
&raw[4..]
} else if raw.first() == Some(&b'{') {
// Legacy format: plain JSON without magic prefix
&raw[..]
} else {
return Err(ExtensionError::Vault(
"Invalid vault file: unrecognized format (expected OFV1 header or JSON)"
.to_string(),
));
};

let content = std::str::from_utf8(json_bytes)
.map_err(|e| ExtensionError::Vault(format!("Vault file is not valid UTF-8: {e}")))?;
let vault_file: VaultFile = serde_json::from_str(content)
.map_err(|e| ExtensionError::Vault(format!("Vault file parse failed: {e}")))?;

if vault_file.version != 1 {
Expand Down Expand Up @@ -590,4 +610,89 @@ mod tests {
let k2 = derive_key(&master, &salt).unwrap();
assert_eq!(k1.as_ref(), k2.as_ref());
}

#[test]
fn vault_file_starts_with_magic_bytes() {
let (_dir, mut vault) = test_vault();
let key = random_key();

vault.init_with_key(key).unwrap();
vault
.set(
"TEST_KEY".to_string(),
Zeroizing::new("test_value".to_string()),
)
.unwrap();

// Read raw bytes and verify magic prefix
let raw = std::fs::read(&vault.path).unwrap();
assert!(
raw.len() >= 4,
"Vault file too short to contain magic bytes"
);
assert_eq!(
&raw[..4],
VAULT_MAGIC,
"Vault file should start with OFV1 magic bytes"
);

// The rest should be valid JSON
let json_part = std::str::from_utf8(&raw[4..]).unwrap();
let parsed: serde_json::Value = serde_json::from_str(json_part).unwrap();
assert!(parsed.get("version").is_some());
assert!(parsed.get("ciphertext").is_some());
}

#[test]
fn vault_rejects_wrong_magic_bytes() {
let (dir, _vault) = test_vault();
let vault_path = dir.path().join("vault.enc");

// Write a file with wrong magic bytes
let bad_content = b"BAD!{\"not\": \"valid\"}";
std::fs::write(&vault_path, bad_content).unwrap();

let mut vault = CredentialVault::new(vault_path);
let key = random_key();
let result = vault.unlock_with_key(key);

assert!(result.is_err(), "Should reject file with wrong magic bytes");
let err_msg = format!("{}", result.unwrap_err());
assert!(
err_msg.contains("unrecognized format"),
"Error should mention unrecognized format, got: {err_msg}"
);
}

#[test]
fn vault_loads_legacy_json_without_magic() {
let (dir, mut vault) = test_vault();
let key = random_key();

// Init and store a secret (this writes with magic prefix)
vault.init_with_key(key.clone()).unwrap();
vault
.set(
"LEGACY_SECRET".to_string(),
Zeroizing::new("legacy_value".to_string()),
)
.unwrap();

// Strip the magic prefix to simulate a legacy vault file
let raw = std::fs::read(&vault.path).unwrap();
assert_eq!(&raw[..4], VAULT_MAGIC);
let legacy_json = &raw[4..];
std::fs::write(&vault.path, legacy_json).unwrap();

// Verify the file now starts with '{' (legacy format)
let raw_legacy = std::fs::read(&vault.path).unwrap();
assert_eq!(raw_legacy[0], b'{', "Legacy file should start with '{{' ");

// Load should succeed with backward compatibility
let mut vault2 = CredentialVault::new(dir.path().join("vault.enc"));
vault2.unlock_with_key(key).unwrap();

let val = vault2.get("LEGACY_SECRET").unwrap();
assert_eq!(val.as_str(), "legacy_value");
}
}