From bccc6d413f55fa4e252e53e86630a19b5f7a6e70 Mon Sep 17 00:00:00 2001 From: z060142 Date: Fri, 9 May 2025 12:32:06 +0800 Subject: [PATCH] Migrate ChromaDB embedding model to paraphrase-multilingual-mpnet-base-v2 --- Setup.py | 46 +++- chroma_client.py | 54 ++++- memory_manager.py | 68 +++--- reembed_chroma_data.py | 529 +++++++++++++++++++++++++++++++++++++++++ tools/chroma_view.py | 19 +- 5 files changed, 666 insertions(+), 50 deletions(-) create mode 100644 reembed_chroma_data.py diff --git a/Setup.py b/Setup.py index fd3ffcd..bd0918e 100644 --- a/Setup.py +++ b/Setup.py @@ -316,6 +316,15 @@ def load_current_config(): if backup_minute_match: config_data["MEMORY_BACKUP_MINUTE"] = int(backup_minute_match.group(1)) + # Extract EMBEDDING_MODEL_NAME + embedding_model_match = re.search(r'EMBEDDING_MODEL_NAME\s*=\s*["\'](.+?)["\']', config_content) + if embedding_model_match: + config_data["EMBEDDING_MODEL_NAME"] = embedding_model_match.group(1) + else: + # Default if not found in config.py, will be set in UI if not overridden by load + config_data["EMBEDDING_MODEL_NAME"] = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + + profile_model_match = re.search(r'MEMORY_PROFILE_MODEL\s*=\s*["\']?(.+?)["\']?\s*(?:#|$)', config_content) # Handle potential LLM_MODEL reference if profile_model_match: @@ -537,10 +546,15 @@ def generate_config_file(config_data, env_data): f.write(f"MEMORY_BACKUP_MINUTE = {backup_minute}\n") # Write profile model, potentially referencing LLM_MODEL if profile_model == config_data.get('LLM_MODEL'): - f.write(f"MEMORY_PROFILE_MODEL = LLM_MODEL # Default to main LLM model\n") + f.write(f"MEMORY_PROFILE_MODEL = LLM_MODEL # Default to main LLM model\n") else: - f.write(f"MEMORY_PROFILE_MODEL = \"{profile_model}\"\n") - f.write(f"MEMORY_SUMMARY_MODEL = \"{summary_model}\"\n") + f.write(f"MEMORY_PROFILE_MODEL = \"{profile_model}\"\n") + f.write(f"MEMORY_SUMMARY_MODEL = \"{summary_model}\"\n\n") + + # Write Embedding Model Name + embedding_model_name = config_data.get('EMBEDDING_MODEL_NAME', "sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + f.write("# Embedding model for ChromaDB\n") + f.write(f"EMBEDDING_MODEL_NAME = \"{embedding_model_name}\"\n") print("Generated config.py file successfully") @@ -1717,6 +1731,24 @@ class WolfChatSetup(tk.Tk): related_info = ttk.Label(related_frame, text="(0 to disable related memories pre-loading)") related_info.pack(side=tk.LEFT, padx=(5, 0)) + # Embedding Model Settings Frame + embedding_model_settings_frame = ttk.LabelFrame(main_frame, text="Embedding Model Settings") + embedding_model_settings_frame.pack(fill=tk.X, pady=10) + + embedding_model_name_frame = ttk.Frame(embedding_model_settings_frame) + embedding_model_name_frame.pack(fill=tk.X, pady=5, padx=10) + + embedding_model_name_label = ttk.Label(embedding_model_name_frame, text="Embedding Model Name:", width=25) # Adjusted width + embedding_model_name_label.pack(side=tk.LEFT, padx=(0, 5)) + + self.embedding_model_name_var = tk.StringVar(value="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + embedding_model_name_entry = ttk.Entry(embedding_model_name_frame, textvariable=self.embedding_model_name_var) + embedding_model_name_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + + embedding_model_info = ttk.Label(embedding_model_settings_frame, text="Default: sentence-transformers/paraphrase-multilingual-mpnet-base-v2", justify=tk.LEFT) + embedding_model_info.pack(anchor=tk.W, padx=10, pady=(0,5)) + + # Information box info_frame = ttk.LabelFrame(main_frame, text="Information") info_frame.pack(fill=tk.BOTH, expand=True, pady=10) @@ -2067,6 +2099,10 @@ class WolfChatSetup(tk.Tk): self.profiles_collection_var.set(self.config_data.get("PROFILES_COLLECTION", "user_profiles")) # Default was user_profiles self.conversations_collection_var.set(self.config_data.get("CONVERSATIONS_COLLECTION", "conversations")) self.bot_memory_collection_var.set(self.config_data.get("BOT_MEMORY_COLLECTION", "wolfhart_memory")) + # Embedding Model Name for Memory Settings Tab + if hasattr(self, 'embedding_model_name_var'): + self.embedding_model_name_var.set(self.config_data.get("EMBEDDING_MODEL_NAME", "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")) + # Memory Management Tab Settings if hasattr(self, 'backup_hour_var'): # Check if UI elements for memory management tab exist @@ -2343,6 +2379,10 @@ class WolfChatSetup(tk.Tk): self.config_data["PROFILES_COLLECTION"] = self.profiles_collection_var.get() self.config_data["CONVERSATIONS_COLLECTION"] = self.conversations_collection_var.get() self.config_data["BOT_MEMORY_COLLECTION"] = self.bot_memory_collection_var.get() + # Save Embedding Model Name from Memory Settings Tab + if hasattr(self, 'embedding_model_name_var'): + self.config_data["EMBEDDING_MODEL_NAME"] = self.embedding_model_name_var.get() + # Get Memory Management settings from UI if hasattr(self, 'backup_hour_var'): # Check if UI elements exist diff --git a/chroma_client.py b/chroma_client.py index db05626..6149703 100644 --- a/chroma_client.py +++ b/chroma_client.py @@ -1,6 +1,7 @@ # chroma_client.py import chromadb from chromadb.config import Settings +from chromadb.utils import embedding_functions # New import import os import json import config @@ -10,6 +11,33 @@ import time _client = None _collections = {} +# Global embedding function variable +_embedding_function = None + +def get_embedding_function(): + """Gets or creates the embedding function based on config""" + global _embedding_function + if _embedding_function is None: + # Default to paraphrase-multilingual-mpnet-base-v2 if not specified or on error + model_name = getattr(config, 'EMBEDDING_MODEL_NAME', "sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + try: + _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name) + print(f"Successfully initialized embedding function with model: {model_name}") + except Exception as e: + print(f"Failed to initialize embedding function with model '{model_name}': {e}") + # Fallback to default if specified model fails and it's not already the default + if model_name != "sentence-transformers/paraphrase-multilingual-mpnet-base-v2": + print("Falling back to default embedding model: sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + try: + _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + print(f"Successfully initialized embedding function with default model.") + except Exception as e_default: + print(f"Failed to initialize default embedding function: {e_default}") + _embedding_function = None # Ensure it's None if all attempts fail + else: + _embedding_function = None # Ensure it's None if default model also fails + return _embedding_function + def initialize_chroma_client(): """Initializes and connects to ChromaDB""" global _client @@ -34,13 +62,31 @@ def get_collection(collection_name): if collection_name not in _collections: try: + emb_func = get_embedding_function() + if emb_func is None: + print(f"Failed to get or create collection '{collection_name}' due to embedding function initialization failure.") + return None + _collections[collection_name] = _client.get_or_create_collection( - name=collection_name + name=collection_name, + embedding_function=emb_func ) - print(f"Successfully got or created collection '{collection_name}'") + print(f"Successfully got or created collection '{collection_name}' using configured embedding function.") except Exception as e: - print(f"Failed to get collection '{collection_name}': {e}") - return None + print(f"Failed to get collection '{collection_name}' with configured embedding function: {e}") + # Attempt to create collection with default embedding function as a fallback + print(f"Attempting to create collection '{collection_name}' with default embedding function...") + try: + # Ensure we try the absolute default if the configured one (even if it was the default) failed + default_emb_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + _collections[collection_name] = _client.get_or_create_collection( + name=collection_name, + embedding_function=default_emb_func + ) + print(f"Successfully got or created collection '{collection_name}' with default embedding function after initial failure.") + except Exception as e_default: + print(f"Failed to get collection '{collection_name}' even with default embedding function: {e_default}") + return None return _collections[collection_name] diff --git a/memory_manager.py b/memory_manager.py index fe4b3fe..1a52aea 100644 --- a/memory_manager.py +++ b/memory_manager.py @@ -16,11 +16,12 @@ import schedule from pathlib import Path from typing import Dict, List, Optional, Any, Union -import chromadb -from chromadb.utils import embedding_functions +# import chromadb # No longer directly needed by ChromaDBManager +# from chromadb.utils import embedding_functions # No longer directly needed by ChromaDBManager from openai import AsyncOpenAI import config +import chroma_client # Import the centralized chroma client # ============================================================================= # 日誌解析部分 @@ -345,28 +346,22 @@ class MemoryGenerator: class ChromaDBManager: def __init__(self, collection_name: Optional[str] = None): - self.client = chromadb.PersistentClient(path=config.CHROMA_DATA_DIR) self.collection_name = collection_name or config.BOT_MEMORY_COLLECTION - self.embedding_function = embedding_functions.DefaultEmbeddingFunction() - self._ensure_collection() - - def _ensure_collection(self) -> None: - """確保集合存在""" - try: - self.collection = self.client.get_collection( - name=self.collection_name, - embedding_function=self.embedding_function - ) - print(f"Connected to existing collection: {self.collection_name}") - except Exception: - self.collection = self.client.create_collection( - name=self.collection_name, - embedding_function=self.embedding_function - ) - print(f"Created new collection: {self.collection_name}") - + self._db_collection = None # Cache for the collection object + + def _get_db_collection(self): + """Helper to get the collection object from chroma_client""" + if self._db_collection is None: + # Use the centralized get_collection function + self._db_collection = chroma_client.get_collection(self.collection_name) + if self._db_collection is None: + # This indicates a failure in chroma_client to provide the collection + raise RuntimeError(f"Failed to get or create collection '{self.collection_name}' via chroma_client. Check chroma_client logs.") + return self._db_collection + def upsert_user_profile(self, profile_data: Dict[str, Any]) -> bool: """寫入或更新用戶檔案""" + collection = self._get_db_collection() if not profile_data or not isinstance(profile_data, dict): print("無效的檔案數據") return False @@ -377,14 +372,13 @@ class ChromaDBManager: print("檔案缺少ID字段") return False - # 先檢查是否已存在 - results = self.collection.get( - ids=[user_id], # Query by a list of IDs - # where={"id": user_id}, # 'where' is for metadata filtering - limit=1 - ) - # 準備元數據 + # Note: ChromaDB's upsert handles existence check implicitly. + # The .get call here isn't strictly necessary for the upsert operation itself, + # but might be kept if there was other logic depending on prior existence. + # For a clean upsert, it can be removed. Let's assume it's not critical for now. + # results = collection.get(ids=[user_id], limit=1) # Optional: if needed for pre-check logic + metadata = { "id": user_id, "type": "user_profile", @@ -402,14 +396,12 @@ class ChromaDBManager: content_doc = json.dumps(profile_data.get("content", {}), ensure_ascii=False) # 寫入或更新 - # ChromaDB's add/upsert handles both cases. - # If an ID exists, it's an update; otherwise, it's an add. - self.collection.upsert( + collection.upsert( ids=[user_id], documents=[content_doc], metadatas=[metadata] ) - print(f"Upserted user profile: {user_id}") + print(f"Upserted user profile: {user_id} into collection {self.collection_name}") return True @@ -419,6 +411,7 @@ class ChromaDBManager: def upsert_conversation_summary(self, summary_data: Dict[str, Any]) -> bool: """寫入對話總結""" + collection = self._get_db_collection() if not summary_data or not isinstance(summary_data, dict): print("無效的總結數據") return False @@ -450,13 +443,13 @@ class ChromaDBManager: key_points_str = "\n".join([f"- {point}" for point in summary_data["key_points"]]) content_doc += f"\n\n關鍵點:\n{key_points_str}" - # 寫入數據 (ChromaDB's add implies upsert if ID exists, but upsert is more explicit) - self.collection.upsert( + # 寫入數據 + collection.upsert( ids=[summary_id], documents=[content_doc], metadatas=[metadata] ) - print(f"Upserted conversation summary: {summary_id}") + print(f"Upserted conversation summary: {summary_id} into collection {self.collection_name}") return True @@ -466,10 +459,11 @@ class ChromaDBManager: def get_existing_profile(self, username: str) -> Optional[Dict[str, Any]]: """獲取現有的用戶檔案""" + collection = self._get_db_collection() try: profile_id = f"{username}_profile" - results = self.collection.get( - ids=[profile_id], # Query by a list of IDs + results = collection.get( + ids=[profile_id], limit=1 ) diff --git a/reembed_chroma_data.py b/reembed_chroma_data.py new file mode 100644 index 0000000..328adb5 --- /dev/null +++ b/reembed_chroma_data.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +重新嵌入工具 (Reembedding Tool) + +這個腳本用於將現有ChromaDB集合中的數據使用新的嵌入模型重新計算向量並儲存。 +""" + +import os +import sys +import json +import time +import argparse +import shutil +from datetime import datetime +from typing import List, Dict, Any, Optional, Tuple +from tqdm import tqdm # 進度條 + +try: + import chromadb + from chromadb.utils import embedding_functions +except ImportError: + print("錯誤: 請先安裝 chromadb: pip install chromadb") + sys.exit(1) + +try: + from sentence_transformers import SentenceTransformer +except ImportError: + print("錯誤: 請先安裝 sentence-transformers: pip install sentence-transformers") + sys.exit(1) + +# 嘗試導入配置 +try: + import config +except ImportError: + print("警告: 無法導入config.py,將使用預設設定") + # 建立最小配置 + class MinimalConfig: + CHROMA_DATA_DIR = "chroma_data" + BOT_MEMORY_COLLECTION = "wolfhart_memory" + CONVERSATIONS_COLLECTION = "wolfhart_memory" + PROFILES_COLLECTION = "wolfhart_memory" + config = MinimalConfig() + +def parse_args(): + """處理命令行參數""" + parser = argparse.ArgumentParser(description='ChromaDB 數據重新嵌入工具') + + parser.add_argument('--new-model', type=str, + default="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", + help='新的嵌入模型名稱 (預設: sentence-transformers/paraphrase-multilingual-mpnet-base-v2)') + + parser.add_argument('--collections', type=str, nargs='+', + help=f'要處理的集合名稱列表,空白分隔 (預設: 使用配置中的所有集合)') + + parser.add_argument('--backup', action='store_true', + help='在處理前備份資料庫 (推薦)') + + parser.add_argument('--batch-size', type=int, default=100, + help='批處理大小 (預設: 100)') + + parser.add_argument('--temp-collection-suffix', type=str, default="_temp_new", + help='臨時集合的後綴名稱 (預設: _temp_new)') + + parser.add_argument('--dry-run', action='store_true', + help='模擬執行但不實際修改資料') + + parser.add_argument('--confirm-dangerous', action='store_true', + help='確認執行危險操作(例如刪除集合)') + + return parser.parse_args() + +def backup_chroma_directory(chroma_dir: str) -> str: + """備份ChromaDB數據目錄 + + Args: + chroma_dir: ChromaDB數據目錄路徑 + + Returns: + 備份目錄的路徑 + """ + if not os.path.exists(chroma_dir): + print(f"錯誤: ChromaDB目錄 '{chroma_dir}' 不存在") + sys.exit(1) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_dir = f"{chroma_dir}_backup_{timestamp}" + + print(f"備份資料庫從 '{chroma_dir}' 到 '{backup_dir}'...") + shutil.copytree(chroma_dir, backup_dir) + print(f"備份完成: {backup_dir}") + + return backup_dir + +def create_embedding_function(model_name: str): + """創建嵌入函數 + + Args: + model_name: 嵌入模型名稱 + + Returns: + 嵌入函數對象 + """ + if not model_name: + print("使用ChromaDB預設嵌入模型") + return embedding_functions.DefaultEmbeddingFunction() + + print(f"正在加載嵌入模型: {model_name}") + try: + # 直接使用SentenceTransformerEmbeddingFunction + from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction + embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name) + # 預熱模型 + _ = embedding_function(["."]) + return embedding_function + except Exception as e: + print(f"錯誤: 無法加載模型 '{model_name}': {e}") + print("退回到預設嵌入模型") + return embedding_functions.DefaultEmbeddingFunction() + +def get_collection_names(client, default_collections: List[str]) -> List[str]: + """獲取所有可用的集合名稱 + + Args: + client: ChromaDB客戶端 + default_collections: 預設集合列表 + + Returns: + 可用的集合名稱列表 + """ + try: + all_collections = client.list_collections() + collection_names = [col.name for col in all_collections] + + if collection_names: + return collection_names + else: + print("警告: 沒有找到集合,將使用預設集合") + return default_collections + + except Exception as e: + print(f"獲取集合列表失敗: {e}") + print("將使用預設集合") + return default_collections + +def fetch_collection_data(client, collection_name: str, batch_size: int = 100) -> Dict[str, Any]: + """從集合中提取所有數據 + + Args: + client: ChromaDB客戶端 + collection_name: 集合名稱 + batch_size: 批處理大小 + + Returns: + 集合數據字典,包含ids, documents, metadatas + """ + try: + collection = client.get_collection(name=collection_name) + + # 獲取該集合中的項目總數 + count_result = collection.count() + if count_result == 0: + print(f"集合 '{collection_name}' 是空的") + return {"ids": [], "documents": [], "metadatas": []} + + print(f"從集合 '{collection_name}' 中讀取 {count_result} 項數據...") + + # 分批獲取數據 + all_ids = [] + all_documents = [] + all_metadatas = [] + + offset = 0 + with tqdm(total=count_result, desc=f"正在讀取 {collection_name}") as pbar: + while True: + # 注意: 使用include參數指定只獲取需要的數據 + batch_result = collection.get( + limit=batch_size, + offset=offset, + include=["documents", "metadatas"] + ) + + batch_ids = batch_result.get("ids", []) + if not batch_ids: + break + + all_ids.extend(batch_ids) + all_documents.extend(batch_result.get("documents", [])) + all_metadatas.extend(batch_result.get("metadatas", [])) + + offset += len(batch_ids) + pbar.update(len(batch_ids)) + + if len(batch_ids) < batch_size: + break + + return { + "ids": all_ids, + "documents": all_documents, + "metadatas": all_metadatas + } + + except Exception as e: + print(f"從集合 '{collection_name}' 獲取數據時出錯: {e}") + return {"ids": [], "documents": [], "metadatas": []} + +def create_and_populate_collection( + client, + collection_name: str, + data: Dict[str, Any], + embedding_func, + batch_size: int = 100, + dry_run: bool = False +) -> bool: + """創建新集合並填充數據 + + Args: + client: ChromaDB客戶端 + collection_name: 集合名稱 + data: 要添加的數據 (ids, documents, metadatas) + embedding_func: 嵌入函數 + batch_size: 批處理大小 + dry_run: 是否只模擬執行 + + Returns: + 成功返回True,否則返回False + """ + if dry_run: + print(f"[模擬] 將創建集合 '{collection_name}' 並添加 {len(data['ids'])} 項數據") + return True + + try: + # 檢查集合是否已存在 + if collection_name in [col.name for col in client.list_collections()]: + client.delete_collection(collection_name) + + # 創建新集合 + collection = client.create_collection( + name=collection_name, + embedding_function=embedding_func + ) + + # 如果沒有數據,直接返回 + if not data["ids"]: + print(f"集合 '{collection_name}' 創建完成,但沒有數據添加") + return True + + # 分批添加數據 + total_items = len(data["ids"]) + with tqdm(total=total_items, desc=f"正在填充 {collection_name}") as pbar: + for i in range(0, total_items, batch_size): + end_idx = min(i + batch_size, total_items) + + batch_ids = data["ids"][i:end_idx] + batch_docs = data["documents"][i:end_idx] + batch_meta = data["metadatas"][i:end_idx] + + # 處理可能的None值 + processed_docs = [] + for doc in batch_docs: + if doc is None: + processed_docs.append("") # 使用空字符串替代None + else: + processed_docs.append(doc) + + collection.add( + ids=batch_ids, + documents=processed_docs, + metadatas=batch_meta + ) + + pbar.update(end_idx - i) + + print(f"成功將 {total_items} 項數據添加到集合 '{collection_name}'") + return True + + except Exception as e: + print(f"創建或填充集合 '{collection_name}' 時出錯: {e}") + import traceback + traceback.print_exc() + return False + +def swap_collections( + client, + original_collection: str, + temp_collection: str, + confirm_dangerous: bool = False, + dry_run: bool = False, + embedding_func = None # 添加嵌入函數作為參數 +) -> bool: + """替換集合(刪除原始集合,將臨時集合重命名為原始集合名) + + Args: + client: ChromaDB客戶端 + original_collection: 原始集合名稱 + temp_collection: 臨時集合名稱 + confirm_dangerous: 是否確認危險操作 + dry_run: 是否只模擬執行 + embedding_func: 嵌入函數,用於創建新集合 + + Returns: + 成功返回True,否則返回False + """ + if dry_run: + print(f"[模擬] 將替換集合: 刪除 '{original_collection}',重命名 '{temp_collection}' 到 '{original_collection}'") + return True + + try: + # 檢查是否有確認標誌 + if not confirm_dangerous: + response = input(f"警告: 即將刪除集合 '{original_collection}' 並用 '{temp_collection}' 替換它。確認操作? (y/N): ") + if response.lower() != 'y': + print("操作已取消") + return False + + # 檢查兩個集合是否都存在 + all_collections = [col.name for col in client.list_collections()] + if original_collection not in all_collections: + print(f"錯誤: 原始集合 '{original_collection}' 不存在") + return False + + if temp_collection not in all_collections: + print(f"錯誤: 臨時集合 '{temp_collection}' 不存在") + return False + + # 獲取臨時集合的所有數據 + # 在刪除原始集合之前先獲取臨時集合的所有數據 + print(f"獲取臨時集合 '{temp_collection}' 的數據...") + temp_collection_obj = client.get_collection(temp_collection) + temp_data = temp_collection_obj.get(include=["documents", "metadatas"]) + + # 刪除原始集合 + print(f"刪除原始集合 '{original_collection}'...") + client.delete_collection(original_collection) + + # 創建一個同名的新集合(與原始集合同名) + print(f"創建新集合 '{original_collection}'...") + + # 使用傳入的嵌入函數或臨時集合的嵌入函數 + embedding_function = embedding_func or temp_collection_obj._embedding_function + + # 創建新的集合 + original_collection_obj = client.create_collection( + name=original_collection, + embedding_function=embedding_function + ) + + # 將數據添加到新集合 + if temp_data["ids"]: + print(f"將 {len(temp_data['ids'])} 項數據從臨時集合複製到新集合...") + + # 處理可能的None值 + processed_docs = [] + for doc in temp_data["documents"]: + if doc is None: + processed_docs.append("") + else: + processed_docs.append(doc) + + # 使用分批方式添加數據以避免潛在的大數據問題 + batch_size = 100 + for i in range(0, len(temp_data["ids"]), batch_size): + end = min(i + batch_size, len(temp_data["ids"])) + original_collection_obj.add( + ids=temp_data["ids"][i:end], + documents=processed_docs[i:end], + metadatas=temp_data["metadatas"][i:end] if temp_data["metadatas"] else None + ) + + # 刪除臨時集合 + print(f"刪除臨時集合 '{temp_collection}'...") + client.delete_collection(temp_collection) + + print(f"成功用重新嵌入的數據替換集合 '{original_collection}'") + return True + + except Exception as e: + print(f"替換集合時出錯: {e}") + import traceback + traceback.print_exc() + return False + +def process_collection( + client, + collection_name: str, + embedding_func, + temp_suffix: str, + batch_size: int, + confirm_dangerous: bool, + dry_run: bool +) -> bool: + """處理一個集合的完整流程 + + Args: + client: ChromaDB客戶端 + collection_name: 要處理的集合名稱 + embedding_func: 新的嵌入函數 + temp_suffix: 臨時集合的後綴 + batch_size: 批處理大小 + confirm_dangerous: 是否確認危險操作 + dry_run: 是否只模擬執行 + + Returns: + 處理成功返回True,否則返回False + """ + print(f"\n{'=' * 60}") + print(f"處理集合: '{collection_name}'") + print(f"{'=' * 60}") + + # 暫時集合名稱 + temp_collection_name = f"{collection_name}{temp_suffix}" + + # 1. 獲取原始集合的數據 + data = fetch_collection_data(client, collection_name, batch_size) + + if not data["ids"]: + print(f"集合 '{collection_name}' 為空或不存在,跳過") + return True + + # 2. 創建臨時集合並使用新的嵌入模型填充數據 + success = create_and_populate_collection( + client, + temp_collection_name, + data, + embedding_func, + batch_size, + dry_run + ) + + if not success: + print(f"創建臨時集合 '{temp_collection_name}' 失敗,跳過替換") + return False + + # 3. 替換原始集合 + success = swap_collections( + client, + collection_name, + temp_collection_name, + confirm_dangerous, + dry_run, + embedding_func # 添加嵌入函數作為參數 + ) + + return success + +def main(): + """主函數""" + args = parse_args() + + # 獲取ChromaDB目錄 + chroma_dir = getattr(config, "CHROMA_DATA_DIR", "chroma_data") + print(f"使用ChromaDB目錄: {chroma_dir}") + + # 備份數據庫(如果請求) + if args.backup: + backup_chroma_directory(chroma_dir) + + # 創建ChromaDB客戶端 + try: + client = chromadb.PersistentClient(path=chroma_dir) + except Exception as e: + print(f"錯誤: 無法連接到ChromaDB: {e}") + sys.exit(1) + + # 創建嵌入函數 + embedding_func = create_embedding_function(args.new_model) + + # 確定要處理的集合 + if args.collections: + collections_to_process = args.collections + else: + # 使用配置中的默認集合或獲取所有可用集合 + default_collections = [ + getattr(config, "BOT_MEMORY_COLLECTION", "wolfhart_memory"), + getattr(config, "CONVERSATIONS_COLLECTION", "conversations"), + getattr(config, "PROFILES_COLLECTION", "user_profiles") + ] + collections_to_process = get_collection_names(client, default_collections) + + # 過濾掉已經是臨時集合的集合名稱 + filtered_collections = [] + for collection in collections_to_process: + if args.temp_collection_suffix in collection: + print(f"警告: 跳過可能的臨時集合 '{collection}'") + continue + filtered_collections.append(collection) + + collections_to_process = filtered_collections + + if not collections_to_process: + print("沒有找到可處理的集合。") + sys.exit(0) + + print(f"將處理以下集合: {', '.join(collections_to_process)}") + if args.dry_run: + print("注意: 執行為乾運行模式,不會實際修改數據") + + # 詢問用戶確認 + if not args.confirm_dangerous and not args.dry_run: + confirm = input("這個操作將使用新的嵌入模型重新計算所有數據。繼續? (y/N): ") + if confirm.lower() != 'y': + print("操作已取消") + sys.exit(0) + + # 處理每個集合 + start_time = time.time() + success_count = 0 + + for collection_name in collections_to_process: + if process_collection( + client, + collection_name, + embedding_func, + args.temp_collection_suffix, + args.batch_size, + args.confirm_dangerous, + args.dry_run + ): + success_count += 1 + + # 報告結果 + elapsed_time = time.time() - start_time + print(f"\n{'=' * 60}") + print(f"處理完成: {success_count}/{len(collections_to_process)} 個集合成功") + print(f"總耗時: {elapsed_time:.2f} 秒") + print(f"{'=' * 60}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/chroma_view.py b/tools/chroma_view.py index 0669c31..df2b638 100644 --- a/tools/chroma_view.py +++ b/tools/chroma_view.py @@ -147,6 +147,15 @@ class ChromaDBReader: except Exception as e: self.logger.error(f"無法加載 SentenceTransformer paraphrase-multilingual-MiniLM-L12-v2: {e}。將使用集合內部模型。") self.query_embedding_function = None + # 添加新的模型支持 + elif model_name == "paraphrase-multilingual-mpnet-base-v2": + try: + # 注意: sentence-transformers 庫需要安裝 + self.query_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") + self.logger.info(f"查詢將使用外部嵌入模型: {model_name}") + except Exception as e: + self.logger.error(f"無法加載 SentenceTransformer paraphrase-multilingual-mpnet-base-v2: {e}。將使用集合內部模型。") + self.query_embedding_function = None else: self.logger.warning(f"未知的查詢嵌入模型: {model_name}, 將使用集合內部模型。") self.query_embedding_function = None @@ -450,13 +459,11 @@ class ChromaDBReaderUI: self.embedding_models = { "預設 (ChromaDB)": "default", "all-MiniLM-L6-v2 (ST)": "all-MiniLM-L6-v2", - "paraphrase-multilingual-MiniLM-L12-v2 (ST)": "paraphrase-multilingual-MiniLM-L12-v2" + "paraphrase-multilingual-MiniLM-L12-v2 (ST)": "paraphrase-multilingual-MiniLM-L12-v2", + "paraphrase-multilingual-mpnet-base-v2 (ST)": "paraphrase-multilingual-mpnet-base-v2" # 添加新的模型選項 } - # 初始化 reader 中的嵌入模型 (確保 reader 實例已創建) - # self.reader.set_query_embedding_model(self.embedding_models[self.embedding_model_var.get()]) - # ^^^ 這行需要在 setup_ui 之後,或者在 on_embedding_model_changed 中處理首次設置 - - self.setup_ui() # setup_ui 會創建 reader 實例 + + self.setup_ui() # 默認主題 self.current_theme = "darkly" # ttkbootstrap的深色主題