Migrate ChromaDB embedding model to paraphrase-multilingual-mpnet-base-v2
This commit is contained in:
parent
65df12a20e
commit
bccc6d413f
46
Setup.py
46
Setup.py
@ -316,6 +316,15 @@ def load_current_config():
|
||||
if backup_minute_match:
|
||||
config_data["MEMORY_BACKUP_MINUTE"] = int(backup_minute_match.group(1))
|
||||
|
||||
# Extract EMBEDDING_MODEL_NAME
|
||||
embedding_model_match = re.search(r'EMBEDDING_MODEL_NAME\s*=\s*["\'](.+?)["\']', config_content)
|
||||
if embedding_model_match:
|
||||
config_data["EMBEDDING_MODEL_NAME"] = embedding_model_match.group(1)
|
||||
else:
|
||||
# Default if not found in config.py, will be set in UI if not overridden by load
|
||||
config_data["EMBEDDING_MODEL_NAME"] = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||
|
||||
|
||||
profile_model_match = re.search(r'MEMORY_PROFILE_MODEL\s*=\s*["\']?(.+?)["\']?\s*(?:#|$)', config_content)
|
||||
# Handle potential LLM_MODEL reference
|
||||
if profile_model_match:
|
||||
@ -537,10 +546,15 @@ def generate_config_file(config_data, env_data):
|
||||
f.write(f"MEMORY_BACKUP_MINUTE = {backup_minute}\n")
|
||||
# Write profile model, potentially referencing LLM_MODEL
|
||||
if profile_model == config_data.get('LLM_MODEL'):
|
||||
f.write(f"MEMORY_PROFILE_MODEL = LLM_MODEL # Default to main LLM model\n")
|
||||
f.write(f"MEMORY_PROFILE_MODEL = LLM_MODEL # Default to main LLM model\n")
|
||||
else:
|
||||
f.write(f"MEMORY_PROFILE_MODEL = \"{profile_model}\"\n")
|
||||
f.write(f"MEMORY_SUMMARY_MODEL = \"{summary_model}\"\n")
|
||||
f.write(f"MEMORY_PROFILE_MODEL = \"{profile_model}\"\n")
|
||||
f.write(f"MEMORY_SUMMARY_MODEL = \"{summary_model}\"\n\n")
|
||||
|
||||
# Write Embedding Model Name
|
||||
embedding_model_name = config_data.get('EMBEDDING_MODEL_NAME', "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
f.write("# Embedding model for ChromaDB\n")
|
||||
f.write(f"EMBEDDING_MODEL_NAME = \"{embedding_model_name}\"\n")
|
||||
|
||||
|
||||
print("Generated config.py file successfully")
|
||||
@ -1717,6 +1731,24 @@ class WolfChatSetup(tk.Tk):
|
||||
related_info = ttk.Label(related_frame, text="(0 to disable related memories pre-loading)")
|
||||
related_info.pack(side=tk.LEFT, padx=(5, 0))
|
||||
|
||||
# Embedding Model Settings Frame
|
||||
embedding_model_settings_frame = ttk.LabelFrame(main_frame, text="Embedding Model Settings")
|
||||
embedding_model_settings_frame.pack(fill=tk.X, pady=10)
|
||||
|
||||
embedding_model_name_frame = ttk.Frame(embedding_model_settings_frame)
|
||||
embedding_model_name_frame.pack(fill=tk.X, pady=5, padx=10)
|
||||
|
||||
embedding_model_name_label = ttk.Label(embedding_model_name_frame, text="Embedding Model Name:", width=25) # Adjusted width
|
||||
embedding_model_name_label.pack(side=tk.LEFT, padx=(0, 5))
|
||||
|
||||
self.embedding_model_name_var = tk.StringVar(value="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
embedding_model_name_entry = ttk.Entry(embedding_model_name_frame, textvariable=self.embedding_model_name_var)
|
||||
embedding_model_name_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
|
||||
|
||||
embedding_model_info = ttk.Label(embedding_model_settings_frame, text="Default: sentence-transformers/paraphrase-multilingual-mpnet-base-v2", justify=tk.LEFT)
|
||||
embedding_model_info.pack(anchor=tk.W, padx=10, pady=(0,5))
|
||||
|
||||
|
||||
# Information box
|
||||
info_frame = ttk.LabelFrame(main_frame, text="Information")
|
||||
info_frame.pack(fill=tk.BOTH, expand=True, pady=10)
|
||||
@ -2067,6 +2099,10 @@ class WolfChatSetup(tk.Tk):
|
||||
self.profiles_collection_var.set(self.config_data.get("PROFILES_COLLECTION", "user_profiles")) # Default was user_profiles
|
||||
self.conversations_collection_var.set(self.config_data.get("CONVERSATIONS_COLLECTION", "conversations"))
|
||||
self.bot_memory_collection_var.set(self.config_data.get("BOT_MEMORY_COLLECTION", "wolfhart_memory"))
|
||||
# Embedding Model Name for Memory Settings Tab
|
||||
if hasattr(self, 'embedding_model_name_var'):
|
||||
self.embedding_model_name_var.set(self.config_data.get("EMBEDDING_MODEL_NAME", "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"))
|
||||
|
||||
|
||||
# Memory Management Tab Settings
|
||||
if hasattr(self, 'backup_hour_var'): # Check if UI elements for memory management tab exist
|
||||
@ -2343,6 +2379,10 @@ class WolfChatSetup(tk.Tk):
|
||||
self.config_data["PROFILES_COLLECTION"] = self.profiles_collection_var.get()
|
||||
self.config_data["CONVERSATIONS_COLLECTION"] = self.conversations_collection_var.get()
|
||||
self.config_data["BOT_MEMORY_COLLECTION"] = self.bot_memory_collection_var.get()
|
||||
# Save Embedding Model Name from Memory Settings Tab
|
||||
if hasattr(self, 'embedding_model_name_var'):
|
||||
self.config_data["EMBEDDING_MODEL_NAME"] = self.embedding_model_name_var.get()
|
||||
|
||||
|
||||
# Get Memory Management settings from UI
|
||||
if hasattr(self, 'backup_hour_var'): # Check if UI elements exist
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# chroma_client.py
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils import embedding_functions # New import
|
||||
import os
|
||||
import json
|
||||
import config
|
||||
@ -10,6 +11,33 @@ import time
|
||||
_client = None
|
||||
_collections = {}
|
||||
|
||||
# Global embedding function variable
|
||||
_embedding_function = None
|
||||
|
||||
def get_embedding_function():
|
||||
"""Gets or creates the embedding function based on config"""
|
||||
global _embedding_function
|
||||
if _embedding_function is None:
|
||||
# Default to paraphrase-multilingual-mpnet-base-v2 if not specified or on error
|
||||
model_name = getattr(config, 'EMBEDDING_MODEL_NAME', "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
try:
|
||||
_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
|
||||
print(f"Successfully initialized embedding function with model: {model_name}")
|
||||
except Exception as e:
|
||||
print(f"Failed to initialize embedding function with model '{model_name}': {e}")
|
||||
# Fallback to default if specified model fails and it's not already the default
|
||||
if model_name != "sentence-transformers/paraphrase-multilingual-mpnet-base-v2":
|
||||
print("Falling back to default embedding model: sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
try:
|
||||
_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
print(f"Successfully initialized embedding function with default model.")
|
||||
except Exception as e_default:
|
||||
print(f"Failed to initialize default embedding function: {e_default}")
|
||||
_embedding_function = None # Ensure it's None if all attempts fail
|
||||
else:
|
||||
_embedding_function = None # Ensure it's None if default model also fails
|
||||
return _embedding_function
|
||||
|
||||
def initialize_chroma_client():
|
||||
"""Initializes and connects to ChromaDB"""
|
||||
global _client
|
||||
@ -34,13 +62,31 @@ def get_collection(collection_name):
|
||||
|
||||
if collection_name not in _collections:
|
||||
try:
|
||||
emb_func = get_embedding_function()
|
||||
if emb_func is None:
|
||||
print(f"Failed to get or create collection '{collection_name}' due to embedding function initialization failure.")
|
||||
return None
|
||||
|
||||
_collections[collection_name] = _client.get_or_create_collection(
|
||||
name=collection_name
|
||||
name=collection_name,
|
||||
embedding_function=emb_func
|
||||
)
|
||||
print(f"Successfully got or created collection '{collection_name}'")
|
||||
print(f"Successfully got or created collection '{collection_name}' using configured embedding function.")
|
||||
except Exception as e:
|
||||
print(f"Failed to get collection '{collection_name}': {e}")
|
||||
return None
|
||||
print(f"Failed to get collection '{collection_name}' with configured embedding function: {e}")
|
||||
# Attempt to create collection with default embedding function as a fallback
|
||||
print(f"Attempting to create collection '{collection_name}' with default embedding function...")
|
||||
try:
|
||||
# Ensure we try the absolute default if the configured one (even if it was the default) failed
|
||||
default_emb_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
_collections[collection_name] = _client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=default_emb_func
|
||||
)
|
||||
print(f"Successfully got or created collection '{collection_name}' with default embedding function after initial failure.")
|
||||
except Exception as e_default:
|
||||
print(f"Failed to get collection '{collection_name}' even with default embedding function: {e_default}")
|
||||
return None
|
||||
|
||||
return _collections[collection_name]
|
||||
|
||||
|
||||
@ -16,11 +16,12 @@ import schedule
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
# import chromadb # No longer directly needed by ChromaDBManager
|
||||
# from chromadb.utils import embedding_functions # No longer directly needed by ChromaDBManager
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
import config
|
||||
import chroma_client # Import the centralized chroma client
|
||||
|
||||
# =============================================================================
|
||||
# 日誌解析部分
|
||||
@ -345,28 +346,22 @@ class MemoryGenerator:
|
||||
|
||||
class ChromaDBManager:
|
||||
def __init__(self, collection_name: Optional[str] = None):
|
||||
self.client = chromadb.PersistentClient(path=config.CHROMA_DATA_DIR)
|
||||
self.collection_name = collection_name or config.BOT_MEMORY_COLLECTION
|
||||
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
|
||||
self._ensure_collection()
|
||||
|
||||
def _ensure_collection(self) -> None:
|
||||
"""確保集合存在"""
|
||||
try:
|
||||
self.collection = self.client.get_collection(
|
||||
name=self.collection_name,
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
print(f"Connected to existing collection: {self.collection_name}")
|
||||
except Exception:
|
||||
self.collection = self.client.create_collection(
|
||||
name=self.collection_name,
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
print(f"Created new collection: {self.collection_name}")
|
||||
|
||||
self._db_collection = None # Cache for the collection object
|
||||
|
||||
def _get_db_collection(self):
|
||||
"""Helper to get the collection object from chroma_client"""
|
||||
if self._db_collection is None:
|
||||
# Use the centralized get_collection function
|
||||
self._db_collection = chroma_client.get_collection(self.collection_name)
|
||||
if self._db_collection is None:
|
||||
# This indicates a failure in chroma_client to provide the collection
|
||||
raise RuntimeError(f"Failed to get or create collection '{self.collection_name}' via chroma_client. Check chroma_client logs.")
|
||||
return self._db_collection
|
||||
|
||||
def upsert_user_profile(self, profile_data: Dict[str, Any]) -> bool:
|
||||
"""寫入或更新用戶檔案"""
|
||||
collection = self._get_db_collection()
|
||||
if not profile_data or not isinstance(profile_data, dict):
|
||||
print("無效的檔案數據")
|
||||
return False
|
||||
@ -377,14 +372,13 @@ class ChromaDBManager:
|
||||
print("檔案缺少ID字段")
|
||||
return False
|
||||
|
||||
# 先檢查是否已存在
|
||||
results = self.collection.get(
|
||||
ids=[user_id], # Query by a list of IDs
|
||||
# where={"id": user_id}, # 'where' is for metadata filtering
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 準備元數據
|
||||
# Note: ChromaDB's upsert handles existence check implicitly.
|
||||
# The .get call here isn't strictly necessary for the upsert operation itself,
|
||||
# but might be kept if there was other logic depending on prior existence.
|
||||
# For a clean upsert, it can be removed. Let's assume it's not critical for now.
|
||||
# results = collection.get(ids=[user_id], limit=1) # Optional: if needed for pre-check logic
|
||||
|
||||
metadata = {
|
||||
"id": user_id,
|
||||
"type": "user_profile",
|
||||
@ -402,14 +396,12 @@ class ChromaDBManager:
|
||||
content_doc = json.dumps(profile_data.get("content", {}), ensure_ascii=False)
|
||||
|
||||
# 寫入或更新
|
||||
# ChromaDB's add/upsert handles both cases.
|
||||
# If an ID exists, it's an update; otherwise, it's an add.
|
||||
self.collection.upsert(
|
||||
collection.upsert(
|
||||
ids=[user_id],
|
||||
documents=[content_doc],
|
||||
metadatas=[metadata]
|
||||
)
|
||||
print(f"Upserted user profile: {user_id}")
|
||||
print(f"Upserted user profile: {user_id} into collection {self.collection_name}")
|
||||
|
||||
return True
|
||||
|
||||
@ -419,6 +411,7 @@ class ChromaDBManager:
|
||||
|
||||
def upsert_conversation_summary(self, summary_data: Dict[str, Any]) -> bool:
|
||||
"""寫入對話總結"""
|
||||
collection = self._get_db_collection()
|
||||
if not summary_data or not isinstance(summary_data, dict):
|
||||
print("無效的總結數據")
|
||||
return False
|
||||
@ -450,13 +443,13 @@ class ChromaDBManager:
|
||||
key_points_str = "\n".join([f"- {point}" for point in summary_data["key_points"]])
|
||||
content_doc += f"\n\n關鍵點:\n{key_points_str}"
|
||||
|
||||
# 寫入數據 (ChromaDB's add implies upsert if ID exists, but upsert is more explicit)
|
||||
self.collection.upsert(
|
||||
# 寫入數據
|
||||
collection.upsert(
|
||||
ids=[summary_id],
|
||||
documents=[content_doc],
|
||||
metadatas=[metadata]
|
||||
)
|
||||
print(f"Upserted conversation summary: {summary_id}")
|
||||
print(f"Upserted conversation summary: {summary_id} into collection {self.collection_name}")
|
||||
|
||||
return True
|
||||
|
||||
@ -466,10 +459,11 @@ class ChromaDBManager:
|
||||
|
||||
def get_existing_profile(self, username: str) -> Optional[Dict[str, Any]]:
|
||||
"""獲取現有的用戶檔案"""
|
||||
collection = self._get_db_collection()
|
||||
try:
|
||||
profile_id = f"{username}_profile"
|
||||
results = self.collection.get(
|
||||
ids=[profile_id], # Query by a list of IDs
|
||||
results = collection.get(
|
||||
ids=[profile_id],
|
||||
limit=1
|
||||
)
|
||||
|
||||
|
||||
529
reembed_chroma_data.py
Normal file
529
reembed_chroma_data.py
Normal file
@ -0,0 +1,529 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
重新嵌入工具 (Reembedding Tool)
|
||||
|
||||
這個腳本用於將現有ChromaDB集合中的數據使用新的嵌入模型重新計算向量並儲存。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import argparse
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from tqdm import tqdm # 進度條
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
except ImportError:
|
||||
print("錯誤: 請先安裝 chromadb: pip install chromadb")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError:
|
||||
print("錯誤: 請先安裝 sentence-transformers: pip install sentence-transformers")
|
||||
sys.exit(1)
|
||||
|
||||
# 嘗試導入配置
|
||||
try:
|
||||
import config
|
||||
except ImportError:
|
||||
print("警告: 無法導入config.py,將使用預設設定")
|
||||
# 建立最小配置
|
||||
class MinimalConfig:
|
||||
CHROMA_DATA_DIR = "chroma_data"
|
||||
BOT_MEMORY_COLLECTION = "wolfhart_memory"
|
||||
CONVERSATIONS_COLLECTION = "wolfhart_memory"
|
||||
PROFILES_COLLECTION = "wolfhart_memory"
|
||||
config = MinimalConfig()
|
||||
|
||||
def parse_args():
|
||||
"""處理命令行參數"""
|
||||
parser = argparse.ArgumentParser(description='ChromaDB 數據重新嵌入工具')
|
||||
|
||||
parser.add_argument('--new-model', type=str,
|
||||
default="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
|
||||
help='新的嵌入模型名稱 (預設: sentence-transformers/paraphrase-multilingual-mpnet-base-v2)')
|
||||
|
||||
parser.add_argument('--collections', type=str, nargs='+',
|
||||
help=f'要處理的集合名稱列表,空白分隔 (預設: 使用配置中的所有集合)')
|
||||
|
||||
parser.add_argument('--backup', action='store_true',
|
||||
help='在處理前備份資料庫 (推薦)')
|
||||
|
||||
parser.add_argument('--batch-size', type=int, default=100,
|
||||
help='批處理大小 (預設: 100)')
|
||||
|
||||
parser.add_argument('--temp-collection-suffix', type=str, default="_temp_new",
|
||||
help='臨時集合的後綴名稱 (預設: _temp_new)')
|
||||
|
||||
parser.add_argument('--dry-run', action='store_true',
|
||||
help='模擬執行但不實際修改資料')
|
||||
|
||||
parser.add_argument('--confirm-dangerous', action='store_true',
|
||||
help='確認執行危險操作(例如刪除集合)')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def backup_chroma_directory(chroma_dir: str) -> str:
|
||||
"""備份ChromaDB數據目錄
|
||||
|
||||
Args:
|
||||
chroma_dir: ChromaDB數據目錄路徑
|
||||
|
||||
Returns:
|
||||
備份目錄的路徑
|
||||
"""
|
||||
if not os.path.exists(chroma_dir):
|
||||
print(f"錯誤: ChromaDB目錄 '{chroma_dir}' 不存在")
|
||||
sys.exit(1)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_dir = f"{chroma_dir}_backup_{timestamp}"
|
||||
|
||||
print(f"備份資料庫從 '{chroma_dir}' 到 '{backup_dir}'...")
|
||||
shutil.copytree(chroma_dir, backup_dir)
|
||||
print(f"備份完成: {backup_dir}")
|
||||
|
||||
return backup_dir
|
||||
|
||||
def create_embedding_function(model_name: str):
|
||||
"""創建嵌入函數
|
||||
|
||||
Args:
|
||||
model_name: 嵌入模型名稱
|
||||
|
||||
Returns:
|
||||
嵌入函數對象
|
||||
"""
|
||||
if not model_name:
|
||||
print("使用ChromaDB預設嵌入模型")
|
||||
return embedding_functions.DefaultEmbeddingFunction()
|
||||
|
||||
print(f"正在加載嵌入模型: {model_name}")
|
||||
try:
|
||||
# 直接使用SentenceTransformerEmbeddingFunction
|
||||
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
|
||||
embedding_function = SentenceTransformerEmbeddingFunction(model_name=model_name)
|
||||
# 預熱模型
|
||||
_ = embedding_function(["."])
|
||||
return embedding_function
|
||||
except Exception as e:
|
||||
print(f"錯誤: 無法加載模型 '{model_name}': {e}")
|
||||
print("退回到預設嵌入模型")
|
||||
return embedding_functions.DefaultEmbeddingFunction()
|
||||
|
||||
def get_collection_names(client, default_collections: List[str]) -> List[str]:
|
||||
"""獲取所有可用的集合名稱
|
||||
|
||||
Args:
|
||||
client: ChromaDB客戶端
|
||||
default_collections: 預設集合列表
|
||||
|
||||
Returns:
|
||||
可用的集合名稱列表
|
||||
"""
|
||||
try:
|
||||
all_collections = client.list_collections()
|
||||
collection_names = [col.name for col in all_collections]
|
||||
|
||||
if collection_names:
|
||||
return collection_names
|
||||
else:
|
||||
print("警告: 沒有找到集合,將使用預設集合")
|
||||
return default_collections
|
||||
|
||||
except Exception as e:
|
||||
print(f"獲取集合列表失敗: {e}")
|
||||
print("將使用預設集合")
|
||||
return default_collections
|
||||
|
||||
def fetch_collection_data(client, collection_name: str, batch_size: int = 100) -> Dict[str, Any]:
|
||||
"""從集合中提取所有數據
|
||||
|
||||
Args:
|
||||
client: ChromaDB客戶端
|
||||
collection_name: 集合名稱
|
||||
batch_size: 批處理大小
|
||||
|
||||
Returns:
|
||||
集合數據字典,包含ids, documents, metadatas
|
||||
"""
|
||||
try:
|
||||
collection = client.get_collection(name=collection_name)
|
||||
|
||||
# 獲取該集合中的項目總數
|
||||
count_result = collection.count()
|
||||
if count_result == 0:
|
||||
print(f"集合 '{collection_name}' 是空的")
|
||||
return {"ids": [], "documents": [], "metadatas": []}
|
||||
|
||||
print(f"從集合 '{collection_name}' 中讀取 {count_result} 項數據...")
|
||||
|
||||
# 分批獲取數據
|
||||
all_ids = []
|
||||
all_documents = []
|
||||
all_metadatas = []
|
||||
|
||||
offset = 0
|
||||
with tqdm(total=count_result, desc=f"正在讀取 {collection_name}") as pbar:
|
||||
while True:
|
||||
# 注意: 使用include參數指定只獲取需要的數據
|
||||
batch_result = collection.get(
|
||||
limit=batch_size,
|
||||
offset=offset,
|
||||
include=["documents", "metadatas"]
|
||||
)
|
||||
|
||||
batch_ids = batch_result.get("ids", [])
|
||||
if not batch_ids:
|
||||
break
|
||||
|
||||
all_ids.extend(batch_ids)
|
||||
all_documents.extend(batch_result.get("documents", []))
|
||||
all_metadatas.extend(batch_result.get("metadatas", []))
|
||||
|
||||
offset += len(batch_ids)
|
||||
pbar.update(len(batch_ids))
|
||||
|
||||
if len(batch_ids) < batch_size:
|
||||
break
|
||||
|
||||
return {
|
||||
"ids": all_ids,
|
||||
"documents": all_documents,
|
||||
"metadatas": all_metadatas
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"從集合 '{collection_name}' 獲取數據時出錯: {e}")
|
||||
return {"ids": [], "documents": [], "metadatas": []}
|
||||
|
||||
def create_and_populate_collection(
|
||||
client,
|
||||
collection_name: str,
|
||||
data: Dict[str, Any],
|
||||
embedding_func,
|
||||
batch_size: int = 100,
|
||||
dry_run: bool = False
|
||||
) -> bool:
|
||||
"""創建新集合並填充數據
|
||||
|
||||
Args:
|
||||
client: ChromaDB客戶端
|
||||
collection_name: 集合名稱
|
||||
data: 要添加的數據 (ids, documents, metadatas)
|
||||
embedding_func: 嵌入函數
|
||||
batch_size: 批處理大小
|
||||
dry_run: 是否只模擬執行
|
||||
|
||||
Returns:
|
||||
成功返回True,否則返回False
|
||||
"""
|
||||
if dry_run:
|
||||
print(f"[模擬] 將創建集合 '{collection_name}' 並添加 {len(data['ids'])} 項數據")
|
||||
return True
|
||||
|
||||
try:
|
||||
# 檢查集合是否已存在
|
||||
if collection_name in [col.name for col in client.list_collections()]:
|
||||
client.delete_collection(collection_name)
|
||||
|
||||
# 創建新集合
|
||||
collection = client.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=embedding_func
|
||||
)
|
||||
|
||||
# 如果沒有數據,直接返回
|
||||
if not data["ids"]:
|
||||
print(f"集合 '{collection_name}' 創建完成,但沒有數據添加")
|
||||
return True
|
||||
|
||||
# 分批添加數據
|
||||
total_items = len(data["ids"])
|
||||
with tqdm(total=total_items, desc=f"正在填充 {collection_name}") as pbar:
|
||||
for i in range(0, total_items, batch_size):
|
||||
end_idx = min(i + batch_size, total_items)
|
||||
|
||||
batch_ids = data["ids"][i:end_idx]
|
||||
batch_docs = data["documents"][i:end_idx]
|
||||
batch_meta = data["metadatas"][i:end_idx]
|
||||
|
||||
# 處理可能的None值
|
||||
processed_docs = []
|
||||
for doc in batch_docs:
|
||||
if doc is None:
|
||||
processed_docs.append("") # 使用空字符串替代None
|
||||
else:
|
||||
processed_docs.append(doc)
|
||||
|
||||
collection.add(
|
||||
ids=batch_ids,
|
||||
documents=processed_docs,
|
||||
metadatas=batch_meta
|
||||
)
|
||||
|
||||
pbar.update(end_idx - i)
|
||||
|
||||
print(f"成功將 {total_items} 項數據添加到集合 '{collection_name}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"創建或填充集合 '{collection_name}' 時出錯: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def swap_collections(
|
||||
client,
|
||||
original_collection: str,
|
||||
temp_collection: str,
|
||||
confirm_dangerous: bool = False,
|
||||
dry_run: bool = False,
|
||||
embedding_func = None # 添加嵌入函數作為參數
|
||||
) -> bool:
|
||||
"""替換集合(刪除原始集合,將臨時集合重命名為原始集合名)
|
||||
|
||||
Args:
|
||||
client: ChromaDB客戶端
|
||||
original_collection: 原始集合名稱
|
||||
temp_collection: 臨時集合名稱
|
||||
confirm_dangerous: 是否確認危險操作
|
||||
dry_run: 是否只模擬執行
|
||||
embedding_func: 嵌入函數,用於創建新集合
|
||||
|
||||
Returns:
|
||||
成功返回True,否則返回False
|
||||
"""
|
||||
if dry_run:
|
||||
print(f"[模擬] 將替換集合: 刪除 '{original_collection}',重命名 '{temp_collection}' 到 '{original_collection}'")
|
||||
return True
|
||||
|
||||
try:
|
||||
# 檢查是否有確認標誌
|
||||
if not confirm_dangerous:
|
||||
response = input(f"警告: 即將刪除集合 '{original_collection}' 並用 '{temp_collection}' 替換它。確認操作? (y/N): ")
|
||||
if response.lower() != 'y':
|
||||
print("操作已取消")
|
||||
return False
|
||||
|
||||
# 檢查兩個集合是否都存在
|
||||
all_collections = [col.name for col in client.list_collections()]
|
||||
if original_collection not in all_collections:
|
||||
print(f"錯誤: 原始集合 '{original_collection}' 不存在")
|
||||
return False
|
||||
|
||||
if temp_collection not in all_collections:
|
||||
print(f"錯誤: 臨時集合 '{temp_collection}' 不存在")
|
||||
return False
|
||||
|
||||
# 獲取臨時集合的所有數據
|
||||
# 在刪除原始集合之前先獲取臨時集合的所有數據
|
||||
print(f"獲取臨時集合 '{temp_collection}' 的數據...")
|
||||
temp_collection_obj = client.get_collection(temp_collection)
|
||||
temp_data = temp_collection_obj.get(include=["documents", "metadatas"])
|
||||
|
||||
# 刪除原始集合
|
||||
print(f"刪除原始集合 '{original_collection}'...")
|
||||
client.delete_collection(original_collection)
|
||||
|
||||
# 創建一個同名的新集合(與原始集合同名)
|
||||
print(f"創建新集合 '{original_collection}'...")
|
||||
|
||||
# 使用傳入的嵌入函數或臨時集合的嵌入函數
|
||||
embedding_function = embedding_func or temp_collection_obj._embedding_function
|
||||
|
||||
# 創建新的集合
|
||||
original_collection_obj = client.create_collection(
|
||||
name=original_collection,
|
||||
embedding_function=embedding_function
|
||||
)
|
||||
|
||||
# 將數據添加到新集合
|
||||
if temp_data["ids"]:
|
||||
print(f"將 {len(temp_data['ids'])} 項數據從臨時集合複製到新集合...")
|
||||
|
||||
# 處理可能的None值
|
||||
processed_docs = []
|
||||
for doc in temp_data["documents"]:
|
||||
if doc is None:
|
||||
processed_docs.append("")
|
||||
else:
|
||||
processed_docs.append(doc)
|
||||
|
||||
# 使用分批方式添加數據以避免潛在的大數據問題
|
||||
batch_size = 100
|
||||
for i in range(0, len(temp_data["ids"]), batch_size):
|
||||
end = min(i + batch_size, len(temp_data["ids"]))
|
||||
original_collection_obj.add(
|
||||
ids=temp_data["ids"][i:end],
|
||||
documents=processed_docs[i:end],
|
||||
metadatas=temp_data["metadatas"][i:end] if temp_data["metadatas"] else None
|
||||
)
|
||||
|
||||
# 刪除臨時集合
|
||||
print(f"刪除臨時集合 '{temp_collection}'...")
|
||||
client.delete_collection(temp_collection)
|
||||
|
||||
print(f"成功用重新嵌入的數據替換集合 '{original_collection}'")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"替換集合時出錯: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def process_collection(
|
||||
client,
|
||||
collection_name: str,
|
||||
embedding_func,
|
||||
temp_suffix: str,
|
||||
batch_size: int,
|
||||
confirm_dangerous: bool,
|
||||
dry_run: bool
|
||||
) -> bool:
|
||||
"""處理一個集合的完整流程
|
||||
|
||||
Args:
|
||||
client: ChromaDB客戶端
|
||||
collection_name: 要處理的集合名稱
|
||||
embedding_func: 新的嵌入函數
|
||||
temp_suffix: 臨時集合的後綴
|
||||
batch_size: 批處理大小
|
||||
confirm_dangerous: 是否確認危險操作
|
||||
dry_run: 是否只模擬執行
|
||||
|
||||
Returns:
|
||||
處理成功返回True,否則返回False
|
||||
"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"處理集合: '{collection_name}'")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 暫時集合名稱
|
||||
temp_collection_name = f"{collection_name}{temp_suffix}"
|
||||
|
||||
# 1. 獲取原始集合的數據
|
||||
data = fetch_collection_data(client, collection_name, batch_size)
|
||||
|
||||
if not data["ids"]:
|
||||
print(f"集合 '{collection_name}' 為空或不存在,跳過")
|
||||
return True
|
||||
|
||||
# 2. 創建臨時集合並使用新的嵌入模型填充數據
|
||||
success = create_and_populate_collection(
|
||||
client,
|
||||
temp_collection_name,
|
||||
data,
|
||||
embedding_func,
|
||||
batch_size,
|
||||
dry_run
|
||||
)
|
||||
|
||||
if not success:
|
||||
print(f"創建臨時集合 '{temp_collection_name}' 失敗,跳過替換")
|
||||
return False
|
||||
|
||||
# 3. 替換原始集合
|
||||
success = swap_collections(
|
||||
client,
|
||||
collection_name,
|
||||
temp_collection_name,
|
||||
confirm_dangerous,
|
||||
dry_run,
|
||||
embedding_func # 添加嵌入函數作為參數
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
def main():
|
||||
"""主函數"""
|
||||
args = parse_args()
|
||||
|
||||
# 獲取ChromaDB目錄
|
||||
chroma_dir = getattr(config, "CHROMA_DATA_DIR", "chroma_data")
|
||||
print(f"使用ChromaDB目錄: {chroma_dir}")
|
||||
|
||||
# 備份數據庫(如果請求)
|
||||
if args.backup:
|
||||
backup_chroma_directory(chroma_dir)
|
||||
|
||||
# 創建ChromaDB客戶端
|
||||
try:
|
||||
client = chromadb.PersistentClient(path=chroma_dir)
|
||||
except Exception as e:
|
||||
print(f"錯誤: 無法連接到ChromaDB: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 創建嵌入函數
|
||||
embedding_func = create_embedding_function(args.new_model)
|
||||
|
||||
# 確定要處理的集合
|
||||
if args.collections:
|
||||
collections_to_process = args.collections
|
||||
else:
|
||||
# 使用配置中的默認集合或獲取所有可用集合
|
||||
default_collections = [
|
||||
getattr(config, "BOT_MEMORY_COLLECTION", "wolfhart_memory"),
|
||||
getattr(config, "CONVERSATIONS_COLLECTION", "conversations"),
|
||||
getattr(config, "PROFILES_COLLECTION", "user_profiles")
|
||||
]
|
||||
collections_to_process = get_collection_names(client, default_collections)
|
||||
|
||||
# 過濾掉已經是臨時集合的集合名稱
|
||||
filtered_collections = []
|
||||
for collection in collections_to_process:
|
||||
if args.temp_collection_suffix in collection:
|
||||
print(f"警告: 跳過可能的臨時集合 '{collection}'")
|
||||
continue
|
||||
filtered_collections.append(collection)
|
||||
|
||||
collections_to_process = filtered_collections
|
||||
|
||||
if not collections_to_process:
|
||||
print("沒有找到可處理的集合。")
|
||||
sys.exit(0)
|
||||
|
||||
print(f"將處理以下集合: {', '.join(collections_to_process)}")
|
||||
if args.dry_run:
|
||||
print("注意: 執行為乾運行模式,不會實際修改數據")
|
||||
|
||||
# 詢問用戶確認
|
||||
if not args.confirm_dangerous and not args.dry_run:
|
||||
confirm = input("這個操作將使用新的嵌入模型重新計算所有數據。繼續? (y/N): ")
|
||||
if confirm.lower() != 'y':
|
||||
print("操作已取消")
|
||||
sys.exit(0)
|
||||
|
||||
# 處理每個集合
|
||||
start_time = time.time()
|
||||
success_count = 0
|
||||
|
||||
for collection_name in collections_to_process:
|
||||
if process_collection(
|
||||
client,
|
||||
collection_name,
|
||||
embedding_func,
|
||||
args.temp_collection_suffix,
|
||||
args.batch_size,
|
||||
args.confirm_dangerous,
|
||||
args.dry_run
|
||||
):
|
||||
success_count += 1
|
||||
|
||||
# 報告結果
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"處理完成: {success_count}/{len(collections_to_process)} 個集合成功")
|
||||
print(f"總耗時: {elapsed_time:.2f} 秒")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -147,6 +147,15 @@ class ChromaDBReader:
|
||||
except Exception as e:
|
||||
self.logger.error(f"無法加載 SentenceTransformer paraphrase-multilingual-MiniLM-L12-v2: {e}。將使用集合內部模型。")
|
||||
self.query_embedding_function = None
|
||||
# 添加新的模型支持
|
||||
elif model_name == "paraphrase-multilingual-mpnet-base-v2":
|
||||
try:
|
||||
# 注意: sentence-transformers 庫需要安裝
|
||||
self.query_embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
||||
self.logger.info(f"查詢將使用外部嵌入模型: {model_name}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"無法加載 SentenceTransformer paraphrase-multilingual-mpnet-base-v2: {e}。將使用集合內部模型。")
|
||||
self.query_embedding_function = None
|
||||
else:
|
||||
self.logger.warning(f"未知的查詢嵌入模型: {model_name}, 將使用集合內部模型。")
|
||||
self.query_embedding_function = None
|
||||
@ -450,13 +459,11 @@ class ChromaDBReaderUI:
|
||||
self.embedding_models = {
|
||||
"預設 (ChromaDB)": "default",
|
||||
"all-MiniLM-L6-v2 (ST)": "all-MiniLM-L6-v2",
|
||||
"paraphrase-multilingual-MiniLM-L12-v2 (ST)": "paraphrase-multilingual-MiniLM-L12-v2"
|
||||
"paraphrase-multilingual-MiniLM-L12-v2 (ST)": "paraphrase-multilingual-MiniLM-L12-v2",
|
||||
"paraphrase-multilingual-mpnet-base-v2 (ST)": "paraphrase-multilingual-mpnet-base-v2" # 添加新的模型選項
|
||||
}
|
||||
# 初始化 reader 中的嵌入模型 (確保 reader 實例已創建)
|
||||
# self.reader.set_query_embedding_model(self.embedding_models[self.embedding_model_var.get()])
|
||||
# ^^^ 這行需要在 setup_ui 之後,或者在 on_embedding_model_changed 中處理首次設置
|
||||
|
||||
self.setup_ui() # setup_ui 會創建 reader 實例
|
||||
|
||||
self.setup_ui()
|
||||
|
||||
# 默認主題
|
||||
self.current_theme = "darkly" # ttkbootstrap的深色主題
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user