def _md5(s: str) -> str:
    """Return the lowercase hexadecimal MD5 digest of *s*.

    The string is encoded with ``str.encode()``'s default (UTF-8) before
    hashing. The tests use this to build the same ``Msg_<md5>`` table-name
    suffixes and lookup keys that they expect from ``wx_daemon``.
    """
    digest = hashlib.md5()
    digest.update(s.encode())
    return digest.hexdigest()
def test_refresh_names_clears_md5_cache(self):
    """_refresh_names() must drop stale entries from the md5 lookup cache.

    Both module-level caches are seeded with a stale "old" contact and
    ``_db.get`` is stubbed to return ``None`` before the refresh runs. The
    md5 cache must afterwards be a rebuilt mapping (not ``None``) that no
    longer contains the stale key.
    """
    import wx_daemon

    stale_key = _md5("old")
    # patch.object restores the real module globals on exit, so the caches
    # of other tests are left untouched even if an assertion fails.
    with patch.object(wx_daemon, '_names', {"old": "OldName"}), \
         patch.object(wx_daemon, '_md5_to_uname', {stale_key: "old"}):
        with patch.object(wx_daemon._db, 'get', return_value=None):
            wx_daemon._refresh_names()
        # After refresh, the md5 cache must be rebuilt (not None)
        assert wx_daemon._md5_to_uname is not None
        # ...and must no longer contain the stale "old" username
        assert stale_key not in wx_daemon._md5_to_uname
def test_detect_db_dir_linux(self):
    """_detect_db_dir returns the xwechat_files db_storage path on Linux.

    ``glob.glob`` is stubbed so that only the wildcard pattern
    (``.../xwechat_files/*/db_storage``) yields a candidate; the fixed
    fallback pattern yields nothing. The single candidate must therefore
    be returned verbatim.
    """
    import wx

    expected = '/home/user/Documents/xwechat_files/wxid/db_storage'

    def fake_glob(pattern):
        # Only the pattern containing a wildcard produces a match.
        return [expected] if '*' in pattern else []

    with patch('wx.platform.system', return_value='Linux'), \
         patch('wx.glob.glob', side_effect=fake_glob), \
         patch('wx.os.path.isdir', return_value=True), \
         patch('wx.os.path.getmtime', return_value=1000.0):
        result = wx._detect_db_dir()

    # Pin the exact path: "is not None" would also pass for a wrong directory.
    assert result == expected
def test_export_to_file(self):
    """`wx export -o FILE` writes the (default markdown) export to FILE.

    Runs inside click's isolated filesystem so the output file lands in a
    throwaway directory.
    """
    from click.testing import CliRunner
    import wx
    runner = CliRunner()
    with runner.isolated_filesystem():
        with patch('wx._send', return_value=self._SAMPLE_RESP), \
             patch('wx._ensure_daemon'):
            result = runner.invoke(wx.cli, ['export', 'Alice', '-o', 'out.md'])
        assert result.exit_code == 0
        assert os.path.exists('out.md')
        # The export is written as UTF-8 and contains CJK text, so read it
        # back with the same encoding (not the locale default) and close
        # the handle deterministically.
        with open('out.md', encoding='utf-8') as f:
            content = f.read()
        assert '# Alice' in content
def test_watch_json_outputs_message_events(self):
    """watch --json should print message events as JSON lines.

    The fake socket stream carries a "connected" event followed by one
    "message" event; exactly one JSON line (the message) must be printed.
    """
    import wx
    from click.testing import CliRunner

    message = {"event": "message", "chat": "Alice", "content": "hi",
               "time": "10:00", "sender": "", "is_group": False}
    stream = [json.dumps(evt) + '\n' for evt in ({"event": "connected"}, message)]

    # socket.socket() -> fake socket -> makefile() -> iterable of lines
    fake_file = MagicMock()
    fake_file.__iter__ = lambda s: iter(stream)
    fake_sock = MagicMock()
    fake_sock.makefile.return_value = fake_file

    cli_runner = CliRunner()
    with patch('wx.socket.socket', return_value=fake_sock), \
         patch('wx._ensure_daemon'):
        result = cli_runner.invoke(wx.cli, ['watch', '--json'],
                                   catch_exceptions=False)

    assert result.exit_code == 0
    emitted = [ln for ln in result.output.strip().split('\n') if ln]
    assert len(emitted) == 1
    payload = json.loads(emitted[0])
    assert payload['chat'] == 'Alice'
    assert payload['event'] == 'message'
def test_watch_filters_by_chat(self):
    """watch --chat should filter events to only the specified chat.

    The stream contains one message from "Bob" and one from "Alice"; with
    ``--chat Alice --json`` only Alice's message may be printed.
    """
    import wx
    from click.testing import CliRunner

    def message(chat, content, clock, username):
        # Build one "message" event in the shape used by the stream.
        return {"event": "message", "chat": chat, "content": content,
                "time": clock, "sender": "", "is_group": False,
                "username": username}

    payloads = [
        {"event": "connected"},
        message("Bob", "noise", "10:01", "wxid_bob"),
        message("Alice", "signal", "10:02", "wxid_alice"),
    ]
    stream = [json.dumps(p) + '\n' for p in payloads]

    # socket.socket() -> fake socket -> makefile() -> iterable of lines
    fake_file = MagicMock()
    fake_file.__iter__ = lambda s: iter(stream)
    fake_sock = MagicMock()
    fake_sock.makefile.return_value = fake_file

    cli_runner = CliRunner()
    with patch('wx.socket.socket', return_value=fake_sock), \
         patch('wx._ensure_daemon'):
        result = cli_runner.invoke(wx.cli, ['watch', '--chat', 'Alice', '--json'],
                                   catch_exceptions=False)

    assert result.exit_code == 0
    emitted = [ln for ln in result.output.strip().split('\n') if ln]
    assert len(emitted) == 1
    assert json.loads(emitted[0])['chat'] == 'Alice'


if __name__ == '__main__':
    unittest.main(verbosity=2)
"config.json") + + +def _detect_db_dir() -> str | None: + """自动检测微信数据库目录(支持 macOS/Linux)。""" + if platform.system() == "Darwin": + pattern = os.path.expanduser( + "~/Library/Containers/com.tencent.xinWeChat/Data/Documents" + "/xwechat_files/*/db_storage" + ) + candidates = sorted( + (p for p in glob.glob(pattern) if os.path.isdir(p)), + key=os.path.getmtime, + reverse=True, + ) + return candidates[0] if candidates else None + if platform.system() == "Linux": + patterns = [ + os.path.expanduser("~/Documents/xwechat_files/*/db_storage"), + os.path.expanduser("~/.local/share/weixin/data/db_storage"), + ] + candidates = [] + for pat in patterns: + candidates.extend(p for p in glob.glob(pat) if os.path.isdir(p)) + candidates.sort(key=os.path.getmtime, reverse=True) + return candidates[0] if candidates else None + return None + + +def _ensure_scanner() -> str: + """确保 macOS C 扫描器已编译,返回二进制路径。""" + binary = os.path.join(SCRIPT_DIR, "find_all_keys_macos") + if os.path.exists(binary): + return binary + src = os.path.join(SCRIPT_DIR, "find_all_keys_macos.c") + if not os.path.exists(src): + raise click.ClickException(f"找不到扫描器源文件: {src}") + click.echo("编译密钥扫描器...", err=True) + # Try with Xcode SDK first, then fallback to plain clang + sdk_path = ( + "/Applications/Xcode.app/Contents/Developer/Platforms" + "/MacOSX.platform/Developer/SDKs/MacOSX.sdk" + ) + cmds = [] + if os.path.isdir(sdk_path): + cmds.append(["clang", "-O2", "-isysroot", sdk_path, "-o", binary, src]) + cmds.append(["clang", "-O2", "-o", binary, src]) + for cmd in cmds: + ret = subprocess.run(cmd, capture_output=True, text=True) + if ret.returncode == 0: + click.echo("编译完成", err=True) + return binary + raise click.ClickException(f"编译失败: {ret.stderr.strip()}") + + +@cli.command() +@click.option('--force', is_flag=True, help='强制重新扫描(覆盖已有配置)') +def init(force): + """初始化:检测数据目录并扫描加密密钥 + + \b + 首次使用前运行(WeChat 需正在运行): + wx init + 重新扫描密钥(例如微信更新后): + wx init --force + """ + # Check if already initialized + if not force 
def _read_init_config() -> dict:
    """Load ``CONFIG_FILE`` as a dict, or return ``{}`` when it is unusable.

    A missing, unreadable, or malformed config file, or one whose top-level
    JSON value is not an object, is treated as "no configuration yet". The
    file handle is always closed.
    """
    try:
        with open(CONFIG_FILE, encoding='utf-8') as f:
            cfg = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return cfg if isinstance(cfg, dict) else {}


@cli.command()
@click.option('--force', is_flag=True, help='强制重新扫描(覆盖已有配置)')
def init(force):
    """初始化:检测数据目录并扫描加密密钥

    \b
    首次使用前运行(WeChat 需正在运行):
      wx init
    重新扫描密钥(例如微信更新后):
      wx init --force
    """
    # NOTE: the docstring above is the CLI help text shown by click, so it
    # is user-facing and kept verbatim.

    # Already initialized? Only short-circuit when the stored db_dir is a
    # real directory (not the "your_wxid" template placeholder) and the
    # keys file exists.
    if not force:
        cfg = _read_init_config()
        db_dir = cfg.get("db_dir", "")
        keys_file = cfg.get("keys_file", "all_keys.json")
        # Non-string values mean a corrupt config: fall through to re-init.
        if isinstance(db_dir, str) and isinstance(keys_file, str):
            if not os.path.isabs(keys_file):
                keys_file = os.path.join(SCRIPT_DIR, keys_file)
            if (db_dir and "your_wxid" not in db_dir
                    and os.path.isdir(db_dir)
                    and os.path.exists(keys_file)):
                click.echo(f"已初始化,数据目录: {db_dir}")
                click.echo("如需重新扫描密钥,使用 --force")
                return

    # Step 1: Detect db_dir
    click.echo("检测微信数据目录...")
    db_dir = _detect_db_dir()
    if not db_dir:
        raise click.ClickException(
            "未能自动检测到微信数据目录\n"
            "请手动编辑 config.json 中的 db_dir 字段\n"
            "路径格式(macOS): ~/Library/Containers/com.tencent.xinWeChat/..."
            "/xwechat_files//db_storage"
        )
    click.echo(f"找到数据目录: {db_dir}")

    # Step 2: Compile scanner (macOS only)
    if platform.system() == "Darwin":
        scanner = _ensure_scanner()

        # Step 3: Run key extraction
        keys_file = os.path.join(SCRIPT_DIR, "all_keys.json")
        click.echo("扫描加密密钥(需要 sudo 权限)...")
        ret = subprocess.run(
            ["sudo", scanner],
            capture_output=False,  # let stdout/stderr pass through
            cwd=SCRIPT_DIR,
        )
        if ret.returncode != 0:
            raise click.ClickException("密钥扫描失败,请确认微信正在运行")
        if not os.path.exists(keys_file):
            raise click.ClickException(f"扫描完成但未找到输出文件: {keys_file}")
        with open(keys_file, encoding='utf-8') as f:
            keys = json.load(f)
        # Entries whose key starts with "_" are not counted as database keys.
        real_keys = {k: v for k, v in keys.items() if not k.startswith('_')}
        click.echo(f"成功提取 {len(real_keys)} 个数据库密钥")
    else:
        click.echo("非 macOS 系统,请手动运行密钥提取脚本")

    # Step 4: Update config.json, preserving any existing settings.
    cfg = _read_init_config()
    cfg["db_dir"] = db_dir
    cfg.setdefault("keys_file", "all_keys.json")
    cfg.setdefault("decrypted_dir", "decrypted")
    with open(CONFIG_FILE, "w", encoding='utf-8') as f:
        json.dump(cfg, f, indent=4, ensure_ascii=False)
    click.echo(f"配置已保存: {CONFIG_FILE}")
    click.echo("初始化完成,可以使用 wx sessions / wx history 等命令了")
@cli.command()
@click.argument('chat')
@click.option('--since', default=None, metavar='DATE', help='起始时间 YYYY-MM-DD')
@click.option('--until', default=None, metavar='DATE', help='结束时间 YYYY-MM-DD')
@click.option('-n', '--limit', default=500, show_default=True, help='最多导出条数')
@click.option('-f', '--format', 'fmt', type=click.Choice(['markdown', 'txt', 'json']),
              default='markdown', show_default=True, help='输出格式')
@click.option('-o', '--output', default=None, metavar='FILE', help='输出文件(默认 stdout)')
def export(chat, since, until, limit, fmt, output):
    """导出聊天记录到文件

    \b
    示例:
      wx export "张三"
      wx export "AI群" --since 2026-01-01 --format markdown -o chat.md
      wx export "张三" --format json -o chat.json
    """
    # NOTE: the docstring above is the CLI help text shown by click, so it
    # is user-facing and kept verbatim.

    # Build the daemon "history" request; time bounds are optional.
    request = {"cmd": "history", "chat": chat, "limit": limit, "offset": 0}
    if since:
        request["since"] = _parse_time(since)
    if until:
        request["until"] = _parse_time(until, is_end=True)

    resp = _send(request, timeout=60)
    messages = resp.get("messages", [])
    title = resp.get("chat", chat)      # resolved chat name, else the raw argument
    group = resp.get("is_group", False)
    total = len(messages)

    if fmt == 'json':
        # JSON export is the raw daemon response, pretty-printed.
        rendered = json.dumps(resp, ensure_ascii=False, indent=2)
    elif fmt == 'txt':
        suffix = '[群]' if group else ''
        rows = [f"=== {title}{suffix} ({total} 条) ===\n"]
        for msg in messages:
            who = msg.get('sender')
            prefix = f"{who}: " if who else ''
            rows.append(f"[{msg['time']}] {prefix}{msg['content']}")
        rendered = '\n'.join(rows)
    else:  # markdown
        suffix = '(群聊)' if group else ''
        rows = [f"# {title}{suffix}", f"\n> 导出 {total} 条消息\n"]
        for msg in messages:
            who = msg.get('sender')
            prefix = f"**{who}**: " if who else ''
            # Continuation lines of a multi-line message get a "> " prefix.
            body = msg['content'].replace('\n', '\n> ')
            rows.append(f"### {msg['time']}\n\n{prefix}{body}\n")
        rendered = '\n'.join(rows)

    if not output:
        click.echo(rendered)
        return
    with open(output, 'w', encoding='utf-8') as f:
        f.write(rendered)
    click.echo(f"已导出 {total} 条消息到 {output}")
def _get_md5_lookup() -> dict[str, str]:
    """Return the cached ``{md5(username): username}`` mapping.

    The mapping is built on first use from the usernames returned by
    ``_load_names()`` and stored in the module-level ``_md5_to_uname``.
    Global search uses it to map a ``Msg_<md5>`` table name back to the
    contact it belongs to. The check-and-build runs under ``_md5_lock``.
    """
    global _md5_to_uname
    with _md5_lock:
        if _md5_to_uname is None:
            _md5_to_uname = {
                hashlib.md5(uname.encode()).hexdigest(): uname
                for uname in _load_names()
            }
        return _md5_to_uname
targets.append((path, tname, display, uname)) except Exception: continue @@ -535,8 +554,7 @@ def q_search(keyword: str, chats: list[str] | None = None, sender = _sender_label(real_sender_id, content, is_group or False, uname or '', id2u, names) text = _fmt_content(local_id, local_type, content, is_group or False) - # 全局搜索时从 table_name 反推 display(联系人缓存中查) - chat_display = display or '未知' + chat_display = display or uname or table results.append({ "timestamp": ts, "time": datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M'),