wx-cli/src/daemon/cache.rs

222 lines
7.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, info};
use crate::config;
use crate::crypto;
use crate::crypto::wal;
/// One record of the persisted mtime file: the source mtimes observed when a
/// decrypted copy was produced, plus where that copy lives.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MtimeEntry {
    /// mtime of the encrypted `.db` file, in nanoseconds since the Unix epoch.
    db_mt: u64,
    /// mtime of the `.db-wal` file in nanoseconds (0 when it did not exist).
    wal_mt: u64,
    /// Decrypted copy's path, stored as a (lossy) string.
    path: String,
}
/// In-memory cache record for one database.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// `.db` mtime (ns since the Unix epoch) the decrypted copy corresponds to.
    db_mtime: u64,
    /// `.db-wal` mtime (ns) the decrypted copy corresponds to; 0 means "no WAL".
    wal_mtime: u64,
    /// Location of the decrypted database file.
    decrypted_path: PathBuf,
}
/// Mtime-aware cache of decrypted databases.
///
/// Whenever the modification time of a database file (`.db`) or of its WAL file
/// (`.db-wal`) changes, the database is decrypted again and the cache entry is
/// refreshed. The recorded mtimes are also persisted to disk, so a restarted
/// process can reuse previously decrypted DBs whose sources have not changed.
pub struct DbCache {
    /// Root directory containing the encrypted databases.
    db_dir: PathBuf,
    /// Directory where the decrypted copies are written.
    cache_dir: PathBuf,
    /// Relative database key -> encryption key as a hex string.
    all_keys: HashMap<String, String>,
    /// Relative database key -> current cached decryption state.
    inner: Arc<Mutex<HashMap<String, CacheEntry>>>,
}
impl DbCache {
pub async fn new(
db_dir: PathBuf,
all_keys: HashMap<String, String>,
) -> Result<Self> {
let cache_dir = config::cache_dir();
tokio::fs::create_dir_all(&cache_dir).await?;
let inner: HashMap<String, CacheEntry> = HashMap::new();
let cache = DbCache {
db_dir,
cache_dir,
all_keys,
inner: Arc::new(Mutex::new(inner)),
};
cache.load_persistent().await;
Ok(cache)
}
fn cache_file_path(&self, rel_key: &str) -> PathBuf {
let hash = format!("{:x}", md5::compute(rel_key.as_bytes()));
self.cache_dir.join(format!("{}.db", hash))
}
/// 从持久化文件加载 mtime 记录,复用未过期的解密文件
async fn load_persistent(&self) {
let mtime_file = config::mtime_file();
let content = match tokio::fs::read_to_string(&mtime_file).await {
Ok(c) => c,
Err(_) => return,
};
let saved: HashMap<String, MtimeEntry> = match serde_json::from_str(&content) {
Ok(v) => v,
Err(_) => return,
};
let mut inner = self.inner.lock().await;
let mut reused = 0usize;
for (rel_key, entry) in &saved {
let dec_path = PathBuf::from(&entry.path);
if !dec_path.exists() {
continue;
}
let db_path = self.db_dir.join(rel_key.replace('\\', std::path::MAIN_SEPARATOR_STR).replace('/', std::path::MAIN_SEPARATOR_STR));
let wal_path = wal_path_for(&db_path);
let db_mt = mtime_nanos(&db_path);
let wal_mt = if wal_path.exists() { mtime_nanos(&wal_path) } else { 0 };
if db_mt == entry.db_mt && wal_mt == entry.wal_mt {
inner.insert(rel_key.clone(), CacheEntry {
db_mtime: db_mt,
wal_mtime: wal_mt,
decrypted_path: dec_path,
});
reused += 1;
}
}
if reused > 0 {
info!(reused, "复用已解密 DB");
}
}
/// 持久化 mtime 记录
async fn save_persistent(&self) {
let mtime_file = config::mtime_file();
let inner = self.inner.lock().await;
let data: HashMap<String, MtimeEntry> = inner.iter().map(|(k, v)| {
(k.clone(), MtimeEntry {
db_mt: v.db_mtime,
wal_mt: v.wal_mtime,
path: v.decrypted_path.to_string_lossy().into_owned(),
})
}).collect();
drop(inner);
if let Ok(json) = serde_json::to_string_pretty(&data) {
let _ = tokio::fs::write(&mtime_file, json).await;
}
}
/// 获取解密后的数据库路径
///
/// 如果 mtime 未变,直接返回缓存路径;否则重新解密
pub async fn get(&self, rel_key: &str) -> Result<Option<PathBuf>> {
let enc_key_hex = match self.all_keys.get(rel_key) {
Some(k) => k.clone(),
None => return Ok(None),
};
let db_path = self.db_dir.join(
rel_key.replace('\\', std::path::MAIN_SEPARATOR_STR)
.replace('/', std::path::MAIN_SEPARATOR_STR)
);
if !db_path.exists() {
return Ok(None);
}
let wal_path = wal_path_for(&db_path);
let db_mt = mtime_nanos(&db_path);
let wal_mt = if wal_path.exists() { mtime_nanos(&wal_path) } else { 0 };
// 检查缓存
{
let inner = self.inner.lock().await;
if let Some(entry) = inner.get(rel_key) {
if entry.db_mtime == db_mt
&& entry.wal_mtime == wal_mt
&& entry.decrypted_path.exists()
{
debug!(db = rel_key, "缓存命中");
return Ok(Some(entry.decrypted_path.clone()));
}
}
}
// 需要重新解密
let out_path = self.cache_file_path(rel_key);
let enc_key_bytes = hex_to_32bytes(&enc_key_hex)
.with_context(|| format!("密钥格式错误: {}", rel_key))?;
let t0 = std::time::Instant::now();
let db_path2 = db_path.clone();
let out_path2 = out_path.clone();
let key_copy = enc_key_bytes;
tokio::task::spawn_blocking(move || {
crypto::full_decrypt(&db_path2, &out_path2, &key_copy)
}).await??;
// 应用 WAL
if wal_path.exists() {
let out_path3 = out_path.clone();
let wal_path3 = wal_path.clone();
let key_copy2 = enc_key_bytes;
tokio::task::spawn_blocking(move || {
wal::apply_wal(&wal_path3, &out_path3, &key_copy2)
}).await??;
}
let elapsed_ms = t0.elapsed().as_millis();
info!(db = rel_key, elapsed_ms, "解密完成");
// 更新内存缓存
{
let mut inner = self.inner.lock().await;
inner.insert(rel_key.to_string(), CacheEntry {
db_mtime: db_mt,
wal_mtime: wal_mt,
decrypted_path: out_path.clone(),
});
}
self.save_persistent().await;
Ok(Some(out_path))
}
}
/// Modification time of `path` in nanoseconds since the Unix epoch.
///
/// Returns 0 in three cases:
/// - the file cannot be stat'ed;
/// - the platform does not report a modification time;
/// - the mtime predates the epoch.
///
/// The `u128` nanosecond count is truncated to `u64`.
pub(super) fn mtime_nanos(path: &Path) -> u64 {
    let modified = match std::fs::metadata(path).and_then(|meta| meta.modified()) {
        Ok(time) => time,
        Err(_) => return 0,
    };
    let since_epoch = modified
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_nanos() as u64
}
/// Builds the WAL path next to a database: `foo/bar.db` → `foo/bar.db-wal`.
///
/// The name is concatenated as an `OsString`, so non-UTF-8 file names survive
/// intact; going through `display()` would not be lossless. A path without a
/// file name gets a bare `-wal` component.
fn wal_path_for(db_path: &Path) -> PathBuf {
    let wal_name: std::ffi::OsString = match db_path.file_name() {
        Some(base) => {
            let mut joined = base.to_os_string();
            joined.push("-wal");
            joined
        }
        None => "-wal".into(),
    };
    db_path.with_file_name(wal_name)
}
/// Parses a 64-character hex string (upper- or lower-case) into a 32-byte key.
///
/// The input is validated byte by byte. Non-ASCII input is therefore reported as
/// an error; slicing a `&str` at non-char boundaries would have panicked. Sign
/// characters such as `+`, which `u8::from_str_radix` tolerates, are rejected.
///
/// # Errors
/// Returns an error in two cases:
/// - `s` is not exactly 64 bytes long;
/// - `s` contains a byte that is not an ASCII hex digit. The reported position
///   is the offset of the offending two-character pair.
fn hex_to_32bytes(s: &str) -> Result<[u8; 32]> {
    let raw = s.as_bytes();
    if raw.len() != 64 {
        anyhow::bail!("密钥 hex 长度应为 64, 实际为 {}", raw.len());
    }
    // Value of one ASCII hex digit; bytes >= 0x80 map to non-hex chars, so None.
    let nibble = |b: u8| char::from(b).to_digit(16);
    let mut out = [0u8; 32];
    for (i, (pair, byte)) in raw.chunks_exact(2).zip(out.iter_mut()).enumerate() {
        match (nibble(pair[0]), nibble(pair[1])) {
            // hi, lo <= 15, so hi * 16 + lo <= 255 and the cast cannot truncate.
            (Some(hi), Some(lo)) => *byte = (hi * 16 + lo) as u8,
            _ => anyhow::bail!("非法 hex 字符 at {}", i * 2),
        }
    }
    Ok(out)
}