From 7e7f7a251606c77b2892548d8256014330011736 Mon Sep 17 00:00:00 2001 From: dsjzazs Date: Sat, 14 Mar 2026 10:21:21 +0800 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E6=9F=A5=E8=AF=A2=E5=8A=9F=E8=83=BD=EF=BC=8C=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=97=B6=E9=97=B4=E8=8C=83=E5=9B=B4=E5=92=8C=E5=88=86?= =?UTF-8?q?=E9=A1=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 +- USAGE.md | 78 ++++- mcp_server.py | 858 ++++++++++++++++++++++++++++++++++---------------- 3 files changed, 676 insertions(+), 270 deletions(-) diff --git a/README.md b/README.md index 969cd83..b7003c9 100644 --- a/README.md +++ b/README.md @@ -132,13 +132,19 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv | Tool | 功能 | |------|------| | `get_recent_sessions(limit)` | 最近会话列表(含消息摘要、未读数) | -| `get_chat_history(chat_name, limit)` | 指定聊天的消息记录(支持模糊匹配名字) | -| `search_messages(keyword, limit)` | 全库搜索消息内容 | +| `get_chat_history(chat_name, limit, offset, start_time, end_time)` | 指定聊天的消息记录,支持时间范围和分页 | +| `search_messages(keyword, chat_name, start_time, end_time, limit, offset)` | 统一搜索消息;支持全库、单个聊天对象、多个聊天对象、时间范围和分页 | | `get_contacts(query, limit)` | 搜索/列出联系人 | | `get_new_messages()` | 获取自上次调用以来的新消息 | 前置条件:需要先运行 `python main.py` 或 `python find_all_keys.py` 完成密钥提取。 +新增能力: +- `get_chat_history` 支持 `offset` 分页,以及 `start_time` / `end_time` 时间范围过滤 +- `search_messages` 支持“全库 / 单个联系人或群聊 / 多个联系人或群聊”的统一搜索入口 +- `search_messages` 在定向搜索时会汇报无法解析或无消息表的对象 +- 时间格式支持 `YYYY-MM-DD`、`YYYY-MM-DD HH:MM`、`YYYY-MM-DD HH:MM:SS` + **[查看使用案例 →](USAGE.md)** ### 图片解密 (V2 格式) diff --git a/USAGE.md b/USAGE.md index b37cfcb..c193191 100644 --- a/USAGE.md +++ b/USAGE.md @@ -66,7 +66,7 @@ Claude 调用 `get_chat_history`,然后自动分析总结: > 搜一下谁提过"claude" ``` -Claude 调用 `search_messages(keyword="claude")`: +Claude 调用 `search_messages(keyword="claude")`: ``` 搜索 "claude" 找到 20 条结果: @@ -78,7 +78,77 @@ Claude 调用 `search_messages(keyword="claude")`: ... ``` -## 4. 搜索联系人 +## 4. 时间范围 + 分页查看聊天记录 + +``` +> 帮我看一下██群 3 月 1 日到 3 月 7 日的聊天,先给我前 20 条 +``` + +Claude 可以调用: + +```python +get_chat_history( + chat_name="██群", + start_time="2026-03-01", + end_time="2026-03-07", + limit=20, + offset=0, +) +``` + +下一页: + +```python +get_chat_history( + chat_name="██群", + start_time="2026-03-01", + end_time="2026-03-07", + limit=20, + offset=20, +) +``` + +## 5. 搜索指定联系人/群聊在某个时间段内的消息 + +``` +> 帮我搜一下██群这周谁提到过 Claude +``` + +Claude 可以调用统一接口: + +```python +search_messages( + keyword="Claude", + chat_name="██群", + start_time="2026-03-01", + end_time="2026-03-07", + limit=20, + offset=0, +) +``` + +## 6. 多个联系人/群聊联合搜索 + +``` +> 帮我看看联系人A、联系人B 和 ██项目群 这周谁提到过“项目” +``` + +Claude 可以调用统一接口: + +```python +search_messages( + keyword="项目", + chat_name=["联系人A", "联系人B", "██项目群"], + start_time="2026-03-01", + end_time="2026-03-07", + limit=20, + offset=0, +) +``` + +如果某些名字没匹配到联系人,或没有对应消息表,结果里会单独说明。 + +## 7. 搜索联系人 ``` > 帮我找一下姓张的联系人 @@ -95,7 +165,7 @@ wxid_████ 备注: 张██ 昵称: 小██ ... ``` -## 5. 获取新消息 +## 8. 获取新消息 ``` > 有没有新消息 @@ -113,7 +183,7 @@ Claude 调用 `get_new_messages()`: [16:22] ██群 [群] (19条未读): (图片) ``` -## 6. 高级用法:群聊分析 +## 9. 高级用法:群聊分析 Claude 可以获取大量消息后自动分析活跃度、话题分布、关键人物: diff --git a/mcp_server.py b/mcp_server.py index d01b06c..2f141e1 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -217,11 +217,11 @@ atexit.register(_cache.cleanup) # ============ 联系人缓存 ============ -_contact_names = None # {username: display_name} -_contact_full = None # [{username, nick_name, remark}] -_self_username = None -_XML_UNSAFE_RE = re.compile(r' 0xFFFFFFFF: - return t & 0xFFFFFFFF, t >> 32 - return t, 0 +def _split_msg_type(t): + try: + t = int(t) + except (TypeError, ValueError): + return 0, 0 + # WeChat packs the base type into the low 32 bits and app subtype into the high 32 bits. + if t > 0xFFFFFFFF: + return t & 0xFFFFFFFF, t >> 32 + return t, 0 def resolve_username(chat_name): @@ -353,42 +353,42 @@ def _collapse_text(text): return re.sub(r'\s+', ' ', text).strip() -def _get_self_username(): - global _self_username - if _self_username: - return _self_username - - if not DB_DIR: - return '' - - names = get_contact_names() - account_dir = os.path.basename(os.path.dirname(DB_DIR)) - candidates = [account_dir] +def _get_self_username(): + global _self_username + if _self_username: + return _self_username + + if not DB_DIR: + return '' + + names = get_contact_names() + account_dir = os.path.basename(os.path.dirname(DB_DIR)) + candidates = [account_dir] m = re.fullmatch(r'(.+)_([0-9a-fA-F]{4,})', account_dir) if m: candidates.insert(0, m.group(1)) - for candidate in candidates: - if candidate and candidate in names: - _self_username = candidate - return _self_username - - return '' - - -def _load_name2id_maps(conn): - id_to_username = {} - try: - rows = conn.execute("SELECT rowid, user_name FROM Name2Id").fetchall() - except sqlite3.Error: - return id_to_username - - for rowid, user_name in rows: - if not user_name: - continue - id_to_username[rowid] = user_name - return id_to_username + for candidate in candidates: + if candidate and candidate in names: + _self_username = candidate + return _self_username + + return '' + + +def _load_name2id_maps(conn): + id_to_username = {} + try: + rows = conn.execute("SELECT rowid, user_name FROM Name2Id").fetchall() + except sqlite3.Error: + return id_to_username + + for rowid, user_name in rows: + if not user_name: + continue + id_to_username[rowid] = user_name + return id_to_username def _display_name_for_username(username, names): @@ -416,62 +416,62 @@ def _resolve_sender_label(real_sender_id, sender_from_content, is_group, chat_us return '' -def _resolve_quote_sender_label(ref_user, ref_display_name, is_group, chat_username, chat_display_name, names): - if is_group: - if ref_user: - return _display_name_for_username(ref_user, names) - return ref_display_name or '' - - self_username = _get_self_username() - if ref_user: - if ref_user == chat_username: - return chat_display_name - if self_username and ref_user == self_username: - return 'me' - return names.get(ref_user, ref_display_name or ref_user) - if ref_display_name: - if ref_display_name == chat_display_name: - return chat_display_name - self_display_name = names.get(self_username, self_username) if self_username else '' - if self_display_name and ref_display_name == self_display_name: - return 'me' - return ref_display_name - return '' - - -def _parse_xml_root(content): - if not content or len(content) > _XML_PARSE_MAX_LEN or _XML_UNSAFE_RE.search(content): - return None - - try: - return ET.fromstring(content) - except ET.ParseError: - return None - - -def _parse_int(value, fallback=0): - try: - return int(value) - except (TypeError, ValueError): - return fallback - - -def _format_app_message_text(content, local_type, is_group, chat_username, chat_display_name, names): - if not content or ' _XML_PARSE_MAX_LEN or _XML_UNSAFE_RE.search(content): + return None + + try: + return ET.fromstring(content) + except ET.ParseError: + return None + + +def _parse_int(value, fallback=0): + try: + return int(value) + except (TypeError, ValueError): + return fallback + + +def _format_app_message_text(content, local_type, is_group, chat_username, chat_display_name, names): + if not content or ' end_ts: + raise ValueError('start_time 不能晚于 end_time') + return start_ts, end_ts + + +def _build_message_filters(start_ts=None, end_ts=None, keyword=''): + clauses = [] + params = [] + if start_ts is not None: + clauses.append('create_time >= ?') + params.append(start_ts) + if end_ts is not None: + clauses.append('create_time <= ?') + params.append(end_ts) + if keyword: + clauses.append('message_content LIKE ?') + params.append(f'%{keyword}%') + return clauses, params + + +def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0): + if not _is_safe_msg_table_name(table_name): + raise ValueError(f'非法消息表名: {table_name}') + + clauses, params = _build_message_filters(start_ts, end_ts, keyword) + where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' + sql = f""" + SELECT local_id, local_type, create_time, real_sender_id, message_content, + WCDB_CT_message_content + FROM [{table_name}] + {where_sql} + ORDER BY create_time DESC + LIMIT ? OFFSET ? + """ + return conn.execute(sql, (*params, limit, offset)).fetchall() + + +def _resolve_chat_context(chat_name): + username = resolve_username(chat_name) + if not username: + return None + + names = get_contact_names() + display_name = names.get(username, username) + db_path, table_name = _find_msg_table_for_user(username) + if not db_path: + return { + 'query': chat_name, + 'username': username, + 'display_name': display_name, + 'db_path': None, + 'table_name': None, + 'is_group': '@chatroom' in username, + } + + return { + 'query': chat_name, + 'username': username, + 'display_name': display_name, + 'db_path': db_path, + 'table_name': table_name, + 'is_group': '@chatroom' in username, + } + + +def _resolve_chat_contexts(chat_names): + if not chat_names: + raise ValueError('chat_names 不能为空') + + resolved = [] + unresolved = [] + missing_tables = [] + seen = set() + + for chat_name in chat_names: + name = (chat_name or '').strip() + if not name: + unresolved.append('(空)') + continue + ctx = _resolve_chat_context(name) + if not ctx: + unresolved.append(name) + continue + if not ctx['db_path']: + missing_tables.append(ctx['display_name']) + continue + if ctx['username'] in seen: + continue + seen.add(ctx['username']) + resolved.append(ctx) + + return resolved, unresolved, missing_tables + + +def _normalize_chat_names(chat_name): + if chat_name is None: + return [] + if isinstance(chat_name, str): + value = chat_name.strip() + return [value] if value else [] + if isinstance(chat_name, (list, tuple, set)): + normalized = [] + for item in chat_name: + if item is None: + continue + value = str(item).strip() + if value: + normalized.append(value) + return normalized + value = str(chat_name).strip() + return [value] if value else [] + + +def _format_history_lines(rows, username, display_name, is_group, names, id_to_username): + lines = [] + for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows): + time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') + content = _decompress_content(content, ct) + if content is None: + content = '(无法解压)' + + sender, text = _format_message_text( + local_id, local_type, content, is_group, username, display_name, names + ) + if text and len(text) > 500: + text = text[:500] + '...' + + sender_label = _resolve_sender_label( + real_sender_id, sender, is_group, username, display_name, names, id_to_username + ) + if sender_label: + lines.append(f'[{time_str}] {sender_label}: {text}') + else: + lines.append(f'[{time_str}] {text}') + return lines + + +def _build_search_entry(row, ctx, names, id_to_username): + local_id, local_type, create_time, real_sender_id, content, ct = row + content = _decompress_content(content, ct) + if content is None: + return None + + sender, text = _format_message_text( + local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names + ) + if text and len(text) > 300: + text = text[:300] + '...' + + sender_label = _resolve_sender_label( + real_sender_id, + sender, + ctx['is_group'], + ctx['username'], + ctx['display_name'], + names, + id_to_username, + ) + time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') + entry = f"[{time_str}] [{ctx['display_name']}]" + if sender_label: + entry += f" {sender_label}:" + entry += f" {text}" + return create_time, entry + + # ============ MCP Server ============ mcp = FastMCP("wechat", instructions="查询微信消息、联系人等数据") @@ -664,92 +864,218 @@ def get_recent_sessions(limit: int = 20) -> str: @mcp.tool() -def get_chat_history(chat_name: str, limit: int = 50) -> str: +def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str: """获取指定聊天的消息记录。 Args: chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配 limit: 返回的消息数量,默认50 + offset: 分页偏移量,默认0 + start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS + end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS """ - username = resolve_username(chat_name) - if not username: + try: + _validate_pagination(limit, offset) + start_ts, end_ts = _parse_time_range(start_time, end_time) + except ValueError as e: + return f"错误: {e}" + + ctx = _resolve_chat_context(chat_name) + if not ctx: return f"找不到聊天对象: {chat_name}\n提示: 可以用 get_contacts(query='{chat_name}') 搜索联系人" + if not ctx['db_path']: + return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)" names = get_contact_names() - display_name = names.get(username, username) - is_group = '@chatroom' in username - - db_path, table_name = _find_msg_table_for_user(username) - if not db_path: - return f"找不到 {display_name} 的消息记录(可能在未解密的DB中或无消息)" - - conn = sqlite3.connect(db_path) - try: - id_to_username = _load_name2id_maps(conn) - rows = conn.execute(f""" - SELECT local_id, local_type, create_time, real_sender_id, message_content, - WCDB_CT_message_content - FROM [{table_name}] - ORDER BY create_time DESC - LIMIT ? - """, (limit,)).fetchall() + conn = sqlite3.connect(ctx['db_path']) + try: + id_to_username = _load_name2id_maps(conn) + rows = _query_messages( + conn, + ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + limit=limit, + offset=offset, + ) except Exception as e: conn.close() return f"查询失败: {e}" conn.close() if not rows: - return f"{display_name} 无消息记录" + return f"{ctx['display_name']} 无消息记录" - lines = [] - for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows): - time_str = datetime.fromtimestamp(create_time).strftime('%m-%d %H:%M') + lines = _format_history_lines( + rows, + ctx['username'], + ctx['display_name'], + ctx['is_group'], + names, + id_to_username, + ) - # zstd 解压 - content = _decompress_content(content, ct) - if content is None: - content = '(无法解压)' - - sender, text = _format_message_text( - local_id, local_type, content, is_group, username, display_name, names - ) - - if text and len(text) > 500: - text = text[:500] + "..." - - sender_label = _resolve_sender_label( - real_sender_id, sender, is_group, username, display_name, names, id_to_username - ) - if sender_label: - lines.append(f"[{time_str}] {sender_label}: {text}") - else: - lines.append(f"[{time_str}] {text}") - - header = f"{display_name} 的最近 {len(lines)} 条消息" - if is_group: + header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)} 条,offset={offset}, limit={limit})" + if ctx['is_group']: header += " [群聊]" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" return header + ":\n\n" + "\n".join(lines) -@mcp.tool() -def search_messages(keyword: str, limit: int = 20) -> str: - """在所有聊天记录中搜索包含关键词的消息。 - - Args: - keyword: 搜索关键词 - limit: 返回的结果数量,默认20 - """ - if not keyword or len(keyword) < 1: - return "请提供搜索关键词" - - names = get_contact_names() - results = [] - - for rel_key in MSG_DB_KEYS: - if len(results) >= limit: - break - - path = _cache.get(rel_key) +@mcp.tool() +def search_messages( + keyword: str, + chat_name: str | list[str] | None = None, + start_time: str = "", + end_time: str = "", + limit: int = 20, + offset: int = 0, +) -> str: + """搜索消息内容,支持全库、单个聊天对象、多个聊天对象,以及时间范围和分页。 + + Args: + keyword: 搜索关键词 + chat_name: 聊天对象名称,可为空、单个字符串或字符串列表 + start_time: 起始时间,可为空 + end_time: 结束时间,可为空 + limit: 返回的结果数量,默认20 + offset: 分页偏移量,默认0 + """ + if not keyword or len(keyword) < 1: + return "请提供搜索关键词" + + chat_names = _normalize_chat_names(chat_name) + + try: + _validate_pagination(limit, offset) + start_ts, end_ts = _parse_time_range(start_time, end_time) + except ValueError as e: + return f"错误: {e}" + + if len(chat_names) == 1: + ctx = _resolve_chat_context(chat_names[0]) + if not ctx: + return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人" + if not ctx['db_path']: + return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)" + + names = get_contact_names() + conn = sqlite3.connect(ctx['db_path']) + try: + id_to_username = _load_name2id_maps(conn) + rows = _query_messages( + conn, + ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + keyword=keyword, + limit=limit, + offset=offset, + ) + except Exception as e: + conn.close() + return f"查询失败: {e}" + conn.close() + + if not rows: + return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息" + + entries = [] + for row in rows: + formatted = _build_search_entry(row, ctx, names, id_to_username) + if formatted: + entries.append(formatted) + + if not entries: + return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的可读消息" + + entries.sort(key=lambda x: x[0]) + header = f"在 {ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + return header + ":\n\n" + "\n\n".join(item[1] for item in entries) + + if len(chat_names) > 1: + try: + resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names) + except ValueError as e: + return f"错误: {e}" + + if not resolved_contexts: + details = [] + if unresolved: + details.append("未找到联系人: " + "、".join(unresolved)) + if missing_tables: + details.append("无消息表: " + "、".join(missing_tables)) + suffix = f"\n{chr(10).join(details)}" if details else "" + return f"错误: 没有可查询的聊天对象{suffix}" + + names = get_contact_names() + collected = [] + failures = [] + per_chat_limit = limit + offset + + for ctx in resolved_contexts: + conn = sqlite3.connect(ctx['db_path']) + try: + id_to_username = _load_name2id_maps(conn) + rows = _query_messages( + conn, + ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + keyword=keyword, + limit=per_chat_limit, + offset=0, + ) + for row in rows: + formatted = _build_search_entry(row, ctx, names, id_to_username) + if formatted: + collected.append(formatted) + except Exception as e: + failures.append(f"{ctx['display_name']}: {e}") + finally: + conn.close() + + collected.sort(key=lambda x: x[0], reverse=True) + paged = collected[offset:offset + limit] + + notes = [] + if unresolved: + notes.append("未找到联系人: " + "、".join(unresolved)) + if missing_tables: + notes.append("无消息表: " + "、".join(missing_tables)) + if failures: + notes.append("查询失败: " + ";".join(failures)) + + if not paged: + header = f"在 {len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if notes: + header += "\n" + "\n".join(notes) + return header + + header = ( + f"在 {len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果" + f"(offset={offset}, limit={limit})" + ) + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if notes: + header += "\n" + "\n".join(notes) + return header + ":\n\n" + "\n\n".join(item[1] for item in paged) + + names = get_contact_names() + results = [] + max_results = limit + offset + + for rel_key in MSG_DB_KEYS: + if len(results) >= max_results: + break + + path = _cache.get(rel_key) if not path: continue @@ -768,25 +1094,27 @@ def search_messages(keyword: str, limit: int = 20) -> str: name2id[f"Msg_{h}"] = r[0] except Exception: pass - - for (tname,) in tables: - if len(results) >= limit: - break - username = name2id.get(tname, '') - is_group = '@chatroom' in username - display = names.get(username, username) if username else tname - - try: - rows = conn.execute(f""" - SELECT local_type, create_time, message_content, - WCDB_CT_message_content - FROM [{tname}] - WHERE message_content LIKE ? - ORDER BY create_time DESC - LIMIT ? - """, (f'%{keyword}%', limit - len(results))).fetchall() - except Exception: - continue + + for (tname,) in tables: + if len(results) >= max_results: + break + username = name2id.get(tname, '') + is_group = '@chatroom' in username + display = names.get(username, username) if username else tname + + try: + clauses, params = _build_message_filters(start_ts, end_ts, keyword) + where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' + rows = conn.execute(f""" + SELECT local_type, create_time, message_content, + WCDB_CT_message_content + FROM [{tname}] + {where_sql} + ORDER BY create_time DESC + LIMIT ? OFFSET ? + """, (*params, max_results - len(results), 0)).fetchall() + except Exception: + continue for local_type, ts, content, ct in rows: content = _decompress_content(content, ct) @@ -808,17 +1136,19 @@ def search_messages(keyword: str, limit: int = 20) -> str: finally: conn.close() - results.sort(key=lambda x: x[0], reverse=True) - entries = [r[1] for r in results[:limit]] + results.sort(key=lambda x: x[0], reverse=True) + entries = [r[1] for r in results[offset:offset + limit]] + + if not entries: + return f"未找到包含 \"{keyword}\" 的消息" + + header = f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + return header + ":\n\n" + "\n\n".join(entries) - if not entries: - return f"未找到包含 \"{keyword}\" 的消息" - - return f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果:\n\n" + "\n\n".join(entries) - - -@mcp.tool() -def get_contacts(query: str = "", limit: int = 50) -> str: +@mcp.tool() +def get_contacts(query: str = "", limit: int = 50) -> str: """搜索或列出微信联系人。 Args: From 4bda20f7aad82b412e33e7f69101c21ba67cc9a7 Mon Sep 17 00:00:00 2001 From: dsjzazs Date: Sat, 14 Mar 2026 10:24:23 +0800 Subject: [PATCH 2/5] =?UTF-8?q?feat:=20=E6=9B=B4=E6=96=B0=20README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/README.md b/README.md index b7003c9..4062227 100644 --- a/README.md +++ b/README.md @@ -139,12 +139,6 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv 前置条件:需要先运行 `python main.py` 或 `python find_all_keys.py` 完成密钥提取。 -新增能力: -- `get_chat_history` 支持 `offset` 分页,以及 `start_time` / `end_time` 时间范围过滤 -- `search_messages` 支持“全库 / 单个联系人或群聊 / 多个联系人或群聊”的统一搜索入口 -- `search_messages` 在定向搜索时会汇报无法解析或无消息表的对象 -- 时间格式支持 `YYYY-MM-DD`、`YYYY-MM-DD HH:MM`、`YYYY-MM-DD HH:MM:SS` - **[查看使用案例 →](USAGE.md)** ### 图片解密 (V2 格式) From b6237114104a533aafd0810b263974cf87a1454e Mon Sep 17 00:00:00 2001 From: dsjzazs Date: Sat, 14 Mar 2026 14:07:51 +0800 Subject: [PATCH 3/5] Add MCP search unit tests --- README.md | 2 + mcp_server.py | 787 +++++++++++++++++++------------- tests/test_mcp_server_search.py | 382 ++++++++++++++++ 3 files changed, 853 insertions(+), 318 deletions(-) create mode 100644 tests/test_mcp_server_search.py diff --git a/README.md b/README.md index 4062227..8807e33 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,8 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv 前置条件:需要先运行 `python main.py` 或 `python find_all_keys.py` 完成密钥提取。 +说明:`get_chat_history` 和 `search_messages` 的 `limit` 最大为 `500`。 + **[查看使用案例 →](USAGE.md)** ### 图片解密 (V2 格式) diff --git a/mcp_server.py b/mcp_server.py index 2f141e1..be0abac 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -5,9 +5,10 @@ Based on FastMCP (stdio transport), reuses existing decryption. Runs on Windows Python (needs access to D:\ WeChat databases). """ -import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re -import hmac as hmac_mod -from datetime import datetime +import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re +import hmac as hmac_mod +from contextlib import closing +from datetime import datetime import xml.etree.ElementTree as ET from Crypto.Cipher import AES from mcp.server.fastmcp import FastMCP @@ -219,9 +220,10 @@ atexit.register(_cache.cleanup) _contact_names = None # {username: display_name} _contact_full = None # [{username, nick_name, remark}] -_self_username = None -_XML_UNSAFE_RE = re.compile(r' _QUERY_LIMIT_MAX: + raise ValueError(f"limit 不能大于 {_QUERY_LIMIT_MAX}") + if offset < 0: + raise ValueError("offset 不能小于 0") def _parse_time_value(value, field_name, is_end=False): @@ -650,49 +691,54 @@ def _build_message_filters(start_ts=None, end_ts=None, keyword=''): return clauses, params -def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0): - if not _is_safe_msg_table_name(table_name): - raise ValueError(f'非法消息表名: {table_name}') - - clauses, params = _build_message_filters(start_ts, end_ts, keyword) - where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' - sql = f""" - SELECT local_id, local_type, create_time, real_sender_id, message_content, - WCDB_CT_message_content - FROM [{table_name}] - {where_sql} - ORDER BY create_time DESC - LIMIT ? OFFSET ? - """ - return conn.execute(sql, (*params, limit, offset)).fetchall() +def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0): + if not _is_safe_msg_table_name(table_name): + raise ValueError(f'非法消息表名: {table_name}') + + clauses, params = _build_message_filters(start_ts, end_ts, keyword) + where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' + sql = f""" + SELECT local_id, local_type, create_time, real_sender_id, message_content, + WCDB_CT_message_content + FROM [{table_name}] + {where_sql} + ORDER BY create_time DESC + """ + if limit is None: + return conn.execute(sql, params).fetchall() + sql += "\n LIMIT ? OFFSET ?" + return conn.execute(sql, (*params, limit, offset)).fetchall() -def _resolve_chat_context(chat_name): - username = resolve_username(chat_name) - if not username: - return None - - names = get_contact_names() - display_name = names.get(username, username) - db_path, table_name = _find_msg_table_for_user(username) - if not db_path: - return { - 'query': chat_name, - 'username': username, - 'display_name': display_name, - 'db_path': None, - 'table_name': None, - 'is_group': '@chatroom' in username, - } - - return { - 'query': chat_name, - 'username': username, - 'display_name': display_name, - 'db_path': db_path, - 'table_name': table_name, - 'is_group': '@chatroom' in username, - } +def _resolve_chat_context(chat_name): + username = resolve_username(chat_name) + if not username: + return None + + names = get_contact_names() + display_name = names.get(username, username) + message_tables = _find_msg_tables_for_user(username) + if not message_tables: + return { + 'query': chat_name, + 'username': username, + 'display_name': display_name, + 'db_path': None, + 'table_name': None, + 'message_tables': [], + 'is_group': '@chatroom' in username, + } + + primary = message_tables[0] + return { + 'query': chat_name, + 'username': username, + 'display_name': display_name, + 'db_path': primary['db_path'], + 'table_name': primary['table_name'], + 'message_tables': message_tables, + 'is_group': '@chatroom' in username, + } def _resolve_chat_contexts(chat_names): @@ -713,11 +759,11 @@ def _resolve_chat_contexts(chat_names): if not ctx: unresolved.append(name) continue - if not ctx['db_path']: - missing_tables.append(ctx['display_name']) - continue - if ctx['username'] in seen: - continue + if not ctx['message_tables']: + missing_tables.append(ctx['display_name']) + continue + if ctx['username'] in seen: + continue seen.add(ctx['username']) resolved.append(ctx) @@ -743,31 +789,20 @@ def _normalize_chat_names(chat_name): return [value] if value else [] -def _format_history_lines(rows, username, display_name, is_group, names, id_to_username): - lines = [] - for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows): - time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') - content = _decompress_content(content, ct) - if content is None: - content = '(无法解压)' - - sender, text = _format_message_text( - local_id, local_type, content, is_group, username, display_name, names - ) - if text and len(text) > 500: - text = text[:500] + '...' - - sender_label = _resolve_sender_label( - real_sender_id, sender, is_group, username, display_name, names, id_to_username - ) - if sender_label: - lines.append(f'[{time_str}] {sender_label}: {text}') - else: - lines.append(f'[{time_str}] {text}') - return lines +def _format_history_lines(rows, username, display_name, is_group, names, id_to_username): + lines = [] + ctx = { + 'username': username, + 'display_name': display_name, + 'is_group': is_group, + } + for row in reversed(rows): + _, line = _build_history_line(row, ctx, names, id_to_username) + lines.append(line) + return lines -def _build_search_entry(row, ctx, names, id_to_username): +def _build_search_entry(row, ctx, names, id_to_username): local_id, local_type, create_time, real_sender_id, content, ct = row content = _decompress_content(content, ct) if content is None: @@ -792,8 +827,291 @@ def _build_search_entry(row, ctx, names, id_to_username): entry = f"[{time_str}] [{ctx['display_name']}]" if sender_label: entry += f" {sender_label}:" - entry += f" {text}" - return create_time, entry + entry += f" {text}" + return create_time, entry + + +def _build_history_line(row, ctx, names, id_to_username): + local_id, local_type, create_time, real_sender_id, content, ct = row + time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') + content = _decompress_content(content, ct) + if content is None: + content = '(无法解压)' + + sender, text = _format_message_text( + local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names + ) + if text and len(text) > 500: + text = text[:500] + '...' + + sender_label = _resolve_sender_label( + real_sender_id, sender, ctx['is_group'], ctx['username'], ctx['display_name'], names, id_to_username + ) + if sender_label: + return create_time, f'[{time_str}] {sender_label}: {text}' + return create_time, f'[{time_str}] {text}' + + +def _get_chat_message_tables(ctx): + if ctx.get('message_tables'): + return ctx['message_tables'] + if ctx.get('db_path') and ctx.get('table_name'): + return [{'db_path': ctx['db_path'], 'table_name': ctx['table_name']}] + return [] + + +def _iter_table_contexts(ctx): + for table in _get_chat_message_tables(ctx): + yield { + 'query': ctx['query'], + 'username': ctx['username'], + 'display_name': ctx['display_name'], + 'db_path': table['db_path'], + 'table_name': table['table_name'], + 'is_group': ctx['is_group'], + } + + +def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0): + collected = [] + failures = [] + + for table_ctx in _iter_table_contexts(ctx): + try: + with closing(sqlite3.connect(table_ctx['db_path'])) as conn: + id_to_username = _load_name2id_maps(conn) + rows = _query_messages( + conn, + table_ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + limit=None, + ) + for row in rows: + collected.append(_build_history_line(row, table_ctx, names, id_to_username)) + except Exception as e: + failures.append(f"{table_ctx['db_path']}: {e}") + + ordered = sorted(collected, key=lambda item: item[0], reverse=True) + paged = ordered[offset:offset + limit] + paged.sort(key=lambda item: item[0]) + return [line for _, line in paged], failures + + +def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None): + collected = [] + failures = [] + contexts_by_db = {} + for table_ctx in _iter_table_contexts(ctx): + contexts_by_db.setdefault(table_ctx['db_path'], []).append(table_ctx) + + for db_path, db_contexts in contexts_by_db.items(): + try: + with closing(sqlite3.connect(db_path)) as conn: + db_entries, db_failures = _collect_search_entries( + conn, + db_contexts, + names, + keyword, + start_ts=start_ts, + end_ts=end_ts, + ) + collected.extend(db_entries) + failures.extend(db_failures) + except Exception as e: + failures.extend(f"{table_ctx['display_name']}: {e}" for table_ctx in db_contexts) + + return collected, failures + + +def _load_search_contexts_from_db(conn, db_path, names): + tables = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'" + ).fetchall() + + table_to_username = {} + try: + for (user_name,) in conn.execute("SELECT user_name FROM Name2Id").fetchall(): + if not user_name: + continue + table_hash = hashlib.md5(user_name.encode()).hexdigest() + table_to_username[f"Msg_{table_hash}"] = user_name + except sqlite3.Error: + pass + + contexts = [] + for (table_name,) in tables: + username = table_to_username.get(table_name, '') + display_name = names.get(username, username) if username else table_name + contexts.append({ + 'query': display_name, + 'username': username, + 'display_name': display_name, + 'db_path': db_path, + 'table_name': table_name, + 'is_group': '@chatroom' in username, + }) + return contexts + + +def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None): + collected = [] + failures = [] + id_to_username = _load_name2id_maps(conn) + + for ctx in contexts: + try: + rows = _query_messages( + conn, + ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + keyword=keyword, + limit=None, + ) + for row in rows: + formatted = _build_search_entry(row, ctx, names, id_to_username) + if formatted: + collected.append(formatted) + except Exception as e: + failures.append(f"{ctx['display_name']}: {e}") + + return collected, failures + + +def _page_search_entries(entries, limit, offset): + ordered = sorted(entries, key=lambda x: x[0], reverse=True) + paged = ordered[offset:offset + limit] + paged.sort(key=lambda x: x[0]) + return paged + + +def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset): + names = get_contact_names() + + entries, failures = _collect_chat_search_entries( + ctx, + names, + keyword, + start_ts=start_ts, + end_ts=end_ts, + ) + + paged = _page_search_entries(entries, limit, offset) + + if not paged: + if failures: + return "查询失败: " + ";".join(failures) + return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息" + + header = f"在 {ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(paged)} 条结果(offset={offset}, limit={limit})" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if failures: + header += "\n查询失败: " + ";".join(failures) + return header + ":\n\n" + "\n\n".join(item[1] for item in paged) + + +def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, end_time, limit, offset): + try: + resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names) + except ValueError as e: + return f"错误: {e}" + + if not resolved_contexts: + details = [] + if unresolved: + details.append("未找到联系人: " + "、".join(unresolved)) + if missing_tables: + details.append("无消息表: " + "、".join(missing_tables)) + suffix = f"\n{chr(10).join(details)}" if details else "" + return f"错误: 没有可查询的聊天对象{suffix}" + + names = get_contact_names() + collected = [] + failures = [] + for ctx in resolved_contexts: + chat_entries, chat_failures = _collect_chat_search_entries( + ctx, + names, + keyword, + start_ts=start_ts, + end_ts=end_ts, + ) + collected.extend(chat_entries) + failures.extend(chat_failures) + + paged = _page_search_entries(collected, limit, offset) + + notes = [] + if unresolved: + notes.append("未找到联系人: " + "、".join(unresolved)) + if missing_tables: + notes.append("无消息表: " + "、".join(missing_tables)) + if failures: + notes.append("查询失败: " + ";".join(failures)) + + if not paged: + header = f"在 {len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if notes: + header += "\n" + "\n".join(notes) + return header + + header = ( + f"在 {len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果" + f"(offset={offset}, limit={limit})" + ) + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if notes: + header += "\n" + "\n".join(notes) + return header + ":\n\n" + "\n\n".join(item[1] for item in paged) + + +def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, offset): + names = get_contact_names() + collected = [] + failures = [] + + for rel_key in MSG_DB_KEYS: + path = _cache.get(rel_key) + if not path: + continue + + try: + with closing(sqlite3.connect(path)) as conn: + contexts = _load_search_contexts_from_db(conn, path, names) + db_entries, db_failures = _collect_search_entries( + conn, + contexts, + names, + keyword, + start_ts=start_ts, + end_ts=end_ts, + ) + collected.extend(db_entries) + failures.extend(db_failures) + except Exception as e: + failures.append(f"{rel_key}: {e}") + + paged = _page_search_entries(collected, limit, offset) + + if not paged: + header = f"未找到包含 \"{keyword}\" 的消息" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if failures: + header += "\n查询失败: " + ";".join(failures) + return header + + header = f"搜索 \"{keyword}\" 找到 {len(paged)} 条结果(offset={offset}, limit={limit})" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if failures: + header += "\n查询失败: " + ";".join(failures) + return header + ":\n\n" + "\n\n".join(item[1] for item in paged) # ============ MCP Server ============ @@ -864,15 +1182,15 @@ def get_recent_sessions(limit: int = 20) -> str: @mcp.tool() -def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str: - """获取指定聊天的消息记录。 - - Args: - chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配 - limit: 返回的消息数量,默认50 - offset: 分页偏移量,默认0 - start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS - end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS +def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str: + """获取指定聊天的消息记录。 + + Args: + chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配 + limit: 返回的消息数量,默认50,最大500 + offset: 分页偏移量,默认0 + start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS + end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS """ try: _validate_pagination(limit, offset) @@ -886,41 +1204,29 @@ def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_tim if not ctx['db_path']: return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)" - names = get_contact_names() - conn = sqlite3.connect(ctx['db_path']) - try: - id_to_username = _load_name2id_maps(conn) - rows = _query_messages( - conn, - ctx['table_name'], - start_ts=start_ts, - end_ts=end_ts, - limit=limit, - offset=offset, - ) - except Exception as e: - conn.close() - return f"查询失败: {e}" - conn.close() - - if not rows: - return f"{ctx['display_name']} 无消息记录" - - lines = _format_history_lines( - rows, - ctx['username'], - ctx['display_name'], - ctx['is_group'], - names, - id_to_username, - ) - - header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)} 条,offset={offset}, limit={limit})" - if ctx['is_group']: - header += " [群聊]" - if start_time or end_time: - header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" - return header + ":\n\n" + "\n".join(lines) + names = get_contact_names() + lines, failures = _collect_chat_history_lines( + ctx, + names, + start_ts=start_ts, + end_ts=end_ts, + limit=limit, + offset=offset, + ) + + if not lines: + if failures: + return "查询失败: " + ";".join(failures) + return f"{ctx['display_name']} 无消息记录" + + header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)} 条,offset={offset}, limit={limit})" + if ctx['is_group']: + header += " [群聊]" + if start_time or end_time: + header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" + if failures: + header += "\n查询失败: " + ";".join(failures) + return header + ":\n\n" + "\n".join(lines) @mcp.tool() @@ -939,7 +1245,7 @@ def search_messages( chat_name: 聊天对象名称,可为空、单个字符串或字符串列表 start_time: 起始时间,可为空 end_time: 结束时间,可为空 - limit: 返回的结果数量,默认20 + limit: 返回的结果数量,默认20,最大500 offset: 分页偏移量,默认0 """ if not keyword or len(keyword) < 1: @@ -959,193 +1265,38 @@ def search_messages( return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人" if not ctx['db_path']: return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)" - - names = get_contact_names() - conn = sqlite3.connect(ctx['db_path']) - try: - id_to_username = _load_name2id_maps(conn) - rows = _query_messages( - conn, - ctx['table_name'], - start_ts=start_ts, - end_ts=end_ts, - keyword=keyword, - limit=limit, - offset=offset, - ) - except Exception as e: - conn.close() - return f"查询失败: {e}" - conn.close() - - if not rows: - return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息" - - entries = [] - for row in rows: - formatted = _build_search_entry(row, ctx, names, id_to_username) - if formatted: - entries.append(formatted) - - if not entries: - return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的可读消息" - - entries.sort(key=lambda x: x[0]) - header = f"在 {ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})" - if start_time or end_time: - header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" - return header + ":\n\n" + "\n\n".join(item[1] for item in entries) + return _search_single_chat( + ctx, + keyword, + start_ts, + end_ts, + start_time, + end_time, + limit, + offset, + ) if len(chat_names) > 1: - try: - resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names) - except ValueError as e: - return f"错误: {e}" - - if not resolved_contexts: - details = [] - if unresolved: - details.append("未找到联系人: " + "、".join(unresolved)) - if missing_tables: - details.append("无消息表: " + "、".join(missing_tables)) - suffix = f"\n{chr(10).join(details)}" if details else "" - return f"错误: 没有可查询的聊天对象{suffix}" - - names = get_contact_names() - collected = [] - failures = [] - per_chat_limit = limit + offset - - for ctx in resolved_contexts: - conn = sqlite3.connect(ctx['db_path']) - try: - id_to_username = _load_name2id_maps(conn) - rows = _query_messages( - conn, - ctx['table_name'], - start_ts=start_ts, - end_ts=end_ts, - keyword=keyword, - limit=per_chat_limit, - offset=0, - ) - for row in rows: - formatted = _build_search_entry(row, ctx, names, id_to_username) - if formatted: - collected.append(formatted) - except Exception as e: - failures.append(f"{ctx['display_name']}: {e}") - finally: - conn.close() - - collected.sort(key=lambda x: x[0], reverse=True) - paged = collected[offset:offset + limit] - - notes = [] - if unresolved: - notes.append("未找到联系人: " + "、".join(unresolved)) - if missing_tables: - notes.append("无消息表: " + "、".join(missing_tables)) - if failures: - notes.append("查询失败: " + ";".join(failures)) - - if not paged: - header = f"在 {len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息" - if start_time or end_time: - header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" - if notes: - header += "\n" + "\n".join(notes) - return header - - header = ( - f"在 {len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果" - f"(offset={offset}, limit={limit})" + return _search_multiple_chats( + chat_names, + keyword, + start_ts, + end_ts, + start_time, + end_time, + limit, + offset, ) - if start_time or end_time: - header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" - if notes: - header += "\n" + "\n".join(notes) - return header + ":\n\n" + "\n\n".join(item[1] for item in paged) - names = get_contact_names() - results = [] - max_results = limit + offset - - for rel_key in MSG_DB_KEYS: - if len(results) >= max_results: - break - - path = _cache.get(rel_key) - if not path: - continue - - conn = sqlite3.connect(path) - try: - # 获取所有 Msg_ 表 - tables = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'" - ).fetchall() - - # 获取 Name2Id 映射(hash -> username 反查) - name2id = {} - try: - for r in conn.execute("SELECT user_name FROM Name2Id").fetchall(): - h = hashlib.md5(r[0].encode()).hexdigest() - name2id[f"Msg_{h}"] = r[0] - except Exception: - pass - - for (tname,) in tables: - if len(results) >= max_results: - break - username = name2id.get(tname, '') - is_group = '@chatroom' in username - display = names.get(username, username) if username else tname - - try: - clauses, params = _build_message_filters(start_ts, end_ts, keyword) - where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' - rows = conn.execute(f""" - SELECT local_type, create_time, message_content, - WCDB_CT_message_content - FROM [{tname}] - {where_sql} - ORDER BY create_time DESC - LIMIT ? OFFSET ? - """, (*params, max_results - len(results), 0)).fetchall() - except Exception: - continue - - for local_type, ts, content, ct in rows: - content = _decompress_content(content, ct) - if content is None: - continue - sender, text = _parse_message_content(content, local_type, is_group) - time_str = datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M') - sender_name = '' - if is_group and sender: - sender_name = names.get(sender, sender) - - entry = f"[{time_str}] [{display}]" - if sender_name: - entry += f" {sender_name}:" - entry += f" {text}" - if len(entry) > 300: - entry = entry[:300] + "..." - results.append((ts, entry)) - finally: - conn.close() - - results.sort(key=lambda x: x[0], reverse=True) - entries = [r[1] for r in results[offset:offset + limit]] - - if not entries: - return f"未找到包含 \"{keyword}\" 的消息" - - header = f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})" - if start_time or end_time: - header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" - return header + ":\n\n" + "\n\n".join(entries) + return _search_all_messages( + keyword, + start_ts, + end_ts, + start_time, + end_time, + limit, + offset, + ) @mcp.tool() def get_contacts(query: str = "", limit: int = 50) -> str: diff --git a/tests/test_mcp_server_search.py b/tests/test_mcp_server_search.py new file mode 100644 index 0000000..8505dc0 --- /dev/null +++ b/tests/test_mcp_server_search.py @@ -0,0 +1,382 @@ +import hashlib +import os +import sqlite3 +import tempfile +import unittest +from unittest.mock import patch + +import mcp_server + + +class _FakeCache: + # 用最小缓存桩替代真实解密缓存,避免单元测试依赖本地微信环境。 + def __init__(self, mapping): + self._mapping = mapping + + def get(self, rel_key): + return self._mapping.get(rel_key) + + +def _msg_table_name(username): + # 生产代码使用 username 的 md5 作为消息表名,测试里保持一致。 + return f"Msg_{hashlib.md5(username.encode()).hexdigest()}" + + +def _create_message_db(path, chats): + # 构造最小可用消息库,只包含搜索/历史查询依赖的字段。 + conn = sqlite3.connect(path) + try: + conn.execute("CREATE TABLE Name2Id (user_name TEXT)") + for username, messages in chats.items(): + conn.execute("INSERT INTO Name2Id(user_name) VALUES (?)", (username,)) + table_name = _msg_table_name(username) + conn.execute( + f""" + CREATE TABLE [{table_name}] ( + local_id INTEGER, + local_type INTEGER, + create_time INTEGER, + real_sender_id INTEGER, + message_content TEXT, + WCDB_CT_message_content INTEGER + ) + """ + ) + for local_id, create_time, content in messages: + conn.execute( + f""" + INSERT INTO [{table_name}] ( + local_id, local_type, create_time, real_sender_id, + message_content, WCDB_CT_message_content + ) VALUES (?, ?, ?, ?, ?, ?) + """, + (local_id, 1, create_time, 0, content, 0), + ) + conn.commit() + finally: + conn.close() + + +class SearchMessagesTests(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.addCleanup(self.temp_dir.cleanup) + + def create_db(self, filename, chats): + path = os.path.join(self.temp_dir.name, filename) + _create_message_db(path, chats) + return path + + def test_validate_pagination_rejects_large_limit(self): + # 防止单次查询过大,保证 limit 上限校验存在。 + with self.assertRaisesRegex(ValueError, "limit 不能大于 500"): + mcp_server._validate_pagination(501, 0) + + def test_page_search_entries_returns_chronological_results_with_offset(self): + # 结果应先按最新时间分页,再把当前页恢复成时间正序输出。 + entries = [(1, "a"), (5, "e"), (3, "c"), (4, "d"), (2, "b")] + + paged = mcp_server._page_search_entries(entries, limit=2, offset=1) + + self.assertEqual(paged, [(3, "c"), (4, "d")]) + + def test_search_messages_single_chat_uses_offset_and_returns_page(self): + # 单聊分页应只返回当前页,并按聊天阅读顺序展示。 + db_path = self.create_db( + "single.db", + { + "alice": [ + (1, 100, "foo newest"), + (2, 90, "foo middle"), + (3, 80, "foo oldest"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ): + result = mcp_server.search_messages("foo", chat_name="Alice", limit=2, offset=1) + + self.assertIn('在 Alice 中搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result) + self.assertLess(result.index("foo oldest"), result.index("foo middle")) + self.assertNotIn("foo newest", result) + + def test_search_messages_multiple_chats_applies_global_pagination(self): + # 多个聊天联合搜索时,分页必须基于合并后的全局结果。 + db_path = self.create_db( + "multi.db", + { + "alice": [ + (1, 110, "foo a1"), + (2, 90, "foo a2"), + ], + "bob": [ + (1, 100, "foo b1"), + (2, 80, "foo b2"), + ], + }, + ) + contexts = [ + { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + }, + { + "query": "Bob", + "username": "bob", + "display_name": "Bob", + "db_path": db_path, + "table_name": _msg_table_name("bob"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}], + "is_group": False, + }, + ] + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object( + mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], []) + ): + result = mcp_server.search_messages("foo", chat_name=["Alice", "Bob"], limit=2, offset=1) + + self.assertIn('在 2 个聊天对象中搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result) + self.assertLess(result.index("foo a2"), result.index("foo b1")) + self.assertNotIn("foo a1", result) + self.assertNotIn("foo b2", result) + + def test_search_messages_all_messages_scans_all_dbs_before_paging(self): + # 全库搜索要先扫完整个库,再分页,不能被旧分片提前截断。 + older_db = self.create_db( + "older.db", + {"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]}, + ) + newer_db = self.create_db( + "newer.db", + {"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2")]}, + ) + fake_cache = _FakeCache({"older": older_db, "newer": newer_db}) + + with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object( + mcp_server, "_cache", fake_cache + ), patch.object( + mcp_server, + "get_contact_names", + return_value={"older_user": "Older", "newer_user": "Newer"}, + ): + result = mcp_server.search_messages("foo", limit=2, offset=0) + + self.assertIn('搜索 "foo" 找到 2 条结果(offset=0, limit=2)', result) + self.assertLess(result.index("foo newer 2"), result.index("foo newer 1")) + self.assertNotIn("foo older 1", result) + + def test_search_messages_single_chat_respects_time_range(self): + # 单聊搜索的开始/结束时间都必须严格生效。 + db_path = self.create_db( + "single_time.db", + { + "alice": [ + (1, 300, "foo in range"), + (2, 200, "foo too early"), + (3, 400, "foo too late"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_parse_time_range", return_value=(250, 350) + ): + result = mcp_server.search_messages( + "foo", + chat_name="Alice", + start_time="custom-start", + end_time="custom-end", + limit=20, + offset=0, + ) + + self.assertIn("时间范围: custom-start ~ custom-end", result) + self.assertIn("foo in range", result) + self.assertNotIn("foo too early", result) + self.assertNotIn("foo too late", result) + + def test_search_messages_multiple_chats_respects_time_range(self): + # 多聊联合搜索时,每个聊天对象都要套用同一时间范围。 + db_path = self.create_db( + "multi_time.db", + { + "alice": [(1, 300, "foo alice in range"), (2, 150, "foo alice too early")], + "bob": [(1, 320, "foo bob in range"), (2, 500, "foo bob too late")], + }, + ) + contexts = [ + { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + }, + { + "query": "Bob", + "username": "bob", + "display_name": "Bob", + "db_path": db_path, + "table_name": _msg_table_name("bob"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}], + "is_group": False, + }, + ] + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object( + mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], []) + ), patch.object( + mcp_server, "_parse_time_range", return_value=(250, 400) + ): + result = mcp_server.search_messages( + "foo", + chat_name=["Alice", "Bob"], + start_time="range-start", + end_time="range-end", + limit=20, + offset=0, + ) + + self.assertIn("时间范围: range-start ~ range-end", result) + self.assertIn("foo alice in range", result) + self.assertIn("foo bob in range", result) + self.assertNotIn("foo alice too early", result) + self.assertNotIn("foo bob too late", result) + + def test_search_messages_all_messages_respects_time_range(self): + # 全库搜索也不能返回时间范围外的消息。 + db_path = self.create_db( + "all_time.db", + { + "alice": [ + (1, 100, "foo too early"), + (2, 300, "foo in range"), + (3, 500, "foo too late"), + ] + }, + ) + fake_cache = _FakeCache({"all": db_path}) + + with patch.object(mcp_server, "MSG_DB_KEYS", ["all"]), patch.object( + mcp_server, "_cache", fake_cache + ), patch.object( + mcp_server, + "get_contact_names", + return_value={"alice": "Alice"}, + ), patch.object( + mcp_server, "_parse_time_range", return_value=(250, 350) + ): + result = mcp_server.search_messages( + "foo", + start_time="range-start", + end_time="range-end", + limit=20, + offset=0, + ) + + self.assertIn("时间范围: range-start ~ range-end", result) + self.assertIn("foo in range", result) + self.assertNotIn("foo too early", result) + self.assertNotIn("foo too late", result) + + def test_get_chat_history_merges_sharded_message_tables(self): + # 同一联系人跨多个 message_N.db 分片时,历史查询要先合并再分页。 + older_db = self.create_db("history_older.db", {"alice": [(1, 100, "old message")]}) + newer_db = self.create_db( + "history_newer.db", + {"alice": [(1, 300, "new message"), (2, 250, "middle message")]}, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": newer_db, + "table_name": _msg_table_name("alice"), + "message_tables": [ + {"db_path": older_db, "table_name": _msg_table_name("alice")}, + {"db_path": newer_db, "table_name": _msg_table_name("alice")}, + ], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ): + result = mcp_server.get_chat_history("Alice", limit=2, offset=0) + + self.assertIn("Alice 的消息记录(返回 2 条,offset=0, limit=2)", result) + self.assertIn("middle message", result) + self.assertIn("new message", result) + self.assertNotIn("old message", result) + + def test_search_messages_single_chat_merges_sharded_message_tables(self): + # 单聊搜索也要跨分片合并,否则最近消息可能查不到。 + older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]}) + newer_db = self.create_db( + "search_newer.db", + {"alice": [(1, 300, "foo new"), (2, 200, "foo middle")]}, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": newer_db, + "table_name": _msg_table_name("alice"), + "message_tables": [ + {"db_path": older_db, "table_name": _msg_table_name("alice")}, + {"db_path": newer_db, "table_name": _msg_table_name("alice")}, + ], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_parse_time_range", return_value=(150, 350) + ): + result = mcp_server.search_messages( + "foo", + chat_name="Alice", + start_time="range-start", + end_time="range-end", + limit=20, + offset=0, + ) + + self.assertIn("foo middle", result) + self.assertIn("foo new", result) + self.assertNotIn("foo old", result) + + +if __name__ == "__main__": + unittest.main() From 9ae558a31e363645658f201473cdbf358bb148d4 Mon Sep 17 00:00:00 2001 From: dsjzazs Date: Sat, 14 Mar 2026 16:36:55 +0800 Subject: [PATCH 4/5] Fix global search pagination --- mcp_server.py | 119 +++++++++++------- tests/test_mcp_server_search.py | 212 +++++++++++++++++++++++++++++++- 2 files changed, 286 insertions(+), 45 deletions(-) diff --git a/mcp_server.py b/mcp_server.py index be0abac..f1c1547 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -872,33 +872,49 @@ def _iter_table_contexts(ctx): } +def _candidate_page_size(limit, offset): + return limit + offset + + +def _message_query_batch_size(candidate_limit): + return candidate_limit + + +def _page_ranked_entries(entries, limit, offset): + ordered = sorted(entries, key=lambda item: item[0], reverse=True) + paged = ordered[offset:offset + limit] + paged.sort(key=lambda item: item[0]) + return paged + + def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0): collected = [] failures = [] + candidate_limit = _candidate_page_size(limit, offset) for table_ctx in _iter_table_contexts(ctx): try: with closing(sqlite3.connect(table_ctx['db_path'])) as conn: id_to_username = _load_name2id_maps(conn) + # 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。 rows = _query_messages( conn, table_ctx['table_name'], start_ts=start_ts, end_ts=end_ts, - limit=None, + limit=candidate_limit, + offset=0, ) for row in rows: collected.append(_build_history_line(row, table_ctx, names, id_to_username)) except Exception as e: failures.append(f"{table_ctx['db_path']}: {e}") - ordered = sorted(collected, key=lambda item: item[0], reverse=True) - paged = ordered[offset:offset + limit] - paged.sort(key=lambda item: item[0]) + paged = _page_ranked_entries(collected, limit, offset) return [line for _, line in paged], failures -def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None): +def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None, candidate_limit=20): collected = [] failures = [] contexts_by_db = {} @@ -915,6 +931,7 @@ def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) collected.extend(db_entries) failures.extend(db_failures) @@ -954,25 +971,40 @@ def _load_search_contexts_from_db(conn, db_path, names): return contexts -def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None): +def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None, candidate_limit=20): collected = [] failures = [] id_to_username = _load_name2id_maps(conn) + batch_size = _message_query_batch_size(candidate_limit) for ctx in contexts: try: - rows = _query_messages( - conn, - ctx['table_name'], - start_ts=start_ts, - end_ts=end_ts, - keyword=keyword, - limit=None, - ) - for row in rows: - formatted = _build_search_entry(row, ctx, names, id_to_username) - if formatted: - collected.append(formatted) + fetch_offset = 0 + collected_before_table = len(collected) + # 全局分页只需要每个分表最新的 offset+limit 条有效命中,无需把整表命中读进内存。 + while len(collected) - collected_before_table < candidate_limit: + rows = _query_messages( + conn, + ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + keyword=keyword, + limit=batch_size, + offset=fetch_offset, + ) + if not rows: + break + fetch_offset += len(rows) + + for row in rows: + formatted = _build_search_entry(row, ctx, names, id_to_username) + if formatted: + collected.append(formatted) + if len(collected) - collected_before_table >= candidate_limit: + break + + if len(rows) < batch_size: + break except Exception as e: failures.append(f"{ctx['display_name']}: {e}") @@ -980,14 +1012,12 @@ def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_t def _page_search_entries(entries, limit, offset): - ordered = sorted(entries, key=lambda x: x[0], reverse=True) - paged = ordered[offset:offset + limit] - paged.sort(key=lambda x: x[0]) - return paged + return _page_ranked_entries(entries, limit, offset) def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset): names = get_contact_names() + candidate_limit = _candidate_page_size(limit, offset) entries, failures = _collect_chat_search_entries( ctx, @@ -995,6 +1025,7 @@ def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, li keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) paged = _page_search_entries(entries, limit, offset) @@ -1028,6 +1059,7 @@ def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, en return f"错误: 没有可查询的聊天对象{suffix}" names = get_contact_names() + candidate_limit = _candidate_page_size(limit, offset) collected = [] failures = [] for ctx in resolved_contexts: @@ -1037,6 +1069,7 @@ def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, en keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) collected.extend(chat_entries) failures.extend(chat_failures) @@ -1074,6 +1107,7 @@ def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, names = get_contact_names() collected = [] failures = [] + candidate_limit = _candidate_page_size(limit, offset) for rel_key in MSG_DB_KEYS: path = _cache.get(rel_key) @@ -1090,6 +1124,7 @@ def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) collected.extend(db_entries) failures.extend(db_failures) @@ -1123,7 +1158,7 @@ _last_check_state = {} # {username: last_timestamp} @mcp.tool() -def get_recent_sessions(limit: int = 20) -> str: +def get_recent_sessions(limit: int = 20) -> str: """获取微信最近会话列表,包含最新消息摘要、未读数、时间等。 用于了解最近有哪些人/群在聊天。 @@ -1135,16 +1170,15 @@ def get_recent_sessions(limit: int = 20) -> str: return "错误: 无法解密 session.db" names = get_contact_names() - conn = sqlite3.connect(path) - rows = conn.execute(""" - SELECT username, unread_count, summary, last_timestamp, - last_msg_type, last_msg_sender, last_sender_display_name - FROM SessionTable - WHERE last_timestamp > 0 - ORDER BY last_timestamp DESC - LIMIT ? - """, (limit,)).fetchall() - conn.close() + with closing(sqlite3.connect(path)) as conn: + rows = conn.execute(""" + SELECT username, unread_count, summary, last_timestamp, + last_msg_type, last_msg_sender, last_sender_display_name + FROM SessionTable + WHERE last_timestamp > 0 + ORDER BY last_timestamp DESC + LIMIT ? + """, (limit,)).fetchall() results = [] for r in rows: @@ -1342,7 +1376,7 @@ def get_contacts(query: str = "", limit: int = 50) -> str: @mcp.tool() -def get_new_messages() -> str: +def get_new_messages() -> str: """获取自上次调用以来的新消息。首次调用返回最近的会话状态。""" global _last_check_state @@ -1351,15 +1385,14 @@ def get_new_messages() -> str: return "错误: 无法解密 session.db" names = get_contact_names() - conn = sqlite3.connect(path) - rows = conn.execute(""" - SELECT username, unread_count, summary, last_timestamp, - last_msg_type, last_msg_sender, last_sender_display_name - FROM SessionTable - WHERE last_timestamp > 0 - ORDER BY last_timestamp DESC - """).fetchall() - conn.close() + with closing(sqlite3.connect(path)) as conn: + rows = conn.execute(""" + SELECT username, unread_count, summary, last_timestamp, + last_msg_type, last_msg_sender, last_sender_display_name + FROM SessionTable + WHERE last_timestamp > 0 + ORDER BY last_timestamp DESC + """).fetchall() curr_state = {} for r in rows: diff --git a/tests/test_mcp_server_search.py b/tests/test_mcp_server_search.py index 8505dc0..859d996 100644 --- a/tests/test_mcp_server_search.py +++ b/tests/test_mcp_server_search.py @@ -157,8 +157,8 @@ class SearchMessagesTests(unittest.TestCase): self.assertNotIn("foo a1", result) self.assertNotIn("foo b2", result) - def test_search_messages_all_messages_scans_all_dbs_before_paging(self): - # 全库搜索要先扫完整个库,再分页,不能被旧分片提前截断。 + def test_search_messages_all_messages_merges_global_results_before_paging(self): + # 全库搜索要基于跨库合并后的全局时间线分页,不能被单个分库提前截断。 older_db = self.create_db( "older.db", {"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]}, @@ -182,6 +182,44 @@ class SearchMessagesTests(unittest.TestCase): self.assertLess(result.index("foo newer 2"), result.index("foo newer 1")) self.assertNotIn("foo older 1", result) + def test_search_messages_all_messages_uses_bounded_sql_pagination(self): + # 每个消息表都只应查询当前页所需的候选窗口,不能回退到 limit=None 的全量扫描。 + older_db = self.create_db( + "older_paged.db", + {"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]}, + ) + newer_db = self.create_db( + "newer_paged.db", + {"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2"), (3, 19, "foo newer 3")]}, + ) + fake_cache = _FakeCache({"older": older_db, "newer": newer_db}) + original_query_messages = mcp_server._query_messages + calls = [] + + def recording_query_messages(*args, **kwargs): + calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0))) + return original_query_messages(*args, **kwargs) + + with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object( + mcp_server, "_cache", fake_cache + ), patch.object( + mcp_server, + "get_contact_names", + return_value={"older_user": "Older", "newer_user": "Newer"}, + ), patch.object( + mcp_server, "_query_messages", side_effect=recording_query_messages + ): + result = mcp_server.search_messages("foo", limit=2, offset=1) + + self.assertIn('搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result) + self.assertEqual( + calls, + [ + (_msg_table_name("older_user"), 3, 0), + (_msg_table_name("newer_user"), 3, 0), + ], + ) + def test_search_messages_single_chat_respects_time_range(self): # 单聊搜索的开始/结束时间都必须严格生效。 db_path = self.create_db( @@ -339,6 +377,81 @@ class SearchMessagesTests(unittest.TestCase): self.assertIn("new message", result) self.assertNotIn("old message", result) + def test_get_chat_history_uses_bounded_sql_pagination(self): + # 历史查询应把 offset+limit 下推到 SQL,避免把整张消息表读出来后再切片。 + db_path = self.create_db( + "history_paged.db", + { + "alice": [ + (1, 400, "newest"), + (2, 300, "middle"), + (3, 200, "older"), + (4, 100, "oldest"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + original_query_messages = mcp_server._query_messages + calls = [] + + def recording_query_messages(*args, **kwargs): + calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0))) + return original_query_messages(*args, **kwargs) + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_query_messages", side_effect=recording_query_messages + ): + result = mcp_server.get_chat_history("Alice", limit=2, offset=1) + + self.assertIn("middle", result) + self.assertIn("older", result) + self.assertNotIn("newest", result) + self.assertNotIn("oldest", result) + self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)]) + + def test_get_chat_history_keeps_partial_results_when_formatting_fails(self): + # 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。 + db_path = self.create_db( + "history_partial_failure.db", + {"alice": [(1, 200, "good message"), (2, 100, "bad message")]}, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + original_build_history_line = mcp_server._build_history_line + + def flaky_build_history_line(row, *args, **kwargs): + if row[2] == 100: + raise ValueError("bad row") + return original_build_history_line(row, *args, **kwargs) + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_build_history_line", side_effect=flaky_build_history_line + ): + result = mcp_server.get_chat_history("Alice", limit=2, offset=0) + + self.assertIn("good message", result) + self.assertIn("查询失败:", result) + self.assertIn("bad row", result) + def test_search_messages_single_chat_merges_sharded_message_tables(self): # 单聊搜索也要跨分片合并,否则最近消息可能查不到。 older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]}) @@ -377,6 +490,101 @@ class SearchMessagesTests(unittest.TestCase): self.assertIn("foo new", result) self.assertNotIn("foo old", result) + def test_search_messages_keeps_partial_results_when_later_batch_fails(self): + # 后续批次失败时,前面已经拿到的有效结果不应被丢弃。 + db_path = self.create_db( + "search_partial_failure.db", + { + "alice": [ + (1, 400, "foo newest"), + (2, 300, "foo skipped"), + (3, 200, "foo older"), + (4, 100, "foo bad"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + original_build_search_entry = mcp_server._build_search_entry + + def flaky_build_search_entry(row, *args, **kwargs): + if row[2] == 300: + return None + if row[2] == 100: + raise ValueError("bad row") + return original_build_search_entry(row, *args, **kwargs) + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_build_search_entry", side_effect=flaky_build_search_entry + ): + result = mcp_server.search_messages("foo", chat_name="Alice", limit=3, offset=0) + + self.assertIn("foo newest", result) + self.assertIn("foo older", result) + self.assertIn("查询失败:", result) + self.assertIn("bad row", result) + + def test_get_recent_sessions_closes_connection_when_query_fails(self): + # 会话查询抛异常时也必须关闭 sqlite3 连接,避免资源泄漏。 + fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"}) + + class _FakeConn: + def __init__(self): + self.closed = False + + def execute(self, *args, **kwargs): + raise sqlite3.OperationalError("boom") + + def close(self): + self.closed = True + + fake_conn = _FakeConn() + + with patch.object(mcp_server, "_cache", fake_cache), patch.object( + mcp_server, "get_contact_names", return_value={} + ), patch.object( + mcp_server.sqlite3, "connect", return_value=fake_conn + ): + with self.assertRaisesRegex(sqlite3.OperationalError, "boom"): + mcp_server.get_recent_sessions() + + self.assertTrue(fake_conn.closed) + + def test_get_new_messages_closes_connection_when_query_fails(self): + # 新消息轮询失败时也要释放 sqlite3 连接。 + fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"}) + + class _FakeConn: + def __init__(self): + self.closed = False + + def execute(self, *args, **kwargs): + raise sqlite3.OperationalError("boom") + + def close(self): + self.closed = True + + fake_conn = _FakeConn() + + with patch.object(mcp_server, "_cache", fake_cache), patch.object( + mcp_server, "get_contact_names", return_value={} + ), patch.object( + mcp_server.sqlite3, "connect", return_value=fake_conn + ): + with self.assertRaisesRegex(sqlite3.OperationalError, "boom"): + mcp_server.get_new_messages() + + self.assertTrue(fake_conn.closed) + if __name__ == "__main__": unittest.main() From 7c42ff5d38bb214b20d74e23e83e9a30c2a19eae Mon Sep 17 00:00:00 2001 From: dsjzazs Date: Sat, 14 Mar 2026 16:59:17 +0800 Subject: [PATCH 5/5] Investigate get_chat_history limit --- README.md | 2 +- USAGE.md | 2 +- mcp_server.py | 65 ++++++++++++++-------- tests/test_mcp_server_search.py | 96 +++++++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 8807e33..b63a550 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv 前置条件:需要先运行 `python main.py` 或 `python find_all_keys.py` 完成密钥提取。 -说明:`get_chat_history` 和 `search_messages` 的 `limit` 最大为 `500`。 +说明:`search_messages` 的 `limit` 最大为 `500`;`get_chat_history` 支持更大的 `limit`,但消息很多时仍建议配合 `offset` 分页读取。 **[查看使用案例 →](USAGE.md)** diff --git a/USAGE.md b/USAGE.md index c193191..0d8717c 100644 --- a/USAGE.md +++ b/USAGE.md @@ -191,7 +191,7 @@ Claude 可以获取大量消息后自动分析活跃度、话题分布、关键 > 帮我分析一下██群最近一周的情况 ``` -Claude 会调用 `get_chat_history(chat_name="██群", limit=500)` 获取消息,然后输出: +Claude 会调用 `get_chat_history(chat_name="██群", limit=500)` 获取消息,然后输出。消息很多时,也可以把 `limit` 设得更大,或配合 `offset` 分页读取: ``` ## ██群最近一周分析 diff --git a/mcp_server.py b/mcp_server.py index f1c1547..8ef1f05 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -224,6 +224,7 @@ _self_username = None _XML_UNSAFE_RE = re.compile(r' _QUERY_LIMIT_MAX: - raise ValueError(f"limit 不能大于 {_QUERY_LIMIT_MAX}") + if limit_max is not None and limit > limit_max: + raise ValueError(f"limit 不能大于 {limit_max}") if offset < 0: raise ValueError("offset 不能小于 0") @@ -841,8 +842,6 @@ def _build_history_line(row, ctx, names, id_to_username): sender, text = _format_message_text( local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names ) - if text and len(text) > 500: - text = text[:500] + '...' sender_label = _resolve_sender_label( real_sender_id, sender, ctx['is_group'], ctx['username'], ctx['display_name'], names, id_to_username @@ -880,6 +879,10 @@ def _message_query_batch_size(candidate_limit): return candidate_limit +def _history_query_batch_size(candidate_limit): + return min(candidate_limit, _HISTORY_QUERY_BATCH_SIZE) + + def _page_ranked_entries(entries, limit, offset): ordered = sorted(entries, key=lambda item: item[0], reverse=True) paged = ordered[offset:offset + limit] @@ -891,22 +894,40 @@ def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20 collected = [] failures = [] candidate_limit = _candidate_page_size(limit, offset) + batch_size = _history_query_batch_size(candidate_limit) for table_ctx in _iter_table_contexts(ctx): try: with closing(sqlite3.connect(table_ctx['db_path'])) as conn: id_to_username = _load_name2id_maps(conn) + fetch_offset = 0 + collected_before_table = len(collected) # 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。 - rows = _query_messages( - conn, - table_ctx['table_name'], - start_ts=start_ts, - end_ts=end_ts, - limit=candidate_limit, - offset=0, - ) - for row in rows: - collected.append(_build_history_line(row, table_ctx, names, id_to_username)) + while len(collected) - collected_before_table < candidate_limit: + rows = _query_messages( + conn, + table_ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + limit=batch_size, + offset=fetch_offset, + ) + if not rows: + break + fetch_offset += len(rows) + + for row in rows: + try: + collected.append(_build_history_line(row, table_ctx, names, id_to_username)) + except Exception as e: + failures.append( + f"{table_ctx['display_name']} local_id={row[0]} create_time={row[2]}: {e}" + ) + if len(collected) - collected_before_table >= candidate_limit: + break + + if len(rows) < batch_size: + break except Exception as e: failures.append(f"{table_ctx['db_path']}: {e}") @@ -1221,16 +1242,16 @@ def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_tim Args: chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配 - limit: 返回的消息数量,默认50,最大500 + limit: 返回的消息数量,默认50;支持较大的值,建议配合 offset 分页使用 offset: 分页偏移量,默认0 start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS - """ - try: - _validate_pagination(limit, offset) - start_ts, end_ts = _parse_time_range(start_time, end_time) - except ValueError as e: - return f"错误: {e}" + """ + try: + _validate_pagination(limit, offset, limit_max=None) + start_ts, end_ts = _parse_time_range(start_time, end_time) + except ValueError as e: + return f"错误: {e}" ctx = _resolve_chat_context(chat_name) if not ctx: diff --git a/tests/test_mcp_server_search.py b/tests/test_mcp_server_search.py index 859d996..49b9ec7 100644 --- a/tests/test_mcp_server_search.py +++ b/tests/test_mcp_server_search.py @@ -72,6 +72,10 @@ class SearchMessagesTests(unittest.TestCase): with self.assertRaisesRegex(ValueError, "limit 不能大于 500"): mcp_server._validate_pagination(501, 0) + def test_validate_pagination_allows_large_limit_when_limit_is_unbounded(self): + # get_chat_history 允许更大的 limit,只校验正数和 offset。 + mcp_server._validate_pagination(999999, 0, limit_max=None) + def test_page_search_entries_returns_chronological_results_with_offset(self): # 结果应先按最新时间分页,再把当前页恢复成时间正序输出。 entries = [(1, "a"), (5, "e"), (3, "c"), (4, "d"), (2, "b")] @@ -377,6 +381,43 @@ class SearchMessagesTests(unittest.TestCase): self.assertIn("new message", result) self.assertNotIn("old message", result) + def test_get_chat_history_large_limit_reads_all_rows_across_shards(self): + # 大 limit 下,跨分片历史查询不能只返回较旧分片里的少量消息。 + older_messages = [ + (index, 1000 + index, f"old shard message {index}") + for index in range(1, 18) + ] + newer_messages = [ + (index, 2000 + index, f"new shard message {index}") + for index in range(1, 296) + ] + older_db = self.create_db("history_cross_shard_older.db", {"alice": older_messages}) + newer_db = self.create_db("history_cross_shard_newer.db", {"alice": newer_messages}) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": newer_db, + "table_name": _msg_table_name("alice"), + "message_tables": [ + {"db_path": newer_db, "table_name": _msg_table_name("alice")}, + {"db_path": older_db, "table_name": _msg_table_name("alice")}, + ], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ): + result = mcp_server.get_chat_history("Alice", limit=500, offset=0) + + self.assertIn("Alice 的消息记录(返回 312 条,offset=0, limit=500)", result) + self.assertIn("new shard message 295", result) + self.assertIn("old shard message 17", result) + + body = result.split(":\n\n", 1)[1] + self.assertEqual(len(body.splitlines()), 312) + def test_get_chat_history_uses_bounded_sql_pagination(self): # 历史查询应把 offset+limit 下推到 SQL,避免把整张消息表读出来后再切片。 db_path = self.create_db( @@ -419,6 +460,36 @@ class SearchMessagesTests(unittest.TestCase): self.assertNotIn("oldest", result) self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)]) + def test_get_chat_history_allows_large_limit_values(self): + # 历史查询不应再把大 limit 直接拒绝掉。 + db_path = self.create_db( + "history_large_limit.db", + { + "alice": [ + (1, 200, "message 1"), + (2, 100, "message 2"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ): + result = mcp_server.get_chat_history("Alice", limit=999999, offset=0) + + self.assertNotIn("错误:", result) + self.assertIn("message 1", result) + self.assertIn("message 2", result) + def test_get_chat_history_keeps_partial_results_when_formatting_fails(self): # 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。 db_path = self.create_db( @@ -452,6 +523,31 @@ class SearchMessagesTests(unittest.TestCase): self.assertIn("查询失败:", result) self.assertIn("bad row", result) + def test_get_chat_history_does_not_truncate_long_messages(self): + # 历史记录应返回完整消息内容,而不是固定截断到 500 字符。 + long_message = "x" * 600 + db_path = self.create_db( + "history_long_message.db", + {"alice": [(1, 200, long_message)]}, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ): + result = mcp_server.get_chat_history("Alice", limit=1, offset=0) + + self.assertIn(long_message, result) + self.assertNotIn(("x" * 500) + "...", result) + def test_search_messages_single_chat_merges_sharded_message_tables(self): # 单聊搜索也要跨分片合并,否则最近消息可能查不到。 older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]})