diff --git a/src/cli/transport.rs b/src/cli/transport.rs index 73c2f88..ce2eb0d 100644 --- a/src/cli/transport.rs +++ b/src/cli/transport.rs @@ -1,50 +1,32 @@ use anyhow::{bail, Context, Result}; +use serde::{Deserialize, Serialize}; use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, PathBuf}; +use std::thread; use std::time::Duration; use crate::config; use crate::ipc::{Request, Response}; const STARTUP_TIMEOUT_SECS: u64 = 15; +const STOP_TIMEOUT_MS: u64 = 2_000; + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PidFile { + pid: u32, + #[serde(default)] + exe: Option, +} /// 检查 daemon 是否存活 pub fn is_alive() -> bool { #[cfg(unix)] { - use std::os::unix::net::UnixStream; - let sock_path = config::sock_path(); - if !sock_path.exists() { - return false; - } - let mut stream = match UnixStream::connect(&sock_path) { - Ok(s) => s, - Err(_) => return false, - }; - stream.set_read_timeout(Some(Duration::from_secs(2))).ok(); - stream.set_write_timeout(Some(Duration::from_secs(2))).ok(); - - let req = serde_json::json!({"cmd": "ping"}); - if write!(stream, "{}\n", req).is_err() { - return false; - } - let mut line = String::new(); - let mut reader = BufReader::new(&stream); - if reader.read_line(&mut line).is_err() { - return false; - } - serde_json::from_str::(&line) - .ok() - .and_then(|v| v.get("pong").and_then(|p| p.as_bool())) - .unwrap_or(false) + ping_unix().unwrap_or(false) } #[cfg(windows)] { - use interprocess::local_socket::{prelude::*, GenericNamespaced, Stream}; - // 必须用 interprocess 自己的连接 API,和 server 保持一致 - match "wx-cli-daemon".to_ns_name::() { - Ok(name) => Stream::connect(name).is_ok(), - Err(_) => false, - } + ping_windows().unwrap_or(false) } #[cfg(not(any(unix, windows)))] { @@ -65,25 +47,33 @@ pub fn ensure_daemon() -> Result<()> { /// 停止 daemon(如果正在运行) pub fn stop_daemon() -> Result<()> { let pid_path = config::pid_path(); - if let Ok(pid_str) = std::fs::read_to_string(&pid_path) { - if let Ok(pid) = pid_str.trim().parse::() { - #[cfg(unix)] - { - let _ = std::process::Command::new("kill") - .arg("-TERM") - .arg(pid.to_string()) - .spawn(); + let pid_file = read_pid_file(&pid_path)?; + let daemon_alive = is_alive(); + + match pid_file { + Some(pid_file) => { + let belongs = pid_belongs_to_daemon(&pid_file)?; + if daemon_alive && !belongs { + bail!( + "daemon 正在运行,但 {} 指向的 PID {} 无法确认属于当前 wx-daemon", + pid_path.display(), + pid_file.pid + ); } - #[cfg(windows)] - { - let _ = std::process::Command::new("taskkill") - .args(["/F", "/PID", &pid.to_string()]) - .spawn(); + if belongs { + terminate_pid(pid_file.pid)?; } } + None if daemon_alive => { + bail!( + "daemon 正在运行,但 {} 缺失或损坏,无法安全停止", + pid_path.display() + ); + } + None => {} } - let _ = std::fs::remove_file(config::sock_path()); - let _ = std::fs::remove_file(&pid_path); + + cleanup_ipc_files(); Ok(()) } @@ -123,6 +113,7 @@ fn preflight_cli_dir_writable() -> Result<()> { /// 启动 daemon 进程(自身二进制,设置 WX_DAEMON_MODE=1) fn start_daemon() -> Result<()> { let exe = std::env::current_exe().context("无法获取当前可执行文件路径")?; + let child_pid: u32; // 预检:当前用户是否能写 ~/.wx-cli/。如果不能,给出可操作的错误信息, // 而不是 spawn 一个注定失败的 daemon 然后超时 15s。 @@ -138,7 +129,8 @@ fn start_daemon() -> Result<()> { let _ = std::fs::create_dir_all(parent); } let (stdout_stdio, stderr_stdio) = std::fs::OpenOptions::new() - .create(true).append(true) + .create(true) + .append(true) .open(&log_path) .and_then(|f| f.try_clone().map(|g| (f, g))) .map(|(f, g)| (std::process::Stdio::from(f), std::process::Stdio::from(g))) @@ -149,8 +141,14 @@ fn start_daemon() -> Result<()> { .stdout(stdout_stdio) .stderr(stderr_stdio); // SAFETY: setsid() 在 fork 后的子进程中调用,使 daemon 脱离控制终端 - unsafe { cmd.pre_exec(|| { libc::setsid(); Ok(()) }); } - let _ = cmd.spawn().context("无法启动 daemon 进程")?; + unsafe { + cmd.pre_exec(|| { + libc::setsid(); + Ok(()) + }); + } + let child = cmd.spawn().context("无法启动 daemon 进程")?; + child_pid = child.id(); } #[cfg(windows)] @@ -161,12 +159,13 @@ fn start_daemon() -> Result<()> { let _ = std::fs::create_dir_all(parent); } let (stdout_stdio, stderr_stdio) = std::fs::OpenOptions::new() - .create(true).append(true) + .create(true) + .append(true) .open(&log_path) .and_then(|f| f.try_clone().map(|g| (f, g))) .map(|(f, g)| (std::process::Stdio::from(f), std::process::Stdio::from(g))) .unwrap_or_else(|_| (std::process::Stdio::null(), std::process::Stdio::null())); - let _ = std::process::Command::new(&exe) + let child = std::process::Command::new(&exe) .env("WX_DAEMON_MODE", "1") .stdin(std::process::Stdio::null()) .stdout(stdout_stdio) @@ -174,6 +173,7 @@ fn start_daemon() -> Result<()> { .creation_flags(0x00000008) // DETACHED_PROCESS .spawn() .context("无法启动 daemon 进程")?; + child_pid = child.id(); } // 等待 daemon 就绪(最多 STARTUP_TIMEOUT_SECS 秒) @@ -181,6 +181,7 @@ fn start_daemon() -> Result<()> { while std::time::Instant::now() < deadline { std::thread::sleep(Duration::from_millis(300)); if is_alive() { + write_pid_file(child_pid, &exe)?; return Ok(()); } } @@ -192,6 +193,227 @@ fn start_daemon() -> Result<()> { ) } +fn write_pid_file(pid: u32, exe: &Path) -> Result<()> { + if let Some(parent) = config::pid_path().parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("创建 {} 失败", parent.display()))?; + } + let pid_file = PidFile { + pid, + exe: Some(exe.to_path_buf()), + }; + let content = serde_json::to_string(&pid_file)?; + std::fs::write(config::pid_path(), content) + .with_context(|| format!("写入 {} 失败", config::pid_path().display()))?; + Ok(()) +} + +fn read_pid_file(path: &Path) -> Result> { + let content = match std::fs::read_to_string(path) { + Ok(content) => content, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None), + Err(err) => return Err(err).with_context(|| format!("读取 {} 失败", path.display())), + }; + if let Ok(pid_file) = serde_json::from_str::(&content) { + return Ok(Some(pid_file)); + } + if let Ok(pid) = content.trim().parse::() { + return Ok(Some(PidFile { + pid, + exe: std::env::current_exe().ok(), + })); + } + bail!("{} 不是合法的 PID 文件", path.display()) +} + +fn cleanup_ipc_files() { + let _ = std::fs::remove_file(config::sock_path()); + let _ = std::fs::remove_file(config::pid_path()); +} + +#[cfg(unix)] +fn ping_unix() -> Result { + use std::os::unix::net::UnixStream; + let sock_path = config::sock_path(); + if !sock_path.exists() { + return Ok(false); + } + let mut stream = UnixStream::connect(&sock_path)?; + stream.set_read_timeout(Some(Duration::from_secs(2))).ok(); + stream.set_write_timeout(Some(Duration::from_secs(2))).ok(); + + let req = serde_json::to_string(&Request::Ping)? + "\n"; + stream.write_all(req.as_bytes())?; + + let mut line = String::new(); + let mut reader = BufReader::new(&stream); + reader.read_line(&mut line)?; + + let resp: Response = serde_json::from_str(&line)?; + Ok(resp.ok && resp.data.get("pong").and_then(|p| p.as_bool()) == Some(true)) +} + +#[cfg(windows)] +fn ping_windows() -> Result { + use interprocess::local_socket::{prelude::*, GenericNamespaced, Stream}; + + let name = "wx-cli-daemon".to_ns_name::()?; + let stream = Stream::connect(name)?; + let mut reader = BufReader::new(stream); + + let req = serde_json::to_string(&Request::Ping)? + "\n"; + reader.get_mut().write_all(req.as_bytes())?; + + let mut line = String::new(); + reader.read_line(&mut line)?; + + let resp: Response = serde_json::from_str(&line)?; + Ok(resp.ok && resp.data.get("pong").and_then(|p| p.as_bool()) == Some(true)) +} + +fn pid_belongs_to_daemon(pid_file: &PidFile) -> Result { + let expected_exe = pid_file + .exe + .clone() + .or_else(|| std::env::current_exe().ok()); + #[cfg(unix)] + { + unix_pid_matches_daemon(pid_file.pid, expected_exe.as_deref()) + } + #[cfg(windows)] + { + windows_pid_matches_daemon(pid_file.pid, expected_exe.as_deref()) + } + #[cfg(not(any(unix, windows)))] + { + let _ = expected_exe; + Ok(true) + } +} + +#[cfg(unix)] +fn unix_pid_matches_daemon(pid: u32, expected_exe: Option<&Path>) -> Result { + let Some(expected_exe) = expected_exe else { + return Ok(false); + }; + let output = std::process::Command::new("ps") + .args(["-o", "command=", "-p", &pid.to_string()]) + .output() + .with_context(|| format!("读取 PID {} 的 command 失败", pid))?; + if !output.status.success() { + return Ok(false); + } + let command = String::from_utf8_lossy(&output.stdout); + let expected = expected_exe.to_string_lossy(); + if command.contains(expected.as_ref()) { + return Ok(true); + } + let Some(exe_name) = expected_exe.file_name().and_then(|name| name.to_str()) else { + return Ok(false); + }; + Ok(command + .split_whitespace() + .any(|part| part == exe_name || part.ends_with(&format!("/{}", exe_name)))) +} + +#[cfg(windows)] +fn windows_pid_matches_daemon(pid: u32, expected_exe: Option<&Path>) -> Result { + use windows::core::PWSTR; + use windows::Win32::Foundation::CloseHandle; + use windows::Win32::System::Threading::{ + OpenProcess, QueryFullProcessImageNameW, PROCESS_QUERY_LIMITED_INFORMATION, + }; + + let Some(expected_exe) = expected_exe else { + return Ok(false); + }; + let handle = match unsafe { OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, false, pid) } { + Ok(handle) => handle, + Err(_) => return Ok(false), + }; + + let mut buf = vec![0u16; 260]; + let mut len = buf.len() as u32; + let actual = unsafe { + let result = QueryFullProcessImageNameW(handle, 0, PWSTR(buf.as_mut_ptr()), &mut len); + let _ = CloseHandle(handle); + result + }; + if actual.is_err() { + return Ok(false); + } + + let actual_path = PathBuf::from(String::from_utf16_lossy(&buf[..len as usize])); + Ok(normalize_exe_path(&actual_path) == normalize_exe_path(expected_exe)) +} + +#[cfg(windows)] +fn normalize_exe_path(path: &Path) -> String { + path.to_string_lossy() + .replace('\\', "/") + .to_ascii_lowercase() +} + +fn terminate_pid(pid: u32) -> Result<()> { + #[cfg(unix)] + { + terminate_pid_unix(pid) + } + #[cfg(windows)] + { + terminate_pid_windows(pid) + } + #[cfg(not(any(unix, windows)))] + { + let _ = pid; + Ok(()) + } +} + +#[cfg(unix)] +fn terminate_pid_unix(pid: u32) -> Result<()> { + let rc = unsafe { libc::kill(pid as i32, libc::SIGTERM) }; + if rc != 0 { + let err = std::io::Error::last_os_error(); + if err.raw_os_error() == Some(libc::ESRCH) { + return Ok(()); + } + bail!("停止 PID {} 失败: {}", pid, err); + } + + let deadline = std::time::Instant::now() + Duration::from_millis(STOP_TIMEOUT_MS); + while std::time::Instant::now() < deadline { + if !unix_process_exists(pid) { + return Ok(()); + } + thread::sleep(Duration::from_millis(50)); + } + + bail!("等待 PID {} 退出超时", pid) +} + +#[cfg(unix)] +fn unix_process_exists(pid: u32) -> bool { + let rc = unsafe { libc::kill(pid as i32, 0) }; + if rc == 0 { + return true; + } + let err = std::io::Error::last_os_error(); + err.raw_os_error() == Some(libc::EPERM) +} + +#[cfg(windows)] +fn terminate_pid_windows(pid: u32) -> Result<()> { + let status = std::process::Command::new("taskkill") + .args(["/F", "/PID", &pid.to_string()]) + .status() + .with_context(|| format!("执行 taskkill /PID {} 失败", pid))?; + if !status.success() { + bail!("停止 PID {} 失败: taskkill exit {:?}", pid, status.code()); + } + Ok(()) +} + /// 向 daemon 发送请求并返回响应 pub fn send(req: Request) -> Result { ensure_daemon()?; @@ -214,10 +436,11 @@ pub fn send(req: Request) -> Result { fn send_unix(req: Request) -> Result { use std::os::unix::net::UnixStream; let sock_path = config::sock_path(); - let mut stream = UnixStream::connect(&sock_path) - .context("连接 daemon socket 失败")?; + let mut stream = UnixStream::connect(&sock_path).context("连接 daemon socket 失败")?; stream.set_read_timeout(Some(Duration::from_secs(120))).ok(); - stream.set_write_timeout(Some(Duration::from_secs(120))).ok(); + stream + .set_write_timeout(Some(Duration::from_secs(120))) + .ok(); let req_str = serde_json::to_string(&req)? + "\n"; stream.write_all(req_str.as_bytes())?; @@ -226,8 +449,7 @@ fn send_unix(req: Request) -> Result { let mut reader = BufReader::new(&stream); reader.read_line(&mut line)?; - let resp: Response = serde_json::from_str(&line) - .context("解析 daemon 响应失败")?; + let resp: Response = serde_json::from_str(&line).context("解析 daemon 响应失败")?; if !resp.ok { bail!("{}", resp.error.as_deref().unwrap_or("未知错误")); @@ -240,10 +462,10 @@ fn send_unix(req: Request) -> Result { fn send_windows(req: Request) -> Result { use interprocess::local_socket::{prelude::*, GenericNamespaced, Stream}; - let name = "wx-cli-daemon".to_ns_name::() + let name = "wx-cli-daemon" + .to_ns_name::() .context("构造 pipe name 失败")?; - let stream = Stream::connect(name) - .context("连接 daemon named pipe 失败")?; + let stream = Stream::connect(name).context("连接 daemon named pipe 失败")?; // interprocess::Stream 同时实现 Read + Write,但需要拆分读写端 let mut reader = BufReader::new(stream); @@ -254,8 +476,7 @@ fn send_windows(req: Request) -> Result { let mut line = String::new(); reader.read_line(&mut line)?; - let resp: Response = serde_json::from_str(&line) - .context("解析 daemon 响应失败")?; + let resp: Response = serde_json::from_str(&line).context("解析 daemon 响应失败")?; if !resp.ok { bail!("{}", resp.error.as_deref().unwrap_or("未知错误")); diff --git a/src/config.rs b/src/config.rs index a488ca0..f74fda3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -11,38 +11,50 @@ pub struct Config { pub wechat_process: String, } -/// 从 /config.json 或 $HOME/.wx-cli/config.json 加载配置 +/// 从当前工作目录 / / $HOME/.wx-cli 加载配置 pub fn load_config() -> Result { let config_path = find_config_file()?; let content = std::fs::read_to_string(&config_path) .with_context(|| format!("读取 config.json 失败: {}", config_path.display()))?; - let raw: serde_json::Value = serde_json::from_str(&content) - .with_context(|| "config.json 格式错误")?; + let raw: serde_json::Value = + serde_json::from_str(&content).with_context(|| "config.json 格式错误")?; - let db_dir = raw.get("db_dir") + let db_dir = raw + .get("db_dir") .and_then(|v| v.as_str()) .map(PathBuf::from) .unwrap_or_else(default_db_dir); let base_dir = config_path.parent().unwrap_or(Path::new(".")); - let keys_file = raw.get("keys_file") + let keys_file = raw + .get("keys_file") .and_then(|v| v.as_str()) .map(|s| { let p = PathBuf::from(s); - if p.is_absolute() { p } else { base_dir.join(p) } + if p.is_absolute() { + p + } else { + base_dir.join(p) + } }) .unwrap_or_else(|| base_dir.join("all_keys.json")); - let decrypted_dir = raw.get("decrypted_dir") + let decrypted_dir = raw + .get("decrypted_dir") .and_then(|v| v.as_str()) .map(|s| { let p = PathBuf::from(s); - if p.is_absolute() { p } else { base_dir.join(p) } + if p.is_absolute() { + p + } else { + base_dir.join(p) + } }) .unwrap_or_else(|| base_dir.join("decrypted")); - let wechat_process = raw.get("wechat_process") + let wechat_process = raw + .get("wechat_process") .and_then(|v| v.as_str()) .unwrap_or(default_wechat_process()) .to_string(); @@ -56,35 +68,56 @@ pub fn load_config() -> Result { } fn find_config_file() -> Result { - // 1. 优先查找可执行文件同目录 - if let Ok(exe) = std::env::current_exe() { - if let Some(dir) = exe.parent() { - let p = dir.join("config.json"); - if p.exists() { - return Ok(p); - } - } + let cwd_dir = std::env::current_dir().ok(); + let exe_dir = std::env::current_exe() + .ok() + .and_then(|exe| exe.parent().map(PathBuf::from)); + let cli_home = cli_home_dir(); + let home_dir = (cli_home != PathBuf::from("/tmp")).then_some(cli_home.as_path()); + + if let Some(path) = find_existing_config_path(cwd_dir.as_deref(), exe_dir.as_deref(), home_dir) + { + return Ok(path); } - // 2. 当前工作目录 - let cwd = std::env::current_dir().unwrap_or_default().join("config.json"); - if cwd.exists() { - return Ok(cwd); - } - // 3. ~/.wx-cli/config.json - let home = cli_home_dir(); - if home != PathBuf::from("/tmp") { - let p = home.join(".wx-cli").join("config.json"); - if p.exists() { - return Ok(p); - } - } - // 返回默认路径(可能不存在,调用方负责处理) - if let Ok(exe) = std::env::current_exe() { - if let Some(dir) = exe.parent() { - return Ok(dir.join("config.json")); - } - } - Ok(PathBuf::from("config.json")) + + Ok(default_config_path( + cwd_dir.as_deref(), + exe_dir.as_deref(), + home_dir, + )) +} + +fn find_existing_config_path( + cwd_dir: Option<&Path>, + exe_dir: Option<&Path>, + home_dir: Option<&Path>, +) -> Option { + let candidates = [ + cwd_dir.map(config_path_in_dir), + exe_dir.map(config_path_in_dir), + home_dir.map(home_config_path), + ]; + candidates.into_iter().flatten().find(|path| path.exists()) +} + +fn default_config_path( + cwd_dir: Option<&Path>, + exe_dir: Option<&Path>, + home_dir: Option<&Path>, +) -> PathBuf { + cwd_dir + .map(config_path_in_dir) + .or_else(|| exe_dir.map(config_path_in_dir)) + .or_else(|| home_dir.map(home_config_path)) + .unwrap_or_else(|| PathBuf::from("config.json")) +} + +fn config_path_in_dir(dir: &Path) -> PathBuf { + dir.join("config.json") +} + +fn home_config_path(home_dir: &Path) -> PathBuf { + home_dir.join(".wx-cli").join("config.json") } pub fn cli_dir() -> PathBuf { @@ -163,8 +196,7 @@ fn default_db_dir() -> PathBuf { } #[cfg(target_os = "windows")] { - PathBuf::from(std::env::var("APPDATA").unwrap_or_default()) - .join("Tencent/xwechat") + PathBuf::from(std::env::var("APPDATA").unwrap_or_default()).join("Tencent/xwechat") } #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))] { @@ -174,13 +206,21 @@ fn default_db_dir() -> PathBuf { fn default_wechat_process() -> &'static str { #[cfg(target_os = "macos")] - { "WeChat" } + { + "WeChat" + } #[cfg(target_os = "linux")] - { "wechat" } + { + "wechat" + } #[cfg(target_os = "windows")] - { "Weixin.exe" } + { + "Weixin.exe" + } #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))] - { "WeChat" } + { + "WeChat" + } } /// 自动检测微信 db_storage 目录 @@ -244,6 +284,7 @@ fn detect_db_dir_impl() -> Option { candidates.into_iter().next_back() } +#[cfg(any(target_os = "linux", target_os = "windows"))] /// 递归查找 db_storage 目录下所有 .db 文件的最新 mtime fn latest_db_mtime(dir: &Path) -> Option { let mut latest = None; @@ -253,7 +294,10 @@ fn latest_db_mtime(dir: &Path) -> Option { let mtime = if path.is_dir() { latest_db_mtime(&path).unwrap_or(std::time::SystemTime::UNIX_EPOCH) } else if path.extension().and_then(|s| s.to_str()) == Some("db") { - entry.metadata().and_then(|m| m.modified()).unwrap_or(std::time::SystemTime::UNIX_EPOCH) + entry + .metadata() + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH) } else { continue; }; @@ -278,8 +322,7 @@ fn detect_db_dir_impl() -> Option { if let Ok(content) = std::fs::read_to_string(&path) { let data_root = content.trim().to_string(); if PathBuf::from(&data_root).is_dir() { - let pattern = PathBuf::from(&data_root) - .join("xwechat_files"); + let pattern = PathBuf::from(&data_root).join("xwechat_files"); if let Ok(entries2) = std::fs::read_dir(&pattern) { for entry2 in entries2.flatten() { let storage = entry2.path().join("db_storage"); @@ -293,7 +336,8 @@ fn detect_db_dir_impl() -> Option { } } } - candidates.into_iter().next() + candidates.sort_by_key(|p| latest_db_mtime(p).unwrap_or(std::time::SystemTime::UNIX_EPOCH)); + candidates.into_iter().next_back() } #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))] @@ -303,24 +347,66 @@ fn detect_db_dir_impl() -> Option { #[cfg(test)] mod tests { - use super::resolve_cli_home; + use super::{ + config_path_in_dir, default_config_path, find_existing_config_path, home_config_path, + resolve_cli_home, + }; + use std::fs; use std::path::PathBuf; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir(name: &str) -> PathBuf { + let unique = format!( + "wx-cli-config-test-{}-{}-{}", + name, + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + ); + let dir = std::env::temp_dir().join(unique); + fs::create_dir_all(&dir).unwrap(); + dir + } #[test] fn resolve_cli_home_prefers_sudo_home_when_present() { - let home = resolve_cli_home( - PathBuf::from("/root"), - Some(PathBuf::from("/Users/alice")), - ); + let home = resolve_cli_home(PathBuf::from("/root"), Some(PathBuf::from("/Users/alice"))); assert_eq!(home, PathBuf::from("/Users/alice")); } #[test] fn resolve_cli_home_falls_back_to_default_home() { - let home = resolve_cli_home( - PathBuf::from("/root"), - None, - ); + let home = resolve_cli_home(PathBuf::from("/root"), None); assert_eq!(home, PathBuf::from("/root")); } + + #[test] + fn config_path_prefers_cwd_over_exe_and_home() { + let cwd = temp_dir("cwd"); + let exe = temp_dir("exe"); + let home = temp_dir("home"); + fs::write(config_path_in_dir(&cwd), "{}").unwrap(); + fs::write(config_path_in_dir(&exe), "{}").unwrap(); + fs::create_dir_all(home.join(".wx-cli")).unwrap(); + fs::write(home_config_path(&home), "{}").unwrap(); + + let path = find_existing_config_path(Some(&cwd), Some(&exe), Some(&home)).unwrap(); + assert_eq!(path, config_path_in_dir(&cwd)); + + fs::remove_dir_all(cwd).unwrap(); + fs::remove_dir_all(exe).unwrap(); + fs::remove_dir_all(home).unwrap(); + } + + #[test] + fn default_config_path_matches_init_write_order() { + let cwd = PathBuf::from("/tmp/cwd"); + let exe = PathBuf::from("/tmp/exe"); + let home = PathBuf::from("/tmp/home"); + + let path = default_config_path(Some(&cwd), Some(&exe), Some(&home)); + assert_eq!(path, cwd.join("config.json")); + } } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index e5407b5..da074e7 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,9 +1,9 @@ pub mod wal; -use anyhow::{bail, Result}; use aes::Aes256; -use cbc::Decryptor; +use anyhow::{bail, Result}; use cbc::cipher::{BlockDecryptMut, KeyIvInit}; +use cbc::Decryptor; use std::io::{Read, Write}; use std::path::Path; @@ -65,11 +65,8 @@ fn aes_cbc_decrypt(key: &[u8; 32], iv: &[u8; 16], data: &[u8]) -> Result bail!("密文长度不是 AES 块大小的倍数: {}", data.len()); } // 将 &[u8] 复制为 Block 数组,避免 unsafe from_raw_parts_mut - let mut blocks: Vec = data.chunks_exact(16) - .map(Block::clone_from_slice) - .collect(); - Aes256CbcDec::new(key.into(), iv.into()) - .decrypt_blocks_mut(&mut blocks); + let mut blocks: Vec = data.chunks_exact(16).map(Block::clone_from_slice).collect(); + Aes256CbcDec::new(key.into(), iv.into()).decrypt_blocks_mut(&mut blocks); Ok(blocks.iter().flat_map(|b| b.iter().copied()).collect()) } @@ -92,15 +89,101 @@ pub fn full_decrypt(db_path: &Path, out_path: &Path, enc_key: &[u8; 32]) -> Resu let mut page_buf = vec![0u8; PAGE_SZ]; for pgno in 1..=total_pages { - let n = input.read(&mut page_buf)?; - if n == 0 { break; } - // 不足一页则补零 - if n < PAGE_SZ { - page_buf[n..].fill(0); - } + let page_start = (pgno - 1) * PAGE_SZ; + let bytes_remaining = file_size.saturating_sub(page_start); + read_page(&mut input, &mut page_buf, bytes_remaining)?; let dec = decrypt_page(enc_key, &page_buf, pgno as u32)?; output.write_all(&dec)?; } Ok(()) } + +fn read_page( + input: &mut impl Read, + page_buf: &mut [u8], + bytes_remaining: usize, +) -> std::io::Result { + let expected = bytes_remaining.min(PAGE_SZ); + input.read_exact(&mut page_buf[..expected])?; + if expected < PAGE_SZ { + page_buf[expected..].fill(0); + } + Ok(expected) +} + +#[cfg(test)] +mod tests { + use super::{read_page, PAGE_SZ}; + use std::io::{self, Read}; + + struct ChunkedReader { + chunks: Vec>, + chunk_idx: usize, + offset: usize, + } + + impl ChunkedReader { + fn new(chunks: Vec>) -> Self { + Self { + chunks, + chunk_idx: 0, + offset: 0, + } + } + } + + impl Read for ChunkedReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if self.chunk_idx >= self.chunks.len() { + return Ok(0); + } + let chunk = &self.chunks[self.chunk_idx]; + let remaining = &chunk[self.offset..]; + let n = remaining.len().min(buf.len()); + buf[..n].copy_from_slice(&remaining[..n]); + self.offset += n; + if self.offset == chunk.len() { + self.chunk_idx += 1; + self.offset = 0; + } + Ok(n) + } + } + + #[test] + fn read_page_reads_across_short_chunks() { + let mut reader = ChunkedReader::new(vec![vec![1; 32], vec![2; PAGE_SZ - 32]]); + let mut page_buf = vec![0u8; PAGE_SZ]; + + let n = read_page(&mut reader, &mut page_buf, PAGE_SZ).unwrap(); + + assert_eq!(n, PAGE_SZ); + assert_eq!(page_buf[0], 1); + assert_eq!(page_buf[31], 1); + assert_eq!(page_buf[32], 2); + assert_eq!(page_buf[PAGE_SZ - 1], 2); + } + + #[test] + fn read_page_zero_pads_last_partial_page() { + let mut reader = ChunkedReader::new(vec![vec![7; 8], vec![9; 4]]); + let mut page_buf = vec![0u8; PAGE_SZ]; + + let n = read_page(&mut reader, &mut page_buf, 12).unwrap(); + + assert_eq!(n, 12); + assert_eq!(&page_buf[..8], &[7; 8]); + assert_eq!(&page_buf[8..12], &[9; 4]); + assert!(page_buf[12..].iter().all(|&b| b == 0)); + } + + #[test] + fn read_page_errors_on_early_eof() { + let mut reader = ChunkedReader::new(vec![vec![1; 8]]); + let mut page_buf = vec![0u8; PAGE_SZ]; + + let err = read_page(&mut reader, &mut page_buf, 16).unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof); + } +} diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 02dc99f..ef3723c 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -25,9 +25,7 @@ async fn async_run() -> Result<()> { tokio::fs::create_dir_all(&cli_dir).await?; tokio::fs::create_dir_all(config::cache_dir()).await?; - // 写 PID 文件 let pid = std::process::id(); - tokio::fs::write(config::pid_path(), pid.to_string()).await?; // 注册 SIGTERM / SIGINT 处理 setup_signal_handler().await; @@ -39,7 +37,8 @@ async fn async_run() -> Result<()> { eprintln!("[daemon] DB_DIR: {}", cfg.db_dir.display()); // 加载密钥 - let keys_content = tokio::fs::read_to_string(&cfg.keys_file).await + let keys_content = tokio::fs::read_to_string(&cfg.keys_file) + .await .map_err(|e| anyhow::anyhow!("读取密钥文件 {:?} 失败: {}", cfg.keys_file, e))?; let keys_raw: serde_json::Value = serde_json::from_str(&keys_content)?; let all_keys = extract_keys(&keys_raw); @@ -49,11 +48,14 @@ async fn async_run() -> Result<()> { let db = Arc::new(cache::DbCache::new(cfg.db_dir.clone(), all_keys.clone()).await?); // 收集消息 DB 列表 - let msg_db_keys: Vec = all_keys.keys() + let msg_db_keys: Vec = all_keys + .keys() .filter(|k| { let k = k.replace('\\', "/"); - k.contains("message/message_") && k.ends_with(".db") - && !k.contains("_fts") && !k.contains("_resource") + k.contains("message/message_") + && k.ends_with(".db") + && !k.contains("_fts") + && !k.contains("_resource") }) .cloned() .collect(); @@ -82,7 +84,9 @@ async fn async_run() -> Result<()> { let names_arc = Arc::new(tokio::sync::RwLock::new(Arc::new(names))); // 启动 IPC server(阻塞) - server::serve(Arc::clone(&db), Arc::clone(&names_arc)).await?; + let serve_result = server::serve(Arc::clone(&db), Arc::clone(&names_arc)).await; + cleanup_ipc_files(); + serve_result?; Ok(()) } @@ -96,7 +100,9 @@ fn extract_keys(json: &serde_json::Value) -> HashMap { let mut result = HashMap::new(); if let Some(obj) = json.as_object() { for (k, v) in obj { - if k.starts_with('_') { continue; } + if k.starts_with('_') { + continue; + } let enc_key = if let Some(s) = v.as_str() { s.to_string() } else if let Some(obj2) = v.as_object() { @@ -133,7 +139,11 @@ async fn setup_signal_handler() { } fn cleanup_and_exit() { - let _ = std::fs::remove_file(config::sock_path()); - let _ = std::fs::remove_file(config::pid_path()); + cleanup_ipc_files(); std::process::exit(0); } + +fn cleanup_ipc_files() { + let _ = std::fs::remove_file(config::sock_path()); + let _ = std::fs::remove_file(config::pid_path()); +} diff --git a/src/scanner/windows.rs b/src/scanner/windows.rs index a6660cb..0f7470b 100644 --- a/src/scanner/windows.rs +++ b/src/scanner/windows.rs @@ -8,16 +8,16 @@ use anyhow::{bail, Context, Result}; use std::path::Path; use windows::Win32::Foundation::{CloseHandle, HANDLE}; +use windows::Win32::System::Diagnostics::Debug::ReadProcessMemory; use windows::Win32::System::Diagnostics::ToolHelp::{ CreateToolhelp32Snapshot, Process32First, Process32Next, PROCESSENTRY32, TH32CS_SNAPPROCESS, }; use windows::Win32::System::Memory::{ - VirtualQueryEx, MEMORY_BASIC_INFORMATION, MEM_COMMIT, PAGE_READWRITE, + VirtualQueryEx, MEMORY_BASIC_INFORMATION, MEM_COMMIT, PAGE_EXECUTE_READWRITE, + PAGE_EXECUTE_WRITECOPY, PAGE_GUARD, PAGE_NOCACHE, PAGE_READWRITE, PAGE_WRITECOMBINE, + PAGE_WRITECOPY, }; -use windows::Win32::System::Threading::{ - OpenProcess, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ, -}; -use windows::Win32::System::Diagnostics::Debug::ReadProcessMemory; +use windows::Win32::System::Threading::{OpenProcess, PROCESS_QUERY_INFORMATION, PROCESS_VM_READ}; use super::{collect_db_salts, KeyEntry}; @@ -27,9 +27,7 @@ const CHUNK_SIZE: usize = 2 * 1024 * 1024; /// 查找 Weixin.exe 进程 PID fn find_wechat_pid() -> Option { // SAFETY: CreateToolhelp32Snapshot 标准 Windows API - let snap = unsafe { - CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()? - }; + let snap = unsafe { CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0).ok()? }; let mut entry = PROCESSENTRY32 { dwSize: std::mem::size_of::() as u32, @@ -43,8 +41,8 @@ fn find_wechat_pid() -> Option { return None; } loop { - let name = std::ffi::CStr::from_ptr(entry.szExeFile.as_ptr() as *const i8) - .to_string_lossy(); + let name = + std::ffi::CStr::from_ptr(entry.szExeFile.as_ptr() as *const i8).to_string_lossy(); if name.eq_ignore_ascii_case("Weixin.exe") { let pid = entry.th32ProcessID; let _ = CloseHandle(snap); @@ -60,8 +58,7 @@ fn find_wechat_pid() -> Option { } pub fn scan_keys(db_dir: &Path) -> Result> { - let pid = find_wechat_pid() - .context("找不到 Weixin.exe 进程,请确认微信正在运行")?; + let pid = find_wechat_pid().context("找不到 Weixin.exe 进程,请确认微信正在运行")?; eprintln!("WeChat PID: {}", pid); // SAFETY: OpenProcess 请求读取权限 @@ -78,7 +75,9 @@ pub fn scan_keys(db_dir: &Path) -> Result> { eprintln!("找到 {} 个候选密钥", raw_keys.len()); // SAFETY: 关闭进程句柄 - unsafe { let _ = CloseHandle(process); } + unsafe { + let _ = CloseHandle(process); + } let mut entries = Vec::new(); for (key_hex, salt_hex) in &raw_keys { @@ -119,8 +118,9 @@ fn scan_memory(process: HANDLE) -> Result> { let region_size = mbi.RegionSize; let base = mbi.BaseAddress as usize; - // 只扫描已提交的可读写页面 - if mbi.State == MEM_COMMIT && mbi.Protect == PAGE_READWRITE { + // 只扫描已提交的可读可写页面。Windows 的保护位可能带 modifier bits, + // 也可能是 WRITECOPY / EXECUTE_READWRITE 这种同样可读可写的保护类型。 + if mbi.State == MEM_COMMIT && is_writable_readable_page(mbi.Protect.0) { scan_region(process, base, region_size, &mut results); } @@ -133,12 +133,18 @@ fn scan_memory(process: HANDLE) -> Result> { Ok(results) } -fn scan_region( - process: HANDLE, - base: usize, - size: usize, - results: &mut Vec<(String, String)>, -) { +fn is_writable_readable_page(protect: u32) -> bool { + let base = protect & !(PAGE_GUARD.0 | PAGE_NOCACHE.0 | PAGE_WRITECOMBINE.0); + matches!( + base, + x if x == PAGE_READWRITE.0 + || x == PAGE_WRITECOPY.0 + || x == PAGE_EXECUTE_READWRITE.0 + || x == PAGE_EXECUTE_WRITECOPY.0 + ) +} + +fn scan_region(process: HANDLE, base: usize, size: usize, results: &mut Vec<(String, String)>) { let overlap = HEX_PATTERN_LEN + 3; let mut offset = 0usize; @@ -159,7 +165,8 @@ fn scan_region( buf.as_mut_ptr() as *mut _, chunk_size, Some(&mut bytes_read), - ).is_ok() + ) + .is_ok() }; if ok && bytes_read > 0 { @@ -203,10 +210,8 @@ fn search_pattern(buf: &[u8], results: &mut Vec<(String, String)>) { i += 1; continue; } - let key_hex = String::from_utf8_lossy(&buf[hex_start..hex_start + 64]) - .to_lowercase(); - let salt_hex = String::from_utf8_lossy(&buf[hex_start + 64..hex_start + 96]) - .to_lowercase(); + let key_hex = String::from_utf8_lossy(&buf[hex_start..hex_start + 64]).to_lowercase(); + let salt_hex = String::from_utf8_lossy(&buf[hex_start + 64..hex_start + 96]).to_lowercase(); let is_dup = results.iter().any(|(k, s)| k == &key_hex && s == &salt_hex); if !is_dup { results.push((key_hex, salt_hex));