diff --git a/README.md b/README.md index ea7c22f..1c1c7b5 100644 --- a/README.md +++ b/README.md @@ -217,7 +217,7 @@ wx sns-search "婚礼" --user "李四" --since 2023-01-01 ### 公众号文章 -公众号文章推送存在独立的 `biz_message_0.db`,用 `biz-articles` 单独查: +公众号文章推送存在独立的 `biz_message_*.db` 分片,用 `biz-articles` 单独查: ```bash wx biz-articles # 最近 50 篇 diff --git a/SKILL.md b/SKILL.md index f75cadc..61082fe 100644 --- a/SKILL.md +++ b/SKILL.md @@ -242,7 +242,7 @@ wx sns-search "婚礼" --user "李四" --since 2023-01-01 -n 50 ### 公众号文章 -公众号的文章推送存在独立的 `biz_message_0.db`,与普通 `message_0.db` 分开: +公众号的文章推送存在独立的 `biz_message_*.db` 分片,与普通 `message_0.db` 分开: ```bash # 最近 50 篇(默认) diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 6503134..f503ce5 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -9,6 +9,39 @@ use std::sync::Arc; use crate::config; +fn normalized_rel_key(rel_key: &str) -> String { + rel_key.replace('\\', "/") +} + +fn is_msg_db_key(rel_key: &str) -> bool { + let rel_key = normalized_rel_key(rel_key); + rel_key.starts_with("message/message_") + && rel_key.ends_with(".db") + && !rel_key.contains("_fts") + && !rel_key.contains("_resource") +} + +fn is_biz_msg_db_key(rel_key: &str) -> bool { + let rel_key = normalized_rel_key(rel_key); + rel_key.starts_with("message/biz_message_") + && rel_key.ends_with(".db") + && !rel_key.contains("_fts") + && !rel_key.contains("_resource") +} + +fn collect_db_keys( + all_keys: &HashMap, + predicate: fn(&str) -> bool, +) -> Vec { + let mut keys: Vec = all_keys + .keys() + .filter(|k| predicate(k)) + .cloned() + .collect(); + keys.sort(); + keys +} + /// daemon 入口 /// /// 当 WX_DAEMON_MODE 环境变量设置时,main() 调用此函数 @@ -49,17 +82,8 @@ async fn async_run() -> Result<()> { let db = Arc::new(cache::DbCache::new(cfg.db_dir.clone(), all_keys.clone()).await?); // 收集消息 DB 列表 - let msg_db_keys: Vec = all_keys - .keys() - .filter(|k| { - let k = k.replace('\\', "/"); - k.contains("message/message_") - && k.ends_with(".db") - && !k.contains("_fts") - && !k.contains("_resource") - }) - .cloned() - .collect(); + let msg_db_keys = collect_db_keys(&all_keys, is_msg_db_key); + let biz_msg_db_keys = collect_db_keys(&all_keys, is_biz_msg_db_key); // 预热:加载联系人 + 解密 session.db eprintln!("[daemon] 预热..."); @@ -69,11 +93,13 @@ async fn async_run() -> Result<()> { map: HashMap::new(), md5_to_uname: HashMap::new(), msg_db_keys: Vec::new(), + biz_msg_db_keys: Vec::new(), verify_flags: HashMap::new(), } }); let mut names = names_raw; names.msg_db_keys = msg_db_keys; + names.biz_msg_db_keys = biz_msg_db_keys; let _ = db.get("session/session.db").await; let _ = db.get("sns/sns.db").await; @@ -149,3 +175,28 @@ fn cleanup_ipc_files() { let _ = std::fs::remove_file(config::sock_path()); let _ = std::fs::remove_file(config::pid_path()); } + +#[cfg(test)] +mod tests { + use super::{is_biz_msg_db_key, is_msg_db_key}; + + #[test] + fn message_db_key_filter_ignores_biz_and_auxiliary_files() { + assert!(is_msg_db_key("message/message_0.db")); + assert!(is_msg_db_key("message\\message_12.db")); + assert!(!is_msg_db_key("message/biz_message_0.db")); + assert!(!is_msg_db_key("message/message_0.db-wal")); + assert!(!is_msg_db_key("message/message_0_fts.db")); + assert!(!is_msg_db_key("message/message_0_resource.db")); + } + + #[test] + fn biz_message_db_key_filter_matches_only_biz_shards() { + assert!(is_biz_msg_db_key("message/biz_message_0.db")); + assert!(is_biz_msg_db_key("message\\biz_message_3.db")); + assert!(!is_biz_msg_db_key("message/message_0.db")); + assert!(!is_biz_msg_db_key("message/biz_message_0.db-wal")); + assert!(!is_biz_msg_db_key("message/biz_message_0_fts.db")); + assert!(!is_biz_msg_db_key("message/biz_message_0_resource.db")); + } +} diff --git a/src/daemon/query.rs b/src/daemon/query.rs index 0de0eca..ac9ec0d 100644 --- a/src/daemon/query.rs +++ b/src/daemon/query.rs @@ -55,6 +55,8 @@ pub struct Names { pub md5_to_uname: HashMap, /// 消息 DB 的相对路径列表(message/message_N.db) pub msg_db_keys: Vec, + /// 公众号推送 DB 的相对路径列表(message/biz_message_N.db) + pub biz_msg_db_keys: Vec, /// username -> contact.verify_flag(0=真人,非 0 通常为公众号/服务号/认证账号) pub verify_flags: HashMap, } @@ -269,6 +271,7 @@ pub async fn load_names(db: &DbCache) -> Result { map, md5_to_uname, msg_db_keys: Vec::new(), + biz_msg_db_keys: Vec::new(), verify_flags, }) } @@ -4010,7 +4013,7 @@ fn extract_cdata(xml: &str, tag: &str) -> Option { } } -/// 查询公众号文章推送(biz_message_0.db) +/// 查询公众号文章推送(biz_message_*.db 分片) /// /// 每条消息可能包含多篇文章(多图文推送)。返回所有文章展开就的平底列表。 pub async fn q_biz_articles( @@ -4022,10 +4025,17 @@ pub async fn q_biz_articles( until: Option, unread: bool, ) -> Result { - let biz_path = db - .get("message/biz_message_0.db") - .await? - .context("无法解密 biz_message_0.db,请确认 all_keys.json 包含对应密钥")?; + let mut biz_paths = Vec::new(); + for rel_key in &names.biz_msg_db_keys { + if let Some(path) = db.get(rel_key).await? { + biz_paths.push(path); + } + } + if biz_paths.is_empty() { + return Err(anyhow::anyhow!( + "无法解密任何 biz_message_*.db,请确认 all_keys.json 包含对应密钥" + )); + } // 开启 --unread:从 session.db 拿“公众号 + unread_count>0”的 username 子集, // 作为合集过滤(与 --account 取交集),后续结果按 account_username 去重取顶 1 篇。 @@ -4060,32 +4070,37 @@ pub async fn q_biz_articles( None }; - // 1. 从 Name2Id 表获取 rowid -> username 映射,再推导 md5 -> username - let biz_path2 = biz_path.clone(); - let id2username: HashMap = tokio::task::spawn_blocking(move || { - let conn = Connection::open(&biz_path2)?; - let mut stmt = - conn.prepare("SELECT rowid, user_name FROM Name2Id WHERE user_name LIKE 'gh_%'")?; - let rows = stmt - .query_map([], |row| { - Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?)) - })? - .collect::>>()?; - Ok::<_, anyhow::Error>(rows.into_iter().collect()) + // 1. 从全部 biz shard 的 Name2Id 表收集 username,再推导 md5 -> username + let biz_paths2 = biz_paths.clone(); + let biz_usernames: HashSet = tokio::task::spawn_blocking(move || { + let mut usernames = HashSet::new(); + for biz_path in biz_paths2 { + let conn = Connection::open(&biz_path)?; + let mut stmt = conn.prepare( + "SELECT DISTINCT user_name FROM Name2Id \ + WHERE user_name IS NOT NULL AND user_name != ''", + )?; + let rows: Vec = stmt + .query_map([], |row| row.get::<_, String>(0))? + .filter_map(|r| r.ok()) + .collect(); + usernames.extend(rows); + } + Ok::<_, anyhow::Error>(usernames) }) .await??; // 构建 md5(username) -> username 映射 - let md5_to_uname: HashMap = id2username - .values() + let md5_to_uname: HashMap = biz_usernames + .iter() .map(|u| (format!("{:x}", md5::compute(u.as_bytes())), u.clone())) .collect(); // 2. 如果 指定了 --account,找到匹配的 username 列表 let account_low = account.as_deref().map(|s| s.to_lowercase()); let mut target_usernames: Option> = account_low.as_ref().map(|low| { - id2username - .values() + biz_usernames + .iter() .filter(|u| { let display = names.display(u); display.to_lowercase().contains(low.as_str()) @@ -4115,7 +4130,7 @@ pub async fn q_biz_articles( } // 3. 进行数据库查询 - let biz_path3 = biz_path.clone(); + let biz_paths3 = biz_paths; let since2 = since; let until2 = until; let target_hashes: Option> = target_usernames.as_ref().map(|unames| { @@ -4126,71 +4141,72 @@ pub async fn q_biz_articles( }); let rows: Vec<(String, i64, i64, Vec, i64)> = tokio::task::spawn_blocking(move || { - let conn = Connection::open(&biz_path3)?; - - // 列出所有 Msg_ 表 - let mut stmt = conn - .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'")?; - let table_names: Vec = stmt - .query_map([], |row| row.get(0))? - .filter_map(|r| r.ok()) - .collect(); - let re = regex::Regex::new(r"^Msg_[0-9a-f]{32}$").unwrap(); let mut all_rows: Vec<(String, i64, i64, Vec, i64)> = Vec::new(); - for tname in &table_names { - if !re.is_match(tname) { - continue; - } - let hash = &tname[4..]; + for biz_path in biz_paths3 { + let conn = Connection::open(&biz_path)?; + let mut stmt = conn.prepare( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'", + )?; + let table_names: Vec = stmt + .query_map([], |row| row.get(0))? + .filter_map(|r| r.ok()) + .collect(); - // account 过滤 - if let Some(ref hashes) = target_hashes { - if !hashes.iter().any(|h| h == hash) { + for tname in &table_names { + if !re.is_match(tname) { continue; } - } + let hash = &tname[4..]; - let username = md5_to_uname.get(hash).cloned().unwrap_or_default(); + // account 过滤 + if let Some(ref hashes) = target_hashes { + if !hashes.iter().any(|h| h == hash) { + continue; + } + } - // 构建过滤条件 - let mut clauses: Vec = Vec::new(); - let mut params: Vec> = Vec::new(); - // local_type & 0xFFFFFFFF = 49 是 appmsg(公众号文章) - clauses.push("(local_type & 4294967295) = 49".to_string()); - if let Some(s) = since2 { - clauses.push("create_time >= ?".to_string()); - params.push(Box::new(s)); - } - if let Some(u) = until2 { - clauses.push("create_time <= ?".to_string()); - params.push(Box::new(u)); - } - let where_clause = format!("WHERE {}", clauses.join(" AND ")); + let username = md5_to_uname.get(hash).cloned().unwrap_or_default(); - let sql = format!( - "SELECT create_time, WCDB_CT_message_content, message_content \ - FROM [{}] {} ORDER BY create_time DESC", - tname, where_clause - ); + // 构建过滤条件 + let mut clauses: Vec = Vec::new(); + let mut params: Vec> = Vec::new(); + // local_type & 0xFFFFFFFF = 49 是 appmsg(公众号文章) + clauses.push("(local_type & 4294967295) = 49".to_string()); + if let Some(s) = since2 { + clauses.push("create_time >= ?".to_string()); + params.push(Box::new(s)); + } + if let Some(u) = until2 { + clauses.push("create_time <= ?".to_string()); + params.push(Box::new(u)); + } + let where_clause = format!("WHERE {}", clauses.join(" AND ")); - let params_ref: Vec<&dyn rusqlite::types::ToSql> = - params.iter().map(|p| p.as_ref()).collect(); - if let Ok(mut inner_stmt) = conn.prepare(&sql) { - let msg_rows: Vec<_> = inner_stmt - .query_map(params_ref.as_slice(), |row| { - Ok(( - username.clone(), - row.get::<_, i64>(0)?, - row.get::<_, i64>(1).unwrap_or(0), - get_content_bytes(row, 2), - 0i64, - )) - }) - .map(|it| it.filter_map(|r| r.ok()).collect()) - .unwrap_or_default(); - all_rows.extend(msg_rows); + let sql = format!( + "SELECT create_time, WCDB_CT_message_content, message_content \ + FROM [{}] {} ORDER BY create_time DESC", + tname, where_clause + ); + + let params_ref: Vec<&dyn rusqlite::types::ToSql> = + params.iter().map(|p| p.as_ref()).collect(); + if let Ok(mut inner_stmt) = conn.prepare(&sql) { + let msg_rows: Vec<_> = inner_stmt + .query_map(params_ref.as_slice(), |row| { + Ok(( + username.clone(), + row.get::<_, i64>(0)?, + row.get::<_, i64>(1).unwrap_or(0), + get_content_bytes(row, 2), + 0i64, + )) + }) + .map(|it| it.filter_map(|r| r.ok()).collect()) + .unwrap_or_default(); + all_rows.extend(msg_rows); + } } } Ok::<_, anyhow::Error>(all_rows) diff --git a/src/ipc.rs b/src/ipc.rs index fd6d6bf..93306fb 100644 --- a/src/ipc.rs +++ b/src/ipc.rs @@ -126,7 +126,7 @@ pub enum Request { #[serde(skip_serializing_if = "Option::is_none")] user: Option, }, - /// 查询公众号文章推送(biz_message_0.db) + /// 查询公众号文章推送(biz_message_*.db 分片) BizArticles { #[serde(default = "default_limit_50")] limit: usize,