r""" WeChat MCP Server - query WeChat messages, contacts via Claude Based on FastMCP (stdio transport), reuses existing decryption. Runs on Windows Python (needs access to D:\ WeChat databases). """ import os, sys, json, time, sqlite3, tempfile, struct, hashlib, atexit, re import hmac as hmac_mod from datetime import datetime import xml.etree.ElementTree as ET from Crypto.Cipher import AES from mcp.server.fastmcp import FastMCP import zstandard as zstd from decode_image import ImageResolver from key_utils import get_key_info, key_path_variants, strip_key_metadata # ============ 加密常量 ============ PAGE_SZ = 4096 KEY_SZ = 32 SALT_SZ = 16 RESERVE_SZ = 80 SQLITE_HDR = b'SQLite format 3\x00' WAL_HEADER_SZ = 32 WAL_FRAME_HEADER_SZ = 24 # ============ 配置加载 ============ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) CONFIG_FILE = os.path.join(SCRIPT_DIR, "config.json") with open(CONFIG_FILE) as f: _cfg = json.load(f) for _key in ("keys_file", "decrypted_dir"): if _key in _cfg and not os.path.isabs(_cfg[_key]): _cfg[_key] = os.path.join(SCRIPT_DIR, _cfg[_key]) DB_DIR = _cfg["db_dir"] KEYS_FILE = _cfg["keys_file"] DECRYPTED_DIR = _cfg["decrypted_dir"] # 图片相关路径 _db_dir = _cfg["db_dir"] if os.path.basename(_db_dir) == "db_storage": WECHAT_BASE_DIR = os.path.dirname(_db_dir) else: WECHAT_BASE_DIR = _db_dir DECODED_IMAGE_DIR = _cfg.get("decoded_image_dir") if not DECODED_IMAGE_DIR: DECODED_IMAGE_DIR = os.path.join(SCRIPT_DIR, "decoded_images") elif not os.path.isabs(DECODED_IMAGE_DIR): DECODED_IMAGE_DIR = os.path.join(SCRIPT_DIR, DECODED_IMAGE_DIR) with open(KEYS_FILE) as f: ALL_KEYS = strip_key_metadata(json.load(f)) # ============ 解密函数 ============ def decrypt_page(enc_key, page_data, pgno): iv = page_data[PAGE_SZ - RESERVE_SZ : PAGE_SZ - RESERVE_SZ + 16] if pgno == 1: encrypted = page_data[SALT_SZ : PAGE_SZ - RESERVE_SZ] cipher = AES.new(enc_key, AES.MODE_CBC, iv) decrypted = cipher.decrypt(encrypted) return bytes(bytearray(SQLITE_HDR + decrypted + b'\x00' * RESERVE_SZ)) else: encrypted = page_data[: PAGE_SZ - RESERVE_SZ] cipher = AES.new(enc_key, AES.MODE_CBC, iv) decrypted = cipher.decrypt(encrypted) return decrypted + b'\x00' * RESERVE_SZ def full_decrypt(db_path, out_path, enc_key): file_size = os.path.getsize(db_path) total_pages = file_size // PAGE_SZ os.makedirs(os.path.dirname(out_path), exist_ok=True) with open(db_path, 'rb') as fin, open(out_path, 'wb') as fout: for pgno in range(1, total_pages + 1): page = fin.read(PAGE_SZ) if len(page) < PAGE_SZ: if len(page) > 0: page = page + b'\x00' * (PAGE_SZ - len(page)) else: break fout.write(decrypt_page(enc_key, page, pgno)) return total_pages def decrypt_wal(wal_path, out_path, enc_key): if not os.path.exists(wal_path): return 0 wal_size = os.path.getsize(wal_path) if wal_size <= WAL_HEADER_SZ: return 0 frame_size = WAL_FRAME_HEADER_SZ + PAGE_SZ patched = 0 with open(wal_path, 'rb') as wf, open(out_path, 'r+b') as df: wal_hdr = wf.read(WAL_HEADER_SZ) wal_salt1 = struct.unpack('>I', wal_hdr[16:20])[0] wal_salt2 = struct.unpack('>I', wal_hdr[20:24])[0] while wf.tell() + frame_size <= wal_size: fh = wf.read(WAL_FRAME_HEADER_SZ) if len(fh) < WAL_FRAME_HEADER_SZ: break pgno = struct.unpack('>I', fh[0:4])[0] frame_salt1 = struct.unpack('>I', fh[8:12])[0] frame_salt2 = struct.unpack('>I', fh[12:16])[0] ep = wf.read(PAGE_SZ) if len(ep) < PAGE_SZ: break if pgno == 0 or pgno > 1000000: continue if frame_salt1 != wal_salt1 or frame_salt2 != wal_salt2: continue dec = decrypt_page(enc_key, ep, pgno) df.seek((pgno - 1) * PAGE_SZ) df.write(dec) patched += 1 return patched # ============ DB 缓存 ============ class DBCache: """缓存解密后的 DB,通过 mtime 检测变化。使用固定文件名,重启后可复用。""" CACHE_DIR = os.path.join(tempfile.gettempdir(), "wechat_mcp_cache") MTIME_FILE = os.path.join(tempfile.gettempdir(), "wechat_mcp_cache", "_mtimes.json") def __init__(self): self._cache = {} # rel_key -> (db_mtime, wal_mtime, tmp_path) os.makedirs(self.CACHE_DIR, exist_ok=True) self._load_persistent_cache() def _cache_path(self, rel_key): """rel_key -> 固定的缓存文件路径""" h = hashlib.md5(rel_key.encode()).hexdigest()[:12] return os.path.join(self.CACHE_DIR, f"{h}.db") def _load_persistent_cache(self): """启动时从磁盘恢复缓存映射,验证 mtime 后复用""" if not os.path.exists(self.MTIME_FILE): return try: with open(self.MTIME_FILE) as f: saved = json.load(f) except (json.JSONDecodeError, OSError): return reused = 0 for rel_key, info in saved.items(): tmp_path = info["path"] if not os.path.exists(tmp_path): continue rel_path = rel_key.replace('\\', os.sep) db_path = os.path.join(DB_DIR, rel_path) wal_path = db_path + "-wal" try: db_mtime = os.path.getmtime(db_path) wal_mtime = os.path.getmtime(wal_path) if os.path.exists(wal_path) else 0 except OSError: continue if db_mtime == info["db_mt"] and wal_mtime == info["wal_mt"]: self._cache[rel_key] = (db_mtime, wal_mtime, tmp_path) reused += 1 if reused: print(f"[DBCache] reused {reused} cached decrypted DBs from previous run", flush=True) def _save_persistent_cache(self): """持久化缓存映射到磁盘""" data = {} for rel_key, (db_mt, wal_mt, path) in self._cache.items(): data[rel_key] = {"db_mt": db_mt, "wal_mt": wal_mt, "path": path} try: with open(self.MTIME_FILE, 'w') as f: json.dump(data, f) except OSError: pass def get(self, rel_key): key_info = get_key_info(ALL_KEYS, rel_key) if not key_info: return None rel_path = rel_key.replace('\\', '/').replace('/', os.sep) db_path = os.path.join(DB_DIR, rel_path) wal_path = db_path + "-wal" if not os.path.exists(db_path): return None try: db_mtime = os.path.getmtime(db_path) wal_mtime = os.path.getmtime(wal_path) if os.path.exists(wal_path) else 0 except OSError: return None if rel_key in self._cache: c_db_mt, c_wal_mt, c_path = self._cache[rel_key] if c_db_mt == db_mtime and c_wal_mt == wal_mtime and os.path.exists(c_path): return c_path tmp_path = self._cache_path(rel_key) enc_key = bytes.fromhex(key_info["enc_key"]) full_decrypt(db_path, tmp_path, enc_key) if os.path.exists(wal_path): decrypt_wal(wal_path, tmp_path, enc_key) self._cache[rel_key] = (db_mtime, wal_mtime, tmp_path) self._save_persistent_cache() return tmp_path def cleanup(self): """正常退出时保存缓存映射(不删文件,下次启动可复用)""" self._save_persistent_cache() _cache = DBCache() atexit.register(_cache.cleanup) # ============ 联系人缓存 ============ _contact_names = None # {username: display_name} _contact_full = None # [{username, nick_name, remark}] _self_username = None _XML_UNSAFE_RE = re.compile(r' 0xFFFFFFFF: return t & 0xFFFFFFFF, t >> 32 return t, 0 def resolve_username(chat_name): """将聊天名/备注名/wxid 解析为 username""" names = get_contact_names() # 直接是 username if chat_name in names or chat_name.startswith('wxid_') or '@chatroom' in chat_name: return chat_name # 模糊匹配(优先精确包含) chat_lower = chat_name.lower() for uname, display in names.items(): if chat_lower == display.lower(): return uname for uname, display in names.items(): if chat_lower in display.lower(): return uname return None _zstd_dctx = zstd.ZstdDecompressor() def _decompress_content(content, ct): """解压 zstd 压缩的消息内容""" if ct and ct == 4 and isinstance(content, bytes): try: return _zstd_dctx.decompress(content).decode('utf-8', errors='replace') except Exception: return None if isinstance(content, bytes): try: return content.decode('utf-8', errors='replace') except Exception: return None return content def _parse_message_content(content, local_type, is_group): """解析消息内容,返回 (sender_id, text)""" if content is None: return '', '' if isinstance(content, bytes): return '', '(二进制内容)' sender = '' text = content if is_group and ':\n' in content: sender, text = content.split(':\n', 1) return sender, text def _collapse_text(text): if not text: return '' return re.sub(r'\s+', ' ', text).strip() def _get_self_username(): global _self_username if _self_username: return _self_username if not DB_DIR: return '' names = get_contact_names() account_dir = os.path.basename(os.path.dirname(DB_DIR)) candidates = [account_dir] m = re.fullmatch(r'(.+)_([0-9a-fA-F]{4,})', account_dir) if m: candidates.insert(0, m.group(1)) for candidate in candidates: if candidate and candidate in names: _self_username = candidate return _self_username return '' def _load_name2id_maps(conn): id_to_username = {} try: rows = conn.execute("SELECT rowid, user_name FROM Name2Id").fetchall() except sqlite3.Error: return id_to_username for rowid, user_name in rows: if not user_name: continue id_to_username[rowid] = user_name return id_to_username def _display_name_for_username(username, names): if not username: return '' if username == _get_self_username(): return 'me' return names.get(username, username) def _resolve_sender_label(real_sender_id, sender_from_content, is_group, chat_username, chat_display_name, names, id_to_username): sender_username = id_to_username.get(real_sender_id, '') if is_group: if sender_username and sender_username != chat_username: return _display_name_for_username(sender_username, names) if sender_from_content: return _display_name_for_username(sender_from_content, names) return '' if sender_username == chat_username: return chat_display_name if sender_username: return _display_name_for_username(sender_username, names) return '' def _resolve_quote_sender_label(ref_user, ref_display_name, is_group, chat_username, chat_display_name, names): if is_group: if ref_user: return _display_name_for_username(ref_user, names) return ref_display_name or '' self_username = _get_self_username() if ref_user: if ref_user == chat_username: return chat_display_name if self_username and ref_user == self_username: return 'me' return names.get(ref_user, ref_display_name or ref_user) if ref_display_name: if ref_display_name == chat_display_name: return chat_display_name self_display_name = names.get(self_username, self_username) if self_username else '' if self_display_name and ref_display_name == self_display_name: return 'me' return ref_display_name return '' def _parse_xml_root(content): if not content or len(content) > _XML_PARSE_MAX_LEN or _XML_UNSAFE_RE.search(content): return None try: return ET.fromstring(content) except ET.ParseError: return None def _parse_int(value, fallback=0): try: return int(value) except (TypeError, ValueError): return fallback def _format_app_message_text(content, local_type, is_group, chat_username, chat_display_name, names): if not content or ' 160: ref_content = ref_content[:160] + "..." quote_text = title or "[引用消息]" if ref_content: ref_label = _resolve_quote_sender_label( ref_user, ref_display_name, is_group, chat_username, chat_display_name, names ) prefix = f"回复 {ref_label}: " if ref_label else "回复: " quote_text += f"\n ↳ {prefix}{ref_content}" return quote_text if app_type == 6: return f"[文件] {title}" if title else "[文件]" if app_type == 5: return f"[链接] {title}" if title else "[链接]" if app_type in (33, 36, 44): return f"[小程序] {title}" if title else "[小程序]" if title: return f"[链接/文件] {title}" return "[链接/文件]" def _format_voip_message_text(content): if not content or ' end_ts: raise ValueError('start_time 不能晚于 end_time') return start_ts, end_ts def _build_message_filters(start_ts=None, end_ts=None, keyword=''): clauses = [] params = [] if start_ts is not None: clauses.append('create_time >= ?') params.append(start_ts) if end_ts is not None: clauses.append('create_time <= ?') params.append(end_ts) if keyword: clauses.append('message_content LIKE ?') params.append(f'%{keyword}%') return clauses, params def _query_messages(conn, table_name, start_ts=None, end_ts=None, keyword='', limit=20, offset=0): if not _is_safe_msg_table_name(table_name): raise ValueError(f'非法消息表名: {table_name}') clauses, params = _build_message_filters(start_ts, end_ts, keyword) where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' sql = f""" SELECT local_id, local_type, create_time, real_sender_id, message_content, WCDB_CT_message_content FROM [{table_name}] {where_sql} ORDER BY create_time DESC LIMIT ? OFFSET ? """ return conn.execute(sql, (*params, limit, offset)).fetchall() def _resolve_chat_context(chat_name): username = resolve_username(chat_name) if not username: return None names = get_contact_names() display_name = names.get(username, username) db_path, table_name = _find_msg_table_for_user(username) if not db_path: return { 'query': chat_name, 'username': username, 'display_name': display_name, 'db_path': None, 'table_name': None, 'is_group': '@chatroom' in username, } return { 'query': chat_name, 'username': username, 'display_name': display_name, 'db_path': db_path, 'table_name': table_name, 'is_group': '@chatroom' in username, } def _resolve_chat_contexts(chat_names): if not chat_names: raise ValueError('chat_names 不能为空') resolved = [] unresolved = [] missing_tables = [] seen = set() for chat_name in chat_names: name = (chat_name or '').strip() if not name: unresolved.append('(空)') continue ctx = _resolve_chat_context(name) if not ctx: unresolved.append(name) continue if not ctx['db_path']: missing_tables.append(ctx['display_name']) continue if ctx['username'] in seen: continue seen.add(ctx['username']) resolved.append(ctx) return resolved, unresolved, missing_tables def _normalize_chat_names(chat_name): if chat_name is None: return [] if isinstance(chat_name, str): value = chat_name.strip() return [value] if value else [] if isinstance(chat_name, (list, tuple, set)): normalized = [] for item in chat_name: if item is None: continue value = str(item).strip() if value: normalized.append(value) return normalized value = str(chat_name).strip() return [value] if value else [] def _format_history_lines(rows, username, display_name, is_group, names, id_to_username): lines = [] for local_id, local_type, create_time, real_sender_id, content, ct in reversed(rows): time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') content = _decompress_content(content, ct) if content is None: content = '(无法解压)' sender, text = _format_message_text( local_id, local_type, content, is_group, username, display_name, names ) if text and len(text) > 500: text = text[:500] + '...' sender_label = _resolve_sender_label( real_sender_id, sender, is_group, username, display_name, names, id_to_username ) if sender_label: lines.append(f'[{time_str}] {sender_label}: {text}') else: lines.append(f'[{time_str}] {text}') return lines def _build_search_entry(row, ctx, names, id_to_username): local_id, local_type, create_time, real_sender_id, content, ct = row content = _decompress_content(content, ct) if content is None: return None sender, text = _format_message_text( local_id, local_type, content, ctx['is_group'], ctx['username'], ctx['display_name'], names ) if text and len(text) > 300: text = text[:300] + '...' sender_label = _resolve_sender_label( real_sender_id, sender, ctx['is_group'], ctx['username'], ctx['display_name'], names, id_to_username, ) time_str = datetime.fromtimestamp(create_time).strftime('%Y-%m-%d %H:%M') entry = f"[{time_str}] [{ctx['display_name']}]" if sender_label: entry += f" {sender_label}:" entry += f" {text}" return create_time, entry # ============ MCP Server ============ mcp = FastMCP("wechat", instructions="查询微信消息、联系人等数据") # 新消息追踪 _last_check_state = {} # {username: last_timestamp} @mcp.tool() def get_recent_sessions(limit: int = 20) -> str: """获取微信最近会话列表,包含最新消息摘要、未读数、时间等。 用于了解最近有哪些人/群在聊天。 Args: limit: 返回的会话数量,默认20 """ path = _cache.get(os.path.join("session", "session.db")) if not path: return "错误: 无法解密 session.db" names = get_contact_names() conn = sqlite3.connect(path) rows = conn.execute(""" SELECT username, unread_count, summary, last_timestamp, last_msg_type, last_msg_sender, last_sender_display_name FROM SessionTable WHERE last_timestamp > 0 ORDER BY last_timestamp DESC LIMIT ? """, (limit,)).fetchall() conn.close() results = [] for r in rows: username, unread, summary, ts, msg_type, sender, sender_name = r display = names.get(username, username) is_group = '@chatroom' in username if isinstance(summary, bytes): try: summary = _zstd_dctx.decompress(summary).decode('utf-8', errors='replace') except Exception: summary = '(压缩内容)' if isinstance(summary, str) and ':\n' in summary: summary = summary.split(':\n', 1)[1] sender_display = '' if is_group and sender: sender_display = names.get(sender, sender_name or sender) time_str = datetime.fromtimestamp(ts).strftime('%m-%d %H:%M') entry = f"[{time_str}] {display}" if is_group: entry += " [群]" if unread and unread > 0: entry += f" ({unread}条未读)" entry += f"\n {format_msg_type(msg_type)}: " if sender_display: entry += f"{sender_display}: " entry += str(summary or "(无内容)") results.append(entry) return f"最近 {len(results)} 个会话:\n\n" + "\n\n".join(results) @mcp.tool() def get_chat_history(chat_name: str, limit: int = 50, offset: int = 0, start_time: str = "", end_time: str = "") -> str: """获取指定聊天的消息记录。 Args: chat_name: 聊天对象的名字、备注名或wxid,自动模糊匹配 limit: 返回的消息数量,默认50 offset: 分页偏移量,默认0 start_time: 起始时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS end_time: 结束时间,支持 YYYY-MM-DD / YYYY-MM-DD HH:MM / YYYY-MM-DD HH:MM:SS """ try: _validate_pagination(limit, offset) start_ts, end_ts = _parse_time_range(start_time, end_time) except ValueError as e: return f"错误: {e}" ctx = _resolve_chat_context(chat_name) if not ctx: return f"找不到聊天对象: {chat_name}\n提示: 可以用 get_contacts(query='{chat_name}') 搜索联系人" if not ctx['db_path']: return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)" names = get_contact_names() conn = sqlite3.connect(ctx['db_path']) try: id_to_username = _load_name2id_maps(conn) rows = _query_messages( conn, ctx['table_name'], start_ts=start_ts, end_ts=end_ts, limit=limit, offset=offset, ) except Exception as e: conn.close() return f"查询失败: {e}" conn.close() if not rows: return f"{ctx['display_name']} 无消息记录" lines = _format_history_lines( rows, ctx['username'], ctx['display_name'], ctx['is_group'], names, id_to_username, ) header = f"{ctx['display_name']} 的消息记录(返回 {len(lines)} 条,offset={offset}, limit={limit})" if ctx['is_group']: header += " [群聊]" if start_time or end_time: header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" return header + ":\n\n" + "\n".join(lines) @mcp.tool() def search_messages( keyword: str, chat_name: str | list[str] | None = None, start_time: str = "", end_time: str = "", limit: int = 20, offset: int = 0, ) -> str: """搜索消息内容,支持全库、单个聊天对象、多个聊天对象,以及时间范围和分页。 Args: keyword: 搜索关键词 chat_name: 聊天对象名称,可为空、单个字符串或字符串列表 start_time: 起始时间,可为空 end_time: 结束时间,可为空 limit: 返回的结果数量,默认20 offset: 分页偏移量,默认0 """ if not keyword or len(keyword) < 1: return "请提供搜索关键词" chat_names = _normalize_chat_names(chat_name) try: _validate_pagination(limit, offset) start_ts, end_ts = _parse_time_range(start_time, end_time) except ValueError as e: return f"错误: {e}" if len(chat_names) == 1: ctx = _resolve_chat_context(chat_names[0]) if not ctx: return f"找不到聊天对象: {chat_names[0]}\n提示: 可以用 get_contacts(query='{chat_names[0]}') 搜索联系人" if not ctx['db_path']: return f"找不到 {ctx['display_name']} 的消息记录(可能在未解密的DB中或无消息)" names = get_contact_names() conn = sqlite3.connect(ctx['db_path']) try: id_to_username = _load_name2id_maps(conn) rows = _query_messages( conn, ctx['table_name'], start_ts=start_ts, end_ts=end_ts, keyword=keyword, limit=limit, offset=offset, ) except Exception as e: conn.close() return f"查询失败: {e}" conn.close() if not rows: return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的消息" entries = [] for row in rows: formatted = _build_search_entry(row, ctx, names, id_to_username) if formatted: entries.append(formatted) if not entries: return f"未在 {ctx['display_name']} 中找到包含 \"{keyword}\" 的可读消息" entries.sort(key=lambda x: x[0]) header = f"在 {ctx['display_name']} 中搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})" if start_time or end_time: header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" return header + ":\n\n" + "\n\n".join(item[1] for item in entries) if len(chat_names) > 1: try: resolved_contexts, unresolved, missing_tables = _resolve_chat_contexts(chat_names) except ValueError as e: return f"错误: {e}" if not resolved_contexts: details = [] if unresolved: details.append("未找到联系人: " + "、".join(unresolved)) if missing_tables: details.append("无消息表: " + "、".join(missing_tables)) suffix = f"\n{chr(10).join(details)}" if details else "" return f"错误: 没有可查询的聊天对象{suffix}" names = get_contact_names() collected = [] failures = [] per_chat_limit = limit + offset for ctx in resolved_contexts: conn = sqlite3.connect(ctx['db_path']) try: id_to_username = _load_name2id_maps(conn) rows = _query_messages( conn, ctx['table_name'], start_ts=start_ts, end_ts=end_ts, keyword=keyword, limit=per_chat_limit, offset=0, ) for row in rows: formatted = _build_search_entry(row, ctx, names, id_to_username) if formatted: collected.append(formatted) except Exception as e: failures.append(f"{ctx['display_name']}: {e}") finally: conn.close() collected.sort(key=lambda x: x[0], reverse=True) paged = collected[offset:offset + limit] notes = [] if unresolved: notes.append("未找到联系人: " + "、".join(unresolved)) if missing_tables: notes.append("无消息表: " + "、".join(missing_tables)) if failures: notes.append("查询失败: " + ";".join(failures)) if not paged: header = f"在 {len(resolved_contexts)} 个聊天对象中未找到包含 \"{keyword}\" 的消息" if start_time or end_time: header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" if notes: header += "\n" + "\n".join(notes) return header header = ( f"在 {len(resolved_contexts)} 个聊天对象中搜索 \"{keyword}\" 找到 {len(paged)} 条结果" f"(offset={offset}, limit={limit})" ) if start_time or end_time: header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" if notes: header += "\n" + "\n".join(notes) return header + ":\n\n" + "\n\n".join(item[1] for item in paged) names = get_contact_names() results = [] max_results = limit + offset for rel_key in MSG_DB_KEYS: if len(results) >= max_results: break path = _cache.get(rel_key) if not path: continue conn = sqlite3.connect(path) try: # 获取所有 Msg_ 表 tables = conn.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Msg_%'" ).fetchall() # 获取 Name2Id 映射(hash -> username 反查) name2id = {} try: for r in conn.execute("SELECT user_name FROM Name2Id").fetchall(): h = hashlib.md5(r[0].encode()).hexdigest() name2id[f"Msg_{h}"] = r[0] except Exception: pass for (tname,) in tables: if len(results) >= max_results: break username = name2id.get(tname, '') is_group = '@chatroom' in username display = names.get(username, username) if username else tname try: clauses, params = _build_message_filters(start_ts, end_ts, keyword) where_sql = f"WHERE {' AND '.join(clauses)}" if clauses else '' rows = conn.execute(f""" SELECT local_type, create_time, message_content, WCDB_CT_message_content FROM [{tname}] {where_sql} ORDER BY create_time DESC LIMIT ? OFFSET ? """, (*params, max_results - len(results), 0)).fetchall() except Exception: continue for local_type, ts, content, ct in rows: content = _decompress_content(content, ct) if content is None: continue sender, text = _parse_message_content(content, local_type, is_group) time_str = datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M') sender_name = '' if is_group and sender: sender_name = names.get(sender, sender) entry = f"[{time_str}] [{display}]" if sender_name: entry += f" {sender_name}:" entry += f" {text}" if len(entry) > 300: entry = entry[:300] + "..." results.append((ts, entry)) finally: conn.close() results.sort(key=lambda x: x[0], reverse=True) entries = [r[1] for r in results[offset:offset + limit]] if not entries: return f"未找到包含 \"{keyword}\" 的消息" header = f"搜索 \"{keyword}\" 找到 {len(entries)} 条结果(offset={offset}, limit={limit})" if start_time or end_time: header += f"\n时间范围: {start_time or '最早'} ~ {end_time or '最新'}" return header + ":\n\n" + "\n\n".join(entries) @mcp.tool() def get_contacts(query: str = "", limit: int = 50) -> str: """搜索或列出微信联系人。 Args: query: 搜索关键词(匹配昵称、备注名、wxid),留空列出所有 limit: 返回数量,默认50 """ contacts = get_contact_full() if not contacts: return "错误: 无法加载联系人数据" if query: q = query.lower() filtered = [ c for c in contacts if q in c['nick_name'].lower() or q in c['remark'].lower() or q in c['username'].lower() ] else: filtered = contacts filtered = filtered[:limit] if not filtered: return f"未找到匹配 \"{query}\" 的联系人" lines = [] for c in filtered: line = c['username'] if c['remark']: line += f" 备注: {c['remark']}" if c['nick_name']: line += f" 昵称: {c['nick_name']}" lines.append(line) header = f"找到 {len(filtered)} 个联系人" if query: header += f"(搜索: {query})" return header + ":\n\n" + "\n".join(lines) @mcp.tool() def get_new_messages() -> str: """获取自上次调用以来的新消息。首次调用返回最近的会话状态。""" global _last_check_state path = _cache.get(os.path.join("session", "session.db")) if not path: return "错误: 无法解密 session.db" names = get_contact_names() conn = sqlite3.connect(path) rows = conn.execute(""" SELECT username, unread_count, summary, last_timestamp, last_msg_type, last_msg_sender, last_sender_display_name FROM SessionTable WHERE last_timestamp > 0 ORDER BY last_timestamp DESC """).fetchall() conn.close() curr_state = {} for r in rows: username, unread, summary, ts, msg_type, sender, sender_name = r curr_state[username] = { 'unread': unread, 'summary': summary, 'timestamp': ts, 'msg_type': msg_type, 'sender': sender or '', 'sender_name': sender_name or '', } if not _last_check_state: _last_check_state = {u: s['timestamp'] for u, s in curr_state.items()} # 首次调用,返回有未读的会话 unread_msgs = [] for username, s in curr_state.items(): if s['unread'] and s['unread'] > 0: display = names.get(username, username) is_group = '@chatroom' in username summary = s['summary'] if isinstance(summary, bytes): try: summary = _zstd_dctx.decompress(summary).decode('utf-8', errors='replace') except Exception: summary = '(压缩内容)' if isinstance(summary, str) and ':\n' in summary: summary = summary.split(':\n', 1)[1] time_str = datetime.fromtimestamp(s['timestamp']).strftime('%H:%M') tag = "[群]" if is_group else "" unread_msgs.append(f"[{time_str}] {display}{tag} ({s['unread']}条未读): {summary}") if unread_msgs: return f"当前 {len(unread_msgs)} 个未读会话:\n\n" + "\n".join(unread_msgs) return "当前无未读消息(已记录状态,下次调用将返回新消息)" # 对比上次状态 new_msgs = [] for username, s in curr_state.items(): prev_ts = _last_check_state.get(username, 0) if s['timestamp'] > prev_ts: display = names.get(username, username) is_group = '@chatroom' in username summary = s['summary'] if isinstance(summary, bytes): try: summary = _zstd_dctx.decompress(summary).decode('utf-8', errors='replace') except Exception: summary = '(压缩内容)' if isinstance(summary, str) and ':\n' in summary: summary = summary.split(':\n', 1)[1] sender_display = '' if is_group and s['sender']: sender_display = names.get(s['sender'], s['sender_name'] or s['sender']) time_str = datetime.fromtimestamp(s['timestamp']).strftime('%H:%M:%S') entry = f"[{time_str}] {display}" if is_group: entry += " [群]" entry += f": {format_msg_type(s['msg_type'])}" if sender_display: entry += f" ({sender_display})" entry += f" - {summary}" new_msgs.append((s['timestamp'], entry)) _last_check_state = {u: s['timestamp'] for u, s in curr_state.items()} if not new_msgs: return "无新消息" new_msgs.sort(key=lambda x: x[0]) entries = [m[1] for m in new_msgs] return f"{len(entries)} 条新消息:\n\n" + "\n".join(entries) # ============ 图片解密 ============ _image_resolver = ImageResolver(WECHAT_BASE_DIR, DECODED_IMAGE_DIR, _cache) @mcp.tool() def decode_image(chat_name: str, local_id: int) -> str: """解密微信聊天中的一张图片。 先用 get_chat_history 查看消息,图片消息会显示 local_id, 然后用此工具解密对应图片。 Args: chat_name: 聊天对象的名字、备注名或wxid local_id: 图片消息的 local_id(从 get_chat_history 获取) """ username = resolve_username(chat_name) if not username: return f"找不到聊天对象: {chat_name}" result = _image_resolver.decode_image(username, local_id) if result['success']: return ( f"解密成功!\n" f" 文件: {result['path']}\n" f" 格式: {result['format']}\n" f" 大小: {result['size']:,} bytes\n" f" MD5: {result['md5']}" ) else: error = result['error'] if 'md5' in result: error += f"\n MD5: {result['md5']}" return f"解密失败: {error}" @mcp.tool() def get_chat_images(chat_name: str, limit: int = 20) -> str: """列出某个聊天中的图片消息。 返回图片的时间、local_id、MD5、文件大小等信息。 可以配合 decode_image 工具解密指定图片。 Args: chat_name: 聊天对象的名字、备注名或wxid limit: 返回数量,默认20 """ username = resolve_username(chat_name) if not username: return f"找不到聊天对象: {chat_name}" names = get_contact_names() display_name = names.get(username, username) db_path, table_name = _find_msg_table_for_user(username) if not db_path: return f"找不到 {display_name} 的消息记录" images = _image_resolver.list_chat_images(db_path, table_name, username, limit) if not images: return f"{display_name} 无图片消息" lines = [] for img in images: time_str = datetime.fromtimestamp(img['create_time']).strftime('%Y-%m-%d %H:%M') line = f"[{time_str}] local_id={img['local_id']}" if img.get('md5'): line += f" MD5={img['md5']}" if img.get('size'): size_kb = img['size'] / 1024 line += f" {size_kb:.0f}KB" if not img.get('md5'): line += " (无资源信息)" lines.append(line) return f"{display_name} 的 {len(lines)} 张图片:\n\n" + "\n".join(lines) if __name__ == "__main__": mcp.run()