diff --git a/mcp_server.py b/mcp_server.py index e2a98c4..f4261aa 100644 --- a/mcp_server.py +++ b/mcp_server.py @@ -120,10 +120,59 @@ def decrypt_wal(wal_path, out_path, enc_key): # ============ DB 缓存 ============ class DBCache: - """缓存解密后的 DB,通过 mtime 检测变化""" + """缓存解密后的 DB,通过 mtime 检测变化。使用固定文件名,重启后可复用。""" + + CACHE_DIR = os.path.join(tempfile.gettempdir(), "wechat_mcp_cache") + MTIME_FILE = os.path.join(tempfile.gettempdir(), "wechat_mcp_cache", "_mtimes.json") def __init__(self): self._cache = {} # rel_key -> (db_mtime, wal_mtime, tmp_path) + os.makedirs(self.CACHE_DIR, exist_ok=True) + self._load_persistent_cache() + + def _cache_path(self, rel_key): + """rel_key -> 固定的缓存文件路径""" + h = hashlib.md5(rel_key.encode()).hexdigest()[:12] + return os.path.join(self.CACHE_DIR, f"{h}.db") + + def _load_persistent_cache(self): + """启动时从磁盘恢复缓存映射,验证 mtime 后复用""" + if not os.path.exists(self.MTIME_FILE): + return + try: + with open(self.MTIME_FILE) as f: + saved = json.load(f) + except (json.JSONDecodeError, OSError): + return + reused = 0 + for rel_key, info in saved.items(): + tmp_path = info["path"] + if not os.path.exists(tmp_path): + continue + rel_path = rel_key.replace('\\', os.sep) + db_path = os.path.join(DB_DIR, rel_path) + wal_path = db_path + "-wal" + try: + db_mtime = os.path.getmtime(db_path) + wal_mtime = os.path.getmtime(wal_path) if os.path.exists(wal_path) else 0 + except OSError: + continue + if db_mtime == info["db_mt"] and wal_mtime == info["wal_mt"]: + self._cache[rel_key] = (db_mtime, wal_mtime, tmp_path) + reused += 1 + if reused: + print(f"[DBCache] reused {reused} cached decrypted DBs from previous run", flush=True) + + def _save_persistent_cache(self): + """持久化缓存映射到磁盘""" + data = {} + for rel_key, (db_mt, wal_mt, path) in self._cache.items(): + data[rel_key] = {"db_mt": db_mt, "wal_mt": wal_mt, "path": path} + try: + with open(self.MTIME_FILE, 'w') as f: + json.dump(data, f) + except OSError: + pass def get(self, rel_key): if rel_key not in ALL_KEYS: @@ -144,27 +193,19 @@ class DBCache: c_db_mt, c_wal_mt, c_path = self._cache[rel_key] if c_db_mt == db_mtime and c_wal_mt == wal_mtime and os.path.exists(c_path): return c_path - try: - os.unlink(c_path) - except OSError: - pass + tmp_path = self._cache_path(rel_key) enc_key = bytes.fromhex(ALL_KEYS[rel_key]["enc_key"]) - fd, tmp_path = tempfile.mkstemp(suffix='.db') - os.close(fd) full_decrypt(db_path, tmp_path, enc_key) if os.path.exists(wal_path): decrypt_wal(wal_path, tmp_path, enc_key) self._cache[rel_key] = (db_mtime, wal_mtime, tmp_path) + self._save_persistent_cache() return tmp_path def cleanup(self): - for _, _, path in self._cache.values(): - try: - os.unlink(path) - except OSError: - pass - self._cache.clear() + """正常退出时保存缓存映射(不删文件,下次启动可复用)""" + self._save_persistent_cache() _cache = DBCache()