Merge pull request #2 from dsjzazs/codex/searchmessages

Add unit tests for MCP search and fix pagination
feat/daemon-cli
dsjzazs 2026-03-14 16:39:12 +08:00 committed by GitHub
commit 2cd180c63a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 1115 additions and 339 deletions

View File

@ -139,6 +139,8 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv
前置条件:需要先运行 `python main.py``python find_all_keys.py` 完成密钥提取。 前置条件:需要先运行 `python main.py``python find_all_keys.py` 完成密钥提取。
说明:`get_chat_history` 和 `search_messages``limit` 最大为 `500`
**[查看使用案例 →](USAGE.md)** **[查看使用案例 →](USAGE.md)**
### 图片解密 (V2 格式) ### 图片解密 (V2 格式)

View File

@ -5,9 +5,10 @@ Based on FastMCP (stdio transport), reuses existing decryption.
Runs on Windows Python (needs access to D:\ WeChat databases). Runs on Windows Python (needs access to D:\ WeChat databases).
""" """
import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re
import hmac as hmac_mod import hmac as hmac_mod
from datetime import datetime from contextlib import closing
from datetime import datetime
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from Crypto.Cipher import AES from Crypto.Cipher import AES
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
@ -219,9 +220,10 @@ atexit.register(_cache.cleanup)
_contact_names = None # {username: display_name} _contact_names = None # {username: display_name}
_contact_full = None # [{username, nick_name, remark}] _contact_full = None # [{username, nick_name, remark}]
_self_username = None _self_username = None
_XML_UNSAFE_RE = re.compile(r'<!DOCTYPE|<!ENTITY', re.IGNORECASE) _XML_UNSAFE_RE = re.compile(r'<!DOCTYPE|<!ENTITY', re.IGNORECASE)
_XML_PARSE_MAX_LEN = 20000 _XML_PARSE_MAX_LEN = 20000
_QUERY_LIMIT_MAX = 500
def _load_contacts_from(db_path): def _load_contacts_from(db_path):
@ -568,12 +570,12 @@ MSG_DB_KEYS = sorted([
]) ])
def _find_msg_table_for_user(username): def _find_msg_table_for_user(username):
"""在所有 message_N.db 中查找用户的消息表,返回 (db_path, table_name)""" """在所有 message_N.db 中查找用户的消息表,返回 (db_path, table_name)"""
table_hash = hashlib.md5(username.encode()).hexdigest() table_hash = hashlib.md5(username.encode()).hexdigest()
table_name = f"Msg_{table_hash}" table_name = f"Msg_{table_hash}"
if not _is_safe_msg_table_name(table_name): if not _is_safe_msg_table_name(table_name):
return None, None return None, None
for rel_key in MSG_DB_KEYS: for rel_key in MSG_DB_KEYS:
path = _cache.get(rel_key) path = _cache.get(rel_key)
@ -592,15 +594,54 @@ def _find_msg_table_for_user(username):
pass pass
finally: finally:
conn.close() conn.close()
return None, None return None, None
def _find_msg_tables_for_user(username):
"""返回用户在所有 message_N.db 中对应的消息表,按最新消息时间倒序排列。"""
table_hash = hashlib.md5(username.encode()).hexdigest()
table_name = f"Msg_{table_hash}"
if not _is_safe_msg_table_name(table_name):
return []
matches = []
for rel_key in MSG_DB_KEYS:
path = _cache.get(rel_key)
if not path:
continue
conn = sqlite3.connect(path)
try:
exists = conn.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
(table_name,)
).fetchone()
if not exists:
continue
max_create_time = conn.execute(
f"SELECT MAX(create_time) FROM [{table_name}]"
).fetchone()[0] or 0
matches.append({
'db_path': path,
'table_name': table_name,
'max_create_time': max_create_time,
})
except Exception:
pass
finally:
conn.close()
matches.sort(key=lambda item: item['max_create_time'], reverse=True)
return matches
def _validate_pagination(limit, offset=0): def _validate_pagination(limit, offset=0):
if limit <= 0: if limit <= 0:
raise ValueError("limit 必须大于 0") raise ValueError("limit 必须大于 0")
if offset < 0: if limit > _QUERY_LIMIT_MAX:
raise ValueError("offset 不能小于 0") raise ValueError(f"limit 不能大于 {_QUERY_LIMIT_MAX}")
if offset < 0:
raise ValueError("offset 不能小于 0")
def _parse_time_value(value, field_name, is_end=False): def _parse_time_value(value, field_name, is_end=False):
@ -650,49 +691,54 @@ def _build_message_filters(start_ts=None, end_ts=None, keyword=''):
return clauses, params return clauses, params
def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0): def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0):
if not _is_safe_msg_table_name(table_name): if not _is_safe_msg_table_name(table_name):
raise ValueError(f'非法消息表名: {table_name}') raise ValueError(f'非法消息表名: {table_name}')
clauses, params = _build_message_filters(start_ts, end_ts, keyword) clauses, params = _build_message_filters(start_ts, end_ts, keyword)
where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else ''
sql = f""" sql = f"""
SELECT local_id, local_type, create_time, real_sender_id, message_content, SELECT local_id, local_type, create_time, real_sender_id, message_content,
WCDB_CT_message_content WCDB_CT_message_content
FROM [{table_name}] FROM [{table_name}]
{where_sql} {where_sql}
ORDER BY create_time DESC ORDER BY create_time DESC
LIMIT ? OFFSET ? """
""" if limit is None:
return conn.execute(sql, (*params, limit, offset)).fetchall() return conn.execute(sql, params).fetchall()
sql += "\n LIMIT ? OFFSET ?"
return conn.execute(sql, (*params, limit, offset)).fetchall()
def _resolve_chat_context(chat_name): def _resolve_chat_context(chat_name):
username = resolve_username(chat_name) username = resolve_username(chat_name)
if not username: if not username:
return None return None
names = get_contact_names() names = get_contact_names()
display_name = names.get(username, username) display_name = names.get(username, username)
db_path, table_name = _find_msg_table_for_user(username) message_tables = _find_msg_tables_for_user(username)
if not db_path: if not message_tables:
return { return {
'query': chat_name, 'query': chat_name,
'username': username, 'username': username,
'display_name': display_name, 'display_name': display_name,
'db_path': None, 'db_path': None,
'table_name': None, 'table_name': None,
'is_group': '@chatroom' in username, 'message_tables': [],
} 'is_group': '@chatroom' in username,
}
return {
'query': chat_name, primary = message_tables[0]
'username': username, return {
'display_name': display_name, 'query': chat_name,
'db_path': db_path, 'username': username,
'table_name': table_name, 'display_name': display_name,
'is_group': '@chatroom' in username, 'db_path': primary['db_path'],
} 'table_name': primary['table_name'],
'message_tables': message_tables,
'is_group': '@chatroom' in username,
}
def _resolve_chat_contexts(chat_names): def _resolve_chat_contexts(chat_names):
@ -713,11 +759,11 @@ def _resolve_chat_contexts(chat_names):
if not ctx: if not ctx:
unresolved.append(name) unresolved.append(name)
continue continue
if not ctx['db_path']: if not ctx['message_tables']:
missing_tables.append(ctx['display_name']) missing_tables.append(ctx['display_name'])
continue continue
if ctx['username'] in seen: if ctx['username'] in seen:
continue continue
seen.add(ctx['username']) seen.add(ctx['username'])
resolved.append(ctx) resolved.append(ctx)
@ -743,31 +789,20 @@ def _normalize_chat_names(chat_name):
return [value] if value else [] return [value] if value else []
def _format_history_lines(rows, username, display_name, is_group, names, id_to_username): def _format_history_lines(rows, username, display_name, is_group, names, id_to_username):
lines = [] lines = []
for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows): ctx = {
time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') 'username': username,
content = _decompress_content(content, ct) 'display_name': display_name,
if content is None: 'is_group': is_group,
content = '(无法解压)' }
for row in reversed(rows):
sender, text = _format_message_text( _, line = _build_history_line(row, ctx, names, id_to_username)
local_id, local_type, content, is_group, username, display_name, names lines.append(line)
) return lines
if text and len(text) > 500:
text = text[:500] + '...'
sender_label = _resolve_sender_label(
real_sender_id, sender, is_group, username, display_name, names, id_to_username
)
if sender_label:
lines.append(f'[{time_str}] {sender_label}: {text}')
else:
lines.append(f'[{time_str}] {text}')
return lines
def _build_search_entry(row, ctx, names, id_to_username): def _build_search_entry(row, ctx, names, id_to_username):
local_id, local_type, create_time, real_sender_id, content, ct = row local_id, local_type, create_time, real_sender_id, content, ct = row
content = _decompress_content(content, ct) content = _decompress_content(content, ct)
if content is None: if content is None:
@ -792,8 +827,326 @@ def _build_search_entry(row, ctx, names, id_to_username):
entry = f"[{time_str}] [{ctx['display_name']}]" entry = f"[{time_str}] [{ctx['display_name']}]"
if sender_label: if sender_label:
entry += f" {sender_label}:" entry += f" {sender_label}:"
entry += f" {text}" entry += f" {text}"
return create_time, entry return create_time, entry
def _build_history_line(row, ctx, names, id_to_username):
local_id, local_type, create_time, real_sender_id, content, ct = row
time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M')
content = _decompress_content(content, ct)
if content is None:
content = '(无法解压)'
sender, text = _format_message_text(
local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names
)
if text and len(text) > 500:
text = text[:500] + '...'
sender_label = _resolve_sender_label(
real_sender_id, sender, ctx['is_group'], ctx['username'], ctx['display_name'], names, id_to_username
)
if sender_label:
return create_time, f'[{time_str}] {sender_label}: {text}'
return create_time, f'[{time_str}] {text}'
def _get_chat_message_tables(ctx):
if ctx.get('message_tables'):
return ctx['message_tables']
if ctx.get('db_path') and ctx.get('table_name'):
return [{'db_path': ctx['db_path'], 'table_name': ctx['table_name']}]
return []
def _iter_table_contexts(ctx):
for table in _get_chat_message_tables(ctx):
yield {
'query': ctx['query'],
'username': ctx['username'],
'display_name': ctx['display_name'],
'db_path': table['db_path'],
'table_name': table['table_name'],
'is_group': ctx['is_group'],
}
def _candidate_page_size(limit, offset):
return limit + offset
def _message_query_batch_size(candidate_limit):
return candidate_limit
def _page_ranked_entries(entries, limit, offset):
ordered = sorted(entries, key=lambda item: item[0], reverse=True)
paged = ordered[offset:offset + limit]
paged.sort(key=lambda item: item[0])
return paged
def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0):
collected = []
failures = []
candidate_limit = _candidate_page_size(limit, offset)
for table_ctx in _iter_table_contexts(ctx):
try:
with closing(sqlite3.connect(table_ctx['db_path'])) as conn:
id_to_username = _load_name2id_maps(conn)
# 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。
rows = _query_messages(
conn,
table_ctx['table_name'],
start_ts=start_ts,
end_ts=end_ts,
limit=candidate_limit,
offset=0,
)
for row in rows:
collected.append(_build_history_line(row, table_ctx, names, id_to_username))
except Exception as e:
failures.append(f"{table_ctx['db_path']}: {e}")
paged = _page_ranked_entries(collected, limit, offset)
return [line for _, line in paged], failures
def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
collected = []
failures = []
contexts_by_db = {}
for table_ctx in _iter_table_contexts(ctx):
contexts_by_db.setdefault(table_ctx['db_path'], []).append(table_ctx)
for db_path, db_contexts in contexts_by_db.items():
try:
with closing(sqlite3.connect(db_path)) as conn:
db_entries, db_failures = _collect_search_entries(
conn,
db_contexts,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
collected.extend(db_entries)
failures.extend(db_failures)
except Exception as e:
failures.extend(f"{table_ctx['display_name']}: {e}" for table_ctx in db_contexts)
return collected, failures
def _load_search_contexts_from_db(conn, db_path, names):
tables = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'"
).fetchall()
table_to_username = {}
try:
for (user_name,) in conn.execute("SELECT user_name FROM Name2Id").fetchall():
if not user_name:
continue
table_hash = hashlib.md5(user_name.encode()).hexdigest()
table_to_username[f"Msg_{table_hash}"] = user_name
except sqlite3.Error:
pass
contexts = []
for (table_name,) in tables:
username = table_to_username.get(table_name, '')
display_name = names.get(username, username) if username else table_name
contexts.append({
'query': display_name,
'username': username,
'display_name': display_name,
'db_path': db_path,
'table_name': table_name,
'is_group': '@chatroom' in username,
})
return contexts
def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
collected = []
failures = []
id_to_username = _load_name2id_maps(conn)
batch_size = _message_query_batch_size(candidate_limit)
for ctx in contexts:
try:
fetch_offset = 0
collected_before_table = len(collected)
# 全局分页只需要每个分表最新的 offset+limit 条有效命中,无需把整表命中读进内存。
while len(collected) - collected_before_table < candidate_limit:
rows = _query_messages(
conn,
ctx['table_name'],
start_ts=start_ts,
end_ts=end_ts,
keyword=keyword,
limit=batch_size,
offset=fetch_offset,
)
if not rows:
break
fetch_offset += len(rows)
for row in rows:
formatted = _build_search_entry(row, ctx, names, id_to_username)
if formatted:
collected.append(formatted)
if len(collected) - collected_before_table >= candidate_limit:
break
if len(rows) < batch_size:
break
except Exception as e:
failures.append(f"{ctx['display_name']}: {e}")
return collected, failures
def _page_search_entries(entries, limit, offset):
return _page_ranked_entries(entries, limit, offset)
def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
names = get_contact_names()
candidate_limit = _candidate_page_size(limit, offset)
entries, failures = _collect_chat_search_entries(
ctx,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
paged = _page_search_entries(entries, limit, offset)
if not paged:
if failures:
return "查询失败: " + "".join(failures)
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息"
header = f"{ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(paged)} 条结果offset={offset}, limit={limit}"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
try:
resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names)
except ValueError as e:
return f"错误: {e}"
if not resolved_contexts:
details = []
if unresolved:
details.append("未找到联系人: " + "".join(unresolved))
if missing_tables:
details.append("无消息表: " + "".join(missing_tables))
suffix = f"\n{chr(10).join(details)}" if details else ""
return f"错误: 没有可查询的聊天对象{suffix}"
names = get_contact_names()
candidate_limit = _candidate_page_size(limit, offset)
collected = []
failures = []
for ctx in resolved_contexts:
chat_entries, chat_failures = _collect_chat_search_entries(
ctx,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
collected.extend(chat_entries)
failures.extend(chat_failures)
paged = _page_search_entries(collected, limit, offset)
notes = []
if unresolved:
notes.append("未找到联系人: " + "".join(unresolved))
if missing_tables:
notes.append("无消息表: " + "".join(missing_tables))
if failures:
notes.append("查询失败: " + "".join(failures))
if not paged:
header = f"{len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if notes:
header += "\n" + "\n".join(notes)
return header
header = (
f"{len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果"
f"offset={offset}, limit={limit}"
)
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if notes:
header += "\n" + "\n".join(notes)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, offset):
names = get_contact_names()
collected = []
failures = []
candidate_limit = _candidate_page_size(limit, offset)
for rel_key in MSG_DB_KEYS:
path = _cache.get(rel_key)
if not path:
continue
try:
with closing(sqlite3.connect(path)) as conn:
contexts = _load_search_contexts_from_db(conn, path, names)
db_entries, db_failures = _collect_search_entries(
conn,
contexts,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
collected.extend(db_entries)
failures.extend(db_failures)
except Exception as e:
failures.append(f"{rel_key}: {e}")
paged = _page_search_entries(collected, limit, offset)
if not paged:
header = f"未找到包含 \"{keyword}\" 的消息"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header
header = f"搜索 \"{keyword}\" 找到 {len(paged)} 条结果offset={offset}, limit={limit}"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
# ============ MCP Server ============ # ============ MCP Server ============
@ -805,7 +1158,7 @@ _last_check_state = {} # {username: last_timestamp}
@mcp.tool() @mcp.tool()
def get_recent_sessions(limit: int = 20) -> str: def get_recent_sessions(limit: int = 20) -> str:
"""获取微信最近会话列表,包含最新消息摘要、未读数、时间等。 """获取微信最近会话列表,包含最新消息摘要、未读数、时间等。
用于了解最近有哪些人/群在聊天 用于了解最近有哪些人/群在聊天
@ -817,16 +1170,15 @@ def get_recent_sessions(limit: int = 20) -> str:
return "错误: 无法解密 session.db" return "错误: 无法解密 session.db"
names = get_contact_names() names = get_contact_names()
conn = sqlite3.connect(path) with closing(sqlite3.connect(path)) as conn:
rows = conn.execute(""" rows = conn.execute("""
SELECT username, unread_count, summary, last_timestamp, SELECT username, unread_count, summary, last_timestamp,
last_msg_type, last_msg_sender, last_sender_display_name last_msg_type, last_msg_sender, last_sender_display_name
FROM SessionTable FROM SessionTable
WHERE last_timestamp > 0 WHERE last_timestamp > 0
ORDER BY last_timestamp DESC ORDER BY last_timestamp DESC
LIMIT ? LIMIT ?
""", (limit,)).fetchall() """, (limit,)).fetchall()
conn.close()
results = [] results = []
for r in rows: for r in rows:
@ -864,15 +1216,15 @@ def get_recent_sessions(limit: int = 20) -> str:
@mcp.tool() @mcp.tool()
def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str: def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str:
"""获取指定聊天的消息记录。 """获取指定聊天的消息记录。
Args: Args:
chat_name: 聊天对象的名字备注名或wxid自动模糊匹配 chat_name: 聊天对象的名字备注名或wxid自动模糊匹配
limit: 返回的消息数量默认50 limit: 返回的消息数量默认50最大500
offset: 分页偏移量默认0 offset: 分页偏移量默认0
start_time: 起始时间支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS start_time: 起始时间支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
end_time: 结束时间支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS end_time: 结束时间支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
""" """
try: try:
_validate_pagination(limit, offset) _validate_pagination(limit, offset)
@ -886,41 +1238,29 @@ def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_tim
if not ctx['db_path']: if not ctx['db_path']:
return f"找不到 {ctx['display_name']} 的消息记录可能在未解密的DB中或无消息" return f"找不到 {ctx['display_name']} 的消息记录可能在未解密的DB中或无消息"
names = get_contact_names() names = get_contact_names()
conn = sqlite3.connect(ctx['db_path']) lines, failures = _collect_chat_history_lines(
try: ctx,
id_to_username = _load_name2id_maps(conn) names,
rows = _query_messages( start_ts=start_ts,
conn, end_ts=end_ts,
ctx['table_name'], limit=limit,
start_ts=start_ts, offset=offset,
end_ts=end_ts, )
limit=limit,
offset=offset, if not lines:
) if failures:
except Exception as e: return "查询失败: " + "".join(failures)
conn.close() return f"{ctx['display_name']} 无消息记录"
return f"查询失败: {e}"
conn.close() header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)}offset={offset}, limit={limit}"
if ctx['is_group']:
if not rows: header += " [群聊]"
return f"{ctx['display_name']} 无消息记录" if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
lines = _format_history_lines( if failures:
rows, header += "\n查询失败: " + "".join(failures)
ctx['username'], return header + ":\n\n" + "\n".join(lines)
ctx['display_name'],
ctx['is_group'],
names,
id_to_username,
)
header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)}offset={offset}, limit={limit}"
if ctx['is_group']:
header += " [群聊]"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
return header + ":\n\n" + "\n".join(lines)
@mcp.tool() @mcp.tool()
@ -939,7 +1279,7 @@ def search_messages(
chat_name: 聊天对象名称可为空单个字符串或字符串列表 chat_name: 聊天对象名称可为空单个字符串或字符串列表
start_time: 起始时间可为空 start_time: 起始时间可为空
end_time: 结束时间可为空 end_time: 结束时间可为空
limit: 返回的结果数量默认20 limit: 返回的结果数量默认20最大500
offset: 分页偏移量默认0 offset: 分页偏移量默认0
""" """
if not keyword or len(keyword) < 1: if not keyword or len(keyword) < 1:
@ -959,193 +1299,38 @@ def search_messages(
return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人" return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人"
if not ctx['db_path']: if not ctx['db_path']:
return f"找不到 {ctx['display_name']} 的消息记录可能在未解密的DB中或无消息" return f"找不到 {ctx['display_name']} 的消息记录可能在未解密的DB中或无消息"
return _search_single_chat(
names = get_contact_names() ctx,
conn = sqlite3.connect(ctx['db_path']) keyword,
try: start_ts,
id_to_username = _load_name2id_maps(conn) end_ts,
rows = _query_messages( start_time,
conn, end_time,
ctx['table_name'], limit,
start_ts=start_ts, offset,
end_ts=end_ts, )
keyword=keyword,
limit=limit,
offset=offset,
)
except Exception as e:
conn.close()
return f"查询失败: {e}"
conn.close()
if not rows:
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息"
entries = []
for row in rows:
formatted = _build_search_entry(row, ctx, names, id_to_username)
if formatted:
entries.append(formatted)
if not entries:
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的可读消息"
entries.sort(key=lambda x: x[0])
header = f"{ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(entries)} 条结果offset={offset}, limit={limit}"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
return header + ":\n\n" + "\n\n".join(item[1] for item in entries)
if len(chat_names) > 1: if len(chat_names) > 1:
try: return _search_multiple_chats(
resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names) chat_names,
except ValueError as e: keyword,
return f"错误: {e}" start_ts,
end_ts,
if not resolved_contexts: start_time,
details = [] end_time,
if unresolved: limit,
details.append("未找到联系人: " + "".join(unresolved)) offset,
if missing_tables:
details.append("无消息表: " + "".join(missing_tables))
suffix = f"\n{chr(10).join(details)}" if details else ""
return f"错误: 没有可查询的聊天对象{suffix}"
names = get_contact_names()
collected = []
failures = []
per_chat_limit = limit + offset
for ctx in resolved_contexts:
conn = sqlite3.connect(ctx['db_path'])
try:
id_to_username = _load_name2id_maps(conn)
rows = _query_messages(
conn,
ctx['table_name'],
start_ts=start_ts,
end_ts=end_ts,
keyword=keyword,
limit=per_chat_limit,
offset=0,
)
for row in rows:
formatted = _build_search_entry(row, ctx, names, id_to_username)
if formatted:
collected.append(formatted)
except Exception as e:
failures.append(f"{ctx['display_name']}: {e}")
finally:
conn.close()
collected.sort(key=lambda x: x[0], reverse=True)
paged = collected[offset:offset + limit]
notes = []
if unresolved:
notes.append("未找到联系人: " + "".join(unresolved))
if missing_tables:
notes.append("无消息表: " + "".join(missing_tables))
if failures:
notes.append("查询失败: " + "".join(failures))
if not paged:
header = f"{len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if notes:
header += "\n" + "\n".join(notes)
return header
header = (
f"{len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果"
f"offset={offset}, limit={limit}"
) )
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if notes:
header += "\n" + "\n".join(notes)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
names = get_contact_names() return _search_all_messages(
results = [] keyword,
max_results = limit + offset start_ts,
end_ts,
for rel_key in MSG_DB_KEYS: start_time,
if len(results) >= max_results: end_time,
break limit,
offset,
path = _cache.get(rel_key) )
if not path:
continue
conn = sqlite3.connect(path)
try:
# 获取所有 Msg_ 表
tables = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'"
).fetchall()
# 获取 Name2Id 映射hash -> username 反查)
name2id = {}
try:
for r in conn.execute("SELECT user_name FROM Name2Id").fetchall():
h = hashlib.md5(r[0].encode()).hexdigest()
name2id[f"Msg_{h}"] = r[0]
except Exception:
pass
for (tname,) in tables:
if len(results) >= max_results:
break
username = name2id.get(tname, '')
is_group = '@chatroom' in username
display = names.get(username, username) if username else tname
try:
clauses, params = _build_message_filters(start_ts, end_ts, keyword)
where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else ''
rows = conn.execute(f"""
SELECT local_type, create_time, message_content,
WCDB_CT_message_content
FROM [{tname}]
{where_sql}
ORDER BY create_time DESC
LIMIT ? OFFSET ?
""", (*params, max_results - len(results), 0)).fetchall()
except Exception:
continue
for local_type, ts, content, ct in rows:
content = _decompress_content(content, ct)
if content is None:
continue
sender, text = _parse_message_content(content, local_type, is_group)
time_str = datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M')
sender_name = ''
if is_group and sender:
sender_name = names.get(sender, sender)
entry = f"[{time_str}] [{display}]"
if sender_name:
entry += f" {sender_name}:"
entry += f" {text}"
if len(entry) > 300:
entry = entry[:300] + "..."
results.append((ts, entry))
finally:
conn.close()
results.sort(key=lambda x: x[0], reverse=True)
entries = [r[1] for r in results[offset:offset + limit]]
if not entries:
return f"未找到包含 \"{keyword}\" 的消息"
header = f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果offset={offset}, limit={limit}"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
return header + ":\n\n" + "\n\n".join(entries)
@mcp.tool() @mcp.tool()
def get_contacts(query: str = "", limit: int = 50) -> str: def get_contacts(query: str = "", limit: int = 50) -> str:
@ -1191,7 +1376,7 @@ def get_contacts(query: str = "", limit: int = 50) -> str:
@mcp.tool() @mcp.tool()
def get_new_messages() -> str: def get_new_messages() -> str:
"""获取自上次调用以来的新消息。首次调用返回最近的会话状态。""" """获取自上次调用以来的新消息。首次调用返回最近的会话状态。"""
global _last_check_state global _last_check_state
@ -1200,15 +1385,14 @@ def get_new_messages() -> str:
return "错误: 无法解密 session.db" return "错误: 无法解密 session.db"
names = get_contact_names() names = get_contact_names()
conn = sqlite3.connect(path) with closing(sqlite3.connect(path)) as conn:
rows = conn.execute(""" rows = conn.execute("""
SELECT username, unread_count, summary, last_timestamp, SELECT username, unread_count, summary, last_timestamp,
last_msg_type, last_msg_sender, last_sender_display_name last_msg_type, last_msg_sender, last_sender_display_name
FROM SessionTable FROM SessionTable
WHERE last_timestamp > 0 WHERE last_timestamp > 0
ORDER BY last_timestamp DESC ORDER BY last_timestamp DESC
""").fetchall() """).fetchall()
conn.close()
curr_state = {} curr_state = {}
for r in rows: for r in rows:

View File

@ -0,0 +1,590 @@
import hashlib
import os
import sqlite3
import tempfile
import unittest
from unittest.mock import patch
import mcp_server
class _FakeCache:
# 用最小缓存桩替代真实解密缓存,避免单元测试依赖本地微信环境。
def __init__(self, mapping):
self._mapping = mapping
def get(self, rel_key):
return self._mapping.get(rel_key)
def _msg_table_name(username):
# 生产代码使用 username 的 md5 作为消息表名,测试里保持一致。
return f"Msg_{hashlib.md5(username.encode()).hexdigest()}"
def _create_message_db(path, chats):
# 构造最小可用消息库,只包含搜索/历史查询依赖的字段。
conn = sqlite3.connect(path)
try:
conn.execute("CREATE TABLE Name2Id (user_name TEXT)")
for username, messages in chats.items():
conn.execute("INSERT INTO Name2Id(user_name) VALUES (?)", (username,))
table_name = _msg_table_name(username)
conn.execute(
f"""
CREATE TABLE [{table_name}] (
local_id INTEGER,
local_type INTEGER,
create_time INTEGER,
real_sender_id INTEGER,
message_content TEXT,
WCDB_CT_message_content INTEGER
)
"""
)
for local_id, create_time, content in messages:
conn.execute(
f"""
INSERT INTO [{table_name}] (
local_id, local_type, create_time, real_sender_id,
message_content, WCDB_CT_message_content
) VALUES (?, ?, ?, ?, ?, ?)
""",
(local_id, 1, create_time, 0, content, 0),
)
conn.commit()
finally:
conn.close()
class SearchMessagesTests(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.addCleanup(self.temp_dir.cleanup)
def create_db(self, filename, chats):
path = os.path.join(self.temp_dir.name, filename)
_create_message_db(path, chats)
return path
def test_validate_pagination_rejects_large_limit(self):
# 防止单次查询过大,保证 limit 上限校验存在。
with self.assertRaisesRegex(ValueError, "limit 不能大于 500"):
mcp_server._validate_pagination(501, 0)
def test_page_search_entries_returns_chronological_results_with_offset(self):
# 结果应先按最新时间分页,再把当前页恢复成时间正序输出。
entries = [(1, "a"), (5, "e"), (3, "c"), (4, "d"), (2, "b")]
paged = mcp_server._page_search_entries(entries, limit=2, offset=1)
self.assertEqual(paged, [(3, "c"), (4, "d")])
def test_search_messages_single_chat_uses_offset_and_returns_page(self):
# 单聊分页应只返回当前页,并按聊天阅读顺序展示。
db_path = self.create_db(
"single.db",
{
"alice": [
(1, 100, "foo newest"),
(2, 90, "foo middle"),
(3, 80, "foo oldest"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.search_messages("foo", chat_name="Alice", limit=2, offset=1)
self.assertIn('在 Alice 中搜索 "foo" 找到 2 条结果offset=1, limit=2', result)
self.assertLess(result.index("foo oldest"), result.index("foo middle"))
self.assertNotIn("foo newest", result)
def test_search_messages_multiple_chats_applies_global_pagination(self):
# 多个聊天联合搜索时,分页必须基于合并后的全局结果。
db_path = self.create_db(
"multi.db",
{
"alice": [
(1, 110, "foo a1"),
(2, 90, "foo a2"),
],
"bob": [
(1, 100, "foo b1"),
(2, 80, "foo b2"),
],
},
)
contexts = [
{
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
},
{
"query": "Bob",
"username": "bob",
"display_name": "Bob",
"db_path": db_path,
"table_name": _msg_table_name("bob"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}],
"is_group": False,
},
]
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object(
mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], [])
):
result = mcp_server.search_messages("foo", chat_name=["Alice", "Bob"], limit=2, offset=1)
self.assertIn('在 2 个聊天对象中搜索 "foo" 找到 2 条结果offset=1, limit=2', result)
self.assertLess(result.index("foo a2"), result.index("foo b1"))
self.assertNotIn("foo a1", result)
self.assertNotIn("foo b2", result)
def test_search_messages_all_messages_merges_global_results_before_paging(self):
# 全库搜索要基于跨库合并后的全局时间线分页,不能被单个分库提前截断。
older_db = self.create_db(
"older.db",
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
)
newer_db = self.create_db(
"newer.db",
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2")]},
)
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
mcp_server, "_cache", fake_cache
), patch.object(
mcp_server,
"get_contact_names",
return_value={"older_user": "Older", "newer_user": "Newer"},
):
result = mcp_server.search_messages("foo", limit=2, offset=0)
self.assertIn('搜索 "foo" 找到 2 条结果offset=0, limit=2', result)
self.assertLess(result.index("foo newer 2"), result.index("foo newer 1"))
self.assertNotIn("foo older 1", result)
def test_search_messages_all_messages_uses_bounded_sql_pagination(self):
# 每个消息表都只应查询当前页所需的候选窗口,不能回退到 limit=None 的全量扫描。
older_db = self.create_db(
"older_paged.db",
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
)
newer_db = self.create_db(
"newer_paged.db",
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2"), (3, 19, "foo newer 3")]},
)
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
original_query_messages = mcp_server._query_messages
calls = []
def recording_query_messages(*args, **kwargs):
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
return original_query_messages(*args, **kwargs)
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
mcp_server, "_cache", fake_cache
), patch.object(
mcp_server,
"get_contact_names",
return_value={"older_user": "Older", "newer_user": "Newer"},
), patch.object(
mcp_server, "_query_messages", side_effect=recording_query_messages
):
result = mcp_server.search_messages("foo", limit=2, offset=1)
self.assertIn('搜索 "foo" 找到 2 条结果offset=1, limit=2', result)
self.assertEqual(
calls,
[
(_msg_table_name("older_user"), 3, 0),
(_msg_table_name("newer_user"), 3, 0),
],
)
def test_search_messages_single_chat_respects_time_range(self):
# 单聊搜索的开始/结束时间都必须严格生效。
db_path = self.create_db(
"single_time.db",
{
"alice": [
(1, 300, "foo in range"),
(2, 200, "foo too early"),
(3, 400, "foo too late"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_parse_time_range", return_value=(250, 350)
):
result = mcp_server.search_messages(
"foo",
chat_name="Alice",
start_time="custom-start",
end_time="custom-end",
limit=20,
offset=0,
)
self.assertIn("时间范围: custom-start ~ custom-end", result)
self.assertIn("foo in range", result)
self.assertNotIn("foo too early", result)
self.assertNotIn("foo too late", result)
def test_search_messages_multiple_chats_respects_time_range(self):
# 多聊联合搜索时,每个聊天对象都要套用同一时间范围。
db_path = self.create_db(
"multi_time.db",
{
"alice": [(1, 300, "foo alice in range"), (2, 150, "foo alice too early")],
"bob": [(1, 320, "foo bob in range"), (2, 500, "foo bob too late")],
},
)
contexts = [
{
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
},
{
"query": "Bob",
"username": "bob",
"display_name": "Bob",
"db_path": db_path,
"table_name": _msg_table_name("bob"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}],
"is_group": False,
},
]
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object(
mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], [])
), patch.object(
mcp_server, "_parse_time_range", return_value=(250, 400)
):
result = mcp_server.search_messages(
"foo",
chat_name=["Alice", "Bob"],
start_time="range-start",
end_time="range-end",
limit=20,
offset=0,
)
self.assertIn("时间范围: range-start ~ range-end", result)
self.assertIn("foo alice in range", result)
self.assertIn("foo bob in range", result)
self.assertNotIn("foo alice too early", result)
self.assertNotIn("foo bob too late", result)
def test_search_messages_all_messages_respects_time_range(self):
# 全库搜索也不能返回时间范围外的消息。
db_path = self.create_db(
"all_time.db",
{
"alice": [
(1, 100, "foo too early"),
(2, 300, "foo in range"),
(3, 500, "foo too late"),
]
},
)
fake_cache = _FakeCache({"all": db_path})
with patch.object(mcp_server, "MSG_DB_KEYS", ["all"]), patch.object(
mcp_server, "_cache", fake_cache
), patch.object(
mcp_server,
"get_contact_names",
return_value={"alice": "Alice"},
), patch.object(
mcp_server, "_parse_time_range", return_value=(250, 350)
):
result = mcp_server.search_messages(
"foo",
start_time="range-start",
end_time="range-end",
limit=20,
offset=0,
)
self.assertIn("时间范围: range-start ~ range-end", result)
self.assertIn("foo in range", result)
self.assertNotIn("foo too early", result)
self.assertNotIn("foo too late", result)
def test_get_chat_history_merges_sharded_message_tables(self):
# 同一联系人跨多个 message_N.db 分片时,历史查询要先合并再分页。
older_db = self.create_db("history_older.db", {"alice": [(1, 100, "old message")]})
newer_db = self.create_db(
"history_newer.db",
{"alice": [(1, 300, "new message"), (2, 250, "middle message")]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": newer_db,
"table_name": _msg_table_name("alice"),
"message_tables": [
{"db_path": older_db, "table_name": _msg_table_name("alice")},
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
self.assertIn("Alice 的消息记录(返回 2 条offset=0, limit=2", result)
self.assertIn("middle message", result)
self.assertIn("new message", result)
self.assertNotIn("old message", result)
def test_get_chat_history_uses_bounded_sql_pagination(self):
# 历史查询应把 offset+limit 下推到 SQL避免把整张消息表读出来后再切片。
db_path = self.create_db(
"history_paged.db",
{
"alice": [
(1, 400, "newest"),
(2, 300, "middle"),
(3, 200, "older"),
(4, 100, "oldest"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
original_query_messages = mcp_server._query_messages
calls = []
def recording_query_messages(*args, **kwargs):
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
return original_query_messages(*args, **kwargs)
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_query_messages", side_effect=recording_query_messages
):
result = mcp_server.get_chat_history("Alice", limit=2, offset=1)
self.assertIn("middle", result)
self.assertIn("older", result)
self.assertNotIn("newest", result)
self.assertNotIn("oldest", result)
self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)])
def test_get_chat_history_keeps_partial_results_when_formatting_fails(self):
# 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。
db_path = self.create_db(
"history_partial_failure.db",
{"alice": [(1, 200, "good message"), (2, 100, "bad message")]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
original_build_history_line = mcp_server._build_history_line
def flaky_build_history_line(row, *args, **kwargs):
if row[2] == 100:
raise ValueError("bad row")
return original_build_history_line(row, *args, **kwargs)
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_build_history_line", side_effect=flaky_build_history_line
):
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
self.assertIn("good message", result)
self.assertIn("查询失败:", result)
self.assertIn("bad row", result)
def test_search_messages_single_chat_merges_sharded_message_tables(self):
# 单聊搜索也要跨分片合并,否则最近消息可能查不到。
older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]})
newer_db = self.create_db(
"search_newer.db",
{"alice": [(1, 300, "foo new"), (2, 200, "foo middle")]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": newer_db,
"table_name": _msg_table_name("alice"),
"message_tables": [
{"db_path": older_db, "table_name": _msg_table_name("alice")},
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_parse_time_range", return_value=(150, 350)
):
result = mcp_server.search_messages(
"foo",
chat_name="Alice",
start_time="range-start",
end_time="range-end",
limit=20,
offset=0,
)
self.assertIn("foo middle", result)
self.assertIn("foo new", result)
self.assertNotIn("foo old", result)
def test_search_messages_keeps_partial_results_when_later_batch_fails(self):
# 后续批次失败时,前面已经拿到的有效结果不应被丢弃。
db_path = self.create_db(
"search_partial_failure.db",
{
"alice": [
(1, 400, "foo newest"),
(2, 300, "foo skipped"),
(3, 200, "foo older"),
(4, 100, "foo bad"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
original_build_search_entry = mcp_server._build_search_entry
def flaky_build_search_entry(row, *args, **kwargs):
if row[2] == 300:
return None
if row[2] == 100:
raise ValueError("bad row")
return original_build_search_entry(row, *args, **kwargs)
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_build_search_entry", side_effect=flaky_build_search_entry
):
result = mcp_server.search_messages("foo", chat_name="Alice", limit=3, offset=0)
self.assertIn("foo newest", result)
self.assertIn("foo older", result)
self.assertIn("查询失败:", result)
self.assertIn("bad row", result)
def test_get_recent_sessions_closes_connection_when_query_fails(self):
# 会话查询抛异常时也必须关闭 sqlite3 连接,避免资源泄漏。
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
class _FakeConn:
def __init__(self):
self.closed = False
def execute(self, *args, **kwargs):
raise sqlite3.OperationalError("boom")
def close(self):
self.closed = True
fake_conn = _FakeConn()
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
mcp_server, "get_contact_names", return_value={}
), patch.object(
mcp_server.sqlite3, "connect", return_value=fake_conn
):
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
mcp_server.get_recent_sessions()
self.assertTrue(fake_conn.closed)
def test_get_new_messages_closes_connection_when_query_fails(self):
# 新消息轮询失败时也要释放 sqlite3 连接。
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
class _FakeConn:
def __init__(self):
self.closed = False
def execute(self, *args, **kwargs):
raise sqlite3.OperationalError("boom")
def close(self):
self.closed = True
fake_conn = _FakeConn()
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
mcp_server, "get_contact_names", return_value={}
), patch.object(
mcp_server.sqlite3, "connect", return_value=fake_conn
):
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
mcp_server.get_new_messages()
self.assertTrue(fake_conn.closed)
if __name__ == "__main__":
unittest.main()