Merge pull request #30 from dsjzazs/main

MCP增强消息查询,支持时间范围和分页
feat/daemon-cli
joshua-deng 2026-03-14 17:38:37 +08:00 committed by GitHub
commit 3e79c8e093
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1635 additions and 342 deletions

View File

@ -132,13 +132,15 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv
| Tool | 功能 | | Tool | 功能 |
|------|------| |------|------|
| `get_recent_sessions(limit)` | 最近会话列表(含消息摘要、未读数) | | `get_recent_sessions(limit)` | 最近会话列表(含消息摘要、未读数) |
| `get_chat_history(chat_name, limit)` | 指定聊天的消息记录(支持模糊匹配名字) | | `get_chat_history(chat_name, limit, offset, start_time, end_time)` | 指定聊天的消息记录,支持时间范围和分页 |
| `search_messages(keyword, limit)` | 全库搜索消息内容 | | `search_messages(keyword, chat_name, start_time, end_time, limit, offset)` | 统一搜索消息;支持全库、单个聊天对象、多个聊天对象、时间范围和分页 |
| `get_contacts(query, limit)` | 搜索/列出联系人 | | `get_contacts(query, limit)` | 搜索/列出联系人 |
| `get_new_messages()` | 获取自上次调用以来的新消息 | | `get_new_messages()` | 获取自上次调用以来的新消息 |
前置条件:需要先运行 `python main.py``python find_all_keys.py` 完成密钥提取。 前置条件:需要先运行 `python main.py``python find_all_keys.py` 完成密钥提取。
说明:`search_messages` 的 `limit` 最大为 `500``get_chat_history` 支持更大的 `limit`,但消息很多时仍建议配合 `offset` 分页读取。
**[查看使用案例 →](USAGE.md)** **[查看使用案例 →](USAGE.md)**
### 图片解密 (V2 格式) ### 图片解密 (V2 格式)

View File

@ -78,7 +78,77 @@ Claude 调用 `search_messages(keyword="claude")`
... ...
``` ```
## 4. 搜索联系人 ## 4. 时间范围 + 分页查看聊天记录
```
> 帮我看一下██群 3 月 1 日到 3 月 7 日的聊天,先给我前 20 条
```
Claude 可以调用:
```python
get_chat_history(
chat_name="██群",
start_time="2026-03-01",
end_time="2026-03-07",
limit=20,
offset=0,
)
```
下一页:
```python
get_chat_history(
chat_name="██群",
start_time="2026-03-01",
end_time="2026-03-07",
limit=20,
offset=20,
)
```
## 5. 搜索指定联系人/群聊在某个时间段内的消息
```
> 帮我搜一下██群这周谁提到过 Claude
```
Claude 可以调用统一接口:
```python
search_messages(
keyword="Claude",
chat_name="██群",
start_time="2026-03-01",
end_time="2026-03-07",
limit=20,
offset=0,
)
```
## 6. 多个联系人/群聊联合搜索
```
> 帮我看看联系人A、联系人B 和 ██项目群 这周谁提到过“项目”
```
Claude 可以调用统一接口:
```python
search_messages(
keyword="项目",
chat_name=["联系人A", "联系人B", "██项目群"],
start_time="2026-03-01",
end_time="2026-03-07",
limit=20,
offset=0,
)
```
如果某些名字没匹配到联系人,或没有对应消息表,结果里会单独说明。
## 7. 搜索联系人
``` ```
> 帮我找一下姓张的联系人 > 帮我找一下姓张的联系人
@ -95,7 +165,7 @@ wxid_████ 备注: 张██ 昵称: 小██
... ...
``` ```
## 5. 获取新消息 ## 8. 获取新消息
``` ```
> 有没有新消息 > 有没有新消息
@ -113,7 +183,7 @@ Claude 调用 `get_new_messages()`
[16:22] ██群 [群] (19条未读): (图片) [16:22] ██群 [群] (19条未读): (图片)
``` ```
## 6. 高级用法:群聊分析 ## 9. 高级用法:群聊分析
Claude 可以获取大量消息后自动分析活跃度、话题分布、关键人物: Claude 可以获取大量消息后自动分析活跃度、话题分布、关键人物:
@ -121,7 +191,7 @@ Claude 可以获取大量消息后自动分析活跃度、话题分布、关键
> 帮我分析一下██群最近一周的情况 > 帮我分析一下██群最近一周的情况
``` ```
Claude 会调用 `get_chat_history(chat_name="██群", limit=500)` 获取消息,然后输出: Claude 会调用 `get_chat_history(chat_name="██群", limit=500)` 获取消息,然后输出。消息很多时,也可以把 `limit` 设得更大,或配合 `offset` 分页读取
``` ```
## ██群最近一周分析 ## ██群最近一周分析

View File

@ -7,6 +7,7 @@ Runs on Windows Python (needs access to D:\ WeChat databases).
import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re
import hmac as hmac_mod import hmac as hmac_mod
from contextlib import closing
from datetime import datetime from datetime import datetime
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from Crypto.Cipher import AES from Crypto.Cipher import AES
@ -222,6 +223,8 @@ _contact_full = None # [{username, nick_name, remark}]
_self_username = None _self_username = None
_XML_UNSAFE_RE = re.compile(r'<!DOCTYPE|<!ENTITY', re.IGNORECASE) _XML_UNSAFE_RE = re.compile(r'<!DOCTYPE|<!ENTITY', re.IGNORECASE)
_XML_PARSE_MAX_LEN = 20000 _XML_PARSE_MAX_LEN = 20000
_QUERY_LIMIT_MAX = 500
_HISTORY_QUERY_BATCH_SIZE = 500
def _load_contacts_from(db_path): def _load_contacts_from(db_path):
@ -596,6 +599,577 @@ def _find_msg_table_for_user(username):
return None, None return None, None
def _find_msg_tables_for_user(username):
"""返回用户在所有 message_N.db 中对应的消息表,按最新消息时间倒序排列。"""
table_hash = hashlib.md5(username.encode()).hexdigest()
table_name = f"Msg_{table_hash}"
if not _is_safe_msg_table_name(table_name):
return []
matches = []
for rel_key in MSG_DB_KEYS:
path = _cache.get(rel_key)
if not path:
continue
conn = sqlite3.connect(path)
try:
exists = conn.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
(table_name,)
).fetchone()
if not exists:
continue
max_create_time = conn.execute(
f"SELECT MAX(create_time) FROM [{table_name}]"
).fetchone()[0] or 0
matches.append({
'db_path': path,
'table_name': table_name,
'max_create_time': max_create_time,
})
except Exception:
pass
finally:
conn.close()
matches.sort(key=lambda item: item['max_create_time'], reverse=True)
return matches
def _validate_pagination(limit, offset=0, limit_max=_QUERY_LIMIT_MAX):
if limit <= 0:
raise ValueError("limit 必须大于 0")
if limit_max is not None and limit > limit_max:
raise ValueError(f"limit 不能大于 {limit_max}")
if offset < 0:
raise ValueError("offset 不能小于 0")
def _parse_time_value(value, field_name, is_end=False):
value = (value or '').strip()
if not value:
return None
formats = [
('%Y-%m-%d %H:%M:%S', False),
('%Y-%m-%d %H:%M', False),
('%Y-%m-%d', True),
]
for fmt, date_only in formats:
try:
dt = datetime.strptime(value, fmt)
if date_only and is_end:
dt = dt.replace(hour=23, minute=59, second=59)
return int(dt.timestamp())
except ValueError:
continue
raise ValueError(
f"{field_name} 格式无效: {value}。支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS"
)
def _parse_time_range(start_time='', end_time=''):
start_ts = _parse_time_value(start_time, 'start_time', is_end=False)
end_ts = _parse_time_value(end_time, 'end_time', is_end=True)
if start_ts is not None and end_ts is not None and start_ts > end_ts:
raise ValueError('start_time 不能晚于 end_time')
return start_ts, end_ts
def _build_message_filters(start_ts=None, end_ts=None, keyword=''):
clauses = []
params = []
if start_ts is not None:
clauses.append('create_time >= ?')
params.append(start_ts)
if end_ts is not None:
clauses.append('create_time <= ?')
params.append(end_ts)
if keyword:
clauses.append('message_content LIKE ?')
params.append(f'%{keyword}%')
return clauses, params
def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0):
if not _is_safe_msg_table_name(table_name):
raise ValueError(f'非法消息表名: {table_name}')
clauses, params = _build_message_filters(start_ts, end_ts, keyword)
where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else ''
sql = f"""
SELECT local_id, local_type, create_time, real_sender_id, message_content,
WCDB_CT_message_content
FROM [{table_name}]
{where_sql}
ORDER BY create_time DESC
"""
if limit is None:
return conn.execute(sql, params).fetchall()
sql += "\n LIMIT ? OFFSET ?"
return conn.execute(sql, (*params, limit, offset)).fetchall()
def _resolve_chat_context(chat_name):
username = resolve_username(chat_name)
if not username:
return None
names = get_contact_names()
display_name = names.get(username, username)
message_tables = _find_msg_tables_for_user(username)
if not message_tables:
return {
'query': chat_name,
'username': username,
'display_name': display_name,
'db_path': None,
'table_name': None,
'message_tables': [],
'is_group': '@chatroom' in username,
}
primary = message_tables[0]
return {
'query': chat_name,
'username': username,
'display_name': display_name,
'db_path': primary['db_path'],
'table_name': primary['table_name'],
'message_tables': message_tables,
'is_group': '@chatroom' in username,
}
def _resolve_chat_contexts(chat_names):
if not chat_names:
raise ValueError('chat_names 不能为空')
resolved = []
unresolved = []
missing_tables = []
seen = set()
for chat_name in chat_names:
name = (chat_name or '').strip()
if not name:
unresolved.append('(空)')
continue
ctx = _resolve_chat_context(name)
if not ctx:
unresolved.append(name)
continue
if not ctx['message_tables']:
missing_tables.append(ctx['display_name'])
continue
if ctx['username'] in seen:
continue
seen.add(ctx['username'])
resolved.append(ctx)
return resolved, unresolved, missing_tables
def _normalize_chat_names(chat_name):
if chat_name is None:
return []
if isinstance(chat_name, str):
value = chat_name.strip()
return [value] if value else []
if isinstance(chat_name, (list, tuple, set)):
normalized = []
for item in chat_name:
if item is None:
continue
value = str(item).strip()
if value:
normalized.append(value)
return normalized
value = str(chat_name).strip()
return [value] if value else []
def _format_history_lines(rows, username, display_name, is_group, names, id_to_username):
lines = []
ctx = {
'username': username,
'display_name': display_name,
'is_group': is_group,
}
for row in reversed(rows):
_, line = _build_history_line(row, ctx, names, id_to_username)
lines.append(line)
return lines
def _build_search_entry(row, ctx, names, id_to_username):
local_id, local_type, create_time, real_sender_id, content, ct = row
content = _decompress_content(content, ct)
if content is None:
return None
sender, text = _format_message_text(
local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names
)
if text and len(text) > 300:
text = text[:300] + '...'
sender_label = _resolve_sender_label(
real_sender_id,
sender,
ctx['is_group'],
ctx['username'],
ctx['display_name'],
names,
id_to_username,
)
time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M')
entry = f"[{time_str}] [{ctx['display_name']}]"
if sender_label:
entry += f" {sender_label}:"
entry += f" {text}"
return create_time, entry
def _build_history_line(row, ctx, names, id_to_username):
local_id, local_type, create_time, real_sender_id, content, ct = row
time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M')
content = _decompress_content(content, ct)
if content is None:
content = '(无法解压)'
sender, text = _format_message_text(
local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names
)
sender_label = _resolve_sender_label(
real_sender_id, sender, ctx['is_group'], ctx['username'], ctx['display_name'], names, id_to_username
)
if sender_label:
return create_time, f'[{time_str}] {sender_label}: {text}'
return create_time, f'[{time_str}] {text}'
def _get_chat_message_tables(ctx):
if ctx.get('message_tables'):
return ctx['message_tables']
if ctx.get('db_path') and ctx.get('table_name'):
return [{'db_path': ctx['db_path'], 'table_name': ctx['table_name']}]
return []
def _iter_table_contexts(ctx):
for table in _get_chat_message_tables(ctx):
yield {
'query': ctx['query'],
'username': ctx['username'],
'display_name': ctx['display_name'],
'db_path': table['db_path'],
'table_name': table['table_name'],
'is_group': ctx['is_group'],
}
def _candidate_page_size(limit, offset):
return limit + offset
def _message_query_batch_size(candidate_limit):
return candidate_limit
def _history_query_batch_size(candidate_limit):
return min(candidate_limit, _HISTORY_QUERY_BATCH_SIZE)
def _page_ranked_entries(entries, limit, offset):
ordered = sorted(entries, key=lambda item: item[0], reverse=True)
paged = ordered[offset:offset + limit]
paged.sort(key=lambda item: item[0])
return paged
def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0):
collected = []
failures = []
candidate_limit = _candidate_page_size(limit, offset)
batch_size = _history_query_batch_size(candidate_limit)
for table_ctx in _iter_table_contexts(ctx):
try:
with closing(sqlite3.connect(table_ctx['db_path'])) as conn:
id_to_username = _load_name2id_maps(conn)
fetch_offset = 0
collected_before_table = len(collected)
# 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。
while len(collected) - collected_before_table < candidate_limit:
rows = _query_messages(
conn,
table_ctx['table_name'],
start_ts=start_ts,
end_ts=end_ts,
limit=batch_size,
offset=fetch_offset,
)
if not rows:
break
fetch_offset += len(rows)
for row in rows:
try:
collected.append(_build_history_line(row, table_ctx, names, id_to_username))
except Exception as e:
failures.append(
f"{table_ctx['display_name']} local_id={row[0]} create_time={row[2]}: {e}"
)
if len(collected) - collected_before_table >= candidate_limit:
break
if len(rows) < batch_size:
break
except Exception as e:
failures.append(f"{table_ctx['db_path']}: {e}")
paged = _page_ranked_entries(collected, limit, offset)
return [line for _, line in paged], failures
def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
collected = []
failures = []
contexts_by_db = {}
for table_ctx in _iter_table_contexts(ctx):
contexts_by_db.setdefault(table_ctx['db_path'], []).append(table_ctx)
for db_path, db_contexts in contexts_by_db.items():
try:
with closing(sqlite3.connect(db_path)) as conn:
db_entries, db_failures = _collect_search_entries(
conn,
db_contexts,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
collected.extend(db_entries)
failures.extend(db_failures)
except Exception as e:
failures.extend(f"{table_ctx['display_name']}: {e}" for table_ctx in db_contexts)
return collected, failures
def _load_search_contexts_from_db(conn, db_path, names):
tables = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'"
).fetchall()
table_to_username = {}
try:
for (user_name,) in conn.execute("SELECT user_name FROM Name2Id").fetchall():
if not user_name:
continue
table_hash = hashlib.md5(user_name.encode()).hexdigest()
table_to_username[f"Msg_{table_hash}"] = user_name
except sqlite3.Error:
pass
contexts = []
for (table_name,) in tables:
username = table_to_username.get(table_name, '')
display_name = names.get(username, username) if username else table_name
contexts.append({
'query': display_name,
'username': username,
'display_name': display_name,
'db_path': db_path,
'table_name': table_name,
'is_group': '@chatroom' in username,
})
return contexts
def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
collected = []
failures = []
id_to_username = _load_name2id_maps(conn)
batch_size = _message_query_batch_size(candidate_limit)
for ctx in contexts:
try:
fetch_offset = 0
collected_before_table = len(collected)
# 全局分页只需要每个分表最新的 offset+limit 条有效命中,无需把整表命中读进内存。
while len(collected) - collected_before_table < candidate_limit:
rows = _query_messages(
conn,
ctx['table_name'],
start_ts=start_ts,
end_ts=end_ts,
keyword=keyword,
limit=batch_size,
offset=fetch_offset,
)
if not rows:
break
fetch_offset += len(rows)
for row in rows:
formatted = _build_search_entry(row, ctx, names, id_to_username)
if formatted:
collected.append(formatted)
if len(collected) - collected_before_table >= candidate_limit:
break
if len(rows) < batch_size:
break
except Exception as e:
failures.append(f"{ctx['display_name']}: {e}")
return collected, failures
def _page_search_entries(entries, limit, offset):
return _page_ranked_entries(entries, limit, offset)
def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
names = get_contact_names()
candidate_limit = _candidate_page_size(limit, offset)
entries, failures = _collect_chat_search_entries(
ctx,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
paged = _page_search_entries(entries, limit, offset)
if not paged:
if failures:
return "查询失败: " + "".join(failures)
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息"
header = f"{ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(paged)} 条结果offset={offset}, limit={limit}"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
try:
resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names)
except ValueError as e:
return f"错误: {e}"
if not resolved_contexts:
details = []
if unresolved:
details.append("未找到联系人: " + "".join(unresolved))
if missing_tables:
details.append("无消息表: " + "".join(missing_tables))
suffix = f"\n{chr(10).join(details)}" if details else ""
return f"错误: 没有可查询的聊天对象{suffix}"
names = get_contact_names()
candidate_limit = _candidate_page_size(limit, offset)
collected = []
failures = []
for ctx in resolved_contexts:
chat_entries, chat_failures = _collect_chat_search_entries(
ctx,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
collected.extend(chat_entries)
failures.extend(chat_failures)
paged = _page_search_entries(collected, limit, offset)
notes = []
if unresolved:
notes.append("未找到联系人: " + "".join(unresolved))
if missing_tables:
notes.append("无消息表: " + "".join(missing_tables))
if failures:
notes.append("查询失败: " + "".join(failures))
if not paged:
header = f"{len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if notes:
header += "\n" + "\n".join(notes)
return header
header = (
f"{len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果"
f"offset={offset}, limit={limit}"
)
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if notes:
header += "\n" + "\n".join(notes)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, offset):
names = get_contact_names()
collected = []
failures = []
candidate_limit = _candidate_page_size(limit, offset)
for rel_key in MSG_DB_KEYS:
path = _cache.get(rel_key)
if not path:
continue
try:
with closing(sqlite3.connect(path)) as conn:
contexts = _load_search_contexts_from_db(conn, path, names)
db_entries, db_failures = _collect_search_entries(
conn,
contexts,
names,
keyword,
start_ts=start_ts,
end_ts=end_ts,
candidate_limit=candidate_limit,
)
collected.extend(db_entries)
failures.extend(db_failures)
except Exception as e:
failures.append(f"{rel_key}: {e}")
paged = _page_search_entries(collected, limit, offset)
if not paged:
header = f"未找到包含 \"{keyword}\" 的消息"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header
header = f"搜索 \"{keyword}\" 找到 {len(paged)} 条结果offset={offset}, limit={limit}"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
# ============ MCP Server ============ # ============ MCP Server ============
mcp = FastMCP("wechat", instructions="查询微信消息、联系人等数据") mcp = FastMCP("wechat", instructions="查询微信消息、联系人等数据")
@ -617,16 +1191,15 @@ def get_recent_sessions(limit: int = 20) -> str:
return "错误: 无法解密 session.db" return "错误: 无法解密 session.db"
names = get_contact_names() names = get_contact_names()
conn = sqlite3.connect(path) with closing(sqlite3.connect(path)) as conn:
rows = conn.execute(""" rows = conn.execute("""
SELECT username, unread_count, summary, last_timestamp, SELECT username, unread_count, summary, last_timestamp,
last_msg_type, last_msg_sender, last_sender_display_name last_msg_type, last_msg_sender, last_sender_display_name
FROM SessionTable FROM SessionTable
WHERE last_timestamp > 0 WHERE last_timestamp > 0
ORDER BY last_timestamp DESC ORDER BY last_timestamp DESC
LIMIT ? LIMIT ?
""", (limit,)).fetchall() """, (limit,)).fetchall()
conn.close()
results = [] results = []
for r in rows: for r in rows:
@ -664,158 +1237,121 @@ def get_recent_sessions(limit: int = 20) -> str:
@mcp.tool() @mcp.tool()
def get_chat_history(chat_name: str, limit: int = 50) -> str: def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str:
"""获取指定聊天的消息记录。 """获取指定聊天的消息记录。
Args: Args:
chat_name: 聊天对象的名字备注名或wxid自动模糊匹配 chat_name: 聊天对象的名字备注名或wxid自动模糊匹配
limit: 返回的消息数量默认50 limit: 返回的消息数量默认50支持较大的值建议配合 offset 分页使用
offset: 分页偏移量默认0
start_time: 起始时间支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
end_time: 结束时间支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
""" """
username = resolve_username(chat_name) try:
if not username: _validate_pagination(limit, offset, limit_max=None)
start_ts, end_ts = _parse_time_range(start_time, end_time)
except ValueError as e:
return f"错误: {e}"
ctx = _resolve_chat_context(chat_name)
if not ctx:
return f"找不到聊天对象: {chat_name}\n提示: 可以用 get_contacts(query='{chat_name}') 搜索联系人" return f"找不到聊天对象: {chat_name}\n提示: 可以用 get_contacts(query='{chat_name}') 搜索联系人"
if not ctx['db_path']:
return f"找不到 {ctx['display_name']} 的消息记录可能在未解密的DB中或无消息"
names = get_contact_names() names = get_contact_names()
display_name = names.get(username, username) lines, failures = _collect_chat_history_lines(
is_group = '@chatroom' in username ctx,
names,
start_ts=start_ts,
end_ts=end_ts,
limit=limit,
offset=offset,
)
db_path, table_name = _find_msg_table_for_user(username) if not lines:
if not db_path: if failures:
return f"找不到 {display_name} 的消息记录可能在未解密的DB中或无消息" return "查询失败: " + "".join(failures)
return f"{ctx['display_name']} 无消息记录"
conn = sqlite3.connect(db_path) header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)}offset={offset}, limit={limit}"
try: if ctx['is_group']:
id_to_username = _load_name2id_maps(conn)
rows = conn.execute(f"""
SELECT local_id, local_type, create_time, real_sender_id, message_content,
WCDB_CT_message_content
FROM [{table_name}]
ORDER BY create_time DESC
LIMIT ?
""", (limit,)).fetchall()
except Exception as e:
conn.close()
return f"查询失败: {e}"
conn.close()
if not rows:
return f"{display_name} 无消息记录"
lines = []
for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows):
time_str = datetime.fromtimestamp(create_time).strftime('%m-%d %H:%M')
# zstd 解压
content = _decompress_content(content, ct)
if content is None:
content = '(无法解压)'
sender, text = _format_message_text(
local_id, local_type, content, is_group, username, display_name, names
)
if text and len(text) > 500:
text = text[:500] + "..."
sender_label = _resolve_sender_label(
real_sender_id, sender, is_group, username, display_name, names, id_to_username
)
if sender_label:
lines.append(f"[{time_str}] {sender_label}: {text}")
else:
lines.append(f"[{time_str}] {text}")
header = f"{display_name} 的最近 {len(lines)} 条消息"
if is_group:
header += " [群聊]" header += " [群聊]"
if start_time or end_time:
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
if failures:
header += "\n查询失败: " + "".join(failures)
return header + ":\n\n" + "\n".join(lines) return header + ":\n\n" + "\n".join(lines)
@mcp.tool() @mcp.tool()
def search_messages(keyword: str, limit: int = 20) -> str: def search_messages(
"""在所有聊天记录中搜索包含关键词的消息。 keyword: str,
chat_name: str | list[str] | None = None,
start_time: str = "",
end_time: str = "",
limit: int = 20,
offset: int = 0,
) -> str:
"""搜索消息内容,支持全库、单个聊天对象、多个聊天对象,以及时间范围和分页。
Args: Args:
keyword: 搜索关键词 keyword: 搜索关键词
limit: 返回的结果数量默认20 chat_name: 聊天对象名称可为空单个字符串或字符串列表
start_time: 起始时间可为空
end_time: 结束时间可为空
limit: 返回的结果数量默认20最大500
offset: 分页偏移量默认0
""" """
if not keyword or len(keyword) < 1: if not keyword or len(keyword) < 1:
return "请提供搜索关键词" return "请提供搜索关键词"
names = get_contact_names() chat_names = _normalize_chat_names(chat_name)
results = []
for rel_key in MSG_DB_KEYS: try:
if len(results) >= limit: _validate_pagination(limit, offset)
break start_ts, end_ts = _parse_time_range(start_time, end_time)
except ValueError as e:
return f"错误: {e}"
path = _cache.get(rel_key) if len(chat_names) == 1:
if not path: ctx = _resolve_chat_context(chat_names[0])
continue if not ctx:
return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人"
if not ctx['db_path']:
return f"找不到 {ctx['display_name']} 的消息记录可能在未解密的DB中或无消息"
return _search_single_chat(
ctx,
keyword,
start_ts,
end_ts,
start_time,
end_time,
limit,
offset,
)
conn = sqlite3.connect(path) if len(chat_names) > 1:
try: return _search_multiple_chats(
# 获取所有 Msg_ 表 chat_names,
tables = conn.execute( keyword,
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'" start_ts,
).fetchall() end_ts,
start_time,
# 获取 Name2Id 映射hash -> username 反查) end_time,
name2id = {} limit,
try: offset,
for r in conn.execute("SELECT user_name FROM Name2Id").fetchall(): )
h = hashlib.md5(r[0].encode()).hexdigest()
name2id[f"Msg_{h}"] = r[0]
except Exception:
pass
for (tname,) in tables:
if len(results) >= limit:
break
username = name2id.get(tname, '')
is_group = '@chatroom' in username
display = names.get(username, username) if username else tname
try:
rows = conn.execute(f"""
SELECT local_type, create_time, message_content,
WCDB_CT_message_content
FROM [{tname}]
WHERE message_content LIKE ?
ORDER BY create_time DESC
LIMIT ?
""", (f'%{keyword}%', limit - len(results))).fetchall()
except Exception:
continue
for local_type, ts, content, ct in rows:
content = _decompress_content(content, ct)
if content is None:
continue
sender, text = _parse_message_content(content, local_type, is_group)
time_str = datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M')
sender_name = ''
if is_group and sender:
sender_name = names.get(sender, sender)
entry = f"[{time_str}] [{display}]"
if sender_name:
entry += f" {sender_name}:"
entry += f" {text}"
if len(entry) > 300:
entry = entry[:300] + "..."
results.append((ts, entry))
finally:
conn.close()
results.sort(key=lambda x: x[0], reverse=True)
entries = [r[1] for r in results[:limit]]
if not entries:
return f"未找到包含 \"{keyword}\" 的消息"
return f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果:\n\n" + "\n\n".join(entries)
return _search_all_messages(
keyword,
start_ts,
end_ts,
start_time,
end_time,
limit,
offset,
)
@mcp.tool() @mcp.tool()
def get_contacts(query: str = "", limit: int = 50) -> str: def get_contacts(query: str = "", limit: int = 50) -> str:
@ -870,15 +1406,14 @@ def get_new_messages() -> str:
return "错误: 无法解密 session.db" return "错误: 无法解密 session.db"
names = get_contact_names() names = get_contact_names()
conn = sqlite3.connect(path) with closing(sqlite3.connect(path)) as conn:
rows = conn.execute(""" rows = conn.execute("""
SELECT username, unread_count, summary, last_timestamp, SELECT username, unread_count, summary, last_timestamp,
last_msg_type, last_msg_sender, last_sender_display_name last_msg_type, last_msg_sender, last_sender_display_name
FROM SessionTable FROM SessionTable
WHERE last_timestamp > 0 WHERE last_timestamp > 0
ORDER BY last_timestamp DESC ORDER BY last_timestamp DESC
""").fetchall() """).fetchall()
conn.close()
curr_state = {} curr_state = {}
for r in rows: for r in rows:

View File

@ -0,0 +1,686 @@
import hashlib
import os
import sqlite3
import tempfile
import unittest
from unittest.mock import patch
import mcp_server
class _FakeCache:
# 用最小缓存桩替代真实解密缓存,避免单元测试依赖本地微信环境。
def __init__(self, mapping):
self._mapping = mapping
def get(self, rel_key):
return self._mapping.get(rel_key)
def _msg_table_name(username):
# 生产代码使用 username 的 md5 作为消息表名,测试里保持一致。
return f"Msg_{hashlib.md5(username.encode()).hexdigest()}"
def _create_message_db(path, chats):
# 构造最小可用消息库,只包含搜索/历史查询依赖的字段。
conn = sqlite3.connect(path)
try:
conn.execute("CREATE TABLE Name2Id (user_name TEXT)")
for username, messages in chats.items():
conn.execute("INSERT INTO Name2Id(user_name) VALUES (?)", (username,))
table_name = _msg_table_name(username)
conn.execute(
f"""
CREATE TABLE [{table_name}] (
local_id INTEGER,
local_type INTEGER,
create_time INTEGER,
real_sender_id INTEGER,
message_content TEXT,
WCDB_CT_message_content INTEGER
)
"""
)
for local_id, create_time, content in messages:
conn.execute(
f"""
INSERT INTO [{table_name}] (
local_id, local_type, create_time, real_sender_id,
message_content, WCDB_CT_message_content
) VALUES (?, ?, ?, ?, ?, ?)
""",
(local_id, 1, create_time, 0, content, 0),
)
conn.commit()
finally:
conn.close()
class SearchMessagesTests(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.addCleanup(self.temp_dir.cleanup)
def create_db(self, filename, chats):
path = os.path.join(self.temp_dir.name, filename)
_create_message_db(path, chats)
return path
def test_validate_pagination_rejects_large_limit(self):
# 防止单次查询过大,保证 limit 上限校验存在。
with self.assertRaisesRegex(ValueError, "limit 不能大于 500"):
mcp_server._validate_pagination(501, 0)
def test_validate_pagination_allows_large_limit_when_limit_is_unbounded(self):
# get_chat_history 允许更大的 limit只校验正数和 offset。
mcp_server._validate_pagination(999999, 0, limit_max=None)
def test_page_search_entries_returns_chronological_results_with_offset(self):
# 结果应先按最新时间分页,再把当前页恢复成时间正序输出。
entries = [(1, "a"), (5, "e"), (3, "c"), (4, "d"), (2, "b")]
paged = mcp_server._page_search_entries(entries, limit=2, offset=1)
self.assertEqual(paged, [(3, "c"), (4, "d")])
def test_search_messages_single_chat_uses_offset_and_returns_page(self):
# 单聊分页应只返回当前页,并按聊天阅读顺序展示。
db_path = self.create_db(
"single.db",
{
"alice": [
(1, 100, "foo newest"),
(2, 90, "foo middle"),
(3, 80, "foo oldest"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.search_messages("foo", chat_name="Alice", limit=2, offset=1)
self.assertIn('在 Alice 中搜索 "foo" 找到 2 条结果offset=1, limit=2', result)
self.assertLess(result.index("foo oldest"), result.index("foo middle"))
self.assertNotIn("foo newest", result)
def test_search_messages_multiple_chats_applies_global_pagination(self):
# 多个聊天联合搜索时,分页必须基于合并后的全局结果。
db_path = self.create_db(
"multi.db",
{
"alice": [
(1, 110, "foo a1"),
(2, 90, "foo a2"),
],
"bob": [
(1, 100, "foo b1"),
(2, 80, "foo b2"),
],
},
)
contexts = [
{
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
},
{
"query": "Bob",
"username": "bob",
"display_name": "Bob",
"db_path": db_path,
"table_name": _msg_table_name("bob"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}],
"is_group": False,
},
]
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object(
mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], [])
):
result = mcp_server.search_messages("foo", chat_name=["Alice", "Bob"], limit=2, offset=1)
self.assertIn('在 2 个聊天对象中搜索 "foo" 找到 2 条结果offset=1, limit=2', result)
self.assertLess(result.index("foo a2"), result.index("foo b1"))
self.assertNotIn("foo a1", result)
self.assertNotIn("foo b2", result)
def test_search_messages_all_messages_merges_global_results_before_paging(self):
# 全库搜索要基于跨库合并后的全局时间线分页,不能被单个分库提前截断。
older_db = self.create_db(
"older.db",
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
)
newer_db = self.create_db(
"newer.db",
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2")]},
)
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
mcp_server, "_cache", fake_cache
), patch.object(
mcp_server,
"get_contact_names",
return_value={"older_user": "Older", "newer_user": "Newer"},
):
result = mcp_server.search_messages("foo", limit=2, offset=0)
self.assertIn('搜索 "foo" 找到 2 条结果offset=0, limit=2', result)
self.assertLess(result.index("foo newer 2"), result.index("foo newer 1"))
self.assertNotIn("foo older 1", result)
def test_search_messages_all_messages_uses_bounded_sql_pagination(self):
# 每个消息表都只应查询当前页所需的候选窗口,不能回退到 limit=None 的全量扫描。
older_db = self.create_db(
"older_paged.db",
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
)
newer_db = self.create_db(
"newer_paged.db",
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2"), (3, 19, "foo newer 3")]},
)
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
original_query_messages = mcp_server._query_messages
calls = []
def recording_query_messages(*args, **kwargs):
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
return original_query_messages(*args, **kwargs)
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
mcp_server, "_cache", fake_cache
), patch.object(
mcp_server,
"get_contact_names",
return_value={"older_user": "Older", "newer_user": "Newer"},
), patch.object(
mcp_server, "_query_messages", side_effect=recording_query_messages
):
result = mcp_server.search_messages("foo", limit=2, offset=1)
self.assertIn('搜索 "foo" 找到 2 条结果offset=1, limit=2', result)
self.assertEqual(
calls,
[
(_msg_table_name("older_user"), 3, 0),
(_msg_table_name("newer_user"), 3, 0),
],
)
def test_search_messages_single_chat_respects_time_range(self):
# 单聊搜索的开始/结束时间都必须严格生效。
db_path = self.create_db(
"single_time.db",
{
"alice": [
(1, 300, "foo in range"),
(2, 200, "foo too early"),
(3, 400, "foo too late"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_parse_time_range", return_value=(250, 350)
):
result = mcp_server.search_messages(
"foo",
chat_name="Alice",
start_time="custom-start",
end_time="custom-end",
limit=20,
offset=0,
)
self.assertIn("时间范围: custom-start ~ custom-end", result)
self.assertIn("foo in range", result)
self.assertNotIn("foo too early", result)
self.assertNotIn("foo too late", result)
def test_search_messages_multiple_chats_respects_time_range(self):
# 多聊联合搜索时,每个聊天对象都要套用同一时间范围。
db_path = self.create_db(
"multi_time.db",
{
"alice": [(1, 300, "foo alice in range"), (2, 150, "foo alice too early")],
"bob": [(1, 320, "foo bob in range"), (2, 500, "foo bob too late")],
},
)
contexts = [
{
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
},
{
"query": "Bob",
"username": "bob",
"display_name": "Bob",
"db_path": db_path,
"table_name": _msg_table_name("bob"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}],
"is_group": False,
},
]
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object(
mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], [])
), patch.object(
mcp_server, "_parse_time_range", return_value=(250, 400)
):
result = mcp_server.search_messages(
"foo",
chat_name=["Alice", "Bob"],
start_time="range-start",
end_time="range-end",
limit=20,
offset=0,
)
self.assertIn("时间范围: range-start ~ range-end", result)
self.assertIn("foo alice in range", result)
self.assertIn("foo bob in range", result)
self.assertNotIn("foo alice too early", result)
self.assertNotIn("foo bob too late", result)
def test_search_messages_all_messages_respects_time_range(self):
# 全库搜索也不能返回时间范围外的消息。
db_path = self.create_db(
"all_time.db",
{
"alice": [
(1, 100, "foo too early"),
(2, 300, "foo in range"),
(3, 500, "foo too late"),
]
},
)
fake_cache = _FakeCache({"all": db_path})
with patch.object(mcp_server, "MSG_DB_KEYS", ["all"]), patch.object(
mcp_server, "_cache", fake_cache
), patch.object(
mcp_server,
"get_contact_names",
return_value={"alice": "Alice"},
), patch.object(
mcp_server, "_parse_time_range", return_value=(250, 350)
):
result = mcp_server.search_messages(
"foo",
start_time="range-start",
end_time="range-end",
limit=20,
offset=0,
)
self.assertIn("时间范围: range-start ~ range-end", result)
self.assertIn("foo in range", result)
self.assertNotIn("foo too early", result)
self.assertNotIn("foo too late", result)
def test_get_chat_history_merges_sharded_message_tables(self):
# 同一联系人跨多个 message_N.db 分片时,历史查询要先合并再分页。
older_db = self.create_db("history_older.db", {"alice": [(1, 100, "old message")]})
newer_db = self.create_db(
"history_newer.db",
{"alice": [(1, 300, "new message"), (2, 250, "middle message")]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": newer_db,
"table_name": _msg_table_name("alice"),
"message_tables": [
{"db_path": older_db, "table_name": _msg_table_name("alice")},
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
self.assertIn("Alice 的消息记录(返回 2 条offset=0, limit=2", result)
self.assertIn("middle message", result)
self.assertIn("new message", result)
self.assertNotIn("old message", result)
def test_get_chat_history_large_limit_reads_all_rows_across_shards(self):
# 大 limit 下,跨分片历史查询不能只返回较旧分片里的少量消息。
older_messages = [
(index, 1000 + index, f"old shard message {index}")
for index in range(1, 18)
]
newer_messages = [
(index, 2000 + index, f"new shard message {index}")
for index in range(1, 296)
]
older_db = self.create_db("history_cross_shard_older.db", {"alice": older_messages})
newer_db = self.create_db("history_cross_shard_newer.db", {"alice": newer_messages})
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": newer_db,
"table_name": _msg_table_name("alice"),
"message_tables": [
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
{"db_path": older_db, "table_name": _msg_table_name("alice")},
],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.get_chat_history("Alice", limit=500, offset=0)
self.assertIn("Alice 的消息记录(返回 312 条offset=0, limit=500", result)
self.assertIn("new shard message 295", result)
self.assertIn("old shard message 17", result)
body = result.split(":\n\n", 1)[1]
self.assertEqual(len(body.splitlines()), 312)
def test_get_chat_history_uses_bounded_sql_pagination(self):
# 历史查询应把 offset+limit 下推到 SQL避免把整张消息表读出来后再切片。
db_path = self.create_db(
"history_paged.db",
{
"alice": [
(1, 400, "newest"),
(2, 300, "middle"),
(3, 200, "older"),
(4, 100, "oldest"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
original_query_messages = mcp_server._query_messages
calls = []
def recording_query_messages(*args, **kwargs):
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
return original_query_messages(*args, **kwargs)
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_query_messages", side_effect=recording_query_messages
):
result = mcp_server.get_chat_history("Alice", limit=2, offset=1)
self.assertIn("middle", result)
self.assertIn("older", result)
self.assertNotIn("newest", result)
self.assertNotIn("oldest", result)
self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)])
def test_get_chat_history_allows_large_limit_values(self):
# 历史查询不应再把大 limit 直接拒绝掉。
db_path = self.create_db(
"history_large_limit.db",
{
"alice": [
(1, 200, "message 1"),
(2, 100, "message 2"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.get_chat_history("Alice", limit=999999, offset=0)
self.assertNotIn("错误:", result)
self.assertIn("message 1", result)
self.assertIn("message 2", result)
def test_get_chat_history_keeps_partial_results_when_formatting_fails(self):
# 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。
db_path = self.create_db(
"history_partial_failure.db",
{"alice": [(1, 200, "good message"), (2, 100, "bad message")]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
original_build_history_line = mcp_server._build_history_line
def flaky_build_history_line(row, *args, **kwargs):
if row[2] == 100:
raise ValueError("bad row")
return original_build_history_line(row, *args, **kwargs)
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_build_history_line", side_effect=flaky_build_history_line
):
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
self.assertIn("good message", result)
self.assertIn("查询失败:", result)
self.assertIn("bad row", result)
def test_get_chat_history_does_not_truncate_long_messages(self):
# 历史记录应返回完整消息内容,而不是固定截断到 500 字符。
long_message = "x" * 600
db_path = self.create_db(
"history_long_message.db",
{"alice": [(1, 200, long_message)]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
):
result = mcp_server.get_chat_history("Alice", limit=1, offset=0)
self.assertIn(long_message, result)
self.assertNotIn(("x" * 500) + "...", result)
def test_search_messages_single_chat_merges_sharded_message_tables(self):
# 单聊搜索也要跨分片合并,否则最近消息可能查不到。
older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]})
newer_db = self.create_db(
"search_newer.db",
{"alice": [(1, 300, "foo new"), (2, 200, "foo middle")]},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": newer_db,
"table_name": _msg_table_name("alice"),
"message_tables": [
{"db_path": older_db, "table_name": _msg_table_name("alice")},
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
],
"is_group": False,
}
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_parse_time_range", return_value=(150, 350)
):
result = mcp_server.search_messages(
"foo",
chat_name="Alice",
start_time="range-start",
end_time="range-end",
limit=20,
offset=0,
)
self.assertIn("foo middle", result)
self.assertIn("foo new", result)
self.assertNotIn("foo old", result)
def test_search_messages_keeps_partial_results_when_later_batch_fails(self):
# 后续批次失败时,前面已经拿到的有效结果不应被丢弃。
db_path = self.create_db(
"search_partial_failure.db",
{
"alice": [
(1, 400, "foo newest"),
(2, 300, "foo skipped"),
(3, 200, "foo older"),
(4, 100, "foo bad"),
]
},
)
ctx = {
"query": "Alice",
"username": "alice",
"display_name": "Alice",
"db_path": db_path,
"table_name": _msg_table_name("alice"),
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
"is_group": False,
}
original_build_search_entry = mcp_server._build_search_entry
def flaky_build_search_entry(row, *args, **kwargs):
if row[2] == 300:
return None
if row[2] == 100:
raise ValueError("bad row")
return original_build_search_entry(row, *args, **kwargs)
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
mcp_server, "_resolve_chat_context", return_value=ctx
), patch.object(
mcp_server, "_build_search_entry", side_effect=flaky_build_search_entry
):
result = mcp_server.search_messages("foo", chat_name="Alice", limit=3, offset=0)
self.assertIn("foo newest", result)
self.assertIn("foo older", result)
self.assertIn("查询失败:", result)
self.assertIn("bad row", result)
def test_get_recent_sessions_closes_connection_when_query_fails(self):
# 会话查询抛异常时也必须关闭 sqlite3 连接,避免资源泄漏。
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
class _FakeConn:
def __init__(self):
self.closed = False
def execute(self, *args, **kwargs):
raise sqlite3.OperationalError("boom")
def close(self):
self.closed = True
fake_conn = _FakeConn()
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
mcp_server, "get_contact_names", return_value={}
), patch.object(
mcp_server.sqlite3, "connect", return_value=fake_conn
):
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
mcp_server.get_recent_sessions()
self.assertTrue(fake_conn.closed)
def test_get_new_messages_closes_connection_when_query_fails(self):
# 新消息轮询失败时也要释放 sqlite3 连接。
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
class _FakeConn:
def __init__(self):
self.closed = False
def execute(self, *args, **kwargs):
raise sqlite3.OperationalError("boom")
def close(self):
self.closed = True
fake_conn = _FakeConn()
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
mcp_server, "get_contact_names", return_value={}
), patch.object(
mcp_server.sqlite3, "connect", return_value=fake_conn
):
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
mcp_server.get_new_messages()
self.assertTrue(fake_conn.closed)
if __name__ == "__main__":
unittest.main()