Migrate ChromaDB embedding model to paraphrase-multilingual-mpnet-base-v2

This commit is contained in:
z060142 2025-05-09 12:32:06 +08:00
parent 65df12a20e
commit bccc6d413f
5 changed files with 666 additions and 50 deletions

View File

@ -316,6 +316,15 @@ def load_current_config():
if backup_minute_match:
config_data["MEMORY_BACKUP_MINUTE"] = int(backup_minute_match.group(1))
# Extract EMBEDDING_MODEL_NAME
embedding_model_match = re.search(r'EMBEDDING_MODEL_NAME\s*=\s*["\'](.+?)["\']', config_content)
if embedding_model_match:
config_data["EMBEDDING_MODEL_NAME"] = embedding_model_match.group(1)
else:
# Default if not found in config.py, will be set in UI if not overridden by load
config_data["EMBEDDING_MODEL_NAME"] = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
profile_model_match = re.search(r'MEMORY_PROFILE_MODEL\s*=\s*["\']?(.+?)["\']?\s*(?:#|$)', config_content)
# Handle potential LLM_MODEL reference
if profile_model_match:
@ -537,10 +546,15 @@ def generate_config_file(config_data, env_data):
f.write(f"MEMORY_BACKUP_MINUTE = {backup_minute}\n")
# Write profile model, potentially referencing LLM_MODEL
if profile_model == config_data.get('LLM_MODEL'):
f.write(f"MEMORY_PROFILE_MODEL = LLM_MODEL # Default to main LLM model\n")
f.write(f"MEMORY_PROFILE_MODEL = LLM_MODEL # Default to main LLM model\n")
else:
f.write(f"MEMORY_PROFILE_MODEL = \"{profile_model}\"\n")
f.write(f"MEMORY_SUMMARY_MODEL = \"{summary_model}\"\n")
f.write(f"MEMORY_PROFILE_MODEL = \"{profile_model}\"\n")
f.write(f"MEMORY_SUMMARY_MODEL = \"{summary_model}\"\n\n")
# Write Embedding Model Name
embedding_model_name = config_data.get('EMBEDDING_MODEL_NAME', "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
f.write("# Embedding model for ChromaDB\n")
f.write(f"EMBEDDING_MODEL_NAME = \"{embedding_model_name}\"\n")
print("Generated config.py file successfully")
@ -1717,6 +1731,24 @@ class WolfChatSetup(tk.Tk):
related_info = ttk.Label(related_frame, text="(0 to disable related memories pre-loading)")
related_info.pack(side=tk.LEFT, padx=(5, 0))
# Embedding Model Settings Frame
embedding_model_settings_frame = ttk.LabelFrame(main_frame, text="Embedding Model Settings")
embedding_model_settings_frame.pack(fill=tk.X, pady=10)
embedding_model_name_frame = ttk.Frame(embedding_model_settings_frame)
embedding_model_name_frame.pack(fill=tk.X, pady=5, padx=10)
embedding_model_name_label = ttk.Label(embedding_model_name_frame, text="Embedding Model Name:", width=25) # Adjusted width
embedding_model_name_label.pack(side=tk.LEFT, padx=(0, 5))
self.embedding_model_name_var = tk.StringVar(value="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
embedding_model_name_entry = ttk.Entry(embedding_model_name_frame, textvariable=self.embedding_model_name_var)
embedding_model_name_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
embedding_model_info = ttk.Label(embedding_model_settings_frame, text="Default: sentence-transformers/paraphrase-multilingual-mpnet-base-v2", justify=tk.LEFT)
embedding_model_info.pack(anchor=tk.W, padx=10, pady=(0,5))
# Information box
info_frame = ttk.LabelFrame(main_frame, text="Information")
info_frame.pack(fill=tk.BOTH, expand=True, pady=10)
@ -2067,6 +2099,10 @@ class WolfChatSetup(tk.Tk):
self.profiles_collection_var.set(self.config_data.get("PROFILES_COLLECTION", "user_profiles")) # Default was user_profiles
self.conversations_collection_var.set(self.config_data.get("CONVERSATIONS_COLLECTION", "conversations"))
self.bot_memory_collection_var.set(self.config_data.get("BOT_MEMORY_COLLECTION", "wolfhart_memory"))
# Embedding Model Name for Memory Settings Tab
if hasattr(self, 'embedding_model_name_var'):
self.embedding_model_name_var.set(self.config_data.get("EMBEDDING_MODEL_NAME", "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"))
# Memory Management Tab Settings
if hasattr(self, 'backup_hour_var'): # Check if UI elements for memory management tab exist
@ -2343,6 +2379,10 @@ class WolfChatSetup(tk.Tk):
self.config_data["PROFILES_COLLECTION"] = self.profiles_collection_var.get()
self.config_data["CONVERSATIONS_COLLECTION"] = self.conversations_collection_var.get()
self.config_data["BOT_MEMORY_COLLECTION"] = self.bot_memory_collection_var.get()
# Save Embedding Model Name from Memory Settings Tab
if hasattr(self, 'embedding_model_name_var'):
self.config_data["EMBEDDING_MODEL_NAME"] = self.embedding_model_name_var.get()
# Get Memory Management settings from UI
if hasattr(self, 'backup_hour_var'): # Check if UI elements exist

View File

@ -1,6 +1,7 @@
# chroma_client.py
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions # New import
import os
import json
import config
@ -10,6 +11,33 @@ import time
_client = None
_collections = {}
# Global embedding function variable
_embedding_function = None
def get_embedding_function():
    """Return the module-wide embedding function, building it on first use.

    The model name is read from ``config.EMBEDDING_MODEL_NAME`` and defaults to
    ``sentence-transformers/paraphrase-multilingual-mpnet-base-v2``. If the
    configured model cannot be initialised and it is not already the default,
    the default model is tried once. The result is stored in the module-level
    ``_embedding_function``; it stays ``None`` when every attempt fails, so a
    later call will try again.

    Returns:
        The cached embedding function, or ``None`` if initialisation failed.
    """
    global _embedding_function

    # Fast path: already initialised by an earlier call.
    if _embedding_function is not None:
        return _embedding_function

    default_model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    model_name = getattr(config, 'EMBEDDING_MODEL_NAME', default_model)

    try:
        _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
    except Exception as e:
        print(f"Failed to initialize embedding function with model '{model_name}': {e}")
    else:
        print(f"Successfully initialized embedding function with model: {model_name}")
        return _embedding_function

    if model_name == default_model:
        # The configured model *was* the default, so there is nothing to fall back to.
        _embedding_function = None
        return _embedding_function

    print("Falling back to default embedding model: sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
    try:
        _embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=default_model)
        print("Successfully initialized embedding function with default model.")
    except Exception as e_default:
        print(f"Failed to initialize default embedding function: {e_default}")
        _embedding_function = None  # keep the cache empty when all attempts fail
    return _embedding_function
def initialize_chroma_client():
"""Initializes and connects to ChromaDB"""
global _client
@ -34,13 +62,31 @@ def get_collection(collection_name):
if collection_name not in _collections:
try:
emb_func = get_embedding_function()
if emb_func is None:
print(f"Failed to get or create collection '{collection_name}' due to embedding function initialization failure.")
return None
_collections[collection_name] = _client.get_or_create_collection(
name=collection_name
name=collection_name,
embedding_function=emb_func
)
print(f"Successfully got or created collection '{collection_name}'")
print(f"Successfully got or created collection '{collection_name}' using configured embedding function.")
except Exception as e:
print(f"Failed to get collection '{collection_name}': {e}")
return None
print(f"Failed to get collection '{collection_name}' with configured embedding function: {e}")
# Attempt to create collection with default embedding function as a fallback
print(f"Attempting to create collection '{collection_name}' with default embedding function...")
try:
# Ensure we try the absolute default if the configured one (even if it was the default) failed
default_emb_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
_collections[collection_name] = _client.get_or_create_collection(
name=collection_name,
embedding_function=default_emb_func
)
print(f"Successfully got or created collection '{collection_name}' with default embedding function after initial failure.")
except Exception as e_default:
print(f"Failed to get collection '{collection_name}' even with default embedding function: {e_default}")
return None
return _collections[collection_name]

View File

@ -16,11 +16,12 @@ import schedule
from pathlib import Path
from typing import Dict, List, Optional, Any, Union
import chromadb
from chromadb.utils import embedding_functions
# import chromadb # No longer directly needed by ChromaDBManager
# from chromadb.utils import embedding_functions # No longer directly needed by ChromaDBManager
from openai import AsyncOpenAI
import config
import chroma_client # Import the centralized chroma client
# =============================================================================
# 日誌解析部分
@ -345,28 +346,22 @@ class MemoryGenerator:
class ChromaDBManager:
def __init__(self, collection_name: Optional[str] = None):
self.client = chromadb.PersistentClient(path=config.CHROMA_DATA_DIR)
self.collection_name = collection_name or config.BOT_MEMORY_COLLECTION
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
self._ensure_collection()
self._db_collection = None # Cache for the collection object
def _ensure_collection(self) -> None:
"""確保集合存在"""
try:
self.collection = self.client.get_collection(
name=self.collection_name,
embedding_function=self.embedding_function
)
print(f"Connected to existing collection: {self.collection_name}")
except Exception:
self.collection = self.client.create_collection(
name=self.collection_name,
embedding_function=self.embedding_function
)
print(f"Created new collection: {self.collection_name}")
def _get_db_collection(self):
"""Helper to get the collection object from chroma_client"""
if self._db_collection is None:
# Use the centralized get_collection function
self._db_collection = chroma_client.get_collection(self.collection_name)
if self._db_collection is None:
# This indicates a failure in chroma_client to provide the collection
raise RuntimeError(f"Failed to get or create collection '{self.collection_name}' via chroma_client. Check chroma_client logs.")
return self._db_collection
def upsert_user_profile(self, profile_data: Dict[str, Any]) -> bool:
"""寫入或更新用戶檔案"""
collection = self._get_db_collection()
if not profile_data or not isinstance(profile_data, dict):
print("無效的檔案數據")
return False
@ -377,14 +372,13 @@ class ChromaDBManager:
print("檔案缺少ID字段")
return False
# 先檢查是否已存在
results = self.collection.get(
ids=[user_id], # Query by a list of IDs
# where={"id": user_id}, # 'where' is for metadata filtering
limit=1
)
# 準備元數據
# Note: ChromaDB's upsert handles existence check implicitly.
# The .get call here isn't strictly necessary for the upsert operation itself,
# but might be kept if there was other logic depending on prior existence.
# For a clean upsert, it can be removed. Let's assume it's not critical for now.
# results = collection.get(ids=[user_id], limit=1) # Optional: if needed for pre-check logic
metadata = {
"id": user_id,
"type": "user_profile",
@ -402,14 +396,12 @@ class ChromaDBManager:
content_doc = json.dumps(profile_data.get("content", {}), ensure_ascii=False)
# 寫入或更新
# ChromaDB's add/upsert handles both cases.
# If an ID exists, it's an update; otherwise, it's an add.
self.collection.upsert(
collection.upsert(
ids=[user_id],
documents=[content_doc],
metadatas=[metadata]
)
print(f"Upserted user profile: {user_id}")
print(f"Upserted user profile: {user_id} into collection {self.collection_name}")
return True
@ -419,6 +411,7 @@ class ChromaDBManager:
def upsert_conversation_summary(self, summary_data: Dict[str, Any]) -> bool:
"""寫入對話總結"""
collection = self._get_db_collection()
if not summary_data or not isinstance(summary_data, dict):
print("無效的總結數據")
return False
@ -450,13 +443,13 @@ class ChromaDBManager:
key_points_str = "\n".join([f"- {point}" for point in summary_data["key_points"]])
content_doc += f"\n\n關鍵點:\n{key_points_str}"
# 寫入數據 (ChromaDB's add implies upsert if ID exists, but upsert is more explicit)
self.collection.upsert(
# 寫入數據
collection.upsert(
ids=[summary_id],
documents=[content_doc],
metadatas=[metadata]
)
print(f"Upserted conversation summary: {summary_id}")
print(f"Upserted conversation summary: {summary_id} into collection {self.collection_name}")
return True
@ -466,10 +459,11 @@ class ChromaDBManager:
def get_existing_profile(self, username: str) -> Optional[Dict[str, Any]]:
"""獲取現有的用戶檔案"""
collection = self._get_db_collection()
try:
profile_id = f"{username}_profile"
results = self.collection.get(
ids=[profile_id], # Query by a list of IDs
results = collection.get(
ids=[profile_id],
limit=1
)

529
reembed_chroma_data.py Normal file
View File

@ -0,0 +1,529 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
重新嵌入工具 (Reembedding Tool)
這個腳本用於將現有ChromaDB集合中的數據使用新的嵌入模型重新計算向量並儲存
"""
import os
import sys
import json
import time
import argparse
import shutil
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm # 進度條
try:
import chromadb
from chromadb.utils import embedding_functions
except ImportError:
print("錯誤: 請先安裝 chromadb: pip install chromadb")
sys.exit(1)
try:
from sentence_transformers import SentenceTransformer
except ImportError:
print("錯誤: 請先安裝 sentence-transformers: pip install sentence-transformers")
sys.exit(1)
# Try to import the project configuration; fall back to built-in defaults
# so the tool can still run without a config.py on the import path.
try:
    import config
except ImportError:
    print("警告: 無法導入config.py將使用預設設定")
    # Build a minimal stand-in exposing the attributes this script reads.
    class MinimalConfig:
        # Directory holding the persistent ChromaDB data.
        CHROMA_DATA_DIR = "chroma_data"
        # NOTE(review): all three collection settings share the same name here,
        # while main() uses distinct getattr fallbacks ("conversations",
        # "user_profiles") — confirm which defaults are intended.
        BOT_MEMORY_COLLECTION = "wolfhart_memory"
        CONVERSATIONS_COLLECTION = "wolfhart_memory"
        PROFILES_COLLECTION = "wolfhart_memory"
    config = MinimalConfig()
def _positive_int(value: str) -> int:
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}")
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the re-embedding tool.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The ``argparse.Namespace`` with ``new_model``, ``collections``,
        ``backup``, ``batch_size``, ``temp_collection_suffix``, ``dry_run``
        and ``confirm_dangerous``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including a
            ``--batch-size`` that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description='ChromaDB 數據重新嵌入工具')

    parser.add_argument('--new-model', type=str,
                        default="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                        help='新的嵌入模型名稱 (預設: sentence-transformers/paraphrase-multilingual-mpnet-base-v2)')
    parser.add_argument('--collections', type=str, nargs='+',
                        help='要處理的集合名稱列表,空白分隔 (預設: 使用配置中的所有集合)')
    parser.add_argument('--backup', action='store_true',
                        help='在處理前備份資料庫 (推薦)')
    # A batch size of 0 or below would break the batching loops downstream
    # (range() step of 0, or an empty range that adds nothing), so reject it here.
    parser.add_argument('--batch-size', type=_positive_int, default=100,
                        help='批處理大小 (預設: 100)')
    parser.add_argument('--temp-collection-suffix', type=str, default="_temp_new",
                        help='臨時集合的後綴名稱 (預設: _temp_new)')
    parser.add_argument('--dry-run', action='store_true',
                        help='模擬執行但不實際修改資料')
    parser.add_argument('--confirm-dangerous', action='store_true',
                        help='確認執行危險操作(例如刪除集合)')

    return parser.parse_args(argv)
def backup_chroma_directory(chroma_dir: str) -> str:
    """Copy the ChromaDB data directory to a timestamped sibling directory.

    Args:
        chroma_dir: Path of the ChromaDB data directory.

    Returns:
        Path of the backup directory that was created
        (``<chroma_dir>_backup_<YYYYmmdd_HHMMSS>``).

    Exits the process with status 1 when ``chroma_dir`` does not exist.
    """
    source_present = os.path.exists(chroma_dir)
    if not source_present:
        print(f"錯誤: ChromaDB目錄 '{chroma_dir}' 不存在")
        sys.exit(1)

    # Second-resolution timestamp keeps successive backups apart.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_dir = "{}_backup_{}".format(chroma_dir, stamp)

    print(f"備份資料庫從 '{chroma_dir}''{backup_dir}'...")
    shutil.copytree(chroma_dir, backup_dir)
    print(f"備份完成: {backup_dir}")

    return backup_dir
def create_embedding_function(model_name: str):
    """Build the embedding function used to re-embed documents.

    Args:
        model_name: Name of the SentenceTransformer model. An empty value
            selects ChromaDB's default embedding function.

    Returns:
        An embedding function object. If the requested model cannot be loaded,
        ChromaDB's default embedding function is returned instead.

    NOTE(review): the fallback is silent apart from a printed message, so the
    caller cannot tell that a different model than requested is in use —
    confirm this is acceptable for a migration tool.
    """
    if model_name:
        print(f"正在加載嵌入模型: {model_name}")
        try:
            # Imported here so that a missing class is handled by the fallback below.
            from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
            loaded = SentenceTransformerEmbeddingFunction(model_name=model_name)
            # Warm-up call: run the model once on a trivial input before real use.
            loaded(["."])
            return loaded
        except Exception as e:
            print(f"錯誤: 無法加載模型 '{model_name}': {e}")
            print("退回到預設嵌入模型")
    else:
        print("使用ChromaDB預設嵌入模型")

    return embedding_functions.DefaultEmbeddingFunction()
def get_collection_names(client, default_collections: List[str]) -> List[str]:
    """Return the names of all collections known to the client.

    Args:
        client: ChromaDB client; only ``list_collections()`` is used.
        default_collections: Names to return when no collection is found or
            the listing fails.

    Returns:
        The collection names reported by the client, or ``default_collections``
        when the listing is empty or raises.
    """
    try:
        # list_collections() has returned Collection objects in some chromadb
        # releases and plain name strings in others (reportedly 0.6.x), so
        # accept both instead of failing on ``.name`` and silently falling
        # back to the defaults.
        collection_names = [
            col if isinstance(col, str) else col.name
            for col in client.list_collections()
        ]
    except Exception as e:
        print(f"獲取集合列表失敗: {e}")
        print("將使用預設集合")
        return default_collections

    if collection_names:
        return collection_names

    print("警告: 沒有找到集合,將使用預設集合")
    return default_collections
def fetch_collection_data(client, collection_name: str, batch_size: int = 100) -> Dict[str, Any]:
    """Read every item of a collection, page by page.

    Args:
        client: ChromaDB client.
        collection_name: Name of the collection to read.
        batch_size: Number of items requested per page.

    Returns:
        A dict with the keys ``ids``, ``documents`` and ``metadatas``. All
        three lists are empty when the collection is empty or when any error
        occurs while reading (the error is printed, not raised).
    """
    try:
        source = client.get_collection(name=collection_name)

        # Total item count, used for the early exit and the progress bar.
        total = source.count()
        if total == 0:
            print(f"集合 '{collection_name}' 是空的")
            return {"ids": [], "documents": [], "metadatas": []}

        print(f"從集合 '{collection_name}' 中讀取 {total} 項數據...")

        gathered = {"ids": [], "documents": [], "metadatas": []}
        cursor = 0
        with tqdm(total=total, desc=f"正在讀取 {collection_name}") as progress:
            page_full = True
            while page_full:
                # Only documents and metadata are requested; ids always come back.
                page = source.get(
                    limit=batch_size,
                    offset=cursor,
                    include=["documents", "metadatas"]
                )
                page_ids = page.get("ids", [])
                if not page_ids:
                    break

                gathered["ids"].extend(page_ids)
                gathered["documents"].extend(page.get("documents", []))
                gathered["metadatas"].extend(page.get("metadatas", []))

                fetched = len(page_ids)
                cursor += fetched
                progress.update(fetched)

                # A short page means the end of the collection was reached.
                page_full = fetched >= batch_size

        return gathered
    except Exception as e:
        print(f"從集合 '{collection_name}' 獲取數據時出錯: {e}")
        return {"ids": [], "documents": [], "metadatas": []}
def create_and_populate_collection(
    client,
    collection_name: str,
    data: Dict[str, Any],
    embedding_func,
    batch_size: int = 100,
    dry_run: bool = False
) -> bool:
    """Create a fresh collection and fill it with the given data.

    An existing collection with the same name is deleted first. Documents are
    added without precomputed vectors, so the collection embeds them with
    ``embedding_func``.

    Args:
        client: ChromaDB client.
        collection_name: Name of the collection to (re)create.
        data: Items to add, with the keys ``ids``, ``documents``, ``metadatas``.
        embedding_func: Embedding function attached to the new collection.
        batch_size: Number of items added per call; must be >= 1.
        dry_run: When true, only print what would be done.

    Returns:
        True on success (or in dry-run mode), False when any step fails.
    """
    if dry_run:
        print(f"[模擬] 將創建集合 '{collection_name}' 並添加 {len(data['ids'])} 項數據")
        return True

    try:
        # Validate before touching the database: a step of 0 makes range()
        # raise after the collection was already recreated, and a negative
        # step yields no batches at all while still reporting success.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # list_collections() yields Collection objects or plain name strings
        # depending on the chromadb release; handle both.
        existing_names = {
            col if isinstance(col, str) else col.name
            for col in client.list_collections()
        }
        if collection_name in existing_names:
            client.delete_collection(collection_name)

        collection = client.create_collection(
            name=collection_name,
            embedding_function=embedding_func
        )

        # Nothing to add: the empty collection is the result.
        if not data["ids"]:
            print(f"集合 '{collection_name}' 創建完成,但沒有數據添加")
            return True

        total_items = len(data["ids"])
        with tqdm(total=total_items, desc=f"正在填充 {collection_name}") as pbar:
            for i in range(0, total_items, batch_size):
                end_idx = min(i + batch_size, total_items)

                # Replace missing documents with an empty string so the
                # batch can still be embedded.
                batch_docs = [
                    "" if doc is None else doc
                    for doc in data["documents"][i:end_idx]
                ]

                collection.add(
                    ids=data["ids"][i:end_idx],
                    documents=batch_docs,
                    metadatas=data["metadatas"][i:end_idx]
                )
                pbar.update(end_idx - i)

        print(f"成功將 {total_items} 項數據添加到集合 '{collection_name}'")
        return True
    except Exception as e:
        print(f"創建或填充集合 '{collection_name}' 時出錯: {e}")
        import traceback
        traceback.print_exc()
        return False
def swap_collections(
    client,
    original_collection: str,
    temp_collection: str,
    confirm_dangerous: bool = False,
    dry_run: bool = False,
    embedding_func = None
) -> bool:
    """Replace the original collection with the contents of the temporary one.

    The original collection is deleted, recreated under the same name, filled
    with the temporary collection's items, and the temporary collection is
    then deleted. The temporary collection is removed only after the copy has
    finished, so it survives a failure during the copy.

    Args:
        client: ChromaDB client.
        original_collection: Name of the collection to replace.
        temp_collection: Name of the collection holding the re-embedded data.
        confirm_dangerous: Skip the interactive confirmation when true.
        dry_run: When true, only print what would be done.
        embedding_func: Embedding function for the recreated collection;
            defaults to the temporary collection's own function.

    Returns:
        True on success (or in dry-run mode); False when the user declines,
        a collection is missing, or any step fails.

    NOTE(review): the recreated collection is made with a name and embedding
    function only; any metadata set on the original collection is not carried
    over — confirm none is relied upon.
    """
    if dry_run:
        print(f"[模擬] 將替換集合: 刪除 '{original_collection}',重命名 '{temp_collection}''{original_collection}'")
        return True

    try:
        # Ask for confirmation unless it was given on the command line.
        if not confirm_dangerous:
            response = input(f"警告: 即將刪除集合 '{original_collection}' 並用 '{temp_collection}' 替換它。確認操作? (y/N): ")
            if response.lower() != 'y':
                print("操作已取消")
                return False

        # Both collections must exist. list_collections() yields Collection
        # objects or plain name strings depending on the chromadb release.
        all_collections = {
            col if isinstance(col, str) else col.name
            for col in client.list_collections()
        }
        if original_collection not in all_collections:
            print(f"錯誤: 原始集合 '{original_collection}' 不存在")
            return False
        if temp_collection not in all_collections:
            print(f"錯誤: 臨時集合 '{temp_collection}' 不存在")
            return False

        # Read the temporary collection before deleting anything. The vectors
        # are fetched too, so the documents need not be embedded a second time.
        print(f"獲取臨時集合 '{temp_collection}' 的數據...")
        temp_collection_obj = client.get_collection(temp_collection)
        temp_data = temp_collection_obj.get(include=["documents", "metadatas", "embeddings"])

        print(f"刪除原始集合 '{original_collection}'...")
        client.delete_collection(original_collection)

        print(f"創建新集合 '{original_collection}'...")
        embedding_function = embedding_func or getattr(temp_collection_obj, "_embedding_function", None)
        original_collection_obj = client.create_collection(
            name=original_collection,
            embedding_function=embedding_function
        )

        ids = temp_data["ids"]
        if ids:
            print(f"{len(ids)} 項數據從臨時集合複製到新集合...")

            # Replace missing documents with an empty string.
            processed_docs = ["" if doc is None else doc for doc in temp_data["documents"]]
            metadatas = temp_data["metadatas"]

            # Reuse the stored vectors only when there is one per id;
            # otherwise let the new collection compute them from the documents.
            embeddings = temp_data.get("embeddings")
            reuse_embeddings = embeddings is not None and len(embeddings) == len(ids)

            # Add in batches to avoid oversized single requests.
            batch_size = 100
            for i in range(0, len(ids), batch_size):
                end = min(i + batch_size, len(ids))
                add_kwargs = {
                    "ids": ids[i:end],
                    "documents": processed_docs[i:end],
                    "metadatas": metadatas[i:end] if metadatas else None,
                }
                if reuse_embeddings:
                    # list() turns an array slice into a plain list of vectors.
                    add_kwargs["embeddings"] = list(embeddings[i:end])
                original_collection_obj.add(**add_kwargs)

        print(f"刪除臨時集合 '{temp_collection}'...")
        client.delete_collection(temp_collection)

        print(f"成功用重新嵌入的數據替換集合 '{original_collection}'")
        return True
    except Exception as e:
        print(f"替換集合時出錯: {e}")
        import traceback
        traceback.print_exc()
        return False
def process_collection(
    client,
    collection_name: str,
    embedding_func,
    temp_suffix: str,
    batch_size: int,
    confirm_dangerous: bool,
    dry_run: bool
) -> bool:
    """Run the full re-embedding workflow for a single collection.

    Steps: read the collection, build a temporary collection embedded with
    the new function, then swap it in place of the original.

    Args:
        client: ChromaDB client.
        collection_name: Name of the collection to process.
        embedding_func: New embedding function.
        temp_suffix: Suffix appended to form the temporary collection name.
        batch_size: Batch size for reading and writing.
        confirm_dangerous: Skip the interactive confirmation when true.
        dry_run: When true, only simulate.

    Returns:
        True when the collection was processed or skipped because no data was
        read from it; False when building or swapping failed.
    """
    banner = '=' * 60
    print("\n" + banner)
    print(f"處理集合: '{collection_name}'")
    print(banner)

    staging_name = collection_name + temp_suffix

    # Step 1: read the source data. No ids means nothing to migrate.
    data = fetch_collection_data(client, collection_name, batch_size)
    if not data["ids"]:
        print(f"集合 '{collection_name}' 為空或不存在,跳過")
        return True

    # Step 2: build the temporary collection with the new embedding function.
    populated = create_and_populate_collection(
        client,
        staging_name,
        data,
        embedding_func,
        batch_size,
        dry_run
    )
    if not populated:
        print(f"創建臨時集合 '{staging_name}' 失敗,跳過替換")
        return False

    # Step 3: replace the original; the result of the swap is the overall result.
    return swap_collections(
        client,
        collection_name,
        staging_name,
        confirm_dangerous,
        dry_run,
        embedding_func
    )
def main():
    """Entry point: re-embed the selected ChromaDB collections.

    Reads the command-line options, optionally backs up the data directory,
    builds the new embedding function and processes every selected collection
    in turn. Exits with status 0 when nothing needs processing or the user
    cancels, and with status 1 when the database cannot be opened or at least
    one collection failed.
    """
    args = parse_args()

    chroma_dir = getattr(config, "CHROMA_DATA_DIR", "chroma_data")
    print(f"使用ChromaDB目錄: {chroma_dir}")

    # Back up the database first when requested.
    if args.backup:
        backup_chroma_directory(chroma_dir)

    try:
        client = chromadb.PersistentClient(path=chroma_dir)
    except Exception as e:
        print(f"錯誤: 無法連接到ChromaDB: {e}")
        sys.exit(1)

    embedding_func = create_embedding_function(args.new_model)

    # Decide which collections to process.
    if args.collections:
        collections_to_process = args.collections
    else:
        # Configured collections are the fallback; collections that actually
        # exist in the database take precedence.
        default_collections = [
            getattr(config, "BOT_MEMORY_COLLECTION", "wolfhart_memory"),
            getattr(config, "CONVERSATIONS_COLLECTION", "conversations"),
            getattr(config, "PROFILES_COLLECTION", "user_profiles")
        ]
        collections_to_process = get_collection_names(client, default_collections)

    # Several settings may name the same collection, and a user may repeat a
    # name; drop duplicates (keeping order) so nothing is re-embedded twice.
    collections_to_process = list(dict.fromkeys(collections_to_process))

    # Leave out anything that looks like a leftover temporary collection.
    filtered_collections = []
    for collection in collections_to_process:
        if args.temp_collection_suffix in collection:
            print(f"警告: 跳過可能的臨時集合 '{collection}'")
            continue
        filtered_collections.append(collection)
    collections_to_process = filtered_collections

    if not collections_to_process:
        print("沒有找到可處理的集合。")
        sys.exit(0)

    print(f"將處理以下集合: {', '.join(collections_to_process)}")
    if args.dry_run:
        print("注意: 執行為乾運行模式,不會實際修改數據")

    # Ask once up front unless confirmation was given or this is a dry run.
    if not args.confirm_dangerous and not args.dry_run:
        confirm = input("這個操作將使用新的嵌入模型重新計算所有數據。繼續? (y/N): ")
        if confirm.lower() != 'y':
            print("操作已取消")
            sys.exit(0)

    start_time = time.time()
    success_count = 0
    for collection_name in collections_to_process:
        if process_collection(
            client,
            collection_name,
            embedding_func,
            args.temp_collection_suffix,
            args.batch_size,
            args.confirm_dangerous,
            args.dry_run
        ):
            success_count += 1

    elapsed_time = time.time() - start_time
    print(f"\n{'=' * 60}")
    print(f"處理完成: {success_count}/{len(collections_to_process)} 個集合成功")
    print(f"總耗時: {elapsed_time:.2f}")
    print(f"{'=' * 60}")

    # Signal partial failure to the calling shell or script.
    if success_count < len(collections_to_process):
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@ -147,6 +147,15 @@ class ChromaDBReader:
except Exception as e:
self.logger.error(f"無法加載 SentenceTransformer paraphrase-multilingual-MiniLM-L12-v2: {e}。將使用集合內部模型。")
self.query_embedding_function = None
# 添加新的模型支持
elif model_name == "paraphrase-multilingual-mpnet-base-v2":
try:
# 注意: sentence-transformers 庫需要安裝
self.query_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
self.logger.info(f"查詢將使用外部嵌入模型: {model_name}")
except Exception as e:
self.logger.error(f"無法加載 SentenceTransformer paraphrase-multilingual-mpnet-base-v2: {e}。將使用集合內部模型。")
self.query_embedding_function = None
else:
self.logger.warning(f"未知的查詢嵌入模型: {model_name}, 將使用集合內部模型。")
self.query_embedding_function = None
@ -450,13 +459,11 @@ class ChromaDBReaderUI:
self.embedding_models = {
"預設 (ChromaDB)": "default",
"all-MiniLM-L6-v2 (ST)": "all-MiniLM-L6-v2",
"paraphrase-multilingual-MiniLM-L12-v2 (ST)": "paraphrase-multilingual-MiniLM-L12-v2"
"paraphrase-multilingual-MiniLM-L12-v2 (ST)": "paraphrase-multilingual-MiniLM-L12-v2",
"paraphrase-multilingual-mpnet-base-v2 (ST)": "paraphrase-multilingual-mpnet-base-v2" # 添加新的模型選項
}
# 初始化 reader 中的嵌入模型 (確保 reader 實例已創建)
# self.reader.set_query_embedding_model(self.embedding_models[self.embedding_model_var.get()])
# ^^^ 這行需要在 setup_ui 之後,或者在 on_embedding_model_changed 中處理首次設置
self.setup_ui() # setup_ui 會創建 reader 實例
self.setup_ui()
# 默認主題
self.current_theme = "darkly" # ttkbootstrap的深色主題