mirror of https://github.com/jackwener/wx-cli.git
Merge pull request #2 from dsjzazs/codex/searchmessages
Add unit tests for MCP search and fix paginationfeat/daemon-cli
commit
2cd180c63a
|
|
@ -139,6 +139,8 @@ claude mcp add wechat -- python C:\Users\你的用户名\wechat-decrypt\mcp_serv
|
||||||
|
|
||||||
前置条件:需要先运行 `python main.py` 或 `python find_all_keys.py` 完成密钥提取。
|
前置条件:需要先运行 `python main.py` 或 `python find_all_keys.py` 完成密钥提取。
|
||||||
|
|
||||||
|
说明:`get_chat_history` 和 `search_messages` 的 `limit` 最大为 `500`。
|
||||||
|
|
||||||
**[查看使用案例 →](USAGE.md)**
|
**[查看使用案例 →](USAGE.md)**
|
||||||
|
|
||||||
### 图片解密 (V2 格式)
|
### 图片解密 (V2 格式)
|
||||||
|
|
|
||||||
862
mcp_server.py
862
mcp_server.py
|
|
@ -5,9 +5,10 @@ Based on FastMCP (stdio transport), reuses existing decryption.
|
||||||
Runs on Windows Python (needs access to D:\ WeChat databases).
|
Runs on Windows Python (needs access to D:\ WeChat databases).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re
|
import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re
|
||||||
import hmac as hmac_mod
|
import hmac as hmac_mod
|
||||||
from datetime import datetime
|
from contextlib import closing
|
||||||
|
from datetime import datetime
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from Crypto.Cipher import AES
|
from Crypto.Cipher import AES
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
@ -219,9 +220,10 @@ atexit.register(_cache.cleanup)
|
||||||
|
|
||||||
_contact_names = None # {username: display_name}
|
_contact_names = None # {username: display_name}
|
||||||
_contact_full = None # [{username, nick_name, remark}]
|
_contact_full = None # [{username, nick_name, remark}]
|
||||||
_self_username = None
|
_self_username = None
|
||||||
_XML_UNSAFE_RE = re.compile(r'<!DOCTYPE|<!ENTITY', re.IGNORECASE)
|
_XML_UNSAFE_RE = re.compile(r'<!DOCTYPE|<!ENTITY', re.IGNORECASE)
|
||||||
_XML_PARSE_MAX_LEN = 20000
|
_XML_PARSE_MAX_LEN = 20000
|
||||||
|
_QUERY_LIMIT_MAX = 500
|
||||||
|
|
||||||
|
|
||||||
def _load_contacts_from(db_path):
|
def _load_contacts_from(db_path):
|
||||||
|
|
@ -568,12 +570,12 @@ MSG_DB_KEYS = sorted([
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def _find_msg_table_for_user(username):
|
def _find_msg_table_for_user(username):
|
||||||
"""在所有 message_N.db 中查找用户的消息表,返回 (db_path, table_name)"""
|
"""在所有 message_N.db 中查找用户的消息表,返回 (db_path, table_name)"""
|
||||||
table_hash = hashlib.md5(username.encode()).hexdigest()
|
table_hash = hashlib.md5(username.encode()).hexdigest()
|
||||||
table_name = f"Msg_{table_hash}"
|
table_name = f"Msg_{table_hash}"
|
||||||
if not _is_safe_msg_table_name(table_name):
|
if not _is_safe_msg_table_name(table_name):
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
for rel_key in MSG_DB_KEYS:
|
for rel_key in MSG_DB_KEYS:
|
||||||
path = _cache.get(rel_key)
|
path = _cache.get(rel_key)
|
||||||
|
|
@ -592,15 +594,54 @@ def _find_msg_table_for_user(username):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_msg_tables_for_user(username):
|
||||||
|
"""返回用户在所有 message_N.db 中对应的消息表,按最新消息时间倒序排列。"""
|
||||||
|
table_hash = hashlib.md5(username.encode()).hexdigest()
|
||||||
|
table_name = f"Msg_{table_hash}"
|
||||||
|
if not _is_safe_msg_table_name(table_name):
|
||||||
|
return []
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for rel_key in MSG_DB_KEYS:
|
||||||
|
path = _cache.get(rel_key)
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
conn = sqlite3.connect(path)
|
||||||
|
try:
|
||||||
|
exists = conn.execute(
|
||||||
|
"SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
|
||||||
|
(table_name,)
|
||||||
|
).fetchone()
|
||||||
|
if not exists:
|
||||||
|
continue
|
||||||
|
max_create_time = conn.execute(
|
||||||
|
f"SELECT MAX(create_time) FROM [{table_name}]"
|
||||||
|
).fetchone()[0] or 0
|
||||||
|
matches.append({
|
||||||
|
'db_path': path,
|
||||||
|
'table_name': table_name,
|
||||||
|
'max_create_time': max_create_time,
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
matches.sort(key=lambda item: item['max_create_time'], reverse=True)
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
def _validate_pagination(limit, offset=0):
|
def _validate_pagination(limit, offset=0):
|
||||||
if limit <= 0:
|
if limit <= 0:
|
||||||
raise ValueError("limit 必须大于 0")
|
raise ValueError("limit 必须大于 0")
|
||||||
if offset < 0:
|
if limit > _QUERY_LIMIT_MAX:
|
||||||
raise ValueError("offset 不能小于 0")
|
raise ValueError(f"limit 不能大于 {_QUERY_LIMIT_MAX}")
|
||||||
|
if offset < 0:
|
||||||
|
raise ValueError("offset 不能小于 0")
|
||||||
|
|
||||||
|
|
||||||
def _parse_time_value(value, field_name, is_end=False):
|
def _parse_time_value(value, field_name, is_end=False):
|
||||||
|
|
@ -650,49 +691,54 @@ def _build_message_filters(start_ts=None, end_ts=None, keyword=''):
|
||||||
return clauses, params
|
return clauses, params
|
||||||
|
|
||||||
|
|
||||||
def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0):
|
def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0):
|
||||||
if not _is_safe_msg_table_name(table_name):
|
if not _is_safe_msg_table_name(table_name):
|
||||||
raise ValueError(f'非法消息表名: {table_name}')
|
raise ValueError(f'非法消息表名: {table_name}')
|
||||||
|
|
||||||
clauses, params = _build_message_filters(start_ts, end_ts, keyword)
|
clauses, params = _build_message_filters(start_ts, end_ts, keyword)
|
||||||
where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else ''
|
where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else ''
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT local_id, local_type, create_time, real_sender_id, message_content,
|
SELECT local_id, local_type, create_time, real_sender_id, message_content,
|
||||||
WCDB_CT_message_content
|
WCDB_CT_message_content
|
||||||
FROM [{table_name}]
|
FROM [{table_name}]
|
||||||
{where_sql}
|
{where_sql}
|
||||||
ORDER BY create_time DESC
|
ORDER BY create_time DESC
|
||||||
LIMIT ? OFFSET ?
|
"""
|
||||||
"""
|
if limit is None:
|
||||||
return conn.execute(sql, (*params, limit, offset)).fetchall()
|
return conn.execute(sql, params).fetchall()
|
||||||
|
sql += "\n LIMIT ? OFFSET ?"
|
||||||
|
return conn.execute(sql, (*params, limit, offset)).fetchall()
|
||||||
|
|
||||||
|
|
||||||
def _resolve_chat_context(chat_name):
|
def _resolve_chat_context(chat_name):
|
||||||
username = resolve_username(chat_name)
|
username = resolve_username(chat_name)
|
||||||
if not username:
|
if not username:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
names = get_contact_names()
|
names = get_contact_names()
|
||||||
display_name = names.get(username, username)
|
display_name = names.get(username, username)
|
||||||
db_path, table_name = _find_msg_table_for_user(username)
|
message_tables = _find_msg_tables_for_user(username)
|
||||||
if not db_path:
|
if not message_tables:
|
||||||
return {
|
return {
|
||||||
'query': chat_name,
|
'query': chat_name,
|
||||||
'username': username,
|
'username': username,
|
||||||
'display_name': display_name,
|
'display_name': display_name,
|
||||||
'db_path': None,
|
'db_path': None,
|
||||||
'table_name': None,
|
'table_name': None,
|
||||||
'is_group': '@chatroom' in username,
|
'message_tables': [],
|
||||||
}
|
'is_group': '@chatroom' in username,
|
||||||
|
}
|
||||||
return {
|
|
||||||
'query': chat_name,
|
primary = message_tables[0]
|
||||||
'username': username,
|
return {
|
||||||
'display_name': display_name,
|
'query': chat_name,
|
||||||
'db_path': db_path,
|
'username': username,
|
||||||
'table_name': table_name,
|
'display_name': display_name,
|
||||||
'is_group': '@chatroom' in username,
|
'db_path': primary['db_path'],
|
||||||
}
|
'table_name': primary['table_name'],
|
||||||
|
'message_tables': message_tables,
|
||||||
|
'is_group': '@chatroom' in username,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_chat_contexts(chat_names):
|
def _resolve_chat_contexts(chat_names):
|
||||||
|
|
@ -713,11 +759,11 @@ def _resolve_chat_contexts(chat_names):
|
||||||
if not ctx:
|
if not ctx:
|
||||||
unresolved.append(name)
|
unresolved.append(name)
|
||||||
continue
|
continue
|
||||||
if not ctx['db_path']:
|
if not ctx['message_tables']:
|
||||||
missing_tables.append(ctx['display_name'])
|
missing_tables.append(ctx['display_name'])
|
||||||
continue
|
continue
|
||||||
if ctx['username'] in seen:
|
if ctx['username'] in seen:
|
||||||
continue
|
continue
|
||||||
seen.add(ctx['username'])
|
seen.add(ctx['username'])
|
||||||
resolved.append(ctx)
|
resolved.append(ctx)
|
||||||
|
|
||||||
|
|
@ -743,31 +789,20 @@ def _normalize_chat_names(chat_name):
|
||||||
return [value] if value else []
|
return [value] if value else []
|
||||||
|
|
||||||
|
|
||||||
def _format_history_lines(rows, username, display_name, is_group, names, id_to_username):
|
def _format_history_lines(rows, username, display_name, is_group, names, id_to_username):
|
||||||
lines = []
|
lines = []
|
||||||
for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows):
|
ctx = {
|
||||||
time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M')
|
'username': username,
|
||||||
content = _decompress_content(content, ct)
|
'display_name': display_name,
|
||||||
if content is None:
|
'is_group': is_group,
|
||||||
content = '(无法解压)'
|
}
|
||||||
|
for row in reversed(rows):
|
||||||
sender, text = _format_message_text(
|
_, line = _build_history_line(row, ctx, names, id_to_username)
|
||||||
local_id, local_type, content, is_group, username, display_name, names
|
lines.append(line)
|
||||||
)
|
return lines
|
||||||
if text and len(text) > 500:
|
|
||||||
text = text[:500] + '...'
|
|
||||||
|
|
||||||
sender_label = _resolve_sender_label(
|
|
||||||
real_sender_id, sender, is_group, username, display_name, names, id_to_username
|
|
||||||
)
|
|
||||||
if sender_label:
|
|
||||||
lines.append(f'[{time_str}] {sender_label}: {text}')
|
|
||||||
else:
|
|
||||||
lines.append(f'[{time_str}] {text}')
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
def _build_search_entry(row, ctx, names, id_to_username):
|
def _build_search_entry(row, ctx, names, id_to_username):
|
||||||
local_id, local_type, create_time, real_sender_id, content, ct = row
|
local_id, local_type, create_time, real_sender_id, content, ct = row
|
||||||
content = _decompress_content(content, ct)
|
content = _decompress_content(content, ct)
|
||||||
if content is None:
|
if content is None:
|
||||||
|
|
@ -792,8 +827,326 @@ def _build_search_entry(row, ctx, names, id_to_username):
|
||||||
entry = f"[{time_str}] [{ctx['display_name']}]"
|
entry = f"[{time_str}] [{ctx['display_name']}]"
|
||||||
if sender_label:
|
if sender_label:
|
||||||
entry += f" {sender_label}:"
|
entry += f" {sender_label}:"
|
||||||
entry += f" {text}"
|
entry += f" {text}"
|
||||||
return create_time, entry
|
return create_time, entry
|
||||||
|
|
||||||
|
|
||||||
|
def _build_history_line(row, ctx, names, id_to_username):
|
||||||
|
local_id, local_type, create_time, real_sender_id, content, ct = row
|
||||||
|
time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M')
|
||||||
|
content = _decompress_content(content, ct)
|
||||||
|
if content is None:
|
||||||
|
content = '(无法解压)'
|
||||||
|
|
||||||
|
sender, text = _format_message_text(
|
||||||
|
local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names
|
||||||
|
)
|
||||||
|
if text and len(text) > 500:
|
||||||
|
text = text[:500] + '...'
|
||||||
|
|
||||||
|
sender_label = _resolve_sender_label(
|
||||||
|
real_sender_id, sender, ctx['is_group'], ctx['username'], ctx['display_name'], names, id_to_username
|
||||||
|
)
|
||||||
|
if sender_label:
|
||||||
|
return create_time, f'[{time_str}] {sender_label}: {text}'
|
||||||
|
return create_time, f'[{time_str}] {text}'
|
||||||
|
|
||||||
|
|
||||||
|
def _get_chat_message_tables(ctx):
|
||||||
|
if ctx.get('message_tables'):
|
||||||
|
return ctx['message_tables']
|
||||||
|
if ctx.get('db_path') and ctx.get('table_name'):
|
||||||
|
return [{'db_path': ctx['db_path'], 'table_name': ctx['table_name']}]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_table_contexts(ctx):
|
||||||
|
for table in _get_chat_message_tables(ctx):
|
||||||
|
yield {
|
||||||
|
'query': ctx['query'],
|
||||||
|
'username': ctx['username'],
|
||||||
|
'display_name': ctx['display_name'],
|
||||||
|
'db_path': table['db_path'],
|
||||||
|
'table_name': table['table_name'],
|
||||||
|
'is_group': ctx['is_group'],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _candidate_page_size(limit, offset):
|
||||||
|
return limit + offset
|
||||||
|
|
||||||
|
|
||||||
|
def _message_query_batch_size(candidate_limit):
|
||||||
|
return candidate_limit
|
||||||
|
|
||||||
|
|
||||||
|
def _page_ranked_entries(entries, limit, offset):
|
||||||
|
ordered = sorted(entries, key=lambda item: item[0], reverse=True)
|
||||||
|
paged = ordered[offset:offset + limit]
|
||||||
|
paged.sort(key=lambda item: item[0])
|
||||||
|
return paged
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0):
|
||||||
|
collected = []
|
||||||
|
failures = []
|
||||||
|
candidate_limit = _candidate_page_size(limit, offset)
|
||||||
|
|
||||||
|
for table_ctx in _iter_table_contexts(ctx):
|
||||||
|
try:
|
||||||
|
with closing(sqlite3.connect(table_ctx['db_path'])) as conn:
|
||||||
|
id_to_username = _load_name2id_maps(conn)
|
||||||
|
# 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。
|
||||||
|
rows = _query_messages(
|
||||||
|
conn,
|
||||||
|
table_ctx['table_name'],
|
||||||
|
start_ts=start_ts,
|
||||||
|
end_ts=end_ts,
|
||||||
|
limit=candidate_limit,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
for row in rows:
|
||||||
|
collected.append(_build_history_line(row, table_ctx, names, id_to_username))
|
||||||
|
except Exception as e:
|
||||||
|
failures.append(f"{table_ctx['db_path']}: {e}")
|
||||||
|
|
||||||
|
paged = _page_ranked_entries(collected, limit, offset)
|
||||||
|
return [line for _, line in paged], failures
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
|
||||||
|
collected = []
|
||||||
|
failures = []
|
||||||
|
contexts_by_db = {}
|
||||||
|
for table_ctx in _iter_table_contexts(ctx):
|
||||||
|
contexts_by_db.setdefault(table_ctx['db_path'], []).append(table_ctx)
|
||||||
|
|
||||||
|
for db_path, db_contexts in contexts_by_db.items():
|
||||||
|
try:
|
||||||
|
with closing(sqlite3.connect(db_path)) as conn:
|
||||||
|
db_entries, db_failures = _collect_search_entries(
|
||||||
|
conn,
|
||||||
|
db_contexts,
|
||||||
|
names,
|
||||||
|
keyword,
|
||||||
|
start_ts=start_ts,
|
||||||
|
end_ts=end_ts,
|
||||||
|
candidate_limit=candidate_limit,
|
||||||
|
)
|
||||||
|
collected.extend(db_entries)
|
||||||
|
failures.extend(db_failures)
|
||||||
|
except Exception as e:
|
||||||
|
failures.extend(f"{table_ctx['display_name']}: {e}" for table_ctx in db_contexts)
|
||||||
|
|
||||||
|
return collected, failures
|
||||||
|
|
||||||
|
|
||||||
|
def _load_search_contexts_from_db(conn, db_path, names):
|
||||||
|
tables = conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'"
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
table_to_username = {}
|
||||||
|
try:
|
||||||
|
for (user_name,) in conn.execute("SELECT user_name FROM Name2Id").fetchall():
|
||||||
|
if not user_name:
|
||||||
|
continue
|
||||||
|
table_hash = hashlib.md5(user_name.encode()).hexdigest()
|
||||||
|
table_to_username[f"Msg_{table_hash}"] = user_name
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
|
||||||
|
contexts = []
|
||||||
|
for (table_name,) in tables:
|
||||||
|
username = table_to_username.get(table_name, '')
|
||||||
|
display_name = names.get(username, username) if username else table_name
|
||||||
|
contexts.append({
|
||||||
|
'query': display_name,
|
||||||
|
'username': username,
|
||||||
|
'display_name': display_name,
|
||||||
|
'db_path': db_path,
|
||||||
|
'table_name': table_name,
|
||||||
|
'is_group': '@chatroom' in username,
|
||||||
|
})
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
|
||||||
|
collected = []
|
||||||
|
failures = []
|
||||||
|
id_to_username = _load_name2id_maps(conn)
|
||||||
|
batch_size = _message_query_batch_size(candidate_limit)
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
try:
|
||||||
|
fetch_offset = 0
|
||||||
|
collected_before_table = len(collected)
|
||||||
|
# 全局分页只需要每个分表最新的 offset+limit 条有效命中,无需把整表命中读进内存。
|
||||||
|
while len(collected) - collected_before_table < candidate_limit:
|
||||||
|
rows = _query_messages(
|
||||||
|
conn,
|
||||||
|
ctx['table_name'],
|
||||||
|
start_ts=start_ts,
|
||||||
|
end_ts=end_ts,
|
||||||
|
keyword=keyword,
|
||||||
|
limit=batch_size,
|
||||||
|
offset=fetch_offset,
|
||||||
|
)
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
fetch_offset += len(rows)
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
formatted = _build_search_entry(row, ctx, names, id_to_username)
|
||||||
|
if formatted:
|
||||||
|
collected.append(formatted)
|
||||||
|
if len(collected) - collected_before_table >= candidate_limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(rows) < batch_size:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
failures.append(f"{ctx['display_name']}: {e}")
|
||||||
|
|
||||||
|
return collected, failures
|
||||||
|
|
||||||
|
|
||||||
|
def _page_search_entries(entries, limit, offset):
|
||||||
|
return _page_ranked_entries(entries, limit, offset)
|
||||||
|
|
||||||
|
|
||||||
|
def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
|
||||||
|
names = get_contact_names()
|
||||||
|
candidate_limit = _candidate_page_size(limit, offset)
|
||||||
|
|
||||||
|
entries, failures = _collect_chat_search_entries(
|
||||||
|
ctx,
|
||||||
|
names,
|
||||||
|
keyword,
|
||||||
|
start_ts=start_ts,
|
||||||
|
end_ts=end_ts,
|
||||||
|
candidate_limit=candidate_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
paged = _page_search_entries(entries, limit, offset)
|
||||||
|
|
||||||
|
if not paged:
|
||||||
|
if failures:
|
||||||
|
return "查询失败: " + ";".join(failures)
|
||||||
|
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息"
|
||||||
|
|
||||||
|
header = f"在 {ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(paged)} 条结果(offset={offset}, limit={limit})"
|
||||||
|
if start_time or end_time:
|
||||||
|
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
||||||
|
if failures:
|
||||||
|
header += "\n查询失败: " + ";".join(failures)
|
||||||
|
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
|
||||||
|
|
||||||
|
|
||||||
|
def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
|
||||||
|
try:
|
||||||
|
resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names)
|
||||||
|
except ValueError as e:
|
||||||
|
return f"错误: {e}"
|
||||||
|
|
||||||
|
if not resolved_contexts:
|
||||||
|
details = []
|
||||||
|
if unresolved:
|
||||||
|
details.append("未找到联系人: " + "、".join(unresolved))
|
||||||
|
if missing_tables:
|
||||||
|
details.append("无消息表: " + "、".join(missing_tables))
|
||||||
|
suffix = f"\n{chr(10).join(details)}" if details else ""
|
||||||
|
return f"错误: 没有可查询的聊天对象{suffix}"
|
||||||
|
|
||||||
|
names = get_contact_names()
|
||||||
|
candidate_limit = _candidate_page_size(limit, offset)
|
||||||
|
collected = []
|
||||||
|
failures = []
|
||||||
|
for ctx in resolved_contexts:
|
||||||
|
chat_entries, chat_failures = _collect_chat_search_entries(
|
||||||
|
ctx,
|
||||||
|
names,
|
||||||
|
keyword,
|
||||||
|
start_ts=start_ts,
|
||||||
|
end_ts=end_ts,
|
||||||
|
candidate_limit=candidate_limit,
|
||||||
|
)
|
||||||
|
collected.extend(chat_entries)
|
||||||
|
failures.extend(chat_failures)
|
||||||
|
|
||||||
|
paged = _page_search_entries(collected, limit, offset)
|
||||||
|
|
||||||
|
notes = []
|
||||||
|
if unresolved:
|
||||||
|
notes.append("未找到联系人: " + "、".join(unresolved))
|
||||||
|
if missing_tables:
|
||||||
|
notes.append("无消息表: " + "、".join(missing_tables))
|
||||||
|
if failures:
|
||||||
|
notes.append("查询失败: " + ";".join(failures))
|
||||||
|
|
||||||
|
if not paged:
|
||||||
|
header = f"在 {len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息"
|
||||||
|
if start_time or end_time:
|
||||||
|
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
||||||
|
if notes:
|
||||||
|
header += "\n" + "\n".join(notes)
|
||||||
|
return header
|
||||||
|
|
||||||
|
header = (
|
||||||
|
f"在 {len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果"
|
||||||
|
f"(offset={offset}, limit={limit})"
|
||||||
|
)
|
||||||
|
if start_time or end_time:
|
||||||
|
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
||||||
|
if notes:
|
||||||
|
header += "\n" + "\n".join(notes)
|
||||||
|
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
|
||||||
|
|
||||||
|
|
||||||
|
def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, offset):
|
||||||
|
names = get_contact_names()
|
||||||
|
collected = []
|
||||||
|
failures = []
|
||||||
|
candidate_limit = _candidate_page_size(limit, offset)
|
||||||
|
|
||||||
|
for rel_key in MSG_DB_KEYS:
|
||||||
|
path = _cache.get(rel_key)
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
with closing(sqlite3.connect(path)) as conn:
|
||||||
|
contexts = _load_search_contexts_from_db(conn, path, names)
|
||||||
|
db_entries, db_failures = _collect_search_entries(
|
||||||
|
conn,
|
||||||
|
contexts,
|
||||||
|
names,
|
||||||
|
keyword,
|
||||||
|
start_ts=start_ts,
|
||||||
|
end_ts=end_ts,
|
||||||
|
candidate_limit=candidate_limit,
|
||||||
|
)
|
||||||
|
collected.extend(db_entries)
|
||||||
|
failures.extend(db_failures)
|
||||||
|
except Exception as e:
|
||||||
|
failures.append(f"{rel_key}: {e}")
|
||||||
|
|
||||||
|
paged = _page_search_entries(collected, limit, offset)
|
||||||
|
|
||||||
|
if not paged:
|
||||||
|
header = f"未找到包含 \"{keyword}\" 的消息"
|
||||||
|
if start_time or end_time:
|
||||||
|
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
||||||
|
if failures:
|
||||||
|
header += "\n查询失败: " + ";".join(failures)
|
||||||
|
return header
|
||||||
|
|
||||||
|
header = f"搜索 \"{keyword}\" 找到 {len(paged)} 条结果(offset={offset}, limit={limit})"
|
||||||
|
if start_time or end_time:
|
||||||
|
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
||||||
|
if failures:
|
||||||
|
header += "\n查询失败: " + ";".join(failures)
|
||||||
|
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
|
||||||
|
|
||||||
|
|
||||||
# ============ MCP Server ============
|
# ============ MCP Server ============
|
||||||
|
|
@ -805,7 +1158,7 @@ _last_check_state = {} # {username: last_timestamp}
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def get_recent_sessions(limit: int = 20) -> str:
|
def get_recent_sessions(limit: int = 20) -> str:
|
||||||
"""获取微信最近会话列表,包含最新消息摘要、未读数、时间等。
|
"""获取微信最近会话列表,包含最新消息摘要、未读数、时间等。
|
||||||
用于了解最近有哪些人/群在聊天。
|
用于了解最近有哪些人/群在聊天。
|
||||||
|
|
||||||
|
|
@ -817,16 +1170,15 @@ def get_recent_sessions(limit: int = 20) -> str:
|
||||||
return "错误: 无法解密 session.db"
|
return "错误: 无法解密 session.db"
|
||||||
|
|
||||||
names = get_contact_names()
|
names = get_contact_names()
|
||||||
conn = sqlite3.connect(path)
|
with closing(sqlite3.connect(path)) as conn:
|
||||||
rows = conn.execute("""
|
rows = conn.execute("""
|
||||||
SELECT username, unread_count, summary, last_timestamp,
|
SELECT username, unread_count, summary, last_timestamp,
|
||||||
last_msg_type, last_msg_sender, last_sender_display_name
|
last_msg_type, last_msg_sender, last_sender_display_name
|
||||||
FROM SessionTable
|
FROM SessionTable
|
||||||
WHERE last_timestamp > 0
|
WHERE last_timestamp > 0
|
||||||
ORDER BY last_timestamp DESC
|
ORDER BY last_timestamp DESC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
""", (limit,)).fetchall()
|
""", (limit,)).fetchall()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for r in rows:
|
for r in rows:
|
||||||
|
|
@ -864,15 +1216,15 @@ def get_recent_sessions(limit: int = 20) -> str:
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str:
|
def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str:
|
||||||
"""获取指定聊天的消息记录。
|
"""获取指定聊天的消息记录。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配
|
chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配
|
||||||
limit: 返回的消息数量,默认50
|
limit: 返回的消息数量,默认50,最大500
|
||||||
offset: 分页偏移量,默认0
|
offset: 分页偏移量,默认0
|
||||||
start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
|
start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
|
||||||
end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
|
end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
_validate_pagination(limit, offset)
|
_validate_pagination(limit, offset)
|
||||||
|
|
@ -886,41 +1238,29 @@ def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_tim
|
||||||
if not ctx['db_path']:
|
if not ctx['db_path']:
|
||||||
return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)"
|
return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)"
|
||||||
|
|
||||||
names = get_contact_names()
|
names = get_contact_names()
|
||||||
conn = sqlite3.connect(ctx['db_path'])
|
lines, failures = _collect_chat_history_lines(
|
||||||
try:
|
ctx,
|
||||||
id_to_username = _load_name2id_maps(conn)
|
names,
|
||||||
rows = _query_messages(
|
start_ts=start_ts,
|
||||||
conn,
|
end_ts=end_ts,
|
||||||
ctx['table_name'],
|
limit=limit,
|
||||||
start_ts=start_ts,
|
offset=offset,
|
||||||
end_ts=end_ts,
|
)
|
||||||
limit=limit,
|
|
||||||
offset=offset,
|
if not lines:
|
||||||
)
|
if failures:
|
||||||
except Exception as e:
|
return "查询失败: " + ";".join(failures)
|
||||||
conn.close()
|
return f"{ctx['display_name']} 无消息记录"
|
||||||
return f"查询失败: {e}"
|
|
||||||
conn.close()
|
header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)} 条,offset={offset}, limit={limit})"
|
||||||
|
if ctx['is_group']:
|
||||||
if not rows:
|
header += " [群聊]"
|
||||||
return f"{ctx['display_name']} 无消息记录"
|
if start_time or end_time:
|
||||||
|
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
||||||
lines = _format_history_lines(
|
if failures:
|
||||||
rows,
|
header += "\n查询失败: " + ";".join(failures)
|
||||||
ctx['username'],
|
return header + ":\n\n" + "\n".join(lines)
|
||||||
ctx['display_name'],
|
|
||||||
ctx['is_group'],
|
|
||||||
names,
|
|
||||||
id_to_username,
|
|
||||||
)
|
|
||||||
|
|
||||||
header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)} 条,offset={offset}, limit={limit})"
|
|
||||||
if ctx['is_group']:
|
|
||||||
header += " [群聊]"
|
|
||||||
if start_time or end_time:
|
|
||||||
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
|
||||||
return header + ":\n\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
|
|
@ -939,7 +1279,7 @@ def search_messages(
|
||||||
chat_name: 聊天对象名称,可为空、单个字符串或字符串列表
|
chat_name: 聊天对象名称,可为空、单个字符串或字符串列表
|
||||||
start_time: 起始时间,可为空
|
start_time: 起始时间,可为空
|
||||||
end_time: 结束时间,可为空
|
end_time: 结束时间,可为空
|
||||||
limit: 返回的结果数量,默认20
|
limit: 返回的结果数量,默认20,最大500
|
||||||
offset: 分页偏移量,默认0
|
offset: 分页偏移量,默认0
|
||||||
"""
|
"""
|
||||||
if not keyword or len(keyword) < 1:
|
if not keyword or len(keyword) < 1:
|
||||||
|
|
@ -959,193 +1299,38 @@ def search_messages(
|
||||||
return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人"
|
return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人"
|
||||||
if not ctx['db_path']:
|
if not ctx['db_path']:
|
||||||
return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)"
|
return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)"
|
||||||
|
return _search_single_chat(
|
||||||
names = get_contact_names()
|
ctx,
|
||||||
conn = sqlite3.connect(ctx['db_path'])
|
keyword,
|
||||||
try:
|
start_ts,
|
||||||
id_to_username = _load_name2id_maps(conn)
|
end_ts,
|
||||||
rows = _query_messages(
|
start_time,
|
||||||
conn,
|
end_time,
|
||||||
ctx['table_name'],
|
limit,
|
||||||
start_ts=start_ts,
|
offset,
|
||||||
end_ts=end_ts,
|
)
|
||||||
keyword=keyword,
|
|
||||||
limit=limit,
|
|
||||||
offset=offset,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
conn.close()
|
|
||||||
return f"查询失败: {e}"
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息"
|
|
||||||
|
|
||||||
entries = []
|
|
||||||
for row in rows:
|
|
||||||
formatted = _build_search_entry(row, ctx, names, id_to_username)
|
|
||||||
if formatted:
|
|
||||||
entries.append(formatted)
|
|
||||||
|
|
||||||
if not entries:
|
|
||||||
return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的可读消息"
|
|
||||||
|
|
||||||
entries.sort(key=lambda x: x[0])
|
|
||||||
header = f"在 {ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})"
|
|
||||||
if start_time or end_time:
|
|
||||||
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
|
||||||
return header + ":\n\n" + "\n\n".join(item[1] for item in entries)
|
|
||||||
|
|
||||||
if len(chat_names) > 1:
|
if len(chat_names) > 1:
|
||||||
try:
|
return _search_multiple_chats(
|
||||||
resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names)
|
chat_names,
|
||||||
except ValueError as e:
|
keyword,
|
||||||
return f"错误: {e}"
|
start_ts,
|
||||||
|
end_ts,
|
||||||
if not resolved_contexts:
|
start_time,
|
||||||
details = []
|
end_time,
|
||||||
if unresolved:
|
limit,
|
||||||
details.append("未找到联系人: " + "、".join(unresolved))
|
offset,
|
||||||
if missing_tables:
|
|
||||||
details.append("无消息表: " + "、".join(missing_tables))
|
|
||||||
suffix = f"\n{chr(10).join(details)}" if details else ""
|
|
||||||
return f"错误: 没有可查询的聊天对象{suffix}"
|
|
||||||
|
|
||||||
names = get_contact_names()
|
|
||||||
collected = []
|
|
||||||
failures = []
|
|
||||||
per_chat_limit = limit + offset
|
|
||||||
|
|
||||||
for ctx in resolved_contexts:
|
|
||||||
conn = sqlite3.connect(ctx['db_path'])
|
|
||||||
try:
|
|
||||||
id_to_username = _load_name2id_maps(conn)
|
|
||||||
rows = _query_messages(
|
|
||||||
conn,
|
|
||||||
ctx['table_name'],
|
|
||||||
start_ts=start_ts,
|
|
||||||
end_ts=end_ts,
|
|
||||||
keyword=keyword,
|
|
||||||
limit=per_chat_limit,
|
|
||||||
offset=0,
|
|
||||||
)
|
|
||||||
for row in rows:
|
|
||||||
formatted = _build_search_entry(row, ctx, names, id_to_username)
|
|
||||||
if formatted:
|
|
||||||
collected.append(formatted)
|
|
||||||
except Exception as e:
|
|
||||||
failures.append(f"{ctx['display_name']}: {e}")
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
collected.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
paged = collected[offset:offset + limit]
|
|
||||||
|
|
||||||
notes = []
|
|
||||||
if unresolved:
|
|
||||||
notes.append("未找到联系人: " + "、".join(unresolved))
|
|
||||||
if missing_tables:
|
|
||||||
notes.append("无消息表: " + "、".join(missing_tables))
|
|
||||||
if failures:
|
|
||||||
notes.append("查询失败: " + ";".join(failures))
|
|
||||||
|
|
||||||
if not paged:
|
|
||||||
header = f"在 {len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息"
|
|
||||||
if start_time or end_time:
|
|
||||||
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
|
||||||
if notes:
|
|
||||||
header += "\n" + "\n".join(notes)
|
|
||||||
return header
|
|
||||||
|
|
||||||
header = (
|
|
||||||
f"在 {len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果"
|
|
||||||
f"(offset={offset}, limit={limit})"
|
|
||||||
)
|
)
|
||||||
if start_time or end_time:
|
|
||||||
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
|
||||||
if notes:
|
|
||||||
header += "\n" + "\n".join(notes)
|
|
||||||
return header + ":\n\n" + "\n\n".join(item[1] for item in paged)
|
|
||||||
|
|
||||||
names = get_contact_names()
|
return _search_all_messages(
|
||||||
results = []
|
keyword,
|
||||||
max_results = limit + offset
|
start_ts,
|
||||||
|
end_ts,
|
||||||
for rel_key in MSG_DB_KEYS:
|
start_time,
|
||||||
if len(results) >= max_results:
|
end_time,
|
||||||
break
|
limit,
|
||||||
|
offset,
|
||||||
path = _cache.get(rel_key)
|
)
|
||||||
if not path:
|
|
||||||
continue
|
|
||||||
|
|
||||||
conn = sqlite3.connect(path)
|
|
||||||
try:
|
|
||||||
# 获取所有 Msg_ 表
|
|
||||||
tables = conn.execute(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'"
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
# 获取 Name2Id 映射(hash -> username 反查)
|
|
||||||
name2id = {}
|
|
||||||
try:
|
|
||||||
for r in conn.execute("SELECT user_name FROM Name2Id").fetchall():
|
|
||||||
h = hashlib.md5(r[0].encode()).hexdigest()
|
|
||||||
name2id[f"Msg_{h}"] = r[0]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for (tname,) in tables:
|
|
||||||
if len(results) >= max_results:
|
|
||||||
break
|
|
||||||
username = name2id.get(tname, '')
|
|
||||||
is_group = '@chatroom' in username
|
|
||||||
display = names.get(username, username) if username else tname
|
|
||||||
|
|
||||||
try:
|
|
||||||
clauses, params = _build_message_filters(start_ts, end_ts, keyword)
|
|
||||||
where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else ''
|
|
||||||
rows = conn.execute(f"""
|
|
||||||
SELECT local_type, create_time, message_content,
|
|
||||||
WCDB_CT_message_content
|
|
||||||
FROM [{tname}]
|
|
||||||
{where_sql}
|
|
||||||
ORDER BY create_time DESC
|
|
||||||
LIMIT ? OFFSET ?
|
|
||||||
""", (*params, max_results - len(results), 0)).fetchall()
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for local_type, ts, content, ct in rows:
|
|
||||||
content = _decompress_content(content, ct)
|
|
||||||
if content is None:
|
|
||||||
continue
|
|
||||||
sender, text = _parse_message_content(content, local_type, is_group)
|
|
||||||
time_str = datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M')
|
|
||||||
sender_name = ''
|
|
||||||
if is_group and sender:
|
|
||||||
sender_name = names.get(sender, sender)
|
|
||||||
|
|
||||||
entry = f"[{time_str}] [{display}]"
|
|
||||||
if sender_name:
|
|
||||||
entry += f" {sender_name}:"
|
|
||||||
entry += f" {text}"
|
|
||||||
if len(entry) > 300:
|
|
||||||
entry = entry[:300] + "..."
|
|
||||||
results.append((ts, entry))
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
results.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
entries = [r[1] for r in results[offset:offset + limit]]
|
|
||||||
|
|
||||||
if not entries:
|
|
||||||
return f"未找到包含 \"{keyword}\" 的消息"
|
|
||||||
|
|
||||||
header = f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})"
|
|
||||||
if start_time or end_time:
|
|
||||||
header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}"
|
|
||||||
return header + ":\n\n" + "\n\n".join(entries)
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def get_contacts(query: str = "", limit: int = 50) -> str:
|
def get_contacts(query: str = "", limit: int = 50) -> str:
|
||||||
|
|
@ -1191,7 +1376,7 @@ def get_contacts(query: str = "", limit: int = 50) -> str:
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
def get_new_messages() -> str:
|
def get_new_messages() -> str:
|
||||||
"""获取自上次调用以来的新消息。首次调用返回最近的会话状态。"""
|
"""获取自上次调用以来的新消息。首次调用返回最近的会话状态。"""
|
||||||
global _last_check_state
|
global _last_check_state
|
||||||
|
|
||||||
|
|
@ -1200,15 +1385,14 @@ def get_new_messages() -> str:
|
||||||
return "错误: 无法解密 session.db"
|
return "错误: 无法解密 session.db"
|
||||||
|
|
||||||
names = get_contact_names()
|
names = get_contact_names()
|
||||||
conn = sqlite3.connect(path)
|
with closing(sqlite3.connect(path)) as conn:
|
||||||
rows = conn.execute("""
|
rows = conn.execute("""
|
||||||
SELECT username, unread_count, summary, last_timestamp,
|
SELECT username, unread_count, summary, last_timestamp,
|
||||||
last_msg_type, last_msg_sender, last_sender_display_name
|
last_msg_type, last_msg_sender, last_sender_display_name
|
||||||
FROM SessionTable
|
FROM SessionTable
|
||||||
WHERE last_timestamp > 0
|
WHERE last_timestamp > 0
|
||||||
ORDER BY last_timestamp DESC
|
ORDER BY last_timestamp DESC
|
||||||
""").fetchall()
|
""").fetchall()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
curr_state = {}
|
curr_state = {}
|
||||||
for r in rows:
|
for r in rows:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,590 @@
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import mcp_server
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeCache:
|
||||||
|
# 用最小缓存桩替代真实解密缓存,避免单元测试依赖本地微信环境。
|
||||||
|
def __init__(self, mapping):
|
||||||
|
self._mapping = mapping
|
||||||
|
|
||||||
|
def get(self, rel_key):
|
||||||
|
return self._mapping.get(rel_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _msg_table_name(username):
|
||||||
|
# 生产代码使用 username 的 md5 作为消息表名,测试里保持一致。
|
||||||
|
return f"Msg_{hashlib.md5(username.encode()).hexdigest()}"
|
||||||
|
|
||||||
|
|
||||||
|
def _create_message_db(path, chats):
|
||||||
|
# 构造最小可用消息库,只包含搜索/历史查询依赖的字段。
|
||||||
|
conn = sqlite3.connect(path)
|
||||||
|
try:
|
||||||
|
conn.execute("CREATE TABLE Name2Id (user_name TEXT)")
|
||||||
|
for username, messages in chats.items():
|
||||||
|
conn.execute("INSERT INTO Name2Id(user_name) VALUES (?)", (username,))
|
||||||
|
table_name = _msg_table_name(username)
|
||||||
|
conn.execute(
|
||||||
|
f"""
|
||||||
|
CREATE TABLE [{table_name}] (
|
||||||
|
local_id INTEGER,
|
||||||
|
local_type INTEGER,
|
||||||
|
create_time INTEGER,
|
||||||
|
real_sender_id INTEGER,
|
||||||
|
message_content TEXT,
|
||||||
|
WCDB_CT_message_content INTEGER
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
for local_id, create_time, content in messages:
|
||||||
|
conn.execute(
|
||||||
|
f"""
|
||||||
|
INSERT INTO [{table_name}] (
|
||||||
|
local_id, local_type, create_time, real_sender_id,
|
||||||
|
message_content, WCDB_CT_message_content
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(local_id, 1, create_time, 0, content, 0),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
class SearchMessagesTests(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
self.addCleanup(self.temp_dir.cleanup)
|
||||||
|
|
||||||
|
def create_db(self, filename, chats):
|
||||||
|
path = os.path.join(self.temp_dir.name, filename)
|
||||||
|
_create_message_db(path, chats)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def test_validate_pagination_rejects_large_limit(self):
|
||||||
|
# 防止单次查询过大,保证 limit 上限校验存在。
|
||||||
|
with self.assertRaisesRegex(ValueError, "limit 不能大于 500"):
|
||||||
|
mcp_server._validate_pagination(501, 0)
|
||||||
|
|
||||||
|
def test_page_search_entries_returns_chronological_results_with_offset(self):
|
||||||
|
# 结果应先按最新时间分页,再把当前页恢复成时间正序输出。
|
||||||
|
entries = [(1, "a"), (5, "e"), (3, "c"), (4, "d"), (2, "b")]
|
||||||
|
|
||||||
|
paged = mcp_server._page_search_entries(entries, limit=2, offset=1)
|
||||||
|
|
||||||
|
self.assertEqual(paged, [(3, "c"), (4, "d")])
|
||||||
|
|
||||||
|
def test_search_messages_single_chat_uses_offset_and_returns_page(self):
|
||||||
|
# 单聊分页应只返回当前页,并按聊天阅读顺序展示。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"single.db",
|
||||||
|
{
|
||||||
|
"alice": [
|
||||||
|
(1, 100, "foo newest"),
|
||||||
|
(2, 90, "foo middle"),
|
||||||
|
(3, 80, "foo oldest"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages("foo", chat_name="Alice", limit=2, offset=1)
|
||||||
|
|
||||||
|
self.assertIn('在 Alice 中搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result)
|
||||||
|
self.assertLess(result.index("foo oldest"), result.index("foo middle"))
|
||||||
|
self.assertNotIn("foo newest", result)
|
||||||
|
|
||||||
|
def test_search_messages_multiple_chats_applies_global_pagination(self):
|
||||||
|
# 多个聊天联合搜索时,分页必须基于合并后的全局结果。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"multi.db",
|
||||||
|
{
|
||||||
|
"alice": [
|
||||||
|
(1, 110, "foo a1"),
|
||||||
|
(2, 90, "foo a2"),
|
||||||
|
],
|
||||||
|
"bob": [
|
||||||
|
(1, 100, "foo b1"),
|
||||||
|
(2, 80, "foo b2"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
contexts = [
|
||||||
|
{
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "Bob",
|
||||||
|
"username": "bob",
|
||||||
|
"display_name": "Bob",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("bob"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}],
|
||||||
|
"is_group": False,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], [])
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages("foo", chat_name=["Alice", "Bob"], limit=2, offset=1)
|
||||||
|
|
||||||
|
self.assertIn('在 2 个聊天对象中搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result)
|
||||||
|
self.assertLess(result.index("foo a2"), result.index("foo b1"))
|
||||||
|
self.assertNotIn("foo a1", result)
|
||||||
|
self.assertNotIn("foo b2", result)
|
||||||
|
|
||||||
|
def test_search_messages_all_messages_merges_global_results_before_paging(self):
|
||||||
|
# 全库搜索要基于跨库合并后的全局时间线分页,不能被单个分库提前截断。
|
||||||
|
older_db = self.create_db(
|
||||||
|
"older.db",
|
||||||
|
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
|
||||||
|
)
|
||||||
|
newer_db = self.create_db(
|
||||||
|
"newer.db",
|
||||||
|
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2")]},
|
||||||
|
)
|
||||||
|
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
|
||||||
|
mcp_server, "_cache", fake_cache
|
||||||
|
), patch.object(
|
||||||
|
mcp_server,
|
||||||
|
"get_contact_names",
|
||||||
|
return_value={"older_user": "Older", "newer_user": "Newer"},
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages("foo", limit=2, offset=0)
|
||||||
|
|
||||||
|
self.assertIn('搜索 "foo" 找到 2 条结果(offset=0, limit=2)', result)
|
||||||
|
self.assertLess(result.index("foo newer 2"), result.index("foo newer 1"))
|
||||||
|
self.assertNotIn("foo older 1", result)
|
||||||
|
|
||||||
|
def test_search_messages_all_messages_uses_bounded_sql_pagination(self):
|
||||||
|
# 每个消息表都只应查询当前页所需的候选窗口,不能回退到 limit=None 的全量扫描。
|
||||||
|
older_db = self.create_db(
|
||||||
|
"older_paged.db",
|
||||||
|
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
|
||||||
|
)
|
||||||
|
newer_db = self.create_db(
|
||||||
|
"newer_paged.db",
|
||||||
|
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2"), (3, 19, "foo newer 3")]},
|
||||||
|
)
|
||||||
|
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
|
||||||
|
original_query_messages = mcp_server._query_messages
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def recording_query_messages(*args, **kwargs):
|
||||||
|
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
|
||||||
|
return original_query_messages(*args, **kwargs)
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
|
||||||
|
mcp_server, "_cache", fake_cache
|
||||||
|
), patch.object(
|
||||||
|
mcp_server,
|
||||||
|
"get_contact_names",
|
||||||
|
return_value={"older_user": "Older", "newer_user": "Newer"},
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_query_messages", side_effect=recording_query_messages
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages("foo", limit=2, offset=1)
|
||||||
|
|
||||||
|
self.assertIn('搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result)
|
||||||
|
self.assertEqual(
|
||||||
|
calls,
|
||||||
|
[
|
||||||
|
(_msg_table_name("older_user"), 3, 0),
|
||||||
|
(_msg_table_name("newer_user"), 3, 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_search_messages_single_chat_respects_time_range(self):
|
||||||
|
# 单聊搜索的开始/结束时间都必须严格生效。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"single_time.db",
|
||||||
|
{
|
||||||
|
"alice": [
|
||||||
|
(1, 300, "foo in range"),
|
||||||
|
(2, 200, "foo too early"),
|
||||||
|
(3, 400, "foo too late"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_parse_time_range", return_value=(250, 350)
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages(
|
||||||
|
"foo",
|
||||||
|
chat_name="Alice",
|
||||||
|
start_time="custom-start",
|
||||||
|
end_time="custom-end",
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("时间范围: custom-start ~ custom-end", result)
|
||||||
|
self.assertIn("foo in range", result)
|
||||||
|
self.assertNotIn("foo too early", result)
|
||||||
|
self.assertNotIn("foo too late", result)
|
||||||
|
|
||||||
|
def test_search_messages_multiple_chats_respects_time_range(self):
|
||||||
|
# 多聊联合搜索时,每个聊天对象都要套用同一时间范围。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"multi_time.db",
|
||||||
|
{
|
||||||
|
"alice": [(1, 300, "foo alice in range"), (2, 150, "foo alice too early")],
|
||||||
|
"bob": [(1, 320, "foo bob in range"), (2, 500, "foo bob too late")],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
contexts = [
|
||||||
|
{
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "Bob",
|
||||||
|
"username": "bob",
|
||||||
|
"display_name": "Bob",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("bob"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("bob")}],
|
||||||
|
"is_group": False,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice", "bob": "Bob"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_contexts", return_value=(contexts, [], [])
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_parse_time_range", return_value=(250, 400)
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages(
|
||||||
|
"foo",
|
||||||
|
chat_name=["Alice", "Bob"],
|
||||||
|
start_time="range-start",
|
||||||
|
end_time="range-end",
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("时间范围: range-start ~ range-end", result)
|
||||||
|
self.assertIn("foo alice in range", result)
|
||||||
|
self.assertIn("foo bob in range", result)
|
||||||
|
self.assertNotIn("foo alice too early", result)
|
||||||
|
self.assertNotIn("foo bob too late", result)
|
||||||
|
|
||||||
|
def test_search_messages_all_messages_respects_time_range(self):
|
||||||
|
# 全库搜索也不能返回时间范围外的消息。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"all_time.db",
|
||||||
|
{
|
||||||
|
"alice": [
|
||||||
|
(1, 100, "foo too early"),
|
||||||
|
(2, 300, "foo in range"),
|
||||||
|
(3, 500, "foo too late"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
fake_cache = _FakeCache({"all": db_path})
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "MSG_DB_KEYS", ["all"]), patch.object(
|
||||||
|
mcp_server, "_cache", fake_cache
|
||||||
|
), patch.object(
|
||||||
|
mcp_server,
|
||||||
|
"get_contact_names",
|
||||||
|
return_value={"alice": "Alice"},
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_parse_time_range", return_value=(250, 350)
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages(
|
||||||
|
"foo",
|
||||||
|
start_time="range-start",
|
||||||
|
end_time="range-end",
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("时间范围: range-start ~ range-end", result)
|
||||||
|
self.assertIn("foo in range", result)
|
||||||
|
self.assertNotIn("foo too early", result)
|
||||||
|
self.assertNotIn("foo too late", result)
|
||||||
|
|
||||||
|
def test_get_chat_history_merges_sharded_message_tables(self):
|
||||||
|
# 同一联系人跨多个 message_N.db 分片时,历史查询要先合并再分页。
|
||||||
|
older_db = self.create_db("history_older.db", {"alice": [(1, 100, "old message")]})
|
||||||
|
newer_db = self.create_db(
|
||||||
|
"history_newer.db",
|
||||||
|
{"alice": [(1, 300, "new message"), (2, 250, "middle message")]},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": newer_db,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [
|
||||||
|
{"db_path": older_db, "table_name": _msg_table_name("alice")},
|
||||||
|
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
|
||||||
|
],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
):
|
||||||
|
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
|
||||||
|
|
||||||
|
self.assertIn("Alice 的消息记录(返回 2 条,offset=0, limit=2)", result)
|
||||||
|
self.assertIn("middle message", result)
|
||||||
|
self.assertIn("new message", result)
|
||||||
|
self.assertNotIn("old message", result)
|
||||||
|
|
||||||
|
def test_get_chat_history_uses_bounded_sql_pagination(self):
|
||||||
|
# 历史查询应把 offset+limit 下推到 SQL,避免把整张消息表读出来后再切片。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"history_paged.db",
|
||||||
|
{
|
||||||
|
"alice": [
|
||||||
|
(1, 400, "newest"),
|
||||||
|
(2, 300, "middle"),
|
||||||
|
(3, 200, "older"),
|
||||||
|
(4, 100, "oldest"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
original_query_messages = mcp_server._query_messages
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def recording_query_messages(*args, **kwargs):
|
||||||
|
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
|
||||||
|
return original_query_messages(*args, **kwargs)
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_query_messages", side_effect=recording_query_messages
|
||||||
|
):
|
||||||
|
result = mcp_server.get_chat_history("Alice", limit=2, offset=1)
|
||||||
|
|
||||||
|
self.assertIn("middle", result)
|
||||||
|
self.assertIn("older", result)
|
||||||
|
self.assertNotIn("newest", result)
|
||||||
|
self.assertNotIn("oldest", result)
|
||||||
|
self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)])
|
||||||
|
|
||||||
|
def test_get_chat_history_keeps_partial_results_when_formatting_fails(self):
|
||||||
|
# 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"history_partial_failure.db",
|
||||||
|
{"alice": [(1, 200, "good message"), (2, 100, "bad message")]},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
original_build_history_line = mcp_server._build_history_line
|
||||||
|
|
||||||
|
def flaky_build_history_line(row, *args, **kwargs):
|
||||||
|
if row[2] == 100:
|
||||||
|
raise ValueError("bad row")
|
||||||
|
return original_build_history_line(row, *args, **kwargs)
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_build_history_line", side_effect=flaky_build_history_line
|
||||||
|
):
|
||||||
|
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
|
||||||
|
|
||||||
|
self.assertIn("good message", result)
|
||||||
|
self.assertIn("查询失败:", result)
|
||||||
|
self.assertIn("bad row", result)
|
||||||
|
|
||||||
|
def test_search_messages_single_chat_merges_sharded_message_tables(self):
|
||||||
|
# 单聊搜索也要跨分片合并,否则最近消息可能查不到。
|
||||||
|
older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]})
|
||||||
|
newer_db = self.create_db(
|
||||||
|
"search_newer.db",
|
||||||
|
{"alice": [(1, 300, "foo new"), (2, 200, "foo middle")]},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": newer_db,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [
|
||||||
|
{"db_path": older_db, "table_name": _msg_table_name("alice")},
|
||||||
|
{"db_path": newer_db, "table_name": _msg_table_name("alice")},
|
||||||
|
],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_parse_time_range", return_value=(150, 350)
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages(
|
||||||
|
"foo",
|
||||||
|
chat_name="Alice",
|
||||||
|
start_time="range-start",
|
||||||
|
end_time="range-end",
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("foo middle", result)
|
||||||
|
self.assertIn("foo new", result)
|
||||||
|
self.assertNotIn("foo old", result)
|
||||||
|
|
||||||
|
def test_search_messages_keeps_partial_results_when_later_batch_fails(self):
|
||||||
|
# 后续批次失败时,前面已经拿到的有效结果不应被丢弃。
|
||||||
|
db_path = self.create_db(
|
||||||
|
"search_partial_failure.db",
|
||||||
|
{
|
||||||
|
"alice": [
|
||||||
|
(1, 400, "foo newest"),
|
||||||
|
(2, 300, "foo skipped"),
|
||||||
|
(3, 200, "foo older"),
|
||||||
|
(4, 100, "foo bad"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ctx = {
|
||||||
|
"query": "Alice",
|
||||||
|
"username": "alice",
|
||||||
|
"display_name": "Alice",
|
||||||
|
"db_path": db_path,
|
||||||
|
"table_name": _msg_table_name("alice"),
|
||||||
|
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||||
|
"is_group": False,
|
||||||
|
}
|
||||||
|
original_build_search_entry = mcp_server._build_search_entry
|
||||||
|
|
||||||
|
def flaky_build_search_entry(row, *args, **kwargs):
|
||||||
|
if row[2] == 300:
|
||||||
|
return None
|
||||||
|
if row[2] == 100:
|
||||||
|
raise ValueError("bad row")
|
||||||
|
return original_build_search_entry(row, *args, **kwargs)
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||||
|
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||||
|
), patch.object(
|
||||||
|
mcp_server, "_build_search_entry", side_effect=flaky_build_search_entry
|
||||||
|
):
|
||||||
|
result = mcp_server.search_messages("foo", chat_name="Alice", limit=3, offset=0)
|
||||||
|
|
||||||
|
self.assertIn("foo newest", result)
|
||||||
|
self.assertIn("foo older", result)
|
||||||
|
self.assertIn("查询失败:", result)
|
||||||
|
self.assertIn("bad row", result)
|
||||||
|
|
||||||
|
def test_get_recent_sessions_closes_connection_when_query_fails(self):
|
||||||
|
# 会话查询抛异常时也必须关闭 sqlite3 连接,避免资源泄漏。
|
||||||
|
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
|
||||||
|
|
||||||
|
class _FakeConn:
|
||||||
|
def __init__(self):
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
raise sqlite3.OperationalError("boom")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
fake_conn = _FakeConn()
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
|
||||||
|
mcp_server, "get_contact_names", return_value={}
|
||||||
|
), patch.object(
|
||||||
|
mcp_server.sqlite3, "connect", return_value=fake_conn
|
||||||
|
):
|
||||||
|
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
|
||||||
|
mcp_server.get_recent_sessions()
|
||||||
|
|
||||||
|
self.assertTrue(fake_conn.closed)
|
||||||
|
|
||||||
|
def test_get_new_messages_closes_connection_when_query_fails(self):
|
||||||
|
# 新消息轮询失败时也要释放 sqlite3 连接。
|
||||||
|
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
|
||||||
|
|
||||||
|
class _FakeConn:
|
||||||
|
def __init__(self):
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
raise sqlite3.OperationalError("boom")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
fake_conn = _FakeConn()
|
||||||
|
|
||||||
|
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
|
||||||
|
mcp_server, "get_contact_names", return_value={}
|
||||||
|
), patch.object(
|
||||||
|
mcp_server.sqlite3, "connect", return_value=fake_conn
|
||||||
|
):
|
||||||
|
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
|
||||||
|
mcp_server.get_new_messages()
|
||||||
|
|
||||||
|
self.assertTrue(fake_conn.closed)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue