From 9ae558a31e363645658f201473cdbf358bb148d4 Mon Sep 17 00:00:00 2001 From: dsjzazs Date: Sat, 14 Mar 2026 16:36:55 +0800 Subject: [PATCH] Fix global search pagination --- mcp_server.py | 119 +++++++++++------- tests/test_mcp_server_search.py | 212 +++++++++++++++++++++++++++++++- 2 files changed, 286 insertions(+), 45 deletions(-) diff --git a/mcp_server.py b/mcp_server.py index be0abac..f1c1547 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -872,33 +872,49 @@ def _iter_table_contexts(ctx): } +def _candidate_page_size(limit, offset): + return limit + offset + + +def _message_query_batch_size(candidate_limit): + return candidate_limit + + +def _page_ranked_entries(entries, limit, offset): + ordered = sorted(entries, key=lambda item: item[0], reverse=True) + paged = ordered[offset:offset + limit] + paged.sort(key=lambda item: item[0]) + return paged + + def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0): collected = [] failures = [] + candidate_limit = _candidate_page_size(limit, offset) for table_ctx in _iter_table_contexts(ctx): try: with closing(sqlite3.connect(table_ctx['db_path'])) as conn: id_to_username = _load_name2id_maps(conn) + # 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。 rows = _query_messages( conn, table_ctx['table_name'], start_ts=start_ts, end_ts=end_ts, - limit=None, + limit=candidate_limit, + offset=0, ) for row in rows: collected.append(_build_history_line(row, table_ctx, names, id_to_username)) except Exception as e: failures.append(f"{table_ctx['db_path']}: {e}") - ordered = sorted(collected, key=lambda item: item[0], reverse=True) - paged = ordered[offset:offset + limit] - paged.sort(key=lambda item: item[0]) + paged = _page_ranked_entries(collected, limit, offset) return [line for _, line in paged], failures -def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None): +def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None, candidate_limit=20): collected = [] failures = [] contexts_by_db = {} @@ -915,6 +931,7 @@ def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) collected.extend(db_entries) failures.extend(db_failures) @@ -954,25 +971,40 @@ def _load_search_contexts_from_db(conn, db_path, names): return contexts -def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None): +def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None, candidate_limit=20): collected = [] failures = [] id_to_username = _load_name2id_maps(conn) + batch_size = _message_query_batch_size(candidate_limit) for ctx in contexts: try: - rows = _query_messages( - conn, - ctx['table_name'], - start_ts=start_ts, - end_ts=end_ts, - keyword=keyword, - limit=None, - ) - for row in rows: - formatted = _build_search_entry(row, ctx, names, id_to_username) - if formatted: - collected.append(formatted) + fetch_offset = 0 + collected_before_table = len(collected) + # 全局分页只需要每个分表最新的 offset+limit 条有效命中,无需把整表命中读进内存。 + while len(collected) - collected_before_table < candidate_limit: + rows = _query_messages( + conn, + ctx['table_name'], + start_ts=start_ts, + end_ts=end_ts, + keyword=keyword, + limit=batch_size, + offset=fetch_offset, + ) + if not rows: + break + fetch_offset += len(rows) + + for row in rows: + formatted = _build_search_entry(row, ctx, names, id_to_username) + if formatted: + collected.append(formatted) + if len(collected) - collected_before_table >= candidate_limit: + break + + if len(rows) < batch_size: + break except Exception as e: failures.append(f"{ctx['display_name']}: {e}") @@ -980,14 +1012,12 @@ def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_t def _page_search_entries(entries, limit, offset): - ordered = sorted(entries, key=lambda x: x[0], reverse=True) - paged = ordered[offset:offset + limit] - paged.sort(key=lambda x: x[0]) - return paged + return _page_ranked_entries(entries, limit, offset) def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset): names = get_contact_names() + candidate_limit = _candidate_page_size(limit, offset) entries, failures = _collect_chat_search_entries( ctx, @@ -995,6 +1025,7 @@ def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, li keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) paged = _page_search_entries(entries, limit, offset) @@ -1028,6 +1059,7 @@ def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, en return f"错误: 没有可查询的聊天对象{suffix}" names = get_contact_names() + candidate_limit = _candidate_page_size(limit, offset) collected = [] failures = [] for ctx in resolved_contexts: @@ -1037,6 +1069,7 @@ def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, en keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) collected.extend(chat_entries) failures.extend(chat_failures) @@ -1074,6 +1107,7 @@ def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, names = get_contact_names() collected = [] failures = [] + candidate_limit = _candidate_page_size(limit, offset) for rel_key in MSG_DB_KEYS: path = _cache.get(rel_key) @@ -1090,6 +1124,7 @@ def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit, keyword, start_ts=start_ts, end_ts=end_ts, + candidate_limit=candidate_limit, ) collected.extend(db_entries) failures.extend(db_failures) @@ -1123,7 +1158,7 @@ _last_check_state = {} # {username: last_timestamp} @mcp.tool() -def get_recent_sessions(limit: int = 20) -> str: +def get_recent_sessions(limit: int = 20) -> str: """获取微信最近会话列表,包含最新消息摘要、未读数、时间等。 用于了解最近有哪些人/群在聊天。 @@ -1135,16 +1170,15 @@ def get_recent_sessions(limit: int = 20) -> str: return "错误: 无法解密 session.db" names = get_contact_names() - conn = sqlite3.connect(path) - rows = conn.execute(""" - SELECT username, unread_count, summary, last_timestamp, - last_msg_type, last_msg_sender, last_sender_display_name - FROM SessionTable - WHERE last_timestamp > 0 - ORDER BY last_timestamp DESC - LIMIT ? - """, (limit,)).fetchall() - conn.close() + with closing(sqlite3.connect(path)) as conn: + rows = conn.execute(""" + SELECT username, unread_count, summary, last_timestamp, + last_msg_type, last_msg_sender, last_sender_display_name + FROM SessionTable + WHERE last_timestamp > 0 + ORDER BY last_timestamp DESC + LIMIT ? + """, (limit,)).fetchall() results = [] for r in rows: @@ -1342,7 +1376,7 @@ def get_contacts(query: str = "", limit: int = 50) -> str: @mcp.tool() -def get_new_messages() -> str: +def get_new_messages() -> str: """获取自上次调用以来的新消息。首次调用返回最近的会话状态。""" global _last_check_state @@ -1351,15 +1385,14 @@ def get_new_messages() -> str: return "错误: 无法解密 session.db" names = get_contact_names() - conn = sqlite3.connect(path) - rows = conn.execute(""" - SELECT username, unread_count, summary, last_timestamp, - last_msg_type, last_msg_sender, last_sender_display_name - FROM SessionTable - WHERE last_timestamp > 0 - ORDER BY last_timestamp DESC - """).fetchall() - conn.close() + with closing(sqlite3.connect(path)) as conn: + rows = conn.execute(""" + SELECT username, unread_count, summary, last_timestamp, + last_msg_type, last_msg_sender, last_sender_display_name + FROM SessionTable + WHERE last_timestamp > 0 + ORDER BY last_timestamp DESC + """).fetchall() curr_state = {} for r in rows: diff --git a/tests/test_mcp_server_search.py b/tests/test_mcp_server_search.py index 8505dc0..859d996 100644 --- a/tests/test_mcp_server_search.py +++ b/tests/test_mcp_server_search.py @@ -157,8 +157,8 @@ class SearchMessagesTests(unittest.TestCase): self.assertNotIn("foo a1", result) self.assertNotIn("foo b2", result) - def test_search_messages_all_messages_scans_all_dbs_before_paging(self): - # 全库搜索要先扫完整个库,再分页,不能被旧分片提前截断。 + def test_search_messages_all_messages_merges_global_results_before_paging(self): + # 全库搜索要基于跨库合并后的全局时间线分页,不能被单个分库提前截断。 older_db = self.create_db( "older.db", {"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]}, @@ -182,6 +182,44 @@ class SearchMessagesTests(unittest.TestCase): self.assertLess(result.index("foo newer 2"), result.index("foo newer 1")) self.assertNotIn("foo older 1", result) + def test_search_messages_all_messages_uses_bounded_sql_pagination(self): + # 每个消息表都只应查询当前页所需的候选窗口,不能回退到 limit=None 的全量扫描。 + older_db = self.create_db( + "older_paged.db", + {"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]}, + ) + newer_db = self.create_db( + "newer_paged.db", + {"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2"), (3, 19, "foo newer 3")]}, + ) + fake_cache = _FakeCache({"older": older_db, "newer": newer_db}) + original_query_messages = mcp_server._query_messages + calls = [] + + def recording_query_messages(*args, **kwargs): + calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0))) + return original_query_messages(*args, **kwargs) + + with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object( + mcp_server, "_cache", fake_cache + ), patch.object( + mcp_server, + "get_contact_names", + return_value={"older_user": "Older", "newer_user": "Newer"}, + ), patch.object( + mcp_server, "_query_messages", side_effect=recording_query_messages + ): + result = mcp_server.search_messages("foo", limit=2, offset=1) + + self.assertIn('搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result) + self.assertEqual( + calls, + [ + (_msg_table_name("older_user"), 3, 0), + (_msg_table_name("newer_user"), 3, 0), + ], + ) + def test_search_messages_single_chat_respects_time_range(self): # 单聊搜索的开始/结束时间都必须严格生效。 db_path = self.create_db( @@ -339,6 +377,81 @@ class SearchMessagesTests(unittest.TestCase): self.assertIn("new message", result) self.assertNotIn("old message", result) + def test_get_chat_history_uses_bounded_sql_pagination(self): + # 历史查询应把 offset+limit 下推到 SQL,避免把整张消息表读出来后再切片。 + db_path = self.create_db( + "history_paged.db", + { + "alice": [ + (1, 400, "newest"), + (2, 300, "middle"), + (3, 200, "older"), + (4, 100, "oldest"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + original_query_messages = mcp_server._query_messages + calls = [] + + def recording_query_messages(*args, **kwargs): + calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0))) + return original_query_messages(*args, **kwargs) + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_query_messages", side_effect=recording_query_messages + ): + result = mcp_server.get_chat_history("Alice", limit=2, offset=1) + + self.assertIn("middle", result) + self.assertIn("older", result) + self.assertNotIn("newest", result) + self.assertNotIn("oldest", result) + self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)]) + + def test_get_chat_history_keeps_partial_results_when_formatting_fails(self): + # 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。 + db_path = self.create_db( + "history_partial_failure.db", + {"alice": [(1, 200, "good message"), (2, 100, "bad message")]}, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + original_build_history_line = mcp_server._build_history_line + + def flaky_build_history_line(row, *args, **kwargs): + if row[2] == 100: + raise ValueError("bad row") + return original_build_history_line(row, *args, **kwargs) + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_build_history_line", side_effect=flaky_build_history_line + ): + result = mcp_server.get_chat_history("Alice", limit=2, offset=0) + + self.assertIn("good message", result) + self.assertIn("查询失败:", result) + self.assertIn("bad row", result) + def test_search_messages_single_chat_merges_sharded_message_tables(self): # 单聊搜索也要跨分片合并,否则最近消息可能查不到。 older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]}) @@ -377,6 +490,101 @@ class SearchMessagesTests(unittest.TestCase): self.assertIn("foo new", result) self.assertNotIn("foo old", result) + def test_search_messages_keeps_partial_results_when_later_batch_fails(self): + # 后续批次失败时,前面已经拿到的有效结果不应被丢弃。 + db_path = self.create_db( + "search_partial_failure.db", + { + "alice": [ + (1, 400, "foo newest"), + (2, 300, "foo skipped"), + (3, 200, "foo older"), + (4, 100, "foo bad"), + ] + }, + ) + ctx = { + "query": "Alice", + "username": "alice", + "display_name": "Alice", + "db_path": db_path, + "table_name": _msg_table_name("alice"), + "message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}], + "is_group": False, + } + original_build_search_entry = mcp_server._build_search_entry + + def flaky_build_search_entry(row, *args, **kwargs): + if row[2] == 300: + return None + if row[2] == 100: + raise ValueError("bad row") + return original_build_search_entry(row, *args, **kwargs) + + with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object( + mcp_server, "_resolve_chat_context", return_value=ctx + ), patch.object( + mcp_server, "_build_search_entry", side_effect=flaky_build_search_entry + ): + result = mcp_server.search_messages("foo", chat_name="Alice", limit=3, offset=0) + + self.assertIn("foo newest", result) + self.assertIn("foo older", result) + self.assertIn("查询失败:", result) + self.assertIn("bad row", result) + + def test_get_recent_sessions_closes_connection_when_query_fails(self): + # 会话查询抛异常时也必须关闭 sqlite3 连接,避免资源泄漏。 + fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"}) + + class _FakeConn: + def __init__(self): + self.closed = False + + def execute(self, *args, **kwargs): + raise sqlite3.OperationalError("boom") + + def close(self): + self.closed = True + + fake_conn = _FakeConn() + + with patch.object(mcp_server, "_cache", fake_cache), patch.object( + mcp_server, "get_contact_names", return_value={} + ), patch.object( + mcp_server.sqlite3, "connect", return_value=fake_conn + ): + with self.assertRaisesRegex(sqlite3.OperationalError, "boom"): + mcp_server.get_recent_sessions() + + self.assertTrue(fake_conn.closed) + + def test_get_new_messages_closes_connection_when_query_fails(self): + # 新消息轮询失败时也要释放 sqlite3 连接。 + fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"}) + + class _FakeConn: + def __init__(self): + self.closed = False + + def execute(self, *args, **kwargs): + raise sqlite3.OperationalError("boom") + + def close(self): + self.closed = True + + fake_conn = _FakeConn() + + with patch.object(mcp_server, "_cache", fake_cache), patch.object( + mcp_server, "get_contact_names", return_value={} + ), patch.object( + mcp_server.sqlite3, "connect", return_value=fake_conn + ): + with self.assertRaisesRegex(sqlite3.OperationalError, "boom"): + mcp_server.get_new_messages() + + self.assertTrue(fake_conn.closed) + if __name__ == "__main__": unittest.main()