use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::sync::Mutex; use tracing::{debug, info}; use crate::config; use crate::crypto; use crate::crypto::wal; #[derive(Debug, Clone, Serialize, Deserialize)] struct MtimeEntry { db_mt: u64, wal_mt: u64, path: String, } #[derive(Debug, Clone)] struct CacheEntry { db_mtime: u64, wal_mtime: u64, decrypted_path: PathBuf, } /// 解密后数据库的 mtime-aware 缓存 /// /// 当数据库文件(.db)或 WAL 文件(.db-wal)的 mtime 发生变化时, /// 自动重新解密并更新缓存。跨进程重启可通过持久化 mtime 文件复用已解密的 DB。 pub struct DbCache { db_dir: PathBuf, cache_dir: PathBuf, all_keys: HashMap, // rel_key -> enc_key(hex) inner: Arc>>, } impl DbCache { pub async fn new( db_dir: PathBuf, all_keys: HashMap, ) -> Result { let cache_dir = config::cache_dir(); tokio::fs::create_dir_all(&cache_dir).await?; let inner: HashMap = HashMap::new(); let cache = DbCache { db_dir, cache_dir, all_keys, inner: Arc::new(Mutex::new(inner)), }; cache.load_persistent().await; Ok(cache) } fn cache_file_path(&self, rel_key: &str) -> PathBuf { let hash = format!("{:x}", md5::compute(rel_key.as_bytes())); self.cache_dir.join(format!("{}.db", hash)) } /// 从持久化文件加载 mtime 记录,复用未过期的解密文件 async fn load_persistent(&self) { let mtime_file = config::mtime_file(); let content = match tokio::fs::read_to_string(&mtime_file).await { Ok(c) => c, Err(_) => return, }; let saved: HashMap = match serde_json::from_str(&content) { Ok(v) => v, Err(_) => return, }; let mut inner = self.inner.lock().await; let mut reused = 0usize; for (rel_key, entry) in &saved { let dec_path = PathBuf::from(&entry.path); if !dec_path.exists() { continue; } let db_path = self.db_dir.join(rel_key.replace('\\', std::path::MAIN_SEPARATOR_STR).replace('/', std::path::MAIN_SEPARATOR_STR)); let wal_path = wal_path_for(&db_path); let db_mt = mtime_nanos(&db_path); let wal_mt = if wal_path.exists() { mtime_nanos(&wal_path) } else { 0 }; if db_mt == entry.db_mt && wal_mt == entry.wal_mt { inner.insert(rel_key.clone(), CacheEntry { db_mtime: db_mt, wal_mtime: wal_mt, decrypted_path: dec_path, }); reused += 1; } } if reused > 0 { info!(reused, "复用已解密 DB"); } } /// 持久化 mtime 记录 async fn save_persistent(&self) { let mtime_file = config::mtime_file(); let inner = self.inner.lock().await; let data: HashMap = inner.iter().map(|(k, v)| { (k.clone(), MtimeEntry { db_mt: v.db_mtime, wal_mt: v.wal_mtime, path: v.decrypted_path.to_string_lossy().into_owned(), }) }).collect(); drop(inner); if let Ok(json) = serde_json::to_string_pretty(&data) { let _ = tokio::fs::write(&mtime_file, json).await; } } /// 获取解密后的数据库路径 /// /// 如果 mtime 未变,直接返回缓存路径;否则重新解密 pub async fn get(&self, rel_key: &str) -> Result> { let enc_key_hex = match self.all_keys.get(rel_key) { Some(k) => k.clone(), None => return Ok(None), }; let db_path = self.db_dir.join( rel_key.replace('\\', std::path::MAIN_SEPARATOR_STR) .replace('/', std::path::MAIN_SEPARATOR_STR) ); if !db_path.exists() { return Ok(None); } let wal_path = wal_path_for(&db_path); let db_mt = mtime_nanos(&db_path); let wal_mt = if wal_path.exists() { mtime_nanos(&wal_path) } else { 0 }; // 检查缓存 { let inner = self.inner.lock().await; if let Some(entry) = inner.get(rel_key) { if entry.db_mtime == db_mt && entry.wal_mtime == wal_mt && entry.decrypted_path.exists() { debug!(db = rel_key, "缓存命中"); return Ok(Some(entry.decrypted_path.clone())); } } } // 需要重新解密 let out_path = self.cache_file_path(rel_key); let enc_key_bytes = hex_to_32bytes(&enc_key_hex) .with_context(|| format!("密钥格式错误: {}", rel_key))?; let t0 = std::time::Instant::now(); let db_path2 = db_path.clone(); let out_path2 = out_path.clone(); let key_copy = enc_key_bytes; tokio::task::spawn_blocking(move || { crypto::full_decrypt(&db_path2, &out_path2, &key_copy) }).await??; // 应用 WAL if wal_path.exists() { let out_path3 = out_path.clone(); let wal_path3 = wal_path.clone(); let key_copy2 = enc_key_bytes; tokio::task::spawn_blocking(move || { wal::apply_wal(&wal_path3, &out_path3, &key_copy2) }).await??; } let elapsed_ms = t0.elapsed().as_millis(); info!(db = rel_key, elapsed_ms, "解密完成"); // 更新内存缓存 { let mut inner = self.inner.lock().await; inner.insert(rel_key.to_string(), CacheEntry { db_mtime: db_mt, wal_mtime: wal_mt, decrypted_path: out_path.clone(), }); } self.save_persistent().await; Ok(Some(out_path)) } } pub(super) fn mtime_nanos(path: &Path) -> u64 { std::fs::metadata(path) .and_then(|m| m.modified()) .map(|t| t.duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_nanos() as u64) .unwrap_or(0) } /// `foo/bar.db` → `foo/bar.db-wal`(用 OsString 拼接,避免 display() 的 UTF-8 问题) fn wal_path_for(db_path: &Path) -> PathBuf { let mut name = db_path.file_name().unwrap_or_default().to_os_string(); name.push("-wal"); db_path.with_file_name(name) } fn hex_to_32bytes(s: &str) -> Result<[u8; 32]> { if s.len() != 64 { anyhow::bail!("密钥 hex 长度应为 64,实际为 {}", s.len()); } let mut out = [0u8; 32]; for i in 0..32 { out[i] = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16) .with_context(|| format!("非法 hex 字符 at {}", i * 2))?; } Ok(out) }