import asyncio
from typing import List, Iterable, AsyncGenerator

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from torch import Tensor


class JEmbeddingService:
    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192):
        self.model_path = model_path
        self.max_length = max_length
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            padding_side='left'  # left padding suits instruction-style embedding models
        )

        # Check GPU availability and pick the device automatically
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Using device: {self.device}")

        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True
        )

        # Move the model to the selected device and switch to inference mode
        self.model = self.model.to(self.device)
        self.model.eval()

        # Print model info
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"🎯 GPU memory: {gpu_memory:.1f} GB")
            print(f"📊 Model loaded on GPU: {next(self.model.parameters()).device}")

    def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """
        Extract the hidden state of the last valid token.

        This is the pooling strategy recommended for the Qwen3-Embedding models.
        """
        # If every sequence's final position is a real token, the batch is left-padded
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            # Right padding: index each sequence at its last non-padding position
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

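    # Worked example (hypothetical values): for a right-padded batch with
    # attention_mask = [[1, 1, 0], [1, 1, 1]], the last column sums to 1 while the
    # batch size is 2, so the left-padding fast path is skipped; sequence_lengths
    # becomes [1, 2] and pooling selects the hidden state at position 1 for the
    # first sequence and position 2 for the second. With left padding, column -1
    # is always a real token, so taking the last position suffices.
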
    async def embed(self, texts: List[str], normalize: bool = True) -> List[List[float]]:
        """
        Convert a list of texts into vector representations.

        Args:
            texts: List of input strings.
            normalize: Whether to apply L2 normalization (default True).

        Returns:
            A list of vectors, one per input text.
        """
        if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
            raise ValueError("texts must be a list of strings")

        def encode_texts():
            # Tokenize the whole batch at once
            batch_dict = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )

            # Move the input tensors to the model's device
            batch_dict = {k: v.to(self.device) for k, v in batch_dict.items()}

            # Forward pass without gradient tracking
            with torch.no_grad():
                outputs = self.model(**batch_dict)
                # Pool the embedding of the last valid token
                embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

                # L2 normalization (recommended for similarity computation)
                if normalize:
                    embeddings = F.normalize(embeddings, p=2, dim=1)

            return embeddings.tolist()

        # Run the blocking encode step in a worker thread so the event loop stays responsive
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, encode_texts)

    async def process_batch(self, items: Iterable[str], normalize: bool = True) -> AsyncGenerator[dict, None]:
        """
        Process a batch of texts and yield their vector representations.

        Args:
            items: An iterable of texts.
            normalize: Whether to apply L2 normalization.

        Yields:
            Result dicts containing index, status, and embedding or message.
        """
        texts: List[str] = []
        indices: List[int] = []

        # Collect the valid texts
        for idx, text in enumerate(items):
            try:
                if not isinstance(text, str):
                    raise ValueError("each item must be a string")
                texts.append(text)
                indices.append(idx)
            except Exception as e:
                yield {"index": idx, "status": "error", "message": str(e)}
            await asyncio.sleep(0)  # yield control back to the event loop

        if not texts:
            return

        try:
            # Embed all collected texts in a single batch
            vectors = await self.embed(texts, normalize=normalize)
            for i, vec in zip(indices, vectors):
                yield {"index": i, "status": "success", "embedding": vec}
        except Exception as e:
            # If the whole batch fails, report the error for every text
            for i in indices:
                yield {"index": i, "status": "error", "message": str(e)}
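

# Usage sketch (assumption: this file is run as a script; loading
# "Qwen/Qwen3-Embedding-0.6B" downloads the model on first use). The helper
# name `_demo` is illustrative, not part of the service API. Because `embed`
# L2-normalizes by default, the dot product of two embeddings equals their
# cosine similarity.
async def _demo() -> None:
    service = JEmbeddingService()

    # Batch embedding plus a cosine-similarity check
    vecs = await service.embed([
        "What is the capital of France?",
        "Paris is the capital of France.",
    ])
    sim = sum(a * b for a, b in zip(vecs[0], vecs[1]))
    print(f"cosine similarity: {sim:.4f}")

    # Streaming per-item results; the non-string item exercises the error path
    async for result in service.process_batch(["hello", 42, "world"]):
        print(result)


if __name__ == "__main__":
    asyncio.run(_demo())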