mirror of https://github.com/jackwener/wx-cli.git
Fix global search pagination
parent
b623711410
commit
9ae558a31e
|
|
@ -872,33 +872,49 @@ def _iter_table_contexts(ctx):
|
|||
}
|
||||
|
||||
|
||||
def _candidate_page_size(limit, offset):
|
||||
return limit + offset
|
||||
|
||||
|
||||
def _message_query_batch_size(candidate_limit):
|
||||
return candidate_limit
|
||||
|
||||
|
||||
def _page_ranked_entries(entries, limit, offset):
|
||||
ordered = sorted(entries, key=lambda item: item[0], reverse=True)
|
||||
paged = ordered[offset:offset + limit]
|
||||
paged.sort(key=lambda item: item[0])
|
||||
return paged
|
||||
|
||||
|
||||
def _collect_chat_history_lines(ctx, names, start_ts=None, end_ts=None, limit=20, offset=0):
|
||||
collected = []
|
||||
failures = []
|
||||
candidate_limit = _candidate_page_size(limit, offset)
|
||||
|
||||
for table_ctx in _iter_table_contexts(ctx):
|
||||
try:
|
||||
with closing(sqlite3.connect(table_ctx['db_path'])) as conn:
|
||||
id_to_username = _load_name2id_maps(conn)
|
||||
# 当前页上的消息一定落在各分表最近的 offset+limit 条记录内。
|
||||
rows = _query_messages(
|
||||
conn,
|
||||
table_ctx['table_name'],
|
||||
start_ts=start_ts,
|
||||
end_ts=end_ts,
|
||||
limit=None,
|
||||
limit=candidate_limit,
|
||||
offset=0,
|
||||
)
|
||||
for row in rows:
|
||||
collected.append(_build_history_line(row, table_ctx, names, id_to_username))
|
||||
except Exception as e:
|
||||
failures.append(f"{table_ctx['db_path']}: {e}")
|
||||
|
||||
ordered = sorted(collected, key=lambda item: item[0], reverse=True)
|
||||
paged = ordered[offset:offset + limit]
|
||||
paged.sort(key=lambda item: item[0])
|
||||
paged = _page_ranked_entries(collected, limit, offset)
|
||||
return [line for _, line in paged], failures
|
||||
|
||||
|
||||
def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None):
|
||||
def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
|
||||
collected = []
|
||||
failures = []
|
||||
contexts_by_db = {}
|
||||
|
|
@ -915,6 +931,7 @@ def _collect_chat_search_entries(ctx, names, keyword, start_ts=None, end_ts=None
|
|||
keyword,
|
||||
start_ts=start_ts,
|
||||
end_ts=end_ts,
|
||||
candidate_limit=candidate_limit,
|
||||
)
|
||||
collected.extend(db_entries)
|
||||
failures.extend(db_failures)
|
||||
|
|
@ -954,25 +971,40 @@ def _load_search_contexts_from_db(conn, db_path, names):
|
|||
return contexts
|
||||
|
||||
|
||||
def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None):
|
||||
def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_ts=None, candidate_limit=20):
|
||||
collected = []
|
||||
failures = []
|
||||
id_to_username = _load_name2id_maps(conn)
|
||||
batch_size = _message_query_batch_size(candidate_limit)
|
||||
|
||||
for ctx in contexts:
|
||||
try:
|
||||
fetch_offset = 0
|
||||
collected_before_table = len(collected)
|
||||
# 全局分页只需要每个分表最新的 offset+limit 条有效命中,无需把整表命中读进内存。
|
||||
while len(collected) - collected_before_table < candidate_limit:
|
||||
rows = _query_messages(
|
||||
conn,
|
||||
ctx['table_name'],
|
||||
start_ts=start_ts,
|
||||
end_ts=end_ts,
|
||||
keyword=keyword,
|
||||
limit=None,
|
||||
limit=batch_size,
|
||||
offset=fetch_offset,
|
||||
)
|
||||
if not rows:
|
||||
break
|
||||
fetch_offset += len(rows)
|
||||
|
||||
for row in rows:
|
||||
formatted = _build_search_entry(row, ctx, names, id_to_username)
|
||||
if formatted:
|
||||
collected.append(formatted)
|
||||
if len(collected) - collected_before_table >= candidate_limit:
|
||||
break
|
||||
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
except Exception as e:
|
||||
failures.append(f"{ctx['display_name']}: {e}")
|
||||
|
||||
|
|
@ -980,14 +1012,12 @@ def _collect_search_entries(conn, contexts, names, keyword, start_ts=None, end_t
|
|||
|
||||
|
||||
def _page_search_entries(entries, limit, offset):
|
||||
ordered = sorted(entries, key=lambda x: x[0], reverse=True)
|
||||
paged = ordered[offset:offset + limit]
|
||||
paged.sort(key=lambda x: x[0])
|
||||
return paged
|
||||
return _page_ranked_entries(entries, limit, offset)
|
||||
|
||||
|
||||
def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, limit, offset):
|
||||
names = get_contact_names()
|
||||
candidate_limit = _candidate_page_size(limit, offset)
|
||||
|
||||
entries, failures = _collect_chat_search_entries(
|
||||
ctx,
|
||||
|
|
@ -995,6 +1025,7 @@ def _search_single_chat(ctx, keyword, start_ts, end_ts, start_time, end_time, li
|
|||
keyword,
|
||||
start_ts=start_ts,
|
||||
end_ts=end_ts,
|
||||
candidate_limit=candidate_limit,
|
||||
)
|
||||
|
||||
paged = _page_search_entries(entries, limit, offset)
|
||||
|
|
@ -1028,6 +1059,7 @@ def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, en
|
|||
return f"错误: 没有可查询的聊天对象{suffix}"
|
||||
|
||||
names = get_contact_names()
|
||||
candidate_limit = _candidate_page_size(limit, offset)
|
||||
collected = []
|
||||
failures = []
|
||||
for ctx in resolved_contexts:
|
||||
|
|
@ -1037,6 +1069,7 @@ def _search_multiple_chats(chat_names, keyword, start_ts, end_ts, start_time, en
|
|||
keyword,
|
||||
start_ts=start_ts,
|
||||
end_ts=end_ts,
|
||||
candidate_limit=candidate_limit,
|
||||
)
|
||||
collected.extend(chat_entries)
|
||||
failures.extend(chat_failures)
|
||||
|
|
@ -1074,6 +1107,7 @@ def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit,
|
|||
names = get_contact_names()
|
||||
collected = []
|
||||
failures = []
|
||||
candidate_limit = _candidate_page_size(limit, offset)
|
||||
|
||||
for rel_key in MSG_DB_KEYS:
|
||||
path = _cache.get(rel_key)
|
||||
|
|
@ -1090,6 +1124,7 @@ def _search_all_messages(keyword, start_ts, end_ts, start_time, end_time, limit,
|
|||
keyword,
|
||||
start_ts=start_ts,
|
||||
end_ts=end_ts,
|
||||
candidate_limit=candidate_limit,
|
||||
)
|
||||
collected.extend(db_entries)
|
||||
failures.extend(db_failures)
|
||||
|
|
@ -1135,7 +1170,7 @@ def get_recent_sessions(limit: int = 20) -> str:
|
|||
return "错误: 无法解密 session.db"
|
||||
|
||||
names = get_contact_names()
|
||||
conn = sqlite3.connect(path)
|
||||
with closing(sqlite3.connect(path)) as conn:
|
||||
rows = conn.execute("""
|
||||
SELECT username, unread_count, summary, last_timestamp,
|
||||
last_msg_type, last_msg_sender, last_sender_display_name
|
||||
|
|
@ -1144,7 +1179,6 @@ def get_recent_sessions(limit: int = 20) -> str:
|
|||
ORDER BY last_timestamp DESC
|
||||
LIMIT ?
|
||||
""", (limit,)).fetchall()
|
||||
conn.close()
|
||||
|
||||
results = []
|
||||
for r in rows:
|
||||
|
|
@ -1351,7 +1385,7 @@ def get_new_messages() -> str:
|
|||
return "错误: 无法解密 session.db"
|
||||
|
||||
names = get_contact_names()
|
||||
conn = sqlite3.connect(path)
|
||||
with closing(sqlite3.connect(path)) as conn:
|
||||
rows = conn.execute("""
|
||||
SELECT username, unread_count, summary, last_timestamp,
|
||||
last_msg_type, last_msg_sender, last_sender_display_name
|
||||
|
|
@ -1359,7 +1393,6 @@ def get_new_messages() -> str:
|
|||
WHERE last_timestamp > 0
|
||||
ORDER BY last_timestamp DESC
|
||||
""").fetchall()
|
||||
conn.close()
|
||||
|
||||
curr_state = {}
|
||||
for r in rows:
|
||||
|
|
|
|||
|
|
@ -157,8 +157,8 @@ class SearchMessagesTests(unittest.TestCase):
|
|||
self.assertNotIn("foo a1", result)
|
||||
self.assertNotIn("foo b2", result)
|
||||
|
||||
def test_search_messages_all_messages_scans_all_dbs_before_paging(self):
|
||||
# 全库搜索要先扫完整个库,再分页,不能被旧分片提前截断。
|
||||
def test_search_messages_all_messages_merges_global_results_before_paging(self):
|
||||
# 全库搜索要基于跨库合并后的全局时间线分页,不能被单个分库提前截断。
|
||||
older_db = self.create_db(
|
||||
"older.db",
|
||||
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
|
||||
|
|
@ -182,6 +182,44 @@ class SearchMessagesTests(unittest.TestCase):
|
|||
self.assertLess(result.index("foo newer 2"), result.index("foo newer 1"))
|
||||
self.assertNotIn("foo older 1", result)
|
||||
|
||||
def test_search_messages_all_messages_uses_bounded_sql_pagination(self):
|
||||
# 每个消息表都只应查询当前页所需的候选窗口,不能回退到 limit=None 的全量扫描。
|
||||
older_db = self.create_db(
|
||||
"older_paged.db",
|
||||
{"older_user": [(1, 10, "foo older 1"), (2, 9, "foo older 2"), (3, 8, "foo older 3")]},
|
||||
)
|
||||
newer_db = self.create_db(
|
||||
"newer_paged.db",
|
||||
{"newer_user": [(1, 30, "foo newer 1"), (2, 20, "foo newer 2"), (3, 19, "foo newer 3")]},
|
||||
)
|
||||
fake_cache = _FakeCache({"older": older_db, "newer": newer_db})
|
||||
original_query_messages = mcp_server._query_messages
|
||||
calls = []
|
||||
|
||||
def recording_query_messages(*args, **kwargs):
|
||||
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
|
||||
return original_query_messages(*args, **kwargs)
|
||||
|
||||
with patch.object(mcp_server, "MSG_DB_KEYS", ["older", "newer"]), patch.object(
|
||||
mcp_server, "_cache", fake_cache
|
||||
), patch.object(
|
||||
mcp_server,
|
||||
"get_contact_names",
|
||||
return_value={"older_user": "Older", "newer_user": "Newer"},
|
||||
), patch.object(
|
||||
mcp_server, "_query_messages", side_effect=recording_query_messages
|
||||
):
|
||||
result = mcp_server.search_messages("foo", limit=2, offset=1)
|
||||
|
||||
self.assertIn('搜索 "foo" 找到 2 条结果(offset=1, limit=2)', result)
|
||||
self.assertEqual(
|
||||
calls,
|
||||
[
|
||||
(_msg_table_name("older_user"), 3, 0),
|
||||
(_msg_table_name("newer_user"), 3, 0),
|
||||
],
|
||||
)
|
||||
|
||||
def test_search_messages_single_chat_respects_time_range(self):
|
||||
# 单聊搜索的开始/结束时间都必须严格生效。
|
||||
db_path = self.create_db(
|
||||
|
|
@ -339,6 +377,81 @@ class SearchMessagesTests(unittest.TestCase):
|
|||
self.assertIn("new message", result)
|
||||
self.assertNotIn("old message", result)
|
||||
|
||||
def test_get_chat_history_uses_bounded_sql_pagination(self):
|
||||
# 历史查询应把 offset+limit 下推到 SQL,避免把整张消息表读出来后再切片。
|
||||
db_path = self.create_db(
|
||||
"history_paged.db",
|
||||
{
|
||||
"alice": [
|
||||
(1, 400, "newest"),
|
||||
(2, 300, "middle"),
|
||||
(3, 200, "older"),
|
||||
(4, 100, "oldest"),
|
||||
]
|
||||
},
|
||||
)
|
||||
ctx = {
|
||||
"query": "Alice",
|
||||
"username": "alice",
|
||||
"display_name": "Alice",
|
||||
"db_path": db_path,
|
||||
"table_name": _msg_table_name("alice"),
|
||||
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||
"is_group": False,
|
||||
}
|
||||
original_query_messages = mcp_server._query_messages
|
||||
calls = []
|
||||
|
||||
def recording_query_messages(*args, **kwargs):
|
||||
calls.append((args[1], kwargs.get("limit"), kwargs.get("offset", 0)))
|
||||
return original_query_messages(*args, **kwargs)
|
||||
|
||||
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||
), patch.object(
|
||||
mcp_server, "_query_messages", side_effect=recording_query_messages
|
||||
):
|
||||
result = mcp_server.get_chat_history("Alice", limit=2, offset=1)
|
||||
|
||||
self.assertIn("middle", result)
|
||||
self.assertIn("older", result)
|
||||
self.assertNotIn("newest", result)
|
||||
self.assertNotIn("oldest", result)
|
||||
self.assertEqual(calls, [(_msg_table_name("alice"), 3, 0)])
|
||||
|
||||
def test_get_chat_history_keeps_partial_results_when_formatting_fails(self):
|
||||
# 单条坏消息不应让整个历史查询失败,已有结果仍应返回并附带失败说明。
|
||||
db_path = self.create_db(
|
||||
"history_partial_failure.db",
|
||||
{"alice": [(1, 200, "good message"), (2, 100, "bad message")]},
|
||||
)
|
||||
ctx = {
|
||||
"query": "Alice",
|
||||
"username": "alice",
|
||||
"display_name": "Alice",
|
||||
"db_path": db_path,
|
||||
"table_name": _msg_table_name("alice"),
|
||||
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||
"is_group": False,
|
||||
}
|
||||
original_build_history_line = mcp_server._build_history_line
|
||||
|
||||
def flaky_build_history_line(row, *args, **kwargs):
|
||||
if row[2] == 100:
|
||||
raise ValueError("bad row")
|
||||
return original_build_history_line(row, *args, **kwargs)
|
||||
|
||||
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||
), patch.object(
|
||||
mcp_server, "_build_history_line", side_effect=flaky_build_history_line
|
||||
):
|
||||
result = mcp_server.get_chat_history("Alice", limit=2, offset=0)
|
||||
|
||||
self.assertIn("good message", result)
|
||||
self.assertIn("查询失败:", result)
|
||||
self.assertIn("bad row", result)
|
||||
|
||||
def test_search_messages_single_chat_merges_sharded_message_tables(self):
|
||||
# 单聊搜索也要跨分片合并,否则最近消息可能查不到。
|
||||
older_db = self.create_db("search_older.db", {"alice": [(1, 100, "foo old")]})
|
||||
|
|
@ -377,6 +490,101 @@ class SearchMessagesTests(unittest.TestCase):
|
|||
self.assertIn("foo new", result)
|
||||
self.assertNotIn("foo old", result)
|
||||
|
||||
def test_search_messages_keeps_partial_results_when_later_batch_fails(self):
|
||||
# 后续批次失败时,前面已经拿到的有效结果不应被丢弃。
|
||||
db_path = self.create_db(
|
||||
"search_partial_failure.db",
|
||||
{
|
||||
"alice": [
|
||||
(1, 400, "foo newest"),
|
||||
(2, 300, "foo skipped"),
|
||||
(3, 200, "foo older"),
|
||||
(4, 100, "foo bad"),
|
||||
]
|
||||
},
|
||||
)
|
||||
ctx = {
|
||||
"query": "Alice",
|
||||
"username": "alice",
|
||||
"display_name": "Alice",
|
||||
"db_path": db_path,
|
||||
"table_name": _msg_table_name("alice"),
|
||||
"message_tables": [{"db_path": db_path, "table_name": _msg_table_name("alice")}],
|
||||
"is_group": False,
|
||||
}
|
||||
original_build_search_entry = mcp_server._build_search_entry
|
||||
|
||||
def flaky_build_search_entry(row, *args, **kwargs):
|
||||
if row[2] == 300:
|
||||
return None
|
||||
if row[2] == 100:
|
||||
raise ValueError("bad row")
|
||||
return original_build_search_entry(row, *args, **kwargs)
|
||||
|
||||
with patch.object(mcp_server, "get_contact_names", return_value={"alice": "Alice"}), patch.object(
|
||||
mcp_server, "_resolve_chat_context", return_value=ctx
|
||||
), patch.object(
|
||||
mcp_server, "_build_search_entry", side_effect=flaky_build_search_entry
|
||||
):
|
||||
result = mcp_server.search_messages("foo", chat_name="Alice", limit=3, offset=0)
|
||||
|
||||
self.assertIn("foo newest", result)
|
||||
self.assertIn("foo older", result)
|
||||
self.assertIn("查询失败:", result)
|
||||
self.assertIn("bad row", result)
|
||||
|
||||
def test_get_recent_sessions_closes_connection_when_query_fails(self):
|
||||
# 会话查询抛异常时也必须关闭 sqlite3 连接,避免资源泄漏。
|
||||
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
|
||||
|
||||
class _FakeConn:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
raise sqlite3.OperationalError("boom")
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
fake_conn = _FakeConn()
|
||||
|
||||
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
|
||||
mcp_server, "get_contact_names", return_value={}
|
||||
), patch.object(
|
||||
mcp_server.sqlite3, "connect", return_value=fake_conn
|
||||
):
|
||||
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
|
||||
mcp_server.get_recent_sessions()
|
||||
|
||||
self.assertTrue(fake_conn.closed)
|
||||
|
||||
def test_get_new_messages_closes_connection_when_query_fails(self):
|
||||
# 新消息轮询失败时也要释放 sqlite3 连接。
|
||||
fake_cache = _FakeCache({os.path.join("session", "session.db"): "session.db"})
|
||||
|
||||
class _FakeConn:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
raise sqlite3.OperationalError("boom")
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
fake_conn = _FakeConn()
|
||||
|
||||
with patch.object(mcp_server, "_cache", fake_cache), patch.object(
|
||||
mcp_server, "get_contact_names", return_value={}
|
||||
), patch.object(
|
||||
mcp_server.sqlite3, "connect", return_value=fake_conn
|
||||
):
|
||||
with self.assertRaisesRegex(sqlite3.OperationalError, "boom"):
|
||||
mcp_server.get_new_messages()
|
||||
|
||||
self.assertTrue(fake_conn.closed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue